|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | import os |
18 | | -from typing import Optional, Union, AsyncGenerator |
| 18 | +from typing import AsyncGenerator, Optional, Union |
19 | 19 |
|
20 | 20 | # If user didn't set LITELLM_LOCAL_MODEL_COST_MAP, set it to True |
21 | 21 | # to enable local model cost map. |
|
24 | 24 | if not os.getenv("LITELLM_LOCAL_MODEL_COST_MAP"): |
25 | 25 | os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" |
26 | 26 |
|
27 | | -from google.adk.agents import LlmAgent, RunConfig, InvocationContext |
| 27 | +from google.adk.agents import InvocationContext, LlmAgent, RunConfig |
28 | 28 | from google.adk.agents.base_agent import BaseAgent |
29 | 29 | from google.adk.agents.context_cache_config import ContextCacheConfig |
30 | 30 | from google.adk.agents.llm_agent import InstructionProvider, ToolUnion |
31 | 31 | from google.adk.agents.run_config import StreamingMode |
32 | 32 | from google.adk.events import Event, EventActions |
| 33 | +from google.adk.flows.llm_flows.auto_flow import AutoFlow |
| 34 | +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow |
| 35 | +from google.adk.flows.llm_flows.single_flow import SingleFlow |
33 | 36 | from google.adk.models.lite_llm import LiteLlm |
34 | 37 | from google.adk.runners import Runner |
35 | 38 | from google.genai import types |
|
53 | 56 | from veadk.prompts.prompt_manager import BasePromptManager |
54 | 57 | from veadk.tracing.base_tracer import BaseTracer |
55 | 58 | from veadk.utils.logger import get_logger |
56 | | -from veadk.utils.patches import patch_asyncio, patch_tracer |
57 | 59 | from veadk.utils.misc import check_litellm_version |
| 60 | +from veadk.utils.patches import patch_asyncio, patch_tracer |
58 | 61 | from veadk.version import VERSION |
59 | 62 |
|
60 | 63 | patch_tracer() |
@@ -118,6 +121,8 @@ class Agent(LlmAgent): |
118 | 121 |
|
119 | 122 | enable_responses: bool = False |
120 | 123 |
|
| 124 | + enable_shadow_agent: bool = False |
| 125 | + |
121 | 126 | context_cache_config: Optional[ContextCacheConfig] = None |
122 | 127 |
|
123 | 128 | run_processor: Optional[BaseRunProcessor] = Field(default=None, exclude=True) |
@@ -292,6 +297,28 @@ def model_post_init(self, __context: Any) -> None: |
292 | 297 | f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}" |
293 | 298 | ) |
294 | 299 |
|
| 300 | + @property |
| 301 | + def _llm_flow(self) -> BaseLlmFlow: |
| 302 | + if ( |
| 303 | + self.disallow_transfer_to_parent |
| 304 | + and self.disallow_transfer_to_peers |
| 305 | + and not self.sub_agents |
| 306 | + ): |
| 307 | + from veadk.flows.supervisor_single_flow import SupervisorSingleFlow |
| 308 | + |
| 309 | + if self.enable_shadow_agent: |
| 310 | + logger.debug(f"Enable supervisor flow for agent: {self.name}") |
| 311 | + return SupervisorSingleFlow(supervised_agent=self) |
| 312 | + else: |
| 313 | + return SingleFlow() |
| 314 | + else: |
| 315 | + from veadk.flows.supervisor_auto_flow import SupervisorAutoFlow |
| 316 | + |
| 317 | + if self.enable_shadow_agent: |
| 318 | + logger.debug(f"Enable supervisor flow for agent: {self.name}") |
| 319 | + return SupervisorAutoFlow(supervised_agent=self) |
| 320 | + return AutoFlow() |
| 321 | + |
295 | 322 | async def _run_async_impl( |
296 | 323 | self, ctx: InvocationContext |
297 | 324 | ) -> AsyncGenerator[Event, None]: |
|
0 commit comments