Skip to content
4 changes: 2 additions & 2 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _test_convert_messages(runner):
role="user",
)
]
actual_message = runner._convert_messages(message)
actual_message = runner._convert_messages(message, session_id="test_session_id")
assert actual_message == expected_message

message = ["test message 1", "test message 2"]
Expand All @@ -42,7 +42,7 @@ def _test_convert_messages(runner):
role="user",
),
]
actual_message = runner._convert_messages(message)
actual_message = runner._convert_messages(message, session_id="test_session_id")
assert actual_message == expected_message


Expand Down
110 changes: 110 additions & 0 deletions tests/test_tos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from unittest import mock
import veadk.integrations.ve_tos.ve_tos as tos_mod

# 使用 pytest-asyncio
pytest_plugins = ("pytest_asyncio",)


@pytest.fixture
def mock_client(monkeypatch):
fake_client = mock.Mock()

monkeypatch.setenv("DATABASE_TOS_REGION", "test-region")
monkeypatch.setenv("VOLCENGINE_ACCESS_KEY", "test-access-key")
monkeypatch.setenv("VOLCENGINE_SECRET_KEY", "test-secret-key")
monkeypatch.setenv("DATABASE_TOS_BUCKET", "test-bucket")

monkeypatch.setattr(tos_mod.tos, "TosClientV2", lambda *a, **k: fake_client)

class FakeExceptions:
class TosServerError(Exception):
def __init__(self, msg):
super().__init__(msg)
self.status_code = None

monkeypatch.setattr(tos_mod.tos, "exceptions", FakeExceptions)
monkeypatch.setattr(
tos_mod.tos,
"StorageClassType",
type("S", (), {"Storage_Class_Standard": "STANDARD"}),
)
monkeypatch.setattr(
tos_mod.tos, "ACLType", type("A", (), {"ACL_Private": "private"})
)

return fake_client


@pytest.fixture
def tos_client(mock_client):
return tos_mod.VeTOS()


def test_create_bucket_exists(tos_client, mock_client):
mock_client.head_bucket.return_value = None # head_bucket 正常返回表示存在
result = tos_client.create_bucket()
assert result is True
mock_client.create_bucket.assert_not_called()


def test_create_bucket_not_exists(tos_client, mock_client):
exc = tos_mod.tos.exceptions.TosServerError("not found")
exc.status_code = 404
mock_client.head_bucket.side_effect = exc

result = tos_client.create_bucket()
assert result is True
mock_client.create_bucket.assert_called_once()


@pytest.mark.asyncio
async def test_upload_bytes_success(tos_client, mock_client):
mock_client.head_bucket.return_value = True
data = b"hello world"

result = await tos_client.upload("obj-key", data)
assert result is True
mock_client.put_object.assert_called_once()
mock_client.close.assert_called_once()


@pytest.mark.asyncio
async def test_upload_file_success(tmp_path, tos_client, mock_client):
mock_client.head_bucket.return_value = True
file_path = tmp_path / "file.txt"
file_path.write_text("hello file")

result = await tos_client.upload("obj-key", str(file_path))
assert result is True
mock_client.put_object_from_file.assert_called_once()
mock_client.close.assert_called_once()


def test_download_success(tmp_path, tos_client, mock_client):
save_path = tmp_path / "out.txt"
mock_client.get_object.return_value = [b"abc", b"def"]

result = tos_client.download("obj-key", str(save_path))
assert result is True
assert save_path.read_bytes() == b"abcdef"


def test_download_fail(tos_client, mock_client):
mock_client.get_object.side_effect = Exception("boom")
result = tos_client.download("obj-key", "somewhere.txt")
assert result is False
176 changes: 176 additions & 0 deletions veadk/integrations/ve_tos/ve_tos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from veadk.config import getenv
from veadk.utils.logger import get_logger
import tos
import asyncio
from typing import Union
from pydantic import BaseModel, Field
from typing import Any
from urllib.parse import urlparse
from datetime import datetime

logger = get_logger(__name__)


class TOSConfig(BaseModel):
region: str = Field(
default_factory=lambda: getenv("DATABASE_TOS_REGION"),
description="TOS region",
)
ak: str = Field(
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"),
description="Volcengine access key",
)
sk: str = Field(
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"),
description="Volcengine secret key",
)
bucket_name: str = Field(
default_factory=lambda: getenv("DATABASE_TOS_BUCKET"),
description="TOS bucket name",
)


class VeTOS(BaseModel):
config: TOSConfig = Field(default_factory=TOSConfig)

def model_post_init(self, __context: Any) -> None:
try:
self._client = tos.TosClientV2(
self.config.ak,
self.config.sk,
endpoint=f"tos-{self.config.region}.volces.com",
region=self.config.region,
)
logger.info("Connected to TOS successfully.")
except Exception as e:
logger.error(f"Client initialization failed:{e}")
return None

def create_bucket(self) -> bool:
"""If the bucket does not exist, create it"""
try:
self._client.head_bucket(self.config.bucket_name)
logger.info(f"Bucket {self.config.bucket_name} already exists")
return True
except tos.exceptions.TosServerError as e:
if e.status_code == 404:
self._client.create_bucket(
bucket=self.config.bucket_name,
storage_class=tos.StorageClassType.Storage_Class_Standard,
acl=tos.ACLType.ACL_Private,
)
logger.info(f"Bucket {self.config.bucket_name} created successfully")
return True
except Exception as e:
logger.error(f"Bucket creation failed: {str(e)}")
return False

def build_tos_url(
self, user_id: str, app_name: str, session_id: str, data_path: str
) -> tuple[str, str]:
"""generate TOS object key"""
parsed_url = urlparse(data_path)

if parsed_url.scheme and parsed_url.scheme in ("http", "https", "ftp", "ftps"):
file_name = os.path.basename(parsed_url.path)
else:
file_name = os.path.basename(data_path)

timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
object_key: str = f"{app_name}-{user_id}-{session_id}/{timestamp}-{file_name}"
tos_url: str = f"https://{self.config.bucket_name}.tos-{self.config.region}.volces.com/{object_key}"

return object_key, tos_url

def upload(
self,
object_key: str,
data: Union[str, bytes],
):
if isinstance(data, str):
data_type = "file"
elif isinstance(data, bytes):
data_type = "bytes"
else:
error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}"
logger.error(error_msg)
raise ValueError(error_msg)
if data_type == "file":
return asyncio.to_thread(self._do_upload_file, object_key, data)
elif data_type == "bytes":
return asyncio.to_thread(self._do_upload_bytes, object_key, data)

def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool:
try:
if not self._client:
return False
if not self.create_bucket():
return False
self._client.put_object(
bucket=self.config.bucket_name, key=object_key, content=bytes
)
logger.debug(f"Upload success, object_key: {object_key}")
self._close()
return True
except Exception as e:
logger.error(f"Upload failed: {e}")
self._close()
return False

def _do_upload_file(self, object_key: str, file_path: str) -> bool:
try:
if not self._client:
return False
if not self.create_bucket():
return False

self._client.put_object_from_file(
bucket=self.config.bucket_name, key=object_key, file_path=file_path
)
self._close()
logger.debug(f"Upload success, object_key: {object_key}")
return True
except Exception as e:
logger.error(f"Upload failed: {e}")
self._close()
return False

def download(self, object_key: str, save_path: str) -> bool:
"""download image from TOS"""
try:
object_stream = self._client.get_object(self.config.bucket_name, object_key)

save_dir = os.path.dirname(save_path)
if save_dir and not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)

with open(save_path, "wb") as f:
for chunk in object_stream:
f.write(chunk)

logger.debug(f"Image download success, saved to: {save_path}")
return True

except Exception as e:
logger.error(f"Image download failed: {str(e)}")

return False

def _close(self):
if self._client:
self._client.close()
24 changes: 19 additions & 5 deletions veadk/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Union

from google.adk.agents import RunConfig
Expand All @@ -31,6 +32,7 @@
from veadk.types import MediaMessage
from veadk.utils.logger import get_logger
from veadk.utils.misc import read_png_to_bytes
from veadk.integrations.ve_tos.ve_tos import VeTOS

logger = get_logger(__name__)

Expand Down Expand Up @@ -84,22 +86,34 @@ def __init__(
plugins=plugins,
)

def _convert_messages(self, messages) -> list:
def _convert_messages(self, messages, session_id) -> list:
if isinstance(messages, str):
messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
elif isinstance(messages, MediaMessage):
assert messages.media.endswith(".png"), (
"The MediaMessage only supports PNG format file for now."
)
data = read_png_to_bytes(messages.media)

ve_tos = VeTOS()
object_key, tos_url = ve_tos.build_tos_url(
self.user_id, self.app_name, session_id, messages.media
)
try:
asyncio.create_task(ve_tos.upload(object_key, data))
except Exception as e:
logger.error(f"Upload to TOS failed: {e}")
tos_url = None

messages = [
types.Content(
role="user",
parts=[
types.Part(text=messages.text),
types.Part(
inline_data=Blob(
display_name=messages.media,
data=read_png_to_bytes(messages.media),
display_name=tos_url,
data=data,
mime_type="image/png",
)
),
Expand All @@ -109,7 +123,7 @@ def _convert_messages(self, messages) -> list:
elif isinstance(messages, list):
converted_messages = []
for message in messages:
converted_messages.extend(self._convert_messages(message))
converted_messages.extend(self._convert_messages(message, session_id))
messages = converted_messages
else:
raise ValueError(f"Unknown message type: {type(messages)}")
Expand Down Expand Up @@ -169,7 +183,7 @@ async def run(
run_config: RunConfig | None = None,
save_tracing_data: bool = False,
):
converted_messages: list = self._convert_messages(messages)
converted_messages: list = self._convert_messages(messages, session_id)

await self.short_term_memory.create_session(
app_name=self.app_name, user_id=self.user_id, session_id=session_id
Expand Down
Loading