
Commit 56652f2

feat: use interaction_id to pass response_id
1 parent 4bdecf5 commit 56652f2

3 files changed: 33 additions & 78 deletions


veadk/agent.py

Lines changed: 2 additions & 31 deletions
@@ -15,7 +15,7 @@
 from __future__ import annotations

 import os
-from typing import Optional, Union, AsyncGenerator
+from typing import Optional, Union

 # If user didn't set LITELLM_LOCAL_MODEL_COST_MAP, set it to True
 # to enable local model cost map.
@@ -24,12 +24,11 @@
 if not os.getenv("LITELLM_LOCAL_MODEL_COST_MAP"):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

-from google.adk.agents import LlmAgent, RunConfig, InvocationContext
+from google.adk.agents import LlmAgent, RunConfig
 from google.adk.agents.base_agent import BaseAgent
 from google.adk.agents.context_cache_config import ContextCacheConfig
 from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
 from google.adk.agents.run_config import StreamingMode
-from google.adk.events import Event, EventActions
 from google.adk.models.lite_llm import LiteLlm
 from google.adk.runners import Runner
 from google.genai import types
@@ -178,12 +177,6 @@ def model_post_init(self, __context: Any) -> None:
                 api_base=self.model_api_base,
                 **self.model_extra_config,
             )
-            if not self.context_cache_config:
-                self.context_cache_config = ContextCacheConfig(
-                    cache_intervals=100,  # maximum number
-                    ttl_seconds=315360000,
-                    min_tokens=0,
-                )
         else:
             fallbacks = None
             if isinstance(self.model_name, list):
@@ -288,28 +281,6 @@ def model_post_init(self, __context: Any) -> None:
             f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}"
         )

-    async def _run_async_impl(
-        self, ctx: InvocationContext
-    ) -> AsyncGenerator[Event, None]:
-        if self.enable_responses:
-            if not ctx.context_cache_config:
-                ctx.context_cache_config = self.context_cache_config
-
-        async for event in super()._run_async_impl(ctx):
-            yield event
-            if self.enable_responses and event.cache_metadata:
-                # for persistent short-term memory with response api
-                session_state_event = Event(
-                    invocation_id=event.invocation_id,
-                    author=event.author,
-                    actions=EventActions(
-                        state_delta={
-                            "response_id": event.cache_metadata.cache_name,
-                        }
-                    ),
-                )
-                yield session_state_event
-
     async def _run(
         self,
         runner,
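
The `_run_async_impl` override deleted above relied on a generic ADK idiom: to persist a value into session state, yield an extra Event whose EventActions carries a state_delta. A minimal sketch of that idiom, assuming google-adk is installed (`persist_response_id` is a hypothetical helper, not part of veadk):

from google.adk.events import Event, EventActions


def persist_response_id(source: Event, response_id: str) -> Event:
    # Hypothetical helper mirroring the deleted session_state_event: when the
    # session service appends this event, it merges state_delta into
    # session.state, so response_id survives into the next invocation.
    return Event(
        invocation_id=source.invocation_id,
        author=source.author,
        actions=EventActions(state_delta={"response_id": response_id}),
    )

The commit removes this detour: with google-adk >= 1.21, the response id rides on the interaction fields instead (see veadk/models/ark_llm.py below).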

veadk/memory/short_term_memory.py

Lines changed: 0 additions & 21 deletions
@@ -32,7 +32,6 @@
 from veadk.memory.short_term_memory_backends.sqlite_backend import (
     SQLiteSTMBackend,
 )
-from veadk.models.ark_llm import build_cache_metadata
 from veadk.utils.logger import get_logger

 logger = get_logger(__name__)
@@ -50,21 +49,6 @@ async def wrapper(*args, **kwargs):
     setattr(obj, "get_session", wrapper)


-def enable_responses_api_for_session_service(result, *args, **kwargs):
-    if result and isinstance(result, Session):
-        if result.events:
-            for event in result.events:
-                if (
-                    event.actions
-                    and event.actions.state_delta
-                    and not event.cache_metadata
-                    and "response_id" in event.actions.state_delta
-                ):
-                    event.cache_metadata = build_cache_metadata(
-                        response_id=event.actions.state_delta.get("response_id"),
-                    )
-
-
 class ShortTermMemory(BaseModel):
     """Short term memory for agent execution.

@@ -186,11 +170,6 @@ def model_post_init(self, __context: Any) -> None:
             db_kwargs=self.db_kwargs, **self.backend_configs
         ).session_service

-        if self.backend != "local":
-            wrap_get_session_with_callbacks(
-                self._session_service, enable_responses_api_for_session_service
-            )
-
         if self.after_load_memory_callback:
             wrap_get_session_with_callbacks(
                 self._session_service, self.after_load_memory_callback
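
The deleted `enable_responses_api_for_session_service` hook was attached through `wrap_get_session_with_callbacks`, whose `wrapper`/`setattr` lines appear as diff context above. A rough sketch of that wrapping idiom, reconstructed from those context lines (the exact body is an assumption):

import functools


def wrap_get_session_with_callbacks(obj, callback):
    # Monkey-patch obj.get_session: forward the call, then let the callback
    # post-process the loaded session in place before returning it.
    original = obj.get_session

    @functools.wraps(original)
    async def wrapper(*args, **kwargs):
        result = await original(*args, **kwargs)
        # Callback signature matches the deleted hook: (result, *args, **kwargs).
        callback(result, *args, **kwargs)
        return result

    setattr(obj, "get_session", wrapper)

With the cache_metadata bridge gone, only the user-supplied after_load_memory_callback is still attached this way.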

veadk/models/ark_llm.py

Lines changed: 31 additions & 26 deletions
@@ -17,10 +17,9 @@
 import base64
 import json
 from typing import Any, Dict, Union, AsyncGenerator, Tuple, List, Optional, Literal
+from typing_extensions import override

-from google.adk.models import LlmRequest, LlmResponse
-from google.adk.models.lite_llm import LiteLlm
-from google.adk.models.cache_metadata import CacheMetadata
+from google.adk.models import LlmRequest, LlmResponse, Gemini
 from google.genai import types
 from pydantic import Field, BaseModel
 from volcenginesdkarkruntime import AsyncArk
@@ -148,24 +147,6 @@ def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict:
     return schema_dict


-def build_cache_metadata(response_id: str) -> CacheMetadata:
-    """Create a new CacheMetadata instance for agent response tracking.
-    Args:
-        response_id: Response ID to track
-    Returns:
-        A new CacheMetadata instance with the agent-response mapping
-    """
-    # `adk >= 1.17`
-    cache_metadata = CacheMetadata(
-        cache_name=response_id,
-        expire_time=0,
-        fingerprint="",
-        invocations_used=0,
-        contents_count=0,
-    )
-    return cache_metadata
-
-
 # -----------------------------------------------------------------
 # inputs param transform ------------------------------------------
 def _file_data_to_content_param(
@@ -638,8 +619,7 @@ def ark_response_to_generate_content_response(
     )

     # previous_response_id
-    previous_response_id = raw_response.id
-    llm_response.cache_metadata = build_cache_metadata(previous_response_id)
+    llm_response.interaction_id = raw_response.id

     return llm_response

@@ -662,12 +642,29 @@ async def aresponse(
         return raw_response


-class ArkLlm(LiteLlm):
+class ArkLlm(Gemini):
+    model: str
     llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient)
     _additional_args: Dict[str, Any] = None
+    use_interactions_api: bool = True

     def __init__(self, **kwargs):
+        # adk version check
+        if "previous_interaction_id" not in LlmRequest.model_fields:
+            raise ImportError(
+                "If using the ResponsesAPI, "
+                "please upgrade the version of google-adk to `1.21.0` or higher with the command: "
+                "`pip install -U 'google-adk>=1.21.0'`"
+            )
         super().__init__(**kwargs)
+        drop_params = kwargs.pop("drop_params", None)
+        self._additional_args = dict(kwargs)
+        self._additional_args.pop("llm_client", None)
+        self._additional_args.pop("messages", None)
+        self._additional_args.pop("tools", None)
+        self._additional_args.pop("stream", None)
+        if drop_params is not None:
+            self._additional_args["drop_params"] = drop_params

     async def generate_content_async(
         self, llm_request: LlmRequest, stream: bool = False
@@ -694,8 +691,8 @@ async def generate_content_async(
         # ------------------------------------------------------ #
         # get previous_response_id
         previous_response_id = None
-        if llm_request.cache_metadata and llm_request.cache_metadata.cache_name:
-            previous_response_id = llm_request.cache_metadata.cache_name
+        if llm_request.previous_interaction_id:
+            previous_response_id = llm_request.previous_interaction_id
         responses_args = {
             "model": self.model,
             "instructions": instructions,
@@ -723,3 +720,11 @@
         raw_response = await self.llm_client.aresponse(**responses_args)
         llm_response = ark_response_to_generate_content_response(raw_response)
         yield llm_response
+
+    @classmethod
+    @override
+    def supported_models(cls) -> list[str]:
+        return [
+            # For OpenAI models (e.g., "openai/gpt-4o")
+            r"openai/.*",
+        ]
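
Taken together, these hunks swap the CacheMetadata plumbing for ADK's first-class interaction fields: the Ark response id is written to `LlmResponse.interaction_id`, and the next turn reads it back from `LlmRequest.previous_interaction_id`. A rough sketch of that round trip, assuming google-adk >= 1.21 (`FakeArkResponse` and both helpers are invented for illustration):

from dataclasses import dataclass

from google.adk.models import LlmRequest, LlmResponse


@dataclass
class FakeArkResponse:
    # Stand-in for a volcenginesdkarkruntime Responses API result.
    id: str


def record_interaction(raw: FakeArkResponse) -> LlmResponse:
    # Outbound half: store the server-side response id on interaction_id,
    # replacing the deleted build_cache_metadata() shim.
    llm_response = LlmResponse()
    llm_response.interaction_id = raw.id
    return llm_response


def previous_response_id(llm_request: LlmRequest) -> str | None:
    # Inbound half: google-adk >= 1.21 threads the last interaction_id back
    # onto the request, which generate_content_async forwards to the Ark API
    # as previous_response_id.
    return llm_request.previous_interaction_id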
