Skip to content

Commit bfb38f3

Browse files
committed
fix bugs
1 parent 2f26882 commit bfb38f3

5 files changed

Lines changed: 58 additions & 16 deletions

File tree

veadk/configs/model_configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class EmbeddingModelConfig(BaseSettings):
5151
dim: int = 2560
5252
"""Embedding dim is different from different models."""
5353

54-
api_base: str = "https://ark.cn-beijing.volces.com/api/v3/embeddings"
54+
api_base: str = "https://ark.cn-beijing.volces.com/api/v3/"
5555
"""The api base of the model for embedding."""
5656

5757
@cached_property

veadk/knowledgebase/backends/opensearch_backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,12 @@ def add_from_text(self, text: str | list[str]) -> bool:
9999

100100
@override
101101
def search(self, query: str, top_k: int = 5) -> list[str]:
102-
retrieved_nodes = self._retriever.retrieve(query, top_k=top_k)
102+
_original_top_k = self._retriever.similarity_top_k # type: ignore
103+
self._retriever.similarity_top_k = top_k # type: ignore
104+
105+
retrieved_nodes = self._retriever.retrieve(query)
106+
107+
self._retriever.similarity_top_k = _original_top_k # type: ignore
103108
return [node.text for node in retrieved_nodes]
104109

105110
def _split_documents(self, documents: list[Document]) -> list[BaseNode]:

veadk/knowledgebase/backends/redis_backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class RedisKnowledgeBackend(BaseKnowledgebaseBackend):
3939
redis_config: RedisConfig = Field(default_factory=RedisConfig)
4040
"""Redis client configs"""
4141

42-
embedding_config: EmbeddingModelConfig
42+
embedding_config: EmbeddingModelConfig = Field(default_factory=EmbeddingModelConfig)
4343
"""Embedding model configs"""
4444

4545
def model_post_init(self, __context: Any) -> None:
@@ -105,7 +105,12 @@ def add_from_text(self, text: str | list[str]) -> bool:
105105

106106
@override
107107
def search(self, query: str, top_k: int = 5) -> list[str]:
108+
_original_top_k = self._retriever.similarity_top_k # type: ignore
109+
self._retriever.similarity_top_k = top_k # type: ignore
110+
108111
retrieved_nodes = self._retriever.retrieve(query, top_k=top_k)
112+
113+
self._retriever.similarity_top_k = _original_top_k # type: ignore
109114
return [node.text for node in retrieved_nodes]
110115

111116
def _split_documents(self, documents: list[Document]) -> list[BaseNode]:

veadk/memory/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,25 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from typing import TYPE_CHECKING
16+
17+
if TYPE_CHECKING:
18+
from veadk.memory.long_term_memory import LongTermMemory
19+
from veadk.memory.short_term_memory import ShortTermMemory
20+
21+
22+
# Lazy loading for classes
23+
def __getattr__(name):
24+
if name == "ShortTermMemory":
25+
from veadk.memory.short_term_memory import ShortTermMemory
26+
27+
return ShortTermMemory
28+
if name == "LongTeremMemory":
29+
from veadk.memory.long_term_memory import LongTermMemory
30+
31+
return LongTermMemory
32+
raise AttributeError(f"module 'veadk.memory' has no attribute '{name}'")
33+
34+
35+
__all__ = ["ShortTermMemory", "LongTermMemory"]

veadk/memory/short_term_memory.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
from functools import wraps
1616
from typing import Any, Callable, Literal
1717

18-
from google.adk.sessions import DatabaseSessionService, InMemorySessionService
19-
from pydantic import BaseModel, Field
18+
from google.adk.sessions import (
19+
BaseSessionService,
20+
DatabaseSessionService,
21+
InMemorySessionService,
22+
)
23+
from pydantic import BaseModel, Field, PrivateAttr
2024

2125
from veadk.memory.short_term_memory_backends.mysql_backend import (
2226
MysqlSTMBackend,
@@ -61,10 +65,12 @@ class ShortTermMemory(BaseModel):
6165
after_load_memory_callback: Callable | None = None
6266
"""A callback to be called after loading memory from the backend. The callback function should accept `Session` as an input."""
6367

68+
_session_service: BaseSessionService = PrivateAttr()
69+
6470
def model_post_init(self, __context: Any) -> None:
6571
if self.db_url:
6672
logger.info("The `db_url` is set, ignore `backend` option.")
67-
self.session_service = DatabaseSessionService(db_url=self.db_url)
73+
self._session_service = DatabaseSessionService(db_url=self.db_url)
6874
else:
6975
if self.backend == "database":
7076
logger.warning(
@@ -73,37 +79,41 @@ def model_post_init(self, __context: Any) -> None:
7379
self.backend = "sqlite"
7480
match self.backend:
7581
case "local":
76-
self.session_service = InMemorySessionService()
82+
self._session_service = InMemorySessionService()
7783
case "mysql":
78-
self.session_service = MysqlSTMBackend(
84+
self._session_service = MysqlSTMBackend(
7985
**self.backend_configs
8086
).session_service
8187
case "sqlite":
82-
self.session_service = SQLiteSTMBackend(
88+
self._session_service = SQLiteSTMBackend(
8389
local_path=self.local_database_path
8490
).session_service
8591
case "redis":
86-
self.session_service = RedisSTMBackend(
92+
self._session_service = RedisSTMBackend(
8793
**self.backend_configs
8894
).session_service
8995
case "postgresql":
90-
self.session_service = PostgreSqlSTMBackend(
96+
self._session_service = PostgreSqlSTMBackend(
9197
**self.backend_configs
9298
).session_service
9399

94100
if self.after_load_memory_callback:
95101
wrap_get_session_with_callbacks(
96-
self.session_service, self.after_load_memory_callback
102+
self._session_service, self.after_load_memory_callback
97103
)
98104

105+
@property
106+
def session_service(self) -> BaseSessionService:
107+
return self._session_service
108+
99109
async def create_session(
100110
self,
101111
app_name: str,
102112
user_id: str,
103113
session_id: str,
104114
) -> None:
105-
if isinstance(self.session_service, DatabaseSessionService):
106-
list_sessions_response = await self.session_service.list_sessions(
115+
if isinstance(self._session_service, DatabaseSessionService):
116+
list_sessions_response = await self._session_service.list_sessions(
107117
app_name=app_name, user_id=user_id
108118
)
109119

@@ -112,12 +122,12 @@ async def create_session(
112122
)
113123

114124
if (
115-
await self.session_service.get_session(
125+
await self._session_service.get_session(
116126
app_name=app_name, user_id=user_id, session_id=session_id
117127
)
118128
is None
119129
):
120130
# create a new session for this running
121-
await self.session_service.create_session(
131+
await self._session_service.create_session(
122132
app_name=app_name, user_id=user_id, session_id=session_id
123133
)

0 commit comments

Comments
 (0)