@@ -17,10 +17,9 @@
 import base64
 import json
 from typing import Any, Dict, Union, AsyncGenerator, Tuple, List, Optional, Literal
+from typing_extensions import override

-from google.adk.models import LlmRequest, LlmResponse
-from google.adk.models.lite_llm import LiteLlm
-from google.adk.models.cache_metadata import CacheMetadata
+from google.adk.models import LlmRequest, LlmResponse, Gemini
 from google.genai import types
 from pydantic import Field, BaseModel
 from volcenginesdkarkruntime import AsyncArk
@@ -148,24 +147,6 @@ def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict:
     return schema_dict


-def build_cache_metadata(response_id: str) -> CacheMetadata:
-    """Create a new CacheMetadata instance for agent response tracking.
-    Args:
-        response_id: Response ID to track
-    Returns:
-        A new CacheMetadata instance with the agent-response mapping
-    """
-    # `adk >= 1.17`
-    cache_metadata = CacheMetadata(
-        cache_name=response_id,
-        expire_time=0,
-        fingerprint="",
-        invocations_used=0,
-        contents_count=0,
-    )
-    return cache_metadata
-
-
 # -----------------------------------------------------------------
 # inputs param transform ------------------------------------------
 def _file_data_to_content_param(
@@ -638,8 +619,7 @@ def ark_response_to_generate_content_response(
     )

     # previous_response_id
-    previous_response_id = raw_response.id
-    llm_response.cache_metadata = build_cache_metadata(previous_response_id)
+    llm_response.interaction_id = raw_response.id

     return llm_response

@@ -662,12 +642,29 @@ async def aresponse(
         return raw_response


-class ArkLlm(LiteLlm):
+class ArkLlm(Gemini):
+    model: str
     llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient)
     _additional_args: Dict[str, Any] = None
+    use_interactions_api: bool = True

     def __init__(self, **kwargs):
+        # adk version check
+        if "previous_interaction_id" not in LlmRequest.model_fields:
+            raise ImportError(
+                "To use the Responses API, "
+                "please upgrade google-adk to 1.21.0 or higher with the command: "
+                "`pip install -U 'google-adk>=1.21.0'`"
+            )
         super().__init__(**kwargs)
+        drop_params = kwargs.pop("drop_params", None)
+        self._additional_args = dict(kwargs)
+        self._additional_args.pop("llm_client", None)
+        self._additional_args.pop("messages", None)
+        self._additional_args.pop("tools", None)
+        self._additional_args.pop("stream", None)
+        if drop_params is not None:
+            self._additional_args["drop_params"] = drop_params

     async def generate_content_async(
         self, llm_request: LlmRequest, stream: bool = False
@@ -694,8 +691,8 @@ async def generate_content_async(
         # ------------------------------------------------------ #
         # get previous_response_id
         previous_response_id = None
-        if llm_request.cache_metadata and llm_request.cache_metadata.cache_name:
-            previous_response_id = llm_request.cache_metadata.cache_name
+        if llm_request.previous_interaction_id:
+            previous_response_id = llm_request.previous_interaction_id
         responses_args = {
             "model": self.model,
             "instructions": instructions,
@@ -723,3 +720,11 @@ async def generate_content_async(
         raw_response = await self.llm_client.aresponse(**responses_args)
         llm_response = ark_response_to_generate_content_response(raw_response)
         yield llm_response
+
+    @classmethod
+    @override
+    def supported_models(cls) -> list[str]:
+        return [
+            # For OpenAI models (e.g., "openai/gpt-4o")
+            r"openai/.*",
+        ]
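
For context, a minimal usage sketch of how this `ArkLlm` wrapper could be wired into an ADK agent after this change (assumes google-adk >= 1.21.0; the import path, agent name, and instruction below are placeholders, not part of this diff):

```python
# Hypothetical wiring of ArkLlm into an ADK agent; names below are illustrative.
from google.adk.agents import LlmAgent

from ark_llm import ArkLlm  # placeholder import path for the module patched above

agent = LlmAgent(
    model=ArkLlm(model="openai/gpt-4o"),  # matches the supported_models() pattern r"openai/.*"
    name="ark_agent",
    instruction="You are a helpful assistant.",
)

# With use_interactions_api=True, each LlmResponse carries interaction_id
# (set from raw_response.id above); the intent is that ADK >= 1.21.0 feeds it
# back as llm_request.previous_interaction_id on the next turn, so the
# Responses API can chain turns server-side instead of replaying full history.
```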