diff --git a/veadk/runner.py b/veadk/runner.py index 06f49038..5b7b6183 100644 --- a/veadk/runner.py +++ b/veadk/runner.py @@ -14,6 +14,7 @@ from typing import Union from google.adk.agents import RunConfig +from google.adk.agents.invocation_context import LlmCallsLimitExceededError from google.adk.agents.run_config import StreamingMode from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner as ADKRunner @@ -49,20 +50,25 @@ class Runner: def __init__( self, agent: VeAgent, - short_term_memory: ShortTermMemory, + short_term_memory: ShortTermMemory | None = None, plugins: list[BasePlugin] | None = None, app_name: str = "veadk_default_app", user_id: str = "veadk_default_user", ): - # basic settings self.app_name = app_name self.user_id = user_id - # agent settings self.agent = agent - self.short_term_memory = short_term_memory - self.session_service = short_term_memory.session_service + if not short_term_memory: + logger.info( + "No short term memory provided, using a in-memory memory by default." + ) + self.short_term_memory = ShortTermMemory() + else: + self.short_term_memory = short_term_memory + + self.session_service = self.short_term_memory.session_service # prevent VeRemoteAgent has no long-term memory attr if isinstance(self.agent, Agent): @@ -114,35 +120,44 @@ async def _run( self, session_id: str, message: types.Content, + run_config: RunConfig | None = None, stream: bool = False, ): stream_mode = StreamingMode.SSE if stream else StreamingMode.NONE - async def event_generator(): - async for event in self.runner.run_async( - user_id=self.user_id, - session_id=session_id, - new_message=message, - run_config=RunConfig(streaming_mode=stream_mode), - ): - if event.get_function_calls(): - for function_call in event.get_function_calls(): - logger.debug(f"Function call: {function_call}") - elif ( - event.content is not None - and event.content.parts - and event.content.parts[0].text is not None - and len(event.content.parts[0].text.strip()) > 0 - ): - yield event.content.parts[0].text + if run_config is not None: + stream_mode = run_config.streaming_mode + else: + run_config = RunConfig(streaming_mode=stream_mode) + try: - final_output = "" - async for chunk in event_generator(): + async def event_generator(): + async for event in self.runner.run_async( + user_id=self.user_id, + session_id=session_id, + new_message=message, + run_config=run_config, + ): + if event.get_function_calls(): + for function_call in event.get_function_calls(): + logger.debug(f"Function call: {function_call}") + elif ( + event.content is not None + and event.content.parts + and event.content.parts[0].text is not None + and len(event.content.parts[0].text.strip()) > 0 + ): + yield event.content.parts[0].text + + final_output = "" + async for chunk in event_generator(): + if stream: + print(chunk, end="", flush=True) + final_output += chunk if stream: - print(chunk, end="", flush=True) - final_output += chunk - if stream: - print() # end with a new line + print() # end with a new line + except LlmCallsLimitExceededError as e: + logger.warning(f"Max number of llm calls limit exceeded: {e}") return final_output @@ -151,6 +166,7 @@ async def run( messages: RunnerMessage, session_id: str, stream: bool = False, + run_config: RunConfig | None = None, save_tracing_data: bool = False, ): converted_messages: list = self._convert_messages(messages) @@ -163,7 +179,9 @@ async def run( final_output = "" for converted_message in converted_messages: - final_output = await self._run(session_id, converted_message, stream) + final_output = await self._run( + session_id, converted_message, run_config, stream + ) # try to save tracing file if save_tracing_data: @@ -193,6 +211,47 @@ def get_trace_id(self) -> str: logger.warning(f"Get tracer id failed as {e}") return "" + async def run_with_raw_message( + self, + message: types.Content, + session_id: str, + run_config: RunConfig | None = None, + ): + run_config = RunConfig() if not run_config else run_config + + await self.short_term_memory.create_session( + app_name=self.app_name, user_id=self.user_id, session_id=session_id + ) + + try: + + async def event_generator(): + async for event in self.runner.run_async( + user_id=self.user_id, + session_id=session_id, + new_message=message, + run_config=run_config, + ): + if event.get_function_calls(): + for function_call in event.get_function_calls(): + logger.debug(f"Function call: {function_call}") + elif ( + event.content is not None + and event.content.parts + and event.content.parts[0].text is not None + and len(event.content.parts[0].text.strip()) > 0 + ): + yield event.content.parts[0].text + + final_output = "" + + async for chunk in event_generator(): + final_output += chunk + except LlmCallsLimitExceededError as e: + logger.warning(f"Max number of llm calls limit exceeded: {e}") + + return final_output + def _print_trace_id(self) -> None: if not isinstance(self.agent, Agent): logger.warning(