diff --git a/tests/test_runner.py b/tests/test_runner.py index 510a3323..b6a6ff8a 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -28,7 +28,7 @@ def _test_convert_messages(runner): role="user", ) ] - actual_message = runner._convert_messages(message) + actual_message = runner._convert_messages(message, session_id="test_session_id") assert actual_message == expected_message message = ["test message 1", "test message 2"] @@ -42,7 +42,7 @@ def _test_convert_messages(runner): role="user", ), ] - actual_message = runner._convert_messages(message) + actual_message = runner._convert_messages(message, session_id="test_session_id") assert actual_message == expected_message diff --git a/tests/test_tos.py b/tests/test_tos.py new file mode 100644 index 00000000..51595184 --- /dev/null +++ b/tests/test_tos.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest import mock +import veadk.integrations.ve_tos.ve_tos as tos_mod + +# 使用 pytest-asyncio +pytest_plugins = ("pytest_asyncio",) + + +@pytest.fixture +def mock_client(monkeypatch): + fake_client = mock.Mock() + + monkeypatch.setenv("DATABASE_TOS_REGION", "test-region") + monkeypatch.setenv("VOLCENGINE_ACCESS_KEY", "test-access-key") + monkeypatch.setenv("VOLCENGINE_SECRET_KEY", "test-secret-key") + monkeypatch.setenv("DATABASE_TOS_BUCKET", "test-bucket") + + monkeypatch.setattr(tos_mod.tos, "TosClientV2", lambda *a, **k: fake_client) + + class FakeExceptions: + class TosServerError(Exception): + def __init__(self, msg): + super().__init__(msg) + self.status_code = None + + monkeypatch.setattr(tos_mod.tos, "exceptions", FakeExceptions) + monkeypatch.setattr( + tos_mod.tos, + "StorageClassType", + type("S", (), {"Storage_Class_Standard": "STANDARD"}), + ) + monkeypatch.setattr( + tos_mod.tos, "ACLType", type("A", (), {"ACL_Private": "private"}) + ) + + return fake_client + + +@pytest.fixture +def tos_client(mock_client): + return tos_mod.VeTOS() + + +def test_create_bucket_exists(tos_client, mock_client): + mock_client.head_bucket.return_value = None # head_bucket 正常返回表示存在 + result = tos_client.create_bucket() + assert result is True + mock_client.create_bucket.assert_not_called() + + +def test_create_bucket_not_exists(tos_client, mock_client): + exc = tos_mod.tos.exceptions.TosServerError("not found") + exc.status_code = 404 + mock_client.head_bucket.side_effect = exc + + result = tos_client.create_bucket() + assert result is True + mock_client.create_bucket.assert_called_once() + + +@pytest.mark.asyncio +async def test_upload_bytes_success(tos_client, mock_client): + mock_client.head_bucket.return_value = True + data = b"hello world" + + result = await tos_client.upload("obj-key", data) + assert result is True + mock_client.put_object.assert_called_once() + mock_client.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_upload_file_success(tmp_path, tos_client, mock_client): + mock_client.head_bucket.return_value = True + file_path = tmp_path / "file.txt" + file_path.write_text("hello file") + + result = await tos_client.upload("obj-key", str(file_path)) + assert result is True + mock_client.put_object_from_file.assert_called_once() + mock_client.close.assert_called_once() + + +def test_download_success(tmp_path, tos_client, mock_client): + save_path = tmp_path / "out.txt" + mock_client.get_object.return_value = [b"abc", b"def"] + + result = tos_client.download("obj-key", str(save_path)) + assert result is True + assert save_path.read_bytes() == b"abcdef" + + +def test_download_fail(tos_client, mock_client): + mock_client.get_object.side_effect = Exception("boom") + result = tos_client.download("obj-key", "somewhere.txt") + assert result is False diff --git a/veadk/integrations/ve_tos/ve_tos.py b/veadk/integrations/ve_tos/ve_tos.py new file mode 100644 index 00000000..c5d93fb6 --- /dev/null +++ b/veadk/integrations/ve_tos/ve_tos.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from veadk.config import getenv +from veadk.utils.logger import get_logger +import tos +import asyncio +from typing import Union +from pydantic import BaseModel, Field +from typing import Any +from urllib.parse import urlparse +from datetime import datetime + +logger = get_logger(__name__) + + +class TOSConfig(BaseModel): + region: str = Field( + default_factory=lambda: getenv("DATABASE_TOS_REGION"), + description="TOS region", + ) + ak: str = Field( + default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"), + description="Volcengine access key", + ) + sk: str = Field( + default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"), + description="Volcengine secret key", + ) + bucket_name: str = Field( + default_factory=lambda: getenv("DATABASE_TOS_BUCKET"), + description="TOS bucket name", + ) + + +class VeTOS(BaseModel): + config: TOSConfig = Field(default_factory=TOSConfig) + + def model_post_init(self, __context: Any) -> None: + try: + self._client = tos.TosClientV2( + self.config.ak, + self.config.sk, + endpoint=f"tos-{self.config.region}.volces.com", + region=self.config.region, + ) + logger.info("Connected to TOS successfully.") + except Exception as e: + logger.error(f"Client initialization failed:{e}") + return None + + def create_bucket(self) -> bool: + """If the bucket does not exist, create it""" + try: + self._client.head_bucket(self.config.bucket_name) + logger.info(f"Bucket {self.config.bucket_name} already exists") + return True + except tos.exceptions.TosServerError as e: + if e.status_code == 404: + self._client.create_bucket( + bucket=self.config.bucket_name, + storage_class=tos.StorageClassType.Storage_Class_Standard, + acl=tos.ACLType.ACL_Private, + ) + logger.info(f"Bucket {self.config.bucket_name} created successfully") + return True + except Exception as e: + logger.error(f"Bucket creation failed: {str(e)}") + return False + + def build_tos_url( + self, user_id: str, app_name: str, session_id: str, data_path: str + ) -> tuple[str, str]: + """generate TOS object key""" + parsed_url = urlparse(data_path) + + if parsed_url.scheme and parsed_url.scheme in ("http", "https", "ftp", "ftps"): + file_name = os.path.basename(parsed_url.path) + else: + file_name = os.path.basename(data_path) + + timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] + object_key: str = f"{app_name}-{user_id}-{session_id}/{timestamp}-{file_name}" + tos_url: str = f"https://{self.config.bucket_name}.tos-{self.config.region}.volces.com/{object_key}" + + return object_key, tos_url + + def upload( + self, + object_key: str, + data: Union[str, bytes], + ): + if isinstance(data, str): + data_type = "file" + elif isinstance(data, bytes): + data_type = "bytes" + else: + error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}" + logger.error(error_msg) + raise ValueError(error_msg) + if data_type == "file": + return asyncio.to_thread(self._do_upload_file, object_key, data) + elif data_type == "bytes": + return asyncio.to_thread(self._do_upload_bytes, object_key, data) + + def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool: + try: + if not self._client: + return False + if not self.create_bucket(): + return False + self._client.put_object( + bucket=self.config.bucket_name, key=object_key, content=bytes + ) + logger.debug(f"Upload success, object_key: {object_key}") + self._close() + return True + except Exception as e: + logger.error(f"Upload failed: {e}") + self._close() + return False + + def _do_upload_file(self, object_key: str, file_path: str) -> bool: + try: + if not self._client: + return False + if not self.create_bucket(): + return False + + self._client.put_object_from_file( + bucket=self.config.bucket_name, key=object_key, file_path=file_path + ) + self._close() + logger.debug(f"Upload success, object_key: {object_key}") + return True + except Exception as e: + logger.error(f"Upload failed: {e}") + self._close() + return False + + def download(self, object_key: str, save_path: str) -> bool: + """download image from TOS""" + try: + object_stream = self._client.get_object(self.config.bucket_name, object_key) + + save_dir = os.path.dirname(save_path) + if save_dir and not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + + with open(save_path, "wb") as f: + for chunk in object_stream: + f.write(chunk) + + logger.debug(f"Image download success, saved to: {save_path}") + return True + + except Exception as e: + logger.error(f"Image download failed: {str(e)}") + + return False + + def _close(self): + if self._client: + self._client.close() diff --git a/veadk/runner.py b/veadk/runner.py index 5b7b6183..54d8442a 100644 --- a/veadk/runner.py +++ b/veadk/runner.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from typing import Union from google.adk.agents import RunConfig @@ -31,6 +32,7 @@ from veadk.types import MediaMessage from veadk.utils.logger import get_logger from veadk.utils.misc import read_png_to_bytes +from veadk.integrations.ve_tos.ve_tos import VeTOS logger = get_logger(__name__) @@ -84,13 +86,25 @@ def __init__( plugins=plugins, ) - def _convert_messages(self, messages) -> list: + def _convert_messages(self, messages, session_id) -> list: if isinstance(messages, str): messages = [types.Content(role="user", parts=[types.Part(text=messages)])] elif isinstance(messages, MediaMessage): assert messages.media.endswith(".png"), ( "The MediaMessage only supports PNG format file for now." ) + data = read_png_to_bytes(messages.media) + + ve_tos = VeTOS() + object_key, tos_url = ve_tos.build_tos_url( + self.user_id, self.app_name, session_id, messages.media + ) + try: + asyncio.create_task(ve_tos.upload(object_key, data)) + except Exception as e: + logger.error(f"Upload to TOS failed: {e}") + tos_url = None + messages = [ types.Content( role="user", @@ -98,8 +112,8 @@ def _convert_messages(self, messages) -> list: types.Part(text=messages.text), types.Part( inline_data=Blob( - display_name=messages.media, - data=read_png_to_bytes(messages.media), + display_name=tos_url, + data=data, mime_type="image/png", ) ), @@ -109,7 +123,7 @@ def _convert_messages(self, messages) -> list: elif isinstance(messages, list): converted_messages = [] for message in messages: - converted_messages.extend(self._convert_messages(message)) + converted_messages.extend(self._convert_messages(message, session_id)) messages = converted_messages else: raise ValueError(f"Unknown message type: {type(messages)}") @@ -169,7 +183,7 @@ async def run( run_config: RunConfig | None = None, save_tracing_data: bool = False, ): - converted_messages: list = self._convert_messages(messages) + converted_messages: list = self._convert_messages(messages, session_id) await self.short_term_memory.create_session( app_name=self.app_name, user_id=self.user_id, session_id=session_id diff --git a/veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py b/veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py index 379c033a..d8a1141a 100644 --- a/veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +++ b/veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py @@ -137,6 +137,15 @@ def llm_gen_ai_prompt(params: LLMAttributesParams) -> ExtractorResponse: if part.function_call.args else json.dumps({}) ) + # image + if part.inline_data: + message[f"gen_ai.prompt.{idx}.type"] = "image_url" + message[f"gen_ai.prompt.{idx}.image_url.name"] = ( + part.inline_data.display_name.split("/")[-1] + ) + message[f"gen_ai.prompt.{idx}.image_url.url"] = ( + part.inline_data.display_name + ) if message: messages.append(message) @@ -234,6 +243,14 @@ def llm_gen_ai_user_message(params: LLMAttributesParams) -> ExtractorResponse: message_part[f"parts.{idx}.content"] = str( part.function_response ) + if part.inline_data: + message_part[f"parts.{idx}.type"] = "image_url" + message_part[f"parts.{idx}.image_url.name"] = ( + part.inline_data.display_name.split("/")[-1] + ) + message_part[f"parts.{idx}.image_url.url"] = ( + part.inline_data.display_name + ) message_parts.append(message_part) diff --git a/veadk/tracing/telemetry/telemetry.py b/veadk/tracing/telemetry/telemetry.py index 041b31d2..2f732973 100644 --- a/veadk/tracing/telemetry/telemetry.py +++ b/veadk/tracing/telemetry/telemetry.py @@ -87,6 +87,17 @@ def _set_agent_input_attribute( "gen_ai.user.message", {f"parts.{idx}.type": "text", f"parts.{idx}.content": part.text}, ) + if part.inline_data: + span.add_event( + "gen_ai.user.message", + { + f"parts.{idx}.type": "image_url", + f"parts.{idx}.image_url.name": part.inline_data.display_name.split( + "/" + )[-1], + f"parts.{idx}.image_url.url": part.inline_data.display_name, + }, + ) def _set_agent_output_attribute(span: _Span, llm_response: LlmResponse) -> None: