diff --git a/pyproject.toml b/pyproject.toml index 20169609..58796189 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "aiomysql>=0.3.2", # For async MySQL database (short term memory) "opensearch-py==2.8.0", "filetype>=1.2.0", + "vikingdb-python-sdk>=0.1.3", "agentkit-sdk-python" ] diff --git a/veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py b/veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py index 0207cd32..04d1f481 100644 --- a/veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py +++ b/veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py @@ -95,6 +95,9 @@ def setHeader(self, header): api_info[key].header[item] = header[item] self.api_info = api_info + def get_host(self): + return self.service_info.host + @staticmethod def get_service_info(host, region, scheme, connection_timeout, socket_timeout): service_info = ServiceInfo( @@ -281,28 +284,3 @@ def update_collection( } res = self.json("UpdateCollection", {}, json.dumps(params)) return json.loads(res) - - def search_memory(self, collection_name, query, filter, limit=10): - params = { - "collection_name": collection_name, - "limit": limit, - "filter": filter, - } - if query: - params["query"] = query - res = self.json("SearchMemory", {}, json.dumps(params)) - return json.loads(res) - - def add_messages( - self, collection_name, session_id, messages, metadata, entities=None - ): - params = { - "collection_name": collection_name, - "session_id": session_id, - "messages": messages, - "metadata": metadata, - } - if entities is not None: - params["entities"] = entities - res = self.json("AddMessages", {}, json.dumps(params)) - return json.loads(res) diff --git a/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py b/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py index 427922e4..89df8088 100644 --- a/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +++ b/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py @@ -30,6 +30,9 @@ from veadk.memory.long_term_memory_backends.base_backend import ( BaseLongTermMemoryBackend, ) +from vikingdb import IAM +from vikingdb.memory import VikingMem + from veadk.utils.logger import get_logger logger = get_logger(__name__) @@ -118,6 +121,18 @@ def _get_client(self) -> VikingDBMemoryClient: region=self.region, ) + def _get_sdk_client(self) -> VikingMem: + client = self._get_client() + return VikingMem( + host=client.get_host(), + region=self.region, + auth=IAM( + ak=self.volcengine_access_key, + sk=self.volcengine_secret_key, + ), + sts_token=self.session_token, + ) + @override def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool: assistant_id = kwargs.get("assistant_id", "assistant") @@ -140,12 +155,12 @@ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool: f"Request for add {len(messages)} memory to VikingDB: collection_name={self.index}, metadata={metadata}, session_id={session_id}" ) - client = self._get_client() - response = client.add_messages( - collection_name=self.index, + client = self._get_sdk_client() + collection = client.get_collection(collection_name=self.index) + response = collection.add_session( + session_id=session_id, messages=messages, metadata=metadata, - session_id=session_id, ) logger.debug(f"Response from add memory to VikingDB: {response}") @@ -165,9 +180,12 @@ def search_memory( f"Request for search memory in VikingDB: filter={filter}, collection_name={self.index}, query={query}, limit={top_k}" ) - client = self._get_client() - response = client.search_memory( - collection_name=self.index, query=query, filter=filter, limit=top_k + client = self._get_sdk_client() + collection = client.get_collection(collection_name=self.index) + response = collection.search_memory( + query=query, + filter=filter, + limit=top_k, ) logger.debug(f"Response from search memory in VikingDB: {response}")