diff --git a/docs/docs/agent/responses-api.md b/docs/docs/agent/responses-api.md index 33e66460..61996aa2 100644 --- a/docs/docs/agent/responses-api.md +++ b/docs/docs/agent/responses-api.md @@ -2,27 +2,27 @@ title: Responses API 支持 --- -Responses API 是火山方舟最新推出的 API 接口,原生支持高效的上下文管理,支持更简洁的输入输出格式,并且工具调用方式也更加便捷,不仅延续了 Chat API 的易用性,还结合了更强的智能代理能力。 -随着大模型技术不断升级,Responses API 为开发各类面向实际行动的应用提供了更灵活的基础,并且支持工具调用多种扩展能力,非常适合搭建智能助手、自动化工具等场景。 +Responses API 是火山方舟最新推出的 API 接口,原生支持高效的上下文管理,具备更简洁的输入输出格式。其工具调用方式更加便捷,不仅延续了 Chat API 的易用性,还结合了更强的智能代理能力。 +随着大模型技术不断升级,Responses API 为开发各类面向实际行动的应用提供了更灵活的基础,并支持多种工具调用扩展能力,非常适合搭建智能助手、自动化工具等场景。 --- ## 使用教程 -目前 VeADK Responses API 支持对 LiteLLM 版本依赖有限制,需要 `litellm>=1.79.3`,请确保您的 LiteLLM 版本符合要求。 +目前仅 `veadk-python` 版本支持 Responses API,且对 `google-adk` 版本有特定要求(`google-adk>=1.21.0`)。请确保您的环境符合要求。 === "pip" ```bash - pip install "litellm>=1.79.3" + pip install "google-adk>=1.21.0" ``` === "uv" ```bash - uv pip install "litellm>=1.79.3" + uv pip install "google-adk>=1.21.0" ``` ### 快速开始 -只需要要配置enable_responses=True即可 +只需配置 `enable_responses=True` 即可。 ```python hl_lines="4" from veadk import Agent @@ -39,6 +39,311 @@ root_agent = Agent( ## 注意事项 -- 必须提升litellm版本: `litellm>=1.79.3`,用于支撑responses_request转换 -- 必须保证adk版本: `google-adk>=1.15` -- 请保证使用的模型支持ResponsesAPI +1. **版本要求**:必须保证 `google-adk>=1.21.0`。 +2. **模型支持**:请确保使用的模型支持 Responses API(注:Doubao 系列模型 0615 版本之后,除特殊说明外均支持)。 +3. **缓存机制**:VeADK 开启 Responses API 并使用火山方舟模型时默认开启上下文缓存(Session 缓存)。但若在 Agent 中设置了 `output_schema`,因该字段与缓存机制冲突,系统将自动关闭缓存。 + +## 上下文缓存 + +在 Responses API 模式下,VeADK 默认开启会话缓存(Session Caching)。该机制会自动存储初始上下文信息,并在每一轮对话中动态更新。在后续请求中,系统会将缓存内容与新输入合并后发送给模型推理。此功能特别适用于多轮对话、复杂工具调用等长上下文场景。 + +### 缓存信息查看 + +您可以通过返回的 `Event.usage_metadata` 字段查看 Token 使用及缓存命中情况。 + +以下是一次包含两轮对话的 `usage_metadata` 示例: + +```json +{"cached_content_token_count":0,"candidates_token_count":87,"prompt_token_count":210,"total_token_count":297} +{"cached_content_token_count":297,"candidates_token_count":181,"prompt_token_count":314,"total_token_count":495} +``` + +**字段说明:** + +- `cached_content_token_count`:命中缓存的 Token 数量(即从缓存中读取的 Token 数)。 +- `candidates_token_count`:模型生成的 Token 数量(输出 Token)。 +- `prompt_token_count`:输入给模型的总 Token 数量(包含已缓存和未缓存部分)。 +- `total_token_count`:总消耗 Token 数量(输入 + 输出)。 + +**缓存机制说明:** + +- 缓存仅影响输入(Prompt)Token,不影响输出(Completion)Token。 +- **缓存命中率**反映了缓存策略的有效性,命中率越高,Token 成本节省越多。 + - 计算公式:`缓存命中率 = (cached_content_token_count / prompt_token_count) × 100%` +- 输入 Token 成本节约率:用于量化整个会话的缓存收益,是面向业务侧的核心指标,支持会话级汇总计算。 + +### 成本节省示例 + +基于上述样例数据,缓存命中率计算如下: + +- **第一轮对话**:0%(初始状态,无缓存) +- **第二轮对话**:`297 / 314 * 100% ≈ 94.58%` + +输入 Token 成本节约率:`(0 + 297) / (210 + 314) ≈ 56.68%` + +这意味着在开启缓存后,该次会话的 **输入 Token 缓存命中率达到了 56.68%**,大幅减少了重复内容的计算开销。 +[火山方舟:缓存Token计费说明](https://www.volcengine.com/docs/82379/1544106?lang=zh) + +注:第N轮的`cached_content_token_count`不一定等于第N-1轮的`total_token_count`,如果开启了thinking,二者不等。 + + +## 多模态能力支持 + +Responses API 除文本交互外,还具备图片、视频和文件等多模态理解能力。 +您可以使用 `google.genai.types.FileData` 字段传递多模态数据(如图片路径、视频 URL、Files API 生成的 file_id 等)。 + +`FileData` 支持以下数据传递方式: + +1. **Files API 资源 (`file_id`)** + - 通过 `file_data` 传递,需在 `file_uri` 中添加特定 Scheme 前缀以区分。 + - 格式:`file_uri="file_id://xxxxxxx"`,该值将被自动映射到 `file_id` 字段。 + +2. **通用资源标识符 (URI)** + 火山方舟支持的所有文件上传方式(包括 `image_url`、`video_url`、`file_url`)均可通过 `file_data` 统一处理。 + 1. **网络 URL** + - 系统根据 `mime_type` 自动识别资源类型(视频、图片、文件等)。 + - 格式:`file_uri="https://..."` + 2. **本地文件路径** + - 直接使用本地文件路径作为 `file_uri`,底层会自动调用 Files API 完成上传。 + - 格式:`file_uri=f"file://{local_path}"` + 3. 
**Base64 Data URI** + - 支持传入 Base64 编码的数据。 + - 参考:[火山方舟:图片理解-Base64编码](https://www.volcengine.com/docs/82379/1362931?lang=zh#477e51ce) + +### 样例代码 + +**注**:以下所有示例代码均基于下述 `Agent` 配置与 `main` 函数: + +```python +import os +import asyncio +import uuid + +from google.adk.events import Event +from google.genai import types +from google.genai.types import FileData + +from veadk import Agent, Runner +from veadk.memory.short_term_memory import ShortTermMemory + +agent = Agent( + enable_responses=True +) +short_term_memory = ShortTermMemory() +runner = Runner( + agent=agent, + short_term_memory=short_term_memory, +) + +async def main(message: types.Content): + session_id = uuid.uuid4().hex + await short_term_memory.session_service.create_session( + app_name=runner.app_name, user_id=runner.user_id, session_id=session_id + ) + + async for event in runner.run_async( + user_id=runner.user_id, + session_id=session_id, + new_message=message, + ): + if isinstance(event, Event) and event.is_final_response(): + if event.content and event.content.parts: + if not event.content.parts[0].thought: + print(event.content.parts[0].text) + elif len(event.content.parts) > 1: + print(event.content.parts[1].text) +``` + + +### 图片理解 + +=== "本地路径" + + 支持处理**最大 512MB** 的图片文件。 + + 该方式直接传入本地文件路径,系统会自动调用 Files API 完成上传,随后调用 Responses API 进行分析。 + + ```python hl_lines="7-8" + local_path = os.path.abspath("example-data.png") + message = types.UserContent( + parts=[ + types.Part(text="描述一下这张图片"), + types.Part( + file_data=FileData( + file_uri=f"file://{local_path}", + mime_type="image/png" + ) + ) + ], + ) + asyncio.run( + main(message) + ) + ``` + +=== "Files API" + + 通过火山方舟 Files API 上传文件,支持处理**最大 512MB** 的图片文件。 + + ```python hl_lines="9-10" + from veadk.utils.misc import upload_to_files_api + local_path = os.path.abspath("example-data.png") + file_id = asyncio.run(upload_to_files_api(local_path)) + message = types.UserContent( + parts=[ + types.Part(text="描述一下这张图片"), + types.Part( + file_data=FileData( + file_uri=f"file_id://{file_id}", + mime_type="image/png" + ) + ) + ], + ) + asyncio.run( + main(message) + ) + ``` + +=== "图片 URL" + + 支持处理**最大 10MB** 的图片文件。 + + ```python hl_lines="7-8" + image_url = "" + message = types.UserContent( + parts=[ + types.Part(text="描述一下这张图片"), + types.Part( + file_data=FileData( + file_uri=f"{image_url}", + mime_type="image/png" + ) + ) + ], + ) + asyncio.run( + main(message) + ) + ``` + + + +### 视频理解 + +=== "Files API" + + 通过火山方舟 Files API 上传文件,支持处理**最大 512MB** 的视频文件。 + + 在 `file_data` 中传递视频文件时,可选择添加 `video_metadata` 字段,指定视频的帧率(fps)。 + + 您可以通过`fps`字段,控制从视频中抽取图像的频率,默认为1,即每秒从视频中抽取一帧图像,输入给模型进行视觉理解。可通过fps字段调整抽取频率,以平衡视频长度与模型处理效率。 + + - 当视频画面变化剧烈或需要关注画面变化,如计算视频中角色动作次数,可以跳高fps的设置(最高为5),防止抽帧频率过快导致模型无法准确理解视频内容。 + + - 当视频画面变化缓慢或不需要关注画面变化,如分析视频中人物行为,可适当降低fps的设置(最低为0.2),以平衡视频长度与模型处理效率。 + + ```python hl_lines="9-10" + from veadk.utils.misc import upload_to_files_api + local_path = os.path.abspath("example-data.png") + file_id = asyncio.run(upload_to_files_api(local_path, fps=0.3)) # optional `fps` + message = types.UserContent( + parts=[ + types.Part(text="描述一下这个视频"), + types.Part( + file_data=FileData( + file_uri=f"file_id://{file_id}", + mime_type="video/mp4" + ), + video_metadata={ # optional + "fps": 0.3 + } + ) + ], + ) + asyncio.run( + main(message) + ) + ``` + +=== "视频 URL" + + 支持处理**最大 50MB** 的视频文件。 + + ```python hl_lines="7-8" + video_url = "" + message = types.UserContent( + parts=[ + types.Part(text="描述一下这个视频"), + types.Part( + file_data=FileData( + file_uri=f"{video_url}", + mime_type="video/mp4" 
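+                    # Optionally, video_metadata={"fps": 0.2} can be set on the
+                    # enclosing types.Part (as in the "Files API" tab); based on the
+                    # FileData handling in this PR, the fps value appears to apply
+                    # to URL-based videos as well.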
+ ) + ) + ], + ) + asyncio.run( + main(message) + ) + ``` + + +### 文档理解 + +部分模型支持处理 PDF 格式的文档,系统会通过视觉功能理解整个文档的上下文。 +当传入 PDF 文档时,大模型会将文件分页处理成多张图片,分析解读其中的文本与图片信息,并结合这些信息完成文档理解任务。 + +=== "Files API" + + 通过火山方舟 Files API 上传文件,支持处理**最大 512MB** 的文档。 + + ```python hl_lines="9-10" + from veadk.utils.misc import upload_to_files_api + local_path = os.path.abspath("example-pdf.pdf") + file_id = asyncio.run(upload_to_files_api(local_path)) + message = types.UserContent( + parts=[ + types.Part(text="请概括总结文档的内容"), + types.Part( + file_data=FileData( + file_uri=f"file_id://{file_id}", + mime_type="application/pdf" + ) + ) + ], + ) + asyncio.run( + main(message) + ) + ``` + +=== "文档 URL" + + 支持处理**最大 50MB** 的文档。 + + ```python hl_lines="7-8" + message2 = types.UserContent( + parts=[ + types.Part(text="请总结概括本文档"), + types.Part( + file_data=FileData( + file_uri="", + mime_type="application/pdf" + ) + ) + ], + ) + asyncio.run( + main(message2) + ) + ``` + + + +## 参考文档 + +1. [火山方舟:ResponsesAPI迁移文档](https://www.volcengine.com/docs/82379/1585128?lang=zh) +2. [火山方舟:上下文缓存](https://www.volcengine.com/docs/82379/1602228?lang=zh#3e69e743) +3. [火山方舟:缓存Token计费说明](https://www.volcengine.com/docs/82379/1544106?lang=zh) +4. [火山方舟:多模态理解](https://www.volcengine.com/docs/82379/1958521?lang=zh) \ No newline at end of file diff --git a/docs/docs/assets/images/agents/responses_api.png b/docs/docs/assets/images/agents/responses_api.png index 4ec5cdd4..043d433c 100644 Binary files a/docs/docs/assets/images/agents/responses_api.png and b/docs/docs/assets/images/agents/responses_api.png differ diff --git a/veadk/agent.py b/veadk/agent.py index b3f14365..7f4745d6 100644 --- a/veadk/agent.py +++ b/veadk/agent.py @@ -15,7 +15,7 @@ from __future__ import annotations import os -from typing import Optional, Union, AsyncGenerator +from typing import Optional, Union # If user didn't set LITELLM_LOCAL_MODEL_COST_MAP, set it to True # to enable local model cost map. 
@@ -24,12 +24,11 @@ if not os.getenv("LITELLM_LOCAL_MODEL_COST_MAP"): os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -from google.adk.agents import LlmAgent, RunConfig, InvocationContext +from google.adk.agents import LlmAgent, RunConfig from google.adk.agents.base_agent import BaseAgent from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.agents.llm_agent import InstructionProvider, ToolUnion from google.adk.agents.run_config import StreamingMode -from google.adk.events import Event, EventActions from google.adk.models.lite_llm import LiteLlm from google.adk.runners import Runner from google.genai import types @@ -54,7 +53,6 @@ from veadk.tracing.base_tracer import BaseTracer from veadk.utils.logger import get_logger from veadk.utils.patches import patch_asyncio, patch_tracer -from veadk.utils.misc import check_litellm_version from veadk.version import VERSION patch_tracer() @@ -171,9 +169,6 @@ def model_post_init(self, __context: Any) -> None: if not self.model: if self.enable_responses: - min_version = "1.79.3" - check_litellm_version(min_version) - from veadk.models.ark_llm import ArkLlm self.model = ArkLlm( @@ -182,12 +177,6 @@ def model_post_init(self, __context: Any) -> None: api_base=self.model_api_base, **self.model_extra_config, ) - if not self.context_cache_config: - self.context_cache_config = ContextCacheConfig( - cache_intervals=100, # maximum number - ttl_seconds=315360000, - min_tokens=0, - ) else: fallbacks = None if isinstance(self.model_name, list): @@ -292,28 +281,6 @@ def model_post_init(self, __context: Any) -> None: f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}" ) - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - if self.enable_responses: - if not ctx.context_cache_config: - ctx.context_cache_config = self.context_cache_config - - async for event in super()._run_async_impl(ctx): - yield event - if self.enable_responses and event.cache_metadata: - # for persistent short-term memory with response api - session_state_event = Event( - invocation_id=event.invocation_id, - author=event.author, - actions=EventActions( - state_delta={ - "response_id": event.cache_metadata.cache_name, - } - ), - ) - yield session_state_event - async def _run( self, runner, diff --git a/veadk/memory/short_term_memory.py b/veadk/memory/short_term_memory.py index 259ea928..6b537e82 100644 --- a/veadk/memory/short_term_memory.py +++ b/veadk/memory/short_term_memory.py @@ -32,7 +32,6 @@ from veadk.memory.short_term_memory_backends.sqlite_backend import ( SQLiteSTMBackend, ) -from veadk.models.ark_transform import build_cache_metadata from veadk.utils.logger import get_logger logger = get_logger(__name__) @@ -50,21 +49,6 @@ async def wrapper(*args, **kwargs): setattr(obj, "get_session", wrapper) -def enable_responses_api_for_session_service(result, *args, **kwargs): - if result and isinstance(result, Session): - if result.events: - for event in result.events: - if ( - event.actions - and event.actions.state_delta - and not event.cache_metadata - and "response_id" in event.actions.state_delta - ): - event.cache_metadata = build_cache_metadata( - response_id=event.actions.state_delta.get("response_id"), - ) - - class ShortTermMemory(BaseModel): """Short term memory for agent execution. 
@@ -186,11 +170,6 @@ def model_post_init(self, __context: Any) -> None: db_kwargs=self.db_kwargs, **self.backend_configs ).session_service - if self.backend != "local": - wrap_get_session_with_callbacks( - self._session_service, enable_responses_api_for_session_service - ) - if self.after_load_memory_callback: wrap_get_session_with_callbacks( self._session_service, self.after_load_memory_callback diff --git a/veadk/models/__init__.py b/veadk/models/__init__.py new file mode 100644 index 00000000..7f463206 --- /dev/null +++ b/veadk/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/veadk/models/ark_llm.py b/veadk/models/ark_llm.py index 54295b7e..0da72735 100644 --- a/veadk/models/ark_llm.py +++ b/veadk/models/ark_llm.py @@ -14,50 +14,626 @@ # adapted from Google ADK models adk-python/blob/main/src/google/adk/models/lite_llm.py at f1f44675e4a86b75e72cfd838efd8a0399f23e24 · google/adk-python +import base64 import json -from typing import Any, Dict, Union, AsyncGenerator - -import litellm -import openai -from openai.types.responses import Response as OpenAITypeResponse, ResponseStreamEvent -from google.adk.models import LlmRequest, LlmResponse -from google.adk.models.lite_llm import ( - LiteLlm, - _get_completion_inputs, - FunctionChunk, - TextChunk, - _message_to_generate_content_response, - UsageMetadataChunk, -) +from typing import Any, Dict, Union, AsyncGenerator, Tuple, List, Optional, Literal +from typing_extensions import override + +from google.adk.models import LlmRequest, LlmResponse, Gemini from google.genai import types -from litellm import ChatCompletionAssistantMessage -from litellm.types.utils import ( - ChatCompletionMessageToolCall, - Function, +from pydantic import Field, BaseModel +from volcenginesdkarkruntime import AsyncArk +from volcenginesdkarkruntime._streaming import AsyncStream +from volcenginesdkarkruntime.types.responses import ( + Response as ArkTypeResponse, + ResponseStreamEvent, + FunctionToolParam, + ResponseTextConfigParam, + ResponseReasoningItem, + ResponseOutputMessage, + ResponseOutputText, + ResponseFunctionToolCall, + ResponseReasoningSummaryTextDeltaEvent, + ResponseTextDeltaEvent, + ResponseCompletedEvent, ) -from pydantic import Field - -from veadk.models.ark_transform import ( - CompletionToResponsesAPIHandler, +from volcenginesdkarkruntime.types.responses.response_input_message_content_list_param import ( + ResponseInputTextParam, + ResponseInputImageParam, + ResponseInputVideoParam, + ResponseInputFileParam, + ResponseInputContentParam, +) +from volcenginesdkarkruntime.types.responses.response_input_param import ( + ResponseInputItemParam, + ResponseFunctionToolCallParam, + EasyInputMessageParam, + FunctionCallOutput, ) -from veadk.utils.logger import get_logger -# This will add functions to prompts if functions are provided. 
-litellm.add_function_to_prompt = True +from veadk.config import settings +from veadk.consts import DEFAULT_VIDEO_MODEL_API_BASE +from veadk.utils.logger import get_logger logger = get_logger(__name__) +_ARK_TEXT_FIELD_TYPES = {"json_object", "json_schema"} + +_FINISH_REASON_MAPPING = { + "incomplete": { + "length": types.FinishReason.MAX_TOKENS, + "content_filter": types.FinishReason.SAFETY, + }, + "completed": { + "other": types.FinishReason.STOP, + }, +} + +ark_supported_fields = [ + "input", + "model", + "stream", + "background", + "include", + "instructions", + "max_output_tokens", + "parallel_tool_calls", + "previous_response_id", + "thinking", + "store", + "caching", + "stream", + "temperature", + "text", + "tool_choice", + "tools", + "top_p", + "max_tool_calls", + "expire_at", + "extra_headers", + "extra_query", + "extra_body", + "timeout", + "reasoning" + # auth params + "api_key", + "api_base", +] + + +def _to_ark_role(role: Optional[str]) -> Literal["user", "assistant"]: + if role in ["model", "assistant"]: + return "assistant" + return "user" + + +def _safe_json_serialize(obj) -> str: + try: + return json.dumps(obj, ensure_ascii=False) + except (TypeError, OverflowError): + return str(obj) + + +def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: + schema_dict = ( + schema.model_dump(exclude_none=True) + if isinstance(schema, types.Schema) + else dict(schema) + ) + enum_values = schema_dict.get("enum") + if isinstance(enum_values, (list, tuple)): + schema_dict["enum"] = [value for value in enum_values if value is not None] + + if "type" in schema_dict and schema_dict["type"] is not None: + t = schema_dict["type"] + schema_dict["type"] = (t.value if isinstance(t, types.Type) else str(t)).lower() + + if "items" in schema_dict: + items = schema_dict["items"] + schema_dict["items"] = ( + _schema_to_dict(items) if isinstance(items, (types.Schema, dict)) else items + ) + + if "properties" in schema_dict: + new_props = {} + for key, value in schema_dict["properties"].items(): + if isinstance(value, (types.Schema, dict)): + new_props[key] = _schema_to_dict(value) + else: + new_props[key] = value + schema_dict["properties"] = new_props + + return schema_dict + + +# ----------------------------------------------------------------- +# inputs param transform ------------------------------------------ +def _file_data_to_content_param( + part: types.Part, +) -> ResponseInputContentParam: + file_uri = part.file_data.file_uri + mime_type = part.file_data.mime_type + fps = 1.0 + if getattr(part, "video_metadata", None): + video_metadata = part.video_metadata + if isinstance(video_metadata, dict): + fps = video_metadata.get("fps") + else: + fps = getattr(video_metadata, "fps", 1) + + is_file_id = file_uri.startswith("file_id://") + value = file_uri[10:] if is_file_id else file_uri + # video + if mime_type.startswith("video/"): + param = {"file_id": value} if is_file_id else {"video_url": value} + if fps is not None: + param["fps"] = fps + return ResponseInputVideoParam( + type="input_video", + **param, + ) + # image + if mime_type.startswith("image/"): + return ResponseInputImageParam( + type="input_image", + detail="auto", + **({"file_id": value} if is_file_id else {"image_url": value}), + ) + # file + param = {"file_id": value} if is_file_id else {"file_url": value} + return ResponseInputFileParam( + type="input_file", + **param, + ) + + +def _inline_data_to_content_param(part: types.Part) -> ResponseInputContentParam: + mime_type = ( + part.inline_data.mime_type if 
part.inline_data else None + ) or "application/octet-stream" + base64_string = base64.b64encode(part.inline_data.data).decode("utf-8") + data_uri = f"data:{mime_type};base64,{base64_string}" + + if mime_type.startswith("image"): + return ResponseInputImageParam( + type="input_image", + image_url=data_uri, + detail="auto", + ) + if mime_type.startswith("video"): + param: Dict[str, Any] = {"video_url": data_uri} + if getattr(part, "video_metadata", None): + video_metadata = part.video_metadata + if isinstance(video_metadata, dict): + fps = video_metadata.get("fps") + else: + fps = getattr(video_metadata, "fps", None) + if fps is not None: + param["fps"] = fps + return ResponseInputVideoParam( + type="input_video", + **param, + ) + + file_param: Dict[str, Any] = {"file_data": data_uri} + return ResponseInputFileParam( + type="input_file", + **file_param, + ) + + +def _get_content( + parts: List[types.Part], + role: Literal["user", "system", "developer", "assistant"], +) -> Optional[EasyInputMessageParam]: + content = [] + for part in parts: + if part.text: + content.append( + ResponseInputTextParam( + type="input_text", + text=part.text, + ) + ) + elif part.inline_data and part.inline_data.data: + content.append(_inline_data_to_content_param(part)) + elif part.file_data: # file_id和file_url + content.append(_file_data_to_content_param(part)) + if len(content) > 0: + return EasyInputMessageParam(type="message", role=role, content=content) + else: + return None + + +def _content_to_input_item( + content: types.Content, +) -> Union[ResponseInputItemParam, List[ResponseInputItemParam]]: + role = _to_ark_role(content.role) + + # 1. FunctionResponse:`Tool` messages cannot be mixed with other content + input_list = [] + for part in content.parts: + if part.function_response: # FunctionCallOutput + input_list.append( + FunctionCallOutput( + call_id=part.function_response.id, + output=_safe_json_serialize(part.function_response.response), + type="function_call_output", + ) + ) + if input_list: + return input_list if len(input_list) > 1 else input_list[0] + + input_content = _get_content(content.parts, role=role) or None + + if role == "user": + # 2. Process the user's message + if input_content: + return input_content + else: # model + # 3. 
Processing model messages + for part in content.parts: + if part.function_call: + input_list.append( + ResponseFunctionToolCallParam( + arguments=_safe_json_serialize(part.function_call.args), + call_id=part.function_call.id, + name=part.function_call.name, + type="function_call", + ) + ) + elif part.text or part.inline_data: + if input_content: + input_list.append(input_content) + return input_list + + +def _function_declarations_to_tool_param( + function_declaration: types.FunctionDeclaration, +) -> FunctionToolParam: + assert function_declaration.name + + parameters = {"type": "object", "properties": {}} + if function_declaration.parameters and function_declaration.parameters.properties: + properties = {} + for key, value in function_declaration.parameters.properties.items(): + properties[key] = _schema_to_dict(value) + + parameters = { + "type": "object", + "properties": properties, + } + elif function_declaration.parameters_json_schema: + parameters = function_declaration.parameters_json_schema + + tool_params = FunctionToolParam( + name=function_declaration.name, + parameters=parameters, + type="function", + description=function_declaration.description, + ) + + return tool_params + + +def _responses_schema_to_text( + response_schema: types.SchemaUnion, +) -> Optional[ResponseTextConfigParam | dict]: + schema_name = "" + if isinstance(response_schema, dict): + schema_type = response_schema.get("type") + if ( + isinstance(schema_type, str) + and schema_type.lower() in _ARK_TEXT_FIELD_TYPES + ): + return response_schema + schema_dict = dict(response_schema) + elif isinstance(response_schema, type) and issubclass(response_schema, BaseModel): + schema_name = response_schema.__name__ + schema_dict = response_schema.model_json_schema() + elif isinstance(response_schema, BaseModel): + if isinstance(response_schema, types.Schema): + # GenAI Schema instances already represent JSON schema definitions. + schema_name = response_schema.__name__ + schema_dict = response_schema.model_dump(exclude_none=True, mode="json") + else: + schema_name = response_schema.__name__ + schema_dict = response_schema.__class__.model_json_schema() + elif hasattr(response_schema, "model_dump"): + schema_name = response_schema.__name__ + schema_dict = response_schema.model_dump(exclude_none=True, mode="json") + else: + logger.warning( + "Unsupported response_schema type %s for LiteLLM structured outputs.", + type(response_schema), + ) + return None + + return ResponseTextConfigParam( + format={ # noqa + "type": "json_schema", + "name": schema_name, + "schema": schema_dict, + "strict": True, + } + ) + + +def _get_responses_inputs( + llm_request: LlmRequest, +) -> Tuple[ + Optional[str], + Optional[List[ResponseInputItemParam]], + Optional[List[FunctionToolParam]], + Optional[ResponseTextConfigParam], + Optional[Dict], +]: + # 0. instructions(system prompt) + instructions: Optional[str] = None + if llm_request.config and llm_request.config.system_instruction: + instructions = llm_request.config.system_instruction + # 1. input + input_params: Optional[List[ResponseInputItemParam]] = [] + for content in llm_request.contents or []: + # Each content represents `one conversation`. + # This `one conversation` may contain `multiple pieces of content`, + # but it cannot contain `multiple conversations`. + input_item_or_list = _content_to_input_item(content) + if isinstance(input_item_or_list, list): + input_params.extend(input_item_or_list) + elif input_item_or_list: + input_params.append(input_item_or_list) + + # 2. 
Convert tool declarations + tools: Optional[List[FunctionToolParam]] = None + if ( + llm_request.config + and llm_request.config.tools + and llm_request.config.tools[0].function_declarations + ): + tools = [ + _function_declarations_to_tool_param(tool) + for tool in llm_request.config.tools[0].function_declarations + ] + + # 3. Handle `output-schema` -> `text` + text: Optional[ResponseTextConfigParam] = None + if llm_request.config and llm_request.config.response_schema: + text = _responses_schema_to_text(llm_request.config.response_schema) + + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: + config_dict = llm_request.config.model_dump(exclude_none=True) + generation_params = {} + for key in ("temperature", "max_output_tokens", "top_p"): + if key in config_dict: + generation_params[key] = config_dict[key] + + if not generation_params: + generation_params = None + return instructions, input_params, tools, text, generation_params + + +def get_model_without_provider(request_data: dict) -> dict: + model = request_data.get("model") + + if not isinstance(model, str): + raise ValueError( + "Unsupported Responses API request: 'model' must be a string in the OpenAI-style format, e.g. 'openai/gpt-4o'." + ) + + if "/" not in model: + raise ValueError( + "Unsupported Responses API request: only OpenAI-style model names are supported (use 'openai/')." + ) + + provider, actual_model = model.split("/", 1) + if provider != "openai": + raise ValueError( + f"Unsupported model prefix '{provider}'. Responses API request format only supports 'openai/'." + ) + + request_data["model"] = actual_model + + return request_data + + +def filtered_inputs( + inputs: List[ResponseInputItemParam], +) -> List[ResponseInputItemParam]: + # Keep the first message and all consecutive user messages from the end + # Collect all consecutive user messages from the end + new_inputs = [] + for m in reversed(inputs): # Skip the first message + if m.get("type") == "function_call_output" or m.get("role") == "user": + new_inputs.append(m) + else: + break # Stop when we encounter a non-user message + + return new_inputs[::-1] + + +def _is_caching_enabled(request_data: dict) -> bool: + extra_body = request_data.get("extra_body") + if not isinstance(extra_body, dict): + return False + caching = extra_body.get("caching") + if not isinstance(caching, dict): + return False + return caching.get("type") == "enabled" + + +def _remove_caching(request_data: dict) -> None: + extra_body = request_data.get("extra_body") + if isinstance(extra_body, dict): + extra_body.pop("caching", None) + request_data.pop("caching", None) + + +def request_reorganization_by_ark(request_data: Dict) -> Dict: + # 1. model provider + request_data = get_model_without_provider(request_data) + + # 2. filtered input + request_data["input"] = filtered_inputs(request_data["input"]) + + # 3. filter not support data + request_data = { + key: value for key, value in request_data.items() if key in ark_supported_fields + } + + # [Note: Ark Limitations] caching and text + # After enabling caching, output_schema(text) cannot be used. Caching must be disabled. + if _is_caching_enabled(request_data) and request_data.get("text") is not None: + logger.warning( + "Caching is enabled, but text is provided. Ark does not support caching with text. Caching will be disabled." 
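+            # Mirrors note 3 in docs/docs/agent/responses-api.md: when an
+            # output_schema (text format) is supplied, the session cache is
+            # dropped for this request (with this warning) rather than raising an error.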
+ ) + _remove_caching(request_data) + + # [Note: Ark Limitations] tools and previous_response_id + # Remove tools in subsequent rounds (when previous_response_id is present) + if ( + "tools" in request_data + and "previous_response_id" in request_data + and request_data["previous_response_id"] is not None + ): + # Remove tools in subsequent rounds regardless of caching status + del request_data["tools"] + + # [Note: Ark Limitations] caching and store + # Ensure store field is true or default when caching is enabled + if _is_caching_enabled(request_data): + # Set store to true when caching is enabled for writing + if "store" not in request_data: + request_data["store"] = True + elif request_data["store"] is False: + # Override false to true for cache writing + request_data["store"] = True + + # [NOTE Ark Limitations] instructions -> input (because of caching) + # Due to the Volcano Ark settings, there is a conflict between the cache and the instructions field. + # If a system prompt is needed, it should be placed in the system role message within the input, instead of using the instructions parameter. + # https://www.volcengine.com/docs/82379/1585128 + instructions: Optional[str] = request_data.pop("instructions", None) + if instructions and not request_data.get("previous_response_id"): + request_data["input"].insert( + 0, + EasyInputMessageParam( + role="system", + type="message", + content=[ + ResponseInputTextParam( + type="input_text", + text=instructions, + ) + ], + ), + ) + + return request_data + + +# --------------------------------------- +# output transfer ----------------------- +def event_to_generate_content_response( + event: Union[ArkTypeResponse, ResponseStreamEvent], + *, + is_partial: bool = False, + model_version: str = None, +) -> Optional[LlmResponse]: + parts = [] + if not is_partial: + for output in event.output: + if isinstance(output, ResponseReasoningItem): + parts.append( + types.Part( + text="\n".join([summary.text for summary in output.summary]), + thought=True, + ) + ) + elif isinstance(output, ResponseOutputMessage): + text = "" + if isinstance(output.content, list): + for item in output.content: + if isinstance(item, ResponseOutputText): + text += item.text + parts.append(types.Part(text=text)) + + elif isinstance(output, ResponseFunctionToolCall): + part = types.Part.from_function_call( + name=output.name, args=json.loads(output.arguments or "{}") + ) + part.function_call.id = output.call_id + parts.append(part) + + else: + if isinstance(event, ResponseReasoningSummaryTextDeltaEvent): + parts.append(types.Part(text=event.delta, thought=True)) + elif isinstance(event, ResponseTextDeltaEvent): + parts.append(types.Part.from_text(text=event.delta)) + elif isinstance(event, ResponseCompletedEvent): + raw_response = event.response + llm_response = ark_response_to_generate_content_response(raw_response) + return llm_response + else: + return None + return LlmResponse( + content=types.Content(role="model", parts=parts), + partial=is_partial, + model_version=model_version, + ) + + +def ark_response_to_generate_content_response( + raw_response: ArkTypeResponse, +) -> LlmResponse: + """ + ArkTypeResponse -> LlmResponse + instead of `_model_response_to_generate_content_response`, + """ + outputs = raw_response.output + status = raw_response.status + incomplete_details = getattr( + raw_response.incomplete_details or None, "reason", "other" + ) + + finish_reason = _FINISH_REASON_MAPPING.get(status, {}).get( + incomplete_details, types.FinishReason.OTHER + ) + + if not 
outputs: + raise ValueError("No message in response") + + llm_response = event_to_generate_content_response( + raw_response, model_version=raw_response.model, is_partial=False + ) + llm_response.finish_reason = finish_reason + if raw_response.usage: + llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=raw_response.usage.input_tokens, + candidates_token_count=raw_response.usage.output_tokens, + total_token_count=raw_response.usage.total_tokens, + cached_content_token_count=raw_response.usage.input_tokens_details.cached_tokens, + ) + + # previous_response_id + llm_response.interaction_id = raw_response.id + + return llm_response + + class ArkLlmClient: async def aresponse( self, **kwargs - ) -> Union[OpenAITypeResponse, openai.AsyncStream[ResponseStreamEvent]]: + ) -> Union[ArkTypeResponse, AsyncStream[ResponseStreamEvent]]: # 1. Get request params - api_base = kwargs.pop("api_base", None) - api_key = kwargs.pop("api_key", None) + api_base = kwargs.pop("api_base", DEFAULT_VIDEO_MODEL_API_BASE) + api_key = kwargs.pop("api_key", settings.model.api_key) # 2. Call openai responses - client = openai.AsyncOpenAI( + client = AsyncArk( base_url=api_base, api_key=api_key, ) @@ -66,15 +642,29 @@ async def aresponse( return raw_response -class ArkLlm(LiteLlm): +class ArkLlm(Gemini): + model: str llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient) _additional_args: Dict[str, Any] = None - transform_handler: CompletionToResponsesAPIHandler = Field( - default_factory=CompletionToResponsesAPIHandler - ) + use_interactions_api: bool = True def __init__(self, **kwargs): + # adk version check + if "previous_interaction_id" not in LlmRequest.model_fields: + raise ImportError( + "If using the ResponsesAPI, " + "please upgrade the version of google-adk to `1.21.0` or higher with the command: " + "`pip install -U 'google-adk>=1.21.0'`" + ) super().__init__(**kwargs) + drop_params = kwargs.pop("drop_params", None) + self._additional_args = dict(kwargs) + self._additional_args.pop("llm_client", None) + self._additional_args.pop("messages", None) + self._additional_args.pop("tools", None) + self._additional_args.pop("stream", None) + if drop_params is not None: + self._additional_args["drop_params"] = drop_params async def generate_content_async( self, llm_request: LlmRequest, stream: bool = False @@ -91,8 +681,8 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) # logger.debug(_build_request_log(llm_request)) - messages, tools, response_format, generation_params = _get_completion_inputs( - llm_request + instructions, input_param, tools, text_format, generation_params = ( + _get_responses_inputs(llm_request) ) if "functions" in self._additional_args: @@ -101,152 +691,40 @@ async def generate_content_async( # ------------------------------------------------------ # # get previous_response_id previous_response_id = None - if llm_request.cache_metadata and llm_request.cache_metadata.cache_name: - previous_response_id = llm_request.cache_metadata.cache_name - completion_args = { + if llm_request.previous_interaction_id: + previous_response_id = llm_request.previous_interaction_id + responses_args = { "model": self.model, - "messages": messages, + "instructions": instructions, + "input": input_param, "tools": tools, - "response_format": response_format, + "text": text_format, "previous_response_id": previous_response_id, # supply previous_response_id } # ------------------------------------------------------ # - 
completion_args.update(self._additional_args) + responses_args.update(self._additional_args) if generation_params: - completion_args.update(generation_params) - response_args = self.transform_handler.transform_request(**completion_args) + responses_args.update(generation_params) + responses_args = request_reorganization_by_ark(responses_args) if stream: - text = "" - # Track function calls by index - function_calls = {} # index -> {name, args, id} - response_args["stream"] = True - aggregated_llm_response = None - aggregated_llm_response_with_tool_call = None - usage_metadata = None - fallback_index = 0 - raw_response = await self.llm_client.aresponse(**response_args) - async for part in raw_response: - for ( - model_response, - chunk, - finish_reason, - ) in self.transform_handler.stream_event_to_chunk( - part, model=self.model - ): - if isinstance(chunk, FunctionChunk): - index = chunk.index or fallback_index - if index not in function_calls: - function_calls[index] = {"name": "", "args": "", "id": None} - - if chunk.name: - function_calls[index]["name"] += chunk.name - if chunk.args: - function_calls[index]["args"] += chunk.args - - # check if args is completed (workaround for improper chunk - # indexing) - try: - json.loads(function_calls[index]["args"]) - fallback_index += 1 - except json.JSONDecodeError: - pass - - function_calls[index]["id"] = ( - chunk.id or function_calls[index]["id"] or str(index) - ) - elif isinstance(chunk, TextChunk): - text += chunk.text - yield _message_to_generate_content_response( - ChatCompletionAssistantMessage( - role="assistant", - content=chunk.text, - ), - is_partial=True, - ) - elif isinstance(chunk, UsageMetadataChunk): - usage_metadata = types.GenerateContentResponseUsageMetadata( - prompt_token_count=chunk.prompt_tokens, - candidates_token_count=chunk.completion_tokens, - total_token_count=chunk.total_tokens, - ) - # ------------------------------------------------------ # - if model_response.get("usage", {}).get("prompt_tokens_details"): - usage_metadata.cached_content_token_count = ( - model_response.get("usage", {}) - .get("prompt_tokens_details") - .cached_tokens - ) - # ------------------------------------------------------ # - - if ( - finish_reason == "tool_calls" or finish_reason == "stop" - ) and function_calls: - tool_calls = [] - for index, func_data in function_calls.items(): - if func_data["id"]: - tool_calls.append( - ChatCompletionMessageToolCall( - type="function", - id=func_data["id"], - function=Function( - name=func_data["name"], - arguments=func_data["args"], - index=index, - ), - ) - ) - aggregated_llm_response_with_tool_call = ( - _message_to_generate_content_response( - ChatCompletionAssistantMessage( - role="assistant", - content=text, - tool_calls=tool_calls, - ) - ) - ) - self.transform_handler.adapt_responses_api( - model_response, - aggregated_llm_response_with_tool_call, - stream=True, - ) - text = "" - function_calls.clear() - elif finish_reason == "stop" and text: - aggregated_llm_response = _message_to_generate_content_response( - ChatCompletionAssistantMessage( - role="assistant", content=text - ) - ) - self.transform_handler.adapt_responses_api( - model_response, - aggregated_llm_response, - stream=True, - ) - text = "" - - # waiting until streaming ends to yield the llm_response as litellm tends - # to send chunk that contains usage_metadata after the chunk with - # finish_reason set to tool_calls or stop. 
- if aggregated_llm_response: - if usage_metadata: - aggregated_llm_response.usage_metadata = usage_metadata - usage_metadata = None - yield aggregated_llm_response - - if aggregated_llm_response_with_tool_call: - if usage_metadata: - aggregated_llm_response_with_tool_call.usage_metadata = ( - usage_metadata - ) - yield aggregated_llm_response_with_tool_call - + responses_args["stream"] = True + async for part in await self.llm_client.aresponse(**responses_args): + llm_response = event_to_generate_content_response( + event=part, is_partial=True, model_version=self.model + ) + if llm_response: + yield llm_response else: - raw_response = await self.llm_client.aresponse(**response_args) - for ( - llm_response - ) in self.transform_handler.openai_response_to_generate_content_response( - llm_request, raw_response - ): - yield llm_response + raw_response = await self.llm_client.aresponse(**responses_args) + llm_response = ark_response_to_generate_content_response(raw_response) + yield llm_response + + @classmethod + @override + def supported_models(cls) -> list[str]: + return [ + # For OpenAI models (e.g., "openai/gpt-4o") + r"openai/.*", + ] diff --git a/veadk/models/ark_transform.py b/veadk/models/ark_transform.py deleted file mode 100644 index dc2301e5..00000000 --- a/veadk/models/ark_transform.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# adapted from Google ADK models adk-python/blob/main/src/google/adk/models/lite_llm.py at f1f44675e4a86b75e72cfd838efd8a0399f23e24 · google/adk-python - -import uuid -from typing import Any, Dict, Optional, cast, List, Generator, Tuple, Union - -import litellm -from google.adk.models import LlmResponse, LlmRequest -from google.adk.models.cache_metadata import CacheMetadata -from google.adk.models.lite_llm import ( - TextChunk, - FunctionChunk, - UsageMetadataChunk, - _model_response_to_chunk, - _model_response_to_generate_content_response, -) -from openai.types.responses import ( - Response as OpenAITypeResponse, - ResponseStreamEvent, - ResponseTextDeltaEvent, - ResponseOutputMessage, - ResponseFunctionToolCall, -) -from openai.types.responses import ( - ResponseCompletedEvent, -) -from litellm.completion_extras.litellm_responses_transformation.transformation import ( - LiteLLMResponsesTransformationHandler, -) -from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider -from litellm.types.llms.openai import ResponsesAPIResponse -from litellm.types.utils import ( - ModelResponse, - LlmProviders, - Choices, - Message, -) -from litellm.utils import ProviderConfigManager - -from veadk.utils.logger import get_logger - -# This will add functions to prompts if functions are provided. 
-litellm.add_function_to_prompt = True - -logger = get_logger(__name__) - - -openai_supported_fields = [ - "stream", - "background", - "include", - "input", - "instructions", - "max_output_tokens", - "max_tool_calls", - "metadata", - "model", - "parallel_tool_calls", - "previous_response_id", - "prompt", - "prompt_cache_key", - "reasoning", - "safety_identifier", - "service_tier", - "store", - "stream", - "stream_options", - "temperature", - "text", - "tool_choice", - "tools", - "top_logprobs", - "top_p", - "truncation", - "user", - "extra_headers", - "extra_query", - "extra_body", - "timeout", - # auth params - "api_key", - "api_base", -] - - -def ark_field_reorganization(request_data: dict) -> dict: - # [Note: Ark Limitations] tools and previous_response_id - # Remove tools in subsequent rounds (when previous_response_id is present) - if ( - "tools" in request_data - and "previous_response_id" in request_data - and request_data["previous_response_id"] is not None - ): - # Remove tools in subsequent rounds regardless of caching status - del request_data["tools"] - - # [Note: Ark Limitations] caching and store - # Ensure store field is true or default when caching is enabled - if ( - "extra_body" in request_data - and isinstance(request_data["extra_body"], dict) - and "caching" in request_data["extra_body"] - and isinstance(request_data["extra_body"]["caching"], dict) - and request_data["extra_body"]["caching"].get("type") == "enabled" - ): - # Set store to true when caching is enabled for writing - if "store" not in request_data: - request_data["store"] = True - elif request_data["store"] is False: - # Override false to true for cache writing - request_data["store"] = True - - # [NOTE Ark Limitations] instructions -> input (because of caching) - # Due to the Volcano Ark settings, there is a conflict between the cache and the instructions field. - # If a system prompt is needed, it should be placed in the system role message within the input, instead of using the instructions parameter. - # https://www.volcengine.com/docs/82379/1585128 - instructions = request_data.pop("instructions", None) - if instructions: - request_data["input"] = [ - { - "content": [{"text": instructions, "type": "input_text"}], - "role": "system", - "type": "message", - } - ] + request_data["input"] - - return request_data - - -def build_cache_metadata(response_id: str) -> CacheMetadata: - """Create a new CacheMetadata instance for agent response tracking. 
- - Args: - response_id: Response ID to track - - Returns: - A new CacheMetadata instance with the agent-response mapping - """ - if "contents_count" in CacheMetadata.model_fields: # adk >= 1.17 - cache_metadata = CacheMetadata( - cache_name=response_id, - expire_time=0, - fingerprint="", - invocations_used=0, - contents_count=0, - ) - else: # 1.15 <= adk < 1.17 - cache_metadata = CacheMetadata( - cache_name=response_id, - expire_time=0, - fingerprint="", - invocations_used=0, - cached_contents_count=0, - ) - return cache_metadata - - -class CompletionToResponsesAPIHandler: - def __init__(self): - self.litellm_handler = LiteLLMResponsesTransformationHandler() - - def transform_request( - self, model: str, messages: list, tools: Optional[list], **kwargs - ): - # Keep the first message and all consecutive user messages from the end - filtered_messages = messages[:1] - - # Collect all consecutive user messages from the end - user_messages_from_end = [] - for message in reversed(messages[1:]): # Skip the first message - if message.get("role") and message.get("role") in {"user", "tool"}: - user_messages_from_end.append(message) - else: - break # Stop when we encounter a non-user message - - # Reverse to maintain original order and add to filtered messages - filtered_messages.extend(reversed(user_messages_from_end)) - - messages = filtered_messages - # completion_request to responses api request - # 1. model and llm_custom_provider - model, custom_llm_provider, _, _ = get_llm_provider(model=model) - - # 2. input and instruction - if custom_llm_provider is not None and custom_llm_provider in [ - provider.value for provider in LlmProviders - ]: - provider_config = ProviderConfigManager.get_provider_chat_config( - model=model, provider=LlmProviders(custom_llm_provider) - ) - if provider_config is not None: - messages = provider_config.translate_developer_role_to_system_role( - messages=messages - ) - - input_items, instructions = ( - self.litellm_handler.convert_chat_completion_messages_to_responses_api( - messages - ) - ) - if tools is not None: - tools = self.litellm_handler._convert_tools_to_responses_format( - cast(List[Dict[str, Any]], tools) - ) - - response_args = { - "input": input_items, - "instructions": instructions, - "tools": tools, - "stream": kwargs.get("stream", False), - "model": model, - **kwargs, - } - result = { - key: value - for key, value in response_args.items() - if key in openai_supported_fields - } - - # Filter and reorganize scenarios that are not supported by some arks - return ark_field_reorganization(result) - - def transform_response( - self, openai_response: OpenAITypeResponse, stream: bool = False - ) -> list[ModelResponse]: - # openai_type_response -> responses_api_response -> completion_response - result_list = [] - raw_response_list = construct_responses_api_response(openai_response) - for raw_response in raw_response_list: - model_response = ModelResponse(stream=stream) - setattr(model_response, "usage", litellm.Usage()) - response = self.litellm_handler.transform_response( - model=raw_response.model, - raw_response=raw_response, - model_response=model_response, - logging_obj=None, - request_data={}, - messages=[], - optional_params={}, - litellm_params={}, - encoding=None, - ) - if raw_response and hasattr(raw_response, "id"): - response.id = raw_response.id - result_list.append(response) - - return result_list - - def openai_response_to_generate_content_response( - self, llm_request: LlmRequest, raw_response: OpenAITypeResponse - ) -> list[LlmResponse]: - """ 
- OpenAITypeResponse -> litellm.ModelResponse -> LlmResponse - instead of `_model_response_to_generate_content_response`, - """ - # no stream response - model_response_list = self.transform_response( - openai_response=raw_response, stream=False - ) - llm_response_list = [] - for model_response in model_response_list: - llm_response = _model_response_to_generate_content_response(model_response) - - llm_response = self.adapt_responses_api( - model_response, - llm_response, - ) - llm_response_list.append(llm_response) - return llm_response_list - - def adapt_responses_api( - self, - model_response: ModelResponse, - llm_response: LlmResponse, - stream: bool = False, - ): - """ - Adapt responses api. - """ - if not model_response.id.startswith("chatcmpl"): - previous_response_id = model_response["id"] - llm_response.cache_metadata = build_cache_metadata( - previous_response_id, - ) - # add responses cache data - if not stream: - if model_response.get("usage", {}).get("prompt_tokens_details"): - if llm_response.usage_metadata: - llm_response.usage_metadata.cached_content_token_count = ( - model_response.get("usage", {}) - .get("prompt_tokens_details") - .cached_tokens - ) - return llm_response - - def stream_event_to_chunk( - self, event: ResponseStreamEvent, model: str - ) -> Generator[ - Tuple[ - ModelResponse, - Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]], - Optional[str], - ], - None, - None, - ]: - """ - instead of using `_model_response_to_chunk`, - we use our own implementation to support the responses api. - """ - choices = [] - - if isinstance(event, ResponseTextDeltaEvent): - delta = Message(content=event.delta) - choices.append( - Choices(delta=delta, index=event.output_index, finish_reason=None) - ) - model_response = ModelResponse( - stream=True, choices=choices, model=model, id=str(uuid.uuid4()) - ) - for chunk, _ in _model_response_to_chunk(model_response): - # delta text, not finish - yield model_response, chunk, None - elif isinstance(event, ResponseCompletedEvent): - response = event.response - model_response_list = self.transform_response(response, stream=True) - for model_response in model_response_list: - model_response = fix_model_response(model_response) - - for chunk, finish_reason in _model_response_to_chunk(model_response): - if isinstance(chunk, TextChunk): - yield model_response, None, finish_reason - else: - yield model_response, chunk, finish_reason - else: - # Ignore other event types like ResponseOutputItemAddedEvent, etc. 
- pass - - -def fix_model_response(model_response: ModelResponse) -> ModelResponse: - """ - fix: tool_call has no attribute `index` in `_model_response_to_chunk` - """ - for i, choice in enumerate(model_response.choices): - if choice.message.tool_calls: - for idx, tool_call in enumerate(choice.message.tool_calls): - if not tool_call.get("index"): - model_response.choices[i].message.tool_calls[idx].index = 0 - - return model_response - - -def construct_responses_api_response( - openai_response: OpenAITypeResponse, -) -> list[ResponsesAPIResponse]: - output = openai_response.output - - # Check if we need to split the response - if len(output) >= 2: - # Check if output contains both ResponseOutputMessage and ResponseFunctionToolCall types - has_message = any(isinstance(item, ResponseOutputMessage) for item in output) - has_tool_call = any( - isinstance(item, ResponseFunctionToolCall) for item in output - ) - - if has_message and has_tool_call: - # Split into separate responses for each item - raw_response_list = [] - for item in output: - if isinstance(item, (ResponseOutputMessage, ResponseFunctionToolCall)): - raw_response_list.append( - ResponsesAPIResponse( - **{ - k: v - for k, v in openai_response.model_dump().items() - if k != "output" - }, - output=[item], - ) - ) - return raw_response_list - - # Otherwise, return the original response structure - return [ResponsesAPIResponse(**openai_response.model_dump())] diff --git a/veadk/utils/misc.py b/veadk/utils/misc.py index 6755479e..ffc825e7 100644 --- a/veadk/utils/misc.py +++ b/veadk/utils/misc.py @@ -18,7 +18,7 @@ import sys import time import types -from typing import Any, Dict, List, MutableMapping, Tuple +from typing import Any, Dict, List, MutableMapping, Tuple, Optional import requests from yaml import safe_load @@ -184,32 +184,32 @@ def get_agent_dir(): return full_path -def check_litellm_version(min_version: str): - """ - Check if the installed litellm version meets the minimum requirement. - - Args: - min_version (str): The minimum required version of litellm. - """ - try: - from packaging.version import InvalidVersion - from packaging.version import parse as parse_version - import pkg_resources +async def upload_to_files_api( + local_path: str, + fps: Optional[float] = None, + poll_interval: float = 3.0, + max_wait_seconds: float = 10 * 60, +) -> str: + from veadk.config import getenv, settings + from veadk.consts import DEFAULT_MODEL_AGENT_API_BASE + from volcenginesdkarkruntime import AsyncArk - try: - installed = parse_version(pkg_resources.get_distribution("litellm").version) - except pkg_resources.DistributionNotFound: - raise ImportError( - "litellm installation not detected, please install it first: pip install litellm>=1.79.3" - ) from None - except InvalidVersion as e: - raise ValueError(f"Invalid format of litellm version number:{e}") from None - required = parse_version(min_version) - if installed < required: - raise ValueError( - "You have used `enable_responses=True`. If you want to use the `responses_api`, please ensure that `litellm>=1.79.3`" - ) - except ImportError: - raise ImportError( - "packaging or pkg_resources not found. 
Please install them: pip install packaging setuptools" - ) + client = AsyncArk( + api_key=getenv("MODEL_AGENT_API_KEY", settings.model.api_key), + base_url=getenv("DEFAULT_MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE), + ) + file = await client.files.create( + file=open(local_path, "rb"), + purpose="user_data", + preprocess_configs={ + "video": { + "fps": fps, + } + } + if fps + else None, + ) + await client.files.wait_for_processing( + id=file.id, poll_interval=poll_interval, max_wait_seconds=max_wait_seconds + ) + return file.id