Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"aiomysql>=0.3.2", # For async MySQL database (short term memory)
"opensearch-py==2.8.0",
"filetype>=1.2.0",
"vikingdb-python-sdk>=0.1.3",
"agentkit-sdk-python"
]

Expand Down
28 changes: 3 additions & 25 deletions veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def setHeader(self, header):
api_info[key].header[item] = header[item]
self.api_info = api_info

def get_host(self):
return self.service_info.host

@staticmethod
def get_service_info(host, region, scheme, connection_timeout, socket_timeout):
service_info = ServiceInfo(
Expand Down Expand Up @@ -281,28 +284,3 @@ def update_collection(
}
res = self.json("UpdateCollection", {}, json.dumps(params))
return json.loads(res)

def search_memory(self, collection_name, query, filter, limit=10):
params = {
"collection_name": collection_name,
"limit": limit,
"filter": filter,
}
if query:
params["query"] = query
res = self.json("SearchMemory", {}, json.dumps(params))
return json.loads(res)

def add_messages(
self, collection_name, session_id, messages, metadata, entities=None
):
params = {
"collection_name": collection_name,
"session_id": session_id,
"messages": messages,
"metadata": metadata,
}
if entities is not None:
params["entities"] = entities
res = self.json("AddMessages", {}, json.dumps(params))
return json.loads(res)
32 changes: 25 additions & 7 deletions veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from veadk.memory.long_term_memory_backends.base_backend import (
BaseLongTermMemoryBackend,
)
from vikingdb import IAM
from vikingdb.memory import VikingMem

from veadk.utils.logger import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -118,6 +121,18 @@ def _get_client(self) -> VikingDBMemoryClient:
region=self.region,
)

def _get_sdk_client(self) -> VikingMem:
client = self._get_client()
return VikingMem(
host=client.get_host(),
region=self.region,
auth=IAM(
ak=self.volcengine_access_key,
sk=self.volcengine_secret_key,
),
sts_token=self.session_token,
)

@override
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
assistant_id = kwargs.get("assistant_id", "assistant")
Expand All @@ -140,12 +155,12 @@ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
f"Request for add {len(messages)} memory to VikingDB: collection_name={self.index}, metadata={metadata}, session_id={session_id}"
)

client = self._get_client()
response = client.add_messages(
collection_name=self.index,
client = self._get_sdk_client()
collection = client.get_collection(collection_name=self.index)
response = collection.add_session(
session_id=session_id,
messages=messages,
metadata=metadata,
session_id=session_id,
)

logger.debug(f"Response from add memory to VikingDB: {response}")
Expand All @@ -165,9 +180,12 @@ def search_memory(
f"Request for search memory in VikingDB: filter={filter}, collection_name={self.index}, query={query}, limit={top_k}"
)

client = self._get_client()
response = client.search_memory(
collection_name=self.index, query=query, filter=filter, limit=top_k
client = self._get_sdk_client()
collection = client.get_collection(collection_name=self.index)
response = collection.search_memory(
query=query,
filter=filter,
limit=top_k,
)

logger.debug(f"Response from search memory in VikingDB: {response}")
Expand Down