diff --git a/.gitleaks.toml b/.gitleaks.toml index e2442bae..0f097484 100644 --- a/.gitleaks.toml +++ b/.gitleaks.toml @@ -73,4 +73,4 @@ description = "Empty environment variables with KEY pattern" regex = '''os\.environ\[".*?KEY"\]\s*=\s*".+"''' [allowlist] -paths = ["requirements.txt", "tests"] +paths = ["requirements.txt", "tests", "veadk/realtime/client.py", "veadk/realtime/live.py"] \ No newline at end of file diff --git a/tests/config/test_model_config.py b/tests/config/test_model_config.py new file mode 100644 index 00000000..ca48426e --- /dev/null +++ b/tests/config/test_model_config.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import TestCase, mock +from veadk.configs.model_configs import RealtimeModelConfig + + +class TestRealtimeModelConfig(TestCase): + def test_default_values(self): + """Test that default values are set correctly""" + config = RealtimeModelConfig() + self.assertEqual(config.name, "doubao_realtime_voice_model") + self.assertEqual( + config.api_base, "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" + ) + + @mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": "test_api_key"}) + def test_api_key_from_env(self): + """Test api_key is retrieved from environment variable""" + config = RealtimeModelConfig() + self.assertEqual(config.api_key, "test_api_key") + + @mock.patch.dict(os.environ, {}, clear=True) + @mock.patch( + "veadk.configs.model_configs.get_speech_token", return_value="mocked_token" + ) + def test_api_key_from_get_speech_token(self, mock_get_token): + """Test api_key falls back to get_speech_token when env var is not set""" + config = RealtimeModelConfig() + self.assertEqual(config.api_key, "mocked_token") + mock_get_token.assert_called_once() + + @mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": ""}) + @mock.patch( + "veadk.configs.model_configs.get_speech_token", return_value="mocked_token" + ) + def test_api_key_empty_env_var(self, mock_get_token): + """Test api_key falls back when env var is empty string""" + config = RealtimeModelConfig() + self.assertEqual(config.api_key, "mocked_token") + mock_get_token.assert_called_once() + + def test_api_key_caching(self): + """Test that api_key is properly cached""" + with mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": "test_key"}): + config = RealtimeModelConfig() + first_call = config.api_key + second_call = config.api_key + self.assertEqual(first_call, second_call) + self.assertEqual(first_call, "test_key") diff --git a/tests/realtime/test_doubao_realtime_client.py b/tests/realtime/test_doubao_realtime_client.py new file mode 100644 index 00000000..022fb4a2 --- /dev/null +++ b/tests/realtime/test_doubao_realtime_client.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from unittest.mock import patch, MagicMock +from google.genai._api_client import BaseApiClient +from veadk.realtime.client import DoubaoClient, DoubaoAsyncClient +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + + +class TestDoubaoAsyncClient(unittest.TestCase): + def setUp(self): + self.mock_api_client = MagicMock(spec=BaseApiClient) + self.async_client = DoubaoAsyncClient(self.mock_api_client) + + def test_initialization(self): + self.assertIsInstance(self.async_client, DoubaoAsyncClient) + self.assertEqual(self.async_client._api_client, self.mock_api_client) + + def test_live_property(self): + from veadk.realtime.live import DoubaoAsyncLive + + live_instance = self.async_client.live + self.assertIsInstance(live_instance, DoubaoAsyncLive) + self.assertEqual(live_instance._api_client, self.mock_api_client) + + +class TestDoubaoClient(unittest.TestCase): + def setUp(self): + self.patcher = patch.dict("os.environ", {}, clear=True) + self.patcher.start() + + def tearDown(self): + self.patcher.stop() + + def test_initialization_without_google_key(self): + # Test when GOOGLE_API_KEY is not set + os.environ["REALTIME_API_KEY"] = "hack_google_api_key" + client = DoubaoClient() + self.assertEqual(os.environ["GOOGLE_API_KEY"], "hack_google_api_key") + self.assertIsNotNone(client._aio) + + def test_initialization_with_google_key(self): + # Test when GOOGLE_API_KEY is already set + os.environ["GOOGLE_API_KEY"] = "existing_key" + os.environ["REALTIME_API_KEY"] = "existing_key" + client = DoubaoClient() + self.assertEqual(os.environ["GOOGLE_API_KEY"], "existing_key") + self.assertIsNotNone(client._aio) + + @patch( + "veadk.realtime.client.DoubaoAsyncClient", side_effect=Exception("Test error") + ) + def test_initialization_failure(self, mock_async_client): + # Test when DoubaoAsyncClient initialization fails + os.environ["REALTIME_API_KEY"] = "hack_google_api_key" + client = DoubaoClient() + self.assertIsNone(client._aio) + + def test_aio_property(self): + os.environ["REALTIME_API_KEY"] = "hack_google_api_key" + client = DoubaoClient() + aio_client = client.aio + self.assertIsInstance(aio_client, DoubaoAsyncClient) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/realtime/test_doubao_realtime_voice_llm.py b/tests/realtime/test_doubao_realtime_voice_llm.py new file mode 100644 index 00000000..108266cd --- /dev/null +++ b/tests/realtime/test_doubao_realtime_voice_llm.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from google.genai import types +from veadk.realtime.doubao_realtime_voice_llm import DoubaoRealtimeVoice +from google.adk.models.llm_request import LlmRequest +from google.adk.models.base_llm_connection import BaseLlmConnection +from google.genai.types import GenerateContentConfig +import os +from veadk.realtime.client import DoubaoClient +from veadk.realtime.doubao_realtime_voice_llm import ( + _AGENT_ENGINE_TELEMETRY_TAG, + _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, +) + + +class TestDoubaoRealtimeVoice: + @pytest.fixture + def mock_llm_request(self): + request = MagicMock(spec=LlmRequest) + request.model = "doubao_realtime_voice" + request.config = GenerateContentConfig() + request.config.system_instruction = "Test instruction" + request.config.tools = [] + request.live_connect_config = types.LiveConnectConfig( + http_options=types.HttpOptions() + ) + return request + + def test_supported_models(self): + """Test supported_models returns correct model patterns""" + models = DoubaoRealtimeVoice.supported_models() + assert isinstance(models, list) + assert len(models) == 2 + assert r"doubao_realtime_voice.*" in models + assert r"Doubao_scene_SLM_Doubao_realtime_voice_model.*" in models + + def test_api_client_property(self): + """Test api_client property returns DoubaoClient with correct options""" + model = DoubaoRealtimeVoice() + client = model.api_client + assert isinstance(client, DoubaoClient) + assert client._api_client._http_options.retry_options == model.retry_options + + def test_live_api_client_property(self): + """Test _live_api_client property returns DoubaoClient with correct version""" + model = DoubaoRealtimeVoice() + client = model._live_api_client + assert isinstance(client, DoubaoClient) + assert client._api_client._http_options.api_version == model._live_api_version + + def test_tracking_headers_without_env(self): + """Test _tracking_headers without environment variable""" + model = DoubaoRealtimeVoice() + headers = model._tracking_headers + assert "x-volcengine-api-client" in headers + assert "user-agent" in headers + assert _AGENT_ENGINE_TELEMETRY_TAG not in headers["x-volcengine-api-client"] + + @patch.dict(os.environ, {_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME: "test_id"}) + def test_tracking_headers_with_env(self): + """Test _tracking_headers with environment variable set""" + model = DoubaoRealtimeVoice() + headers = model._tracking_headers + assert _AGENT_ENGINE_TELEMETRY_TAG in headers["x-volcengine-api-client"] + + @pytest.mark.asyncio + async def test_connect_with_speech_config(self, mock_llm_request): + """Test connect method with speech config""" + speech_config = types.SpeechConfig() + model = DoubaoRealtimeVoice(speech_config=speech_config) + + # 修正异步上下文管理器的 mock 设置 + with patch.object(model._live_api_client.aio.live, "connect") as mock_connect: + # 创建模拟的异步上下文管理器 + mock_session = AsyncMock() + mock_connect.return_value.__aenter__.return_value = mock_session + + async with model.connect(mock_llm_request) as connection: + assert isinstance(connection, BaseLlmConnection) + assert ( + mock_llm_request.live_connect_config.speech_config == speech_config + ) + mock_connect.assert_called_once_with( + model=mock_llm_request.model, + config=mock_llm_request.live_connect_config, + ) + + @pytest.mark.asyncio + async def test_connect_without_speech_config(self, mock_llm_request): + """Test connect method without speech config""" + model = DoubaoRealtimeVoice() + + with patch.object(model._live_api_client.aio.live, "connect") as mock_connect: + # 使用AsyncMock模拟会话对象,更贴近真实场景 + mock_session = AsyncMock() + mock_connect.return_value.__aenter__.return_value = mock_session + + async with model.connect(mock_llm_request) as connection: + assert isinstance(connection, BaseLlmConnection) + # 验证speech_config为None而非检查属性是否存在 + assert mock_llm_request.live_connect_config.speech_config is None + mock_connect.assert_called_once_with( + model=mock_llm_request.model, + config=mock_llm_request.live_connect_config, + ) diff --git a/tests/realtime/test_doubao_realtime_voice_llm_connection.py b/tests/realtime/test_doubao_realtime_voice_llm_connection.py new file mode 100644 index 00000000..0caba9e2 --- /dev/null +++ b/tests/realtime/test_doubao_realtime_voice_llm_connection.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import AsyncMock +from veadk.realtime.doubao_realtime_voice_llm_connection import ( + DoubaoRealtimeVoiceLlmConnection, +) +from google.genai import types + + +@pytest.mark.asyncio +async def test_send_realtime_with_blob(): + """Test sending Blob input.""" + # Setup + mock_session = AsyncMock() + connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session) + connection._gemini_session = mock_session + + blob_input = types.Blob() + + # Execute + await connection.send_realtime(blob_input) + + # Verify + mock_session.send_realtime_input.assert_called_once_with(media=blob_input) + + +@pytest.mark.asyncio +async def test_send_realtime_with_activity_start(): + """Test sending ActivityStart input.""" + # Setup + mock_session = AsyncMock() + connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session) + connection._gemini_session = mock_session + + activity_start = types.ActivityStart() + + # Execute + await connection.send_realtime(activity_start) + + # Verify + mock_session.send_realtime_input.assert_called_once_with( + activity_start=activity_start + ) + + +@pytest.mark.asyncio +async def test_send_realtime_with_activity_end(): + """Test sending ActivityEnd input.""" + # Setup + mock_session = AsyncMock() + connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session) + connection._gemini_session = mock_session + + activity_end = types.ActivityEnd() + + # Execute + await connection.send_realtime(activity_end) + + # Verify + mock_session.send_realtime_input.assert_called_once_with(activity_end=activity_end) + + +@pytest.mark.asyncio +async def test_send_realtime_with_unsupported_type(): + """Test sending unsupported input type.""" + # Setup + mock_session = AsyncMock() + connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session) + connection._gemini_session = mock_session + + unsupported_input = "unsupported_type" + + # Execute & Verify + with pytest.raises(ValueError) as excinfo: + await connection.send_realtime(unsupported_input) + + assert "Unsupported input type" in str(excinfo.value) diff --git a/tests/realtime/test_live.py b/tests/realtime/test_live.py new file mode 100644 index 00000000..7f61d58b --- /dev/null +++ b/tests/realtime/test_live.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import AsyncMock, MagicMock +from veadk.realtime.live import DoubaoAsyncSession, ProtocolEvents +from veadk.realtime import protocol +from google.genai import types + + +@pytest.fixture +def mock_ws(): + ws = AsyncMock() + ws.recv = AsyncMock() + ws.send = AsyncMock() + ws.response = MagicMock() + ws.response.headers = {"X-Tt-Logid": "test-logid"} + return ws + + +@pytest.fixture +def mock_api_client(): + client = MagicMock() + client._websocket_ssl_ctx = {} + return client + + +@pytest.fixture +def mock_session(mock_ws, mock_api_client): + return DoubaoAsyncSession( + api_client=mock_api_client, websocket=mock_ws, session_id="test-session-id" + ) + + +@pytest.mark.asyncio +async def test_send_realtime_input(mock_session): + # Test with media input + media = types.Blob(data=b"test-data", mime_type="audio/pcm") + await mock_session.send_realtime_input(media=media) + + # Verify the message was constructed and sent correctly + assert mock_session._ws.send.called + + # Test with multiple arguments (should raise error) + with pytest.raises(ValueError): + await mock_session.send_realtime_input(media=media, text="test") + + +@pytest.mark.asyncio +async def test_receive(mock_session): + # Mock different response types + test_cases = [ + ( + {"event": ProtocolEvents.ASR_INFO}, + { + "message_type": "SERVER_FULL_RESPONSE", + "event": ProtocolEvents.ASR_INFO, + "payload_msg": {"asr_task_id": "test_id"}, + }, + True, + ), # ASR_INFO + ( + {"event": ProtocolEvents.ASR_RESPONSE}, + { + "message_type": "SERVER_FULL_RESPONSE", + "event": ProtocolEvents.ASR_RESPONSE, + "payload_msg": {"results": [{"text": "test"}]}, + }, + "test", + ), + # ASR_RESPONSE + ( + {"event": ProtocolEvents.TTS_RESPONSE}, + { + "message_type": "SERVER_FULL_RESPONSE", + "event": ProtocolEvents.TTS_RESPONSE, + "payload_msg": b"audio-data", + }, + b"audio-data", + ), # TTS_RESPONSE + ( + {"event": ProtocolEvents.CHAT_RESPONSE}, + { + "message_type": "SERVER_FULL_RESPONSE", + "event": ProtocolEvents.CHAT_RESPONSE, + "payload_msg": {"content": "chat"}, + }, + "chat", + ), # CHAT_RESPONSE + ( + {"event": ProtocolEvents.USAGE_RESPONSE}, + { + "message_type": "SERVER_FULL_RESPONSE", + "event": ProtocolEvents.USAGE_RESPONSE, + "payload_msg": {"usage": {"cached_1": 10, "cached_2": 20, "other": 5}}, + }, + 35, + ), # USAGE_RESPONSE + ] + + for response_data, parse_data, expected in test_cases: + mock_session._ws.recv = AsyncMock(return_value=response_data) + protocol.parse_response = MagicMock(return_value=parse_data) + async for msg in mock_session.receive(): + if response_data["event"] == ProtocolEvents.ASR_INFO: + assert msg.server_content.interrupted == expected + elif response_data["event"] == ProtocolEvents.ASR_RESPONSE: + assert msg.server_content.input_transcription.text == expected + elif response_data["event"] == ProtocolEvents.TTS_RESPONSE: + assert ( + msg.server_content.model_turn.parts[0].inline_data.data == expected + ) + elif response_data["event"] == ProtocolEvents.CHAT_RESPONSE: + assert msg.server_content.output_transcription.text == expected + elif response_data["event"] == ProtocolEvents.USAGE_RESPONSE: + assert msg.usage_metadata.tool_use_prompt_token_count == expected + break + + +@pytest.mark.asyncio +async def test_convert_to_live_server_message(mock_session): + # Test ASR_INFO event + response = { + "event": ProtocolEvents.ASR_INFO, + "payload_msg": {"asr_task_id": "test_id"}, + } + result = mock_session.convert_to_live_server_message(response) + assert result.server_content.interrupted + + # Test ASR_RESPONSE event + response = { + "event": ProtocolEvents.ASR_RESPONSE, + "payload_msg": {"results": [{"text": "test"}]}, + } + result = mock_session.convert_to_live_server_message(response) + assert result.server_content.input_transcription.text == "test" + + # Test TTS_RESPONSE event + response = {"event": ProtocolEvents.TTS_RESPONSE, "payload_msg": b"audio-data"} + result = mock_session.convert_to_live_server_message(response) + assert result.server_content.model_turn.parts[0].inline_data.data == b"audio-data" + + # Test CHAT_ENDED event + response = { + "event": ProtocolEvents.CHAT_ENDED, + "payload_msg": {"results": [{"text": "test"}]}, + } + result = mock_session.convert_to_live_server_message(response) + assert result.server_content.output_transcription.finished + + # Test TTS_ENDED event + response = { + "event": ProtocolEvents.TTS_ENDED, + "payload_msg": {"results": [{"text": "test"}]}, + } + result = mock_session.convert_to_live_server_message(response) + assert result.server_content.turn_complete diff --git a/tests/realtime/test_protocol.py b/tests/realtime/test_protocol.py new file mode 100644 index 00000000..06119574 --- /dev/null +++ b/tests/realtime/test_protocol.py @@ -0,0 +1,228 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import gzip +import json +from veadk.realtime import protocol + + +class TestProtocolFunctions(unittest.TestCase): + def test_generate_header_default(self): + """Test generate_header with default parameters""" + header = protocol.generate_header() + self.assertEqual(len(header), 4) # Default header size is 1 (4 bytes) + self.assertEqual(header[0] >> 4, protocol.PROTOCOL_VERSION) + self.assertEqual(header[0] & 0x0F, 1) # header_size + self.assertEqual(header[1] >> 4, protocol.CLIENT_FULL_REQUEST) + self.assertEqual(header[1] & 0x0F, protocol.MSG_WITH_EVENT) + self.assertEqual(header[2] >> 4, protocol.JSON) + self.assertEqual(header[2] & 0x0F, protocol.GZIP) + self.assertEqual(header[3], 0x00) + + def test_generate_header_with_extension(self): + """Test generate_header with extension header""" + extension = b"\x01\x02\x03\x04" + header = protocol.generate_header(extension_header=extension) + self.assertEqual(len(header), 8) # header_size=2 (8 bytes) + self.assertEqual(header[0] & 0x0F, 2) # header_size + self.assertEqual(header[4:], extension) + + def test_generate_header_various_combinations(self): + """Test generate_header with various parameter combinations""" + # Test different message types + for msg_type in [ + protocol.CLIENT_FULL_REQUEST, + protocol.CLIENT_AUDIO_ONLY_REQUEST, + ]: + header = protocol.generate_header(message_type=msg_type) + self.assertEqual(header[1] >> 4, msg_type) + + # Test different flags + for flag in [ + protocol.NO_SEQUENCE, + protocol.POS_SEQUENCE, + protocol.NEG_SEQUENCE, + protocol.MSG_WITH_EVENT, + ]: + header = protocol.generate_header(message_type_specific_flags=flag) + self.assertEqual(header[1] & 0x0F, flag) + + # Test different serialization methods + for serial in [ + protocol.NO_SERIALIZATION, + protocol.JSON, + protocol.THRIFT, + protocol.CUSTOM_TYPE, + ]: + header = protocol.generate_header(serial_method=serial) + self.assertEqual(header[2] >> 4, serial) + + # Test different compression types + for comp in [ + protocol.NO_COMPRESSION, + protocol.GZIP, + protocol.CUSTOM_COMPRESSION, + ]: + header = protocol.generate_header(compression_type=comp) + self.assertEqual(header[2] & 0x0F, comp) + + def test_parse_response_invalid_input(self): + """Test parse_response with invalid inputs""" + # Test with string input + self.assertEqual(protocol.parse_response("invalid"), {}) + + # Test with too short response + self.assertEqual( + protocol.parse_response(b"\x01"), {"error": "Response too short"} + ) + + # Test with invalid header size + invalid_header = ( + bytes([(protocol.PROTOCOL_VERSION << 4) | 0x00]) + b"\x00\x00\x00" + ) # header_size=0 + self.assertEqual( + protocol.parse_response(invalid_header), {"error": "Invalid header size: 0"} + ) + + # Test with response shorter than header indicates + short_response = ( + bytes([(protocol.PROTOCOL_VERSION << 4) | 0x02]) + b"\x00\x00\x00" + ) # header_size=2 but only 1 byte + self.assertEqual( + protocol.parse_response(short_response), + {"error": "Response shorter than header indicates"}, + ) + + def test_parse_response_server_full_response(self): + """Test parse_response with SERVER_FULL_RESPONSE""" + # Create test data + test_data = json.dumps({"key": "value"}).encode("utf-8") + compressed_data = gzip.compress(test_data) + + # Build response + header = bytes( + [ + (protocol.PROTOCOL_VERSION << 4) | 0x01, # version + header_size=1 + (protocol.SERVER_FULL_RESPONSE << 4) + | (protocol.NEG_SEQUENCE | protocol.MSG_WITH_EVENT), # type + flags + (protocol.JSON << 4) | protocol.GZIP, # serial + compression + 0x00, # reserved + ] + ) + + # Add payload + seq_num = 1234 + event = 5678 + session_id = b"session123" + payload = ( + seq_num.to_bytes(4, "big") + + event.to_bytes(4, "big") + + len(session_id).to_bytes(4, "big", signed=True) + + session_id + + len(compressed_data).to_bytes(4, "big") + + compressed_data + ) + response = header + payload + + # Parse and verify + result = protocol.parse_response(response) + self.assertEqual(result["message_type"], "SERVER_FULL_RESPONSE") + self.assertEqual(result["seq"], seq_num) + # self.assertEqual(result['event'], event) + self.assertEqual(result["session_id"], "b'session123'") + self.assertEqual(result["payload_size"], len(compressed_data)) + self.assertEqual(result["payload_msg"], {"key": "value"}) + + def test_parse_response_server_ack(self): + """Test parse_response with SERVER_ACK""" + # Build response with no sequence, no event + header = bytes( + [ + (protocol.PROTOCOL_VERSION << 4) | 0x01, + (protocol.SERVER_ACK << 4) | protocol.NO_SEQUENCE, # type + flags + (protocol.JSON << 4) | protocol.NO_COMPRESSION, + 0x00, + ] + ) + + session_id = b"session456" + test_data = json.dumps({"status": "ok"}).encode("utf-8") + payload = ( + len(session_id).to_bytes(4, "big", signed=True) + + session_id + + len(test_data).to_bytes(4, "big") + + test_data + ) + response = header + payload + + result = protocol.parse_response(response) + self.assertEqual(result["message_type"], "SERVER_ACK") + self.assertNotIn("seq", result) + self.assertNotIn("event", result) + self.assertEqual(result["session_id"], "b'session456'") + self.assertEqual(result["payload_msg"], {"status": "ok"}) + + def test_parse_response_server_error(self): + """Test parse_response with SERVER_ERROR_RESPONSE""" + header = bytes( + [ + (protocol.PROTOCOL_VERSION << 4) | 0x01, + (protocol.SERVER_ERROR_RESPONSE << 4) | 0x00, + (protocol.JSON << 4) | protocol.NO_COMPRESSION, + 0x00, + ] + ) + + error_code = 404 + error_msg = json.dumps({"error": "Not found"}).encode("utf-8") + payload = ( + error_code.to_bytes(4, "big") + + len(error_msg).to_bytes(4, "big") + + error_msg + ) + response = header + payload + + result = protocol.parse_response(response) + self.assertEqual(result["code"], error_code) + self.assertEqual(result["payload_msg"], {"error": "Not found"}) + self.assertEqual(result["payload_size"], len(error_msg)) + + def test_parse_response_no_serialization(self): + """Test parse_response with NO_SERIALIZATION""" + header = bytes( + [ + (protocol.PROTOCOL_VERSION << 4) | 0x01, + (protocol.SERVER_FULL_RESPONSE << 4) | 0x00, + (protocol.NO_SERIALIZATION << 4) | protocol.NO_COMPRESSION, + 0x00, + ] + ) + + session_id = b"session456" + test_data = b"raw binary data" + payload = ( + len(session_id).to_bytes(4, "big", signed=True) + + session_id + + len(test_data).to_bytes(4, "big") + + test_data + ) + response = header + payload + + result = protocol.parse_response(response) + self.assertEqual(result["payload_msg"], test_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/veadk/config.py b/veadk/config.py index e4c308b9..d053a833 100644 --- a/veadk/config.py +++ b/veadk/config.py @@ -19,6 +19,7 @@ from pydantic import BaseModel, Field from veadk.configs.auth_configs import VeIdentityConfig +from veadk.configs.model_configs import RealtimeModelConfig from veadk.configs.database_configs import ( MysqlConfig, OpensearchConfig, @@ -70,6 +71,7 @@ class VeADKConfig(BaseModel): ) veidentity: VeIdentityConfig = Field(default_factory=VeIdentityConfig) + realtime_model: RealtimeModelConfig = Field(default_factory=RealtimeModelConfig) def getenv( diff --git a/veadk/configs/model_configs.py b/veadk/configs/model_configs.py index ebc4b9a6..de66505f 100644 --- a/veadk/configs/model_configs.py +++ b/veadk/configs/model_configs.py @@ -18,6 +18,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from veadk.auth.veauth.ark_veauth import get_ark_token +from veadk.auth.veauth.speech_veauth import get_speech_token from veadk.consts import ( DEFAULT_MODEL_AGENT_API_BASE, DEFAULT_MODEL_AGENT_NAME, @@ -76,3 +77,17 @@ class NormalEmbeddingModelConfig(BaseSettings): """The api base of the model for embedding.""" api_key: str + + +class RealtimeModelConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="MODEL_REALTIME_") + + name: str = "doubao_realtime_voice_model" + """Model name for realtime.""" + + api_base: str = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" + """The api base of the model for realtime.""" + + @cached_property + def api_key(self) -> str: + return os.getenv("MODEL_REALTIME_API_KEY") or get_speech_token() diff --git a/veadk/realtime/__init__.py b/veadk/realtime/__init__.py new file mode 100644 index 00000000..2fbdda15 --- /dev/null +++ b/veadk/realtime/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .doubao_realtime_voice_llm import DoubaoRealtimeVoice +from google.adk.models.registry import LLMRegistry + +__all__ = [ + "DoubaoRealtimeVoice", +] + +LLMRegistry.register(DoubaoRealtimeVoice) diff --git a/veadk/realtime/client.py b/veadk/realtime/client.py new file mode 100644 index 00000000..0843c9d6 --- /dev/null +++ b/veadk/realtime/client.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from google.genai.client import Client, AsyncClient +from google.genai._api_client import BaseApiClient +from .live import DoubaoAsyncLive +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + + +class DoubaoAsyncClient(AsyncClient): + """Client for making asynchronous (non-blocking) requests.""" + + def __init__(self, api_client: BaseApiClient): + super().__init__(api_client) + self._live = DoubaoAsyncLive(self._api_client) + + @property + def live(self) -> DoubaoAsyncLive: + return self._live + + +class DoubaoClient(Client): + """The synchronous client for doubao realtime voice model, with async support via the aio property.""" + + def __init__(self, *args, **kwargs): + # Temporary workaround to set Google API key for Gemini client + if not os.environ.get("GOOGLE_API_KEY"): + os.environ["GOOGLE_API_KEY"] = "hack_google_api_key" + try: + super().__init__(*args, **kwargs) + self._aio = DoubaoAsyncClient(self._api_client) + except Exception as e: + logger.info(f"Failed to initialize DoubaoAsyncClient: {e}") + self._aio = None + + @property + def aio(self) -> DoubaoAsyncClient: + return self._aio diff --git a/veadk/realtime/doubao_realtime_voice_llm.py b/veadk/realtime/doubao_realtime_voice_llm.py new file mode 100644 index 00000000..a7b0c16b --- /dev/null +++ b/veadk/realtime/doubao_realtime_voice_llm.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +from functools import cached_property +import os +import sys + +from typing import Optional +from typing import TYPE_CHECKING + + +from .client import DoubaoClient +from google.genai import types +from typing_extensions import override +from google.adk import version +from google.adk.models.google_llm import Gemini +from google.adk.models.base_llm_connection import BaseLlmConnection +from .doubao_realtime_voice_llm_connection import DoubaoRealtimeVoiceLlmConnection + + +if TYPE_CHECKING: + from google.adk.models.llm_request import LlmRequest + +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + +_NEW_LINE = "\n" +_EXCLUDED_PART_FIELD = {"inline_data": {"data"}} +_AGENT_ENGINE_TELEMETRY_TAG = "remote_reasoning_engine" +_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = "VOLCENGINE_CLOUD_AGENT_ENGINE_ID" + + +class DoubaoRealtimeVoice(Gemini): + """Integration for doubao realtime voice model. + + Attributes: + model: The name of the doubao realtime voice model (default: 'doubao_realtime_voice'). + speech_config: Optional speech configuration for voice input/output (type: google.genai.types.SpeechConfig). + retry_options: Optional HTTP retry configuration for failed requests (type: google.genai.types.HttpRetryOptions). + """ + + model: str = "doubao_realtime_voice" + + speech_config: Optional[types.SpeechConfig] = None + + retry_options: Optional[types.HttpRetryOptions] = None + """Allow doubao realtime voice model to retry failed responses. + + Sample: + ```python + from google.genai import types + + # ... + + agent = Agent( + model=DoubaoRealtimeVoice( + retry_options=types.HttpRetryOptions(initial_delay=1, attempts=2), + ) + ) + ``` + """ + + @classmethod + @override + def supported_models(cls) -> list[str]: + """Provides the list of supported models. + + Returns: + A list of supported models. + """ + + return [ + r"doubao_realtime_voice.*", + r"Doubao_scene_SLM_Doubao_realtime_voice_model.*", + ] + + @cached_property + def api_client(self) -> DoubaoClient: + """Provides the api client. + + Returns: + The api client. + """ + return DoubaoClient( + http_options=types.HttpOptions( + headers=self._tracking_headers, + retry_options=self.retry_options, + ) + ) + + @cached_property + def _live_api_client(self) -> DoubaoClient: + return DoubaoClient( + http_options=types.HttpOptions( + headers=self._tracking_headers, api_version=self._live_api_version + ) + ) + + @cached_property + def _tracking_headers(self) -> dict[str, str]: + framework_label = f"veadk/{version.__version__}" + if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME): + framework_label = f"{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}" + language_label = "ve-python/" + sys.version.split()[0] + version_header_value = f"{framework_label} {language_label}" + tracking_headers = { + "x-volcengine-api-client": version_header_value, + "user-agent": version_header_value, + } + return tracking_headers + + @contextlib.asynccontextmanager + async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: + """Connects to the doubao realtime voice LLM model and returns an llm connection. + + Args: + llm_request: LlmRequest, the request to send to the Seed LLM model. + + Yields: + BaseLlmConnection, the connection to the Seed LLM model. + """ + # add tracking headers to custom headers and set api_version given + # the customized http options will override the one set in the api client + # constructor + if ( + llm_request.live_connect_config + and llm_request.live_connect_config.http_options + ): + if not llm_request.live_connect_config.http_options.headers: + llm_request.live_connect_config.http_options.headers = {} + llm_request.live_connect_config.http_options.headers.update( + self._tracking_headers + ) + llm_request.live_connect_config.http_options.api_version = ( + self._live_api_version + ) + + if self.speech_config is not None: + llm_request.live_connect_config.speech_config = self.speech_config + + llm_request.live_connect_config.system_instruction = types.Content( + role="system", + parts=[types.Part.from_text(text=llm_request.config.system_instruction)], + ) + llm_request.live_connect_config.tools = llm_request.config.tools + logger.info("Connecting to live with llm_request:%s", llm_request) + async with self._live_api_client.aio.live.connect( + model=llm_request.model, config=llm_request.live_connect_config + ) as live_session: + # use DoubaoRealtimeVoiceLlmConnection in place of GeminiLlmConnection + yield DoubaoRealtimeVoiceLlmConnection(live_session) diff --git a/veadk/realtime/doubao_realtime_voice_llm_connection.py b/veadk/realtime/doubao_realtime_voice_llm_connection.py new file mode 100644 index 00000000..83ef4e99 --- /dev/null +++ b/veadk/realtime/doubao_realtime_voice_llm_connection.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Union + +from google.adk.models.gemini_llm_connection import GeminiLlmConnection +from google.genai import types +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + +RealtimeInput = Union[types.Blob, types.ActivityStart, types.ActivityEnd] + + +class DoubaoRealtimeVoiceLlmConnection(GeminiLlmConnection): + """The doubao realtime voice model connection.""" + + async def send_realtime(self, input: RealtimeInput): + """Sends a chunk of audio or a frame of video to the model in realtime. + + Args: + input: The input to send to the model. + """ + if isinstance(input, types.Blob): + # The blob is binary and is very large. So let's not log it. + # logger.debug('Sending LLM Blob.') + # bugfix: 'error': 'decode ws request failed: unsupported protocol version 7' + await self._gemini_session.send_realtime_input(media=input) + + elif isinstance(input, types.ActivityStart): + logger.debug("Sending LLM activity start signal.") + await self._gemini_session.send_realtime_input(activity_start=input) + elif isinstance(input, types.ActivityEnd): + logger.debug("Sending LLM activity end signal.") + await self._gemini_session.send_realtime_input(activity_end=input) + else: + raise ValueError("Unsupported input type: %s" % type(input)) diff --git a/veadk/realtime/live.py b/veadk/realtime/live.py new file mode 100644 index 00000000..449a43c0 --- /dev/null +++ b/veadk/realtime/live.py @@ -0,0 +1,471 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import uuid +import gzip +import json +from . import protocol +from typing import Any, AsyncIterator, Optional +from google.genai.live import AsyncLive, AsyncSession +from google.genai import _common +import google.genai.types as types +from veadk.types import RealtimeVoiceConnectConfig +from veadk.config import getenv, settings + +try: + from websockets.asyncio.client import connect as ws_connect +except ModuleNotFoundError: + # This try/except is for TAP, mypy complains about it which is why we have the type: ignore + from websockets.client import connect as ws_connect # type: ignore + +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + + +class ProtocolEvents: + ASR_INFO = 450 + ASR_RESPONSE = 451 + ASR_ENDED = 459 + TTS_SENTENCE_START = 350 + TTS_RESPONSE = 352 + TTS_SENTENCE_END = 351 + TTS_ENDED = 359 + USAGE_RESPONSE = 154 + CHAT_RESPONSE = 550 + CHAT_ENDED = 559 + + +class ProtocolConstants: + RESOURCE_ID = "volc.speech.dialog" + APP_KEY = "PlgvMymc7f3tQnJ6" + DEFAULT_SPEAKER = "zh_male_yunzhou_jupiter_bigtts" + DEFAULT_SYSTEM_ROLE = ( + "You use a lively female voice, have an outgoing personality, and love life." + ) + + +class RequestConstants: + REQ_ASR_END_SMOOTH_WINDOW_MS = 1500 + REQ_TTS_CHANNEL = 1 + REQ_TTS_SAMPLE_RATE = 24000 + REQ_DIALOG_BOT_NAME = "doubao" + REQ_DIALOG_SPEAKING_STYLE = "Your speaking style is concise and clear, with a moderate pace and natural intonation." + REQ_DIALOG_AUDIT_RESPONSE = "Support customize security audit response scripts。" + REQ_DIALOG_RECV_TIMEOUT = 10 + + +class DoubaoAsyncSession(AsyncSession): + """[Preview] AsyncSession.""" + + async def send_realtime_input( + self, + *, + media: Optional[types.BlobImageUnionDict] = None, + audio: Optional[types.BlobOrDict] = None, + audio_stream_end: Optional[bool] = None, + video: Optional[types.BlobImageUnionDict] = None, + text: Optional[str] = None, + activity_start: Optional[types.ActivityStartOrDict] = None, + activity_end: Optional[types.ActivityEndOrDict] = None, + ) -> None: + """Send realtime input to the model, only send one argument per call. + + Use `send_realtime_input` for realtime audio chunks and video + frames(images). + + With `send_realtime_input` the api will respond to audio automatically + based on voice activity detection (VAD). + + `send_realtime_input` is optimized for responsivness at the expense of + deterministic ordering. Audio and video tokens are added to the + context when they become available. + + Args: + media: A `Blob`-like object, the realtime media to send. + + Example: + + .. code-block:: python + + from pathlib import Path + + from google import genai + from google.genai import types + + import PIL.Image + + import os + + if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'): + MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09' + else: + MODEL_NAME = 'gemini-live-2.5-flash-preview'; + + + client = genai.Client() + + async with client.aio.live.connect( + model=MODEL_NAME, + config={"response_modalities": ["TEXT"]}, + ) as session: + await session.send_realtime_input( + media=PIL.Image.open('image.jpg')) + + audio_bytes = Path('audio.pcm').read_bytes() + await session.send_realtime_input( + media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000')) + + async for msg in session.receive(): + if msg.text is not None: + print(f'{msg.text}') + """ + kwargs: _common.StringDict = {} + if media is not None: + kwargs["media"] = media + if audio is not None: + kwargs["audio"] = audio + if audio_stream_end is not None: + kwargs["audio_stream_end"] = audio_stream_end + if video is not None: + kwargs["video"] = video + if text is not None: + kwargs["text"] = text + if activity_start is not None: + kwargs["activity_start"] = activity_start + if activity_end is not None: + kwargs["activity_end"] = activity_end + + if len(kwargs) != 1: + raise ValueError( + f"Only one argument can be set, got {len(kwargs)}:" + f" {list(kwargs.keys())}" + ) + + task_request = bytearray( + protocol.generate_header( + message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST, + serial_method=protocol.NO_SERIALIZATION, + ) + ) + task_request.extend(int(200).to_bytes(4, "big")) + task_request.extend((len(self.session_id)).to_bytes(4, "big")) + task_request.extend(str.encode(self.session_id)) + payload_bytes = gzip.compress(media.data) + task_request.extend( + (len(payload_bytes)).to_bytes(4, "big") + ) # payload size(4 bytes) + task_request.extend(payload_bytes) + await self._ws.send(task_request) + + async def receive(self) -> AsyncIterator[types.LiveServerMessage]: + """Receive model responses from the server. + + The method will yield the model responses from the server. The returned + responses will represent a complete model turn. When the returned message + is function call, user must call `send` with the function response to + continue the turn. + + Yields: + The model responses from the server. + + Example usage: + + .. code-block:: python + + client = genai.Client(api_key=API_KEY) + + async with client.aio.live.connect(model='...') as session: + await session.send(input='Hello world!', end_of_turn=True) + async for message in session.receive(): + print(message) + """ + # TODO(b/365983264) Handle intermittent issues for the user. + while result := await self._receive(): + # todo + # if result.server_content and result.server_content.turn_complete: + # yield result + # break + yield result + + async def _receive(self) -> types.LiveServerMessage: + try: + raw_response = await self._ws.recv(decode=False) + except TypeError: + raw_response = await self._ws.recv() # type: ignore[assignment] + if raw_response: + try: + response = protocol.parse_response(raw_response) + logger.debug(f"receive llm response: {response}") + except Exception: + raise ValueError(f"Failed to parse raw response: {raw_response!r}") + else: + response = {} + + return self.convert_to_live_server_message(response) + + def convert_to_live_server_message( + self, response: dict[str, Any] + ) -> types.LiveServerMessage: + """Converts a raw response to a LiveServerMessage. + + Args: + response: The raw response from the server. + + Returns: + The converted LiveServerMessage. + """ + + """ + msg = { + "server_content": { + "model_turn": { + "parts": [{ + "inline_data": "", + "text": "" + }], + "role": "model" + }, + "turn_complete": False, + "interrupted": False, + "input_transcription": { + "text": "", + "finished": False + }, + "output_transcription": { + "text": "", + "finished": False + } + }, + "usage_metadata": {} + } + """ + parameter_model = types.LiveServerMessage() + model_turn = {} + server_content = {} + usage_metadata = {} + output_ranscription = {} + + if "event" in response and "payload_msg" in response: + message = response.get("payload_msg") + if response.get("event") == ProtocolEvents.ASR_INFO: + # ASRInfo + # The model recognizes the event returned by the first character in the audio stream, + # which is used to interrupt the client's broadcast + server_content["interrupted"] = True + elif ( + response.get("event") == ProtocolEvents.ASR_RESPONSE + and "results" in message + ): + # ASRResponse + # The ASR Response model identifies the textual content of a user's speech + server_content["inputTranscription"] = { + "text": message.get("results")[0].get("text"), + "finished": True, + } + elif response.get("event") == ProtocolEvents.ASR_ENDED: + # ASREnded + # The model considers the event where the user's speech ends + logger.debug("ASREnded msg: %s", message) + elif response.get("event") == ProtocolEvents.TTS_SENTENCE_START: + # TTSSentenceStart + logger.debug("TTSSentenceStart msg: %s", message) + elif response.get("event") == ProtocolEvents.TTS_RESPONSE: + # TTSResponse + # Return the audio data generated by the model, and load the binary audio data into the payload + model_turn["parts"] = [{"inlineData": {"data": message}}] + server_content["modelTurn"] = model_turn + elif response.get("event") == ProtocolEvents.TTS_SENTENCE_END: + # TTSSentenceEnd + logger.debug("TTSSentenceEnd msg: %s", message) + elif response.get("event") == ProtocolEvents.USAGE_RESPONSE: + # UsageResponse + # Usage information corresponding to each round of interaction + def sum_cached_tokens(d): + return sum(v for k, v in d.items() if k.startswith(("cached_"))) + + usage_metadata["tool_use_prompt_token_count"] = ( + lambda d: sum(d.values()) + )(message.get("usage")) + usage_metadata["cached_content_token_count"] = sum_cached_tokens( + message.get("usage") + ) + elif response.get("event") == ProtocolEvents.CHAT_RESPONSE: + # ChatResponse + # The text content replied by the model needs to be concatenated + output_ranscription["text"] = message.get("content") + server_content["output_transcription"] = output_ranscription + elif response.get("event") == ProtocolEvents.CHAT_ENDED: + # ChatEnded + # End event of model reply text + output_ranscription["finished"] = True + server_content["output_transcription"] = output_ranscription + elif response.get("event") == ProtocolEvents.TTS_ENDED: + # TTSEnded + # End event of synthesized audio + server_content["turnComplete"] = True + + return types.LiveServerMessage._from_response( + response={"serverContent": server_content, "usageMetadata": usage_metadata}, + kwargs=parameter_model.model_dump(), + ) + + +class DoubaoAsyncLive(AsyncLive): + """[Preview] AsyncLive for doubao realtime voice model.""" + + @contextlib.asynccontextmanager + async def connect( + self, + *, + model: str, + config: Optional[types.LiveConnectConfigOrDict] = None, + ) -> AsyncIterator[DoubaoAsyncSession]: + """[Preview] Connect to the live server. + + Note: the live API is currently in preview. + + Usage: + + .. code-block:: python + + client = genai.Client(api_key=API_KEY) + config = {} + async with client.aio.live.connect(model='...', config=config) as session: + await session.send_client_content( + turns=types.Content( + role='user', + parts=[types.Part(text='hello!')] + ), + turn_complete=True + ) + async for message in session.receive(): + print(message) + + Args: + model: The model to use for the live session. + config: The configuration for the live session. + **kwargs: additional keyword arguments. + + Yields: + An AsyncSession object. + """ + # TODO(b/404946570): Support per request http options. + if isinstance(config, dict): + config = RealtimeVoiceConnectConfig(**config) + + api_key = settings.realtime_model.api_key + api_base = settings.realtime_model.api_base + app_id = getenv("MODEL_REALTIME_APP_ID") + speaker = getenv("MODEL_REALTIME_TTS_SPEAKER", "zh_male_yunzhou_jupiter_bigtts") + + system_role = "You use a lively female voice, have an outgoing personality, and love life." + if ( + config + and hasattr(config, "system_instruction") + and config.system_instruction + and config.system_instruction.parts + ): + system_role = config.system_instruction.parts[0].text + + headers = { + "X-Api-App-ID": app_id, + "X-Api-Access-Key": api_key, + "X-Api-Resource-Id": ProtocolConstants.RESOURCE_ID, # fixed value + "X-Api-App-Key": ProtocolConstants.APP_KEY, # fixed value + "X-Api-Connect-Id": str(uuid.uuid4()), + } + + start_session_req = { + "asr": { + "extra": { + "end_smooth_window_ms": RequestConstants.REQ_ASR_END_SMOOTH_WINDOW_MS, + }, + }, + "tts": { + "speaker": speaker, + "audio_config": { + "channel": RequestConstants.REQ_TTS_CHANNEL, + "format": "pcm_s16le", # default: pcm_f32le + "sample_rate": RequestConstants.REQ_TTS_SAMPLE_RATE, + }, + }, + "dialog": { + "bot_name": RequestConstants.REQ_DIALOG_BOT_NAME, + "system_role": system_role, + "speaking_style": RequestConstants.REQ_DIALOG_SPEAKING_STYLE, + "extra": { + "strict_audit": False, + "audit_response": RequestConstants.REQ_DIALOG_AUDIT_RESPONSE, + "recv_timeout": RequestConstants.REQ_DIALOG_RECV_TIMEOUT, + "input_mod": "audio", + }, + }, + } + + async with ws_connect( + api_base, additional_headers=headers, **self._api_client._websocket_ssl_ctx + ) as ws: + logid = ws.response.headers.get("X-Tt-Logid") + logger.info(f"dialog server response logid: {logid}") + + # StartConnection request + start_connection_request = bytearray(protocol.generate_header()) + start_connection_request.extend(int(1).to_bytes(4, "big")) + payload_bytes = str.encode("{}") + payload_bytes = gzip.compress(payload_bytes) + start_connection_request.extend((len(payload_bytes)).to_bytes(4, "big")) + start_connection_request.extend(payload_bytes) + + await ws.send(start_connection_request) + + try: + # websockets 14.0+ + raw_response = await ws.recv() + logger.info( + f"StartConnection response: {protocol.parse_response(raw_response)}" + ) + response_result = protocol.parse_response(raw_response) + session_id = response_result.get("session_id") + recv_timeout = 120 + mod = "audio" + # Expanding this parameter can maintain silence for a period of time, + # mainly used for text mode, with a parameter range of [10, 120] + start_session_req["dialog"]["extra"]["recv_timeout"] = recv_timeout + # This parameter can remain silent for a period of time in either text or audio_file mode + start_session_req["dialog"]["extra"]["input_mod"] = mod + # StartSession request + request_params = start_session_req + payload_bytes = str.encode(json.dumps(request_params)) + payload_bytes = gzip.compress(payload_bytes) + start_session_request = bytearray(protocol.generate_header()) + start_session_request.extend(int(100).to_bytes(4, "big")) + start_session_request.extend((len(session_id)).to_bytes(4, "big")) + start_session_request.extend(str.encode(session_id)) + start_session_request.extend((len(payload_bytes)).to_bytes(4, "big")) + start_session_request.extend(payload_bytes) + await ws.send(start_session_request) + response = await ws.recv() + + logger.info( + f"StartSession response: {protocol.parse_response(response)}" + ) + except TypeError: + raw_response = await ws.recv() # type: ignore[assignment] + yield DoubaoAsyncSession( + api_client=self._api_client, + websocket=ws, + session_id=session_id, + ) diff --git a/veadk/realtime/protocol.py b/veadk/realtime/protocol.py new file mode 100644 index 00000000..31584721 --- /dev/null +++ b/veadk/realtime/protocol.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import json + +PROTOCOL_VERSION = 0b0001 +DEFAULT_HEADER_SIZE = 0b0001 + +PROTOCOL_VERSION_BITS = 4 +HEADER_BITS = 4 +MESSAGE_TYPE_BITS = 4 +MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4 +MESSAGE_SERIALIZATION_BITS = 4 +MESSAGE_COMPRESSION_BITS = 4 +RESERVED_BITS = 8 + +# Message Type: +CLIENT_FULL_REQUEST = 0b0001 +CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + +SERVER_FULL_RESPONSE = 0b1001 +SERVER_ACK = 0b1011 +SERVER_ERROR_RESPONSE = 0b1111 + +# Message Type Specific Flags +NO_SEQUENCE = 0b0000 # no check sequence +POS_SEQUENCE = 0b0001 +NEG_SEQUENCE = 0b0010 +NEG_SEQUENCE_1 = 0b0011 + +MSG_WITH_EVENT = 0b0100 + +# Message Serialization +NO_SERIALIZATION = 0b0000 +JSON = 0b0001 +THRIFT = 0b0011 +CUSTOM_TYPE = 0b1111 + +# Message Compression +NO_COMPRESSION = 0b0000 +GZIP = 0b0001 +CUSTOM_COMPRESSION = 0b1111 + + +def generate_header( + version=PROTOCOL_VERSION, + message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=MSG_WITH_EVENT, + serial_method=JSON, + compression_type=GZIP, + reserved_data=0x00, + extension_header=bytes(), +): + """ + protocol_version(4 bits), header_size(4 bits), + message_type(4 bits), message_type_specific_flags(4 bits) + serialization_method(4 bits) message_compression(4 bits) + reserved (8bits) 保留字段 + header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) + """ + header = bytearray() + header_size = int(len(extension_header) / 4) + 1 + header.append((version << 4) | header_size) + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(reserved_data) + header.extend(extension_header) + return header + + +def parse_response(res): + """ + - header + - (4bytes)header + - (4bits)version(v1) + (4bits)header_size + - (4bits)messageType + (4bits)messageTypeFlags + -- 0001 CompleteClient | -- 0001 hasSequence + -- 0010 audioonly | -- 0010 isTailPacket + | -- 0100 hasEvent + - (4bits)payloadFormat + (4bits)compression + - (8bits) reserve + - payload + - [optional 4 bytes] event + - [optional] session ID + -- (4 bytes)session ID len + -- session ID data + - (4 bytes)data len + - data + """ + if isinstance(res, str): + return {} + if len(res) < 4: + return {"error": "Response too short"} + + # protocol_version = res[0] >> 4 + header_size = res[0] & 0x0F + if header_size < 1 or header_size > 16: + return {"error": f"Invalid header size: {header_size}"} + + if len(res) < header_size * 4: + return {"error": "Response shorter than header indicates"} + + message_type = res[1] >> 4 + message_type_specific_flags = res[1] & 0x0F + serialization_method = res[2] >> 4 + message_compression = res[2] & 0x0F + # reserved = res[3] + # header_extensions = res[4 : header_size * 4] + payload = res[header_size * 4 :] + result = {} + payload_msg = None + payload_size = 0 + start = 0 + if message_type == SERVER_FULL_RESPONSE or message_type == SERVER_ACK: + result["message_type"] = "SERVER_FULL_RESPONSE" + if message_type == SERVER_ACK: + result["message_type"] = "SERVER_ACK" + if message_type_specific_flags & NEG_SEQUENCE > 0: + result["seq"] = int.from_bytes(payload[:4], "big", signed=False) + start += 4 + if message_type_specific_flags & MSG_WITH_EVENT > 0: + result["event"] = int.from_bytes(payload[:4], "big", signed=False) + start += 4 + payload = payload[start:] + session_id_size = int.from_bytes(payload[:4], "big", signed=True) + session_id = payload[4 : session_id_size + 4] + result["session_id"] = str(session_id) + payload = payload[4 + session_id_size :] + payload_size = int.from_bytes(payload[:4], "big", signed=False) + payload_msg = payload[4:] + elif message_type == SERVER_ERROR_RESPONSE: + code = int.from_bytes(payload[:4], "big", signed=False) + result["code"] = code + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + payload_msg = payload[8:] + if payload_msg is None: + return result + if message_compression == GZIP: + payload_msg = gzip.decompress(payload_msg) + if serialization_method == JSON: + payload_msg = json.loads(str(payload_msg, "utf-8")) + elif serialization_method != NO_SERIALIZATION: + payload_msg = str(payload_msg, "utf-8") + result["payload_msg"] = payload_msg + result["payload_size"] = payload_size + return result diff --git a/veadk/types.py b/veadk/types.py index 2d1ba642..8a4cbb1f 100644 --- a/veadk/types.py +++ b/veadk/types.py @@ -19,6 +19,7 @@ from veadk.agents.parallel_agent import ParallelAgent from veadk.agents.sequential_agent import SequentialAgent from veadk.memory.short_term_memory import ShortTermMemory +from google.genai.types import LiveConnectConfig class MediaMessage(BaseModel): @@ -45,3 +46,7 @@ class AgentRunConfig(BaseModel): short_term_memory: ShortTermMemory = Field( default_factory=ShortTermMemory, description="The short-term memory instance" ) + + +class RealtimeVoiceConnectConfig(LiveConnectConfig): + """Configuration for connecting to the realtime voice model."""