1515from functools import wraps
1616from typing import Any , Callable , Literal
1717
18- from google .adk .sessions import DatabaseSessionService , InMemorySessionService
19- from pydantic import BaseModel , Field
18+ from google .adk .sessions import (
19+ BaseSessionService ,
20+ DatabaseSessionService ,
21+ InMemorySessionService ,
22+ )
23+ from pydantic import BaseModel , Field , PrivateAttr
2024
2125from veadk .memory .short_term_memory_backends .mysql_backend import (
2226 MysqlSTMBackend ,
@@ -61,10 +65,12 @@ class ShortTermMemory(BaseModel):
6165 after_load_memory_callback : Callable | None = None
6266 """A callback to be called after loading memory from the backend. The callback function should accept `Session` as an input."""
6367
68+ _session_service : BaseSessionService = PrivateAttr ()
69+
6470 def model_post_init (self , __context : Any ) -> None :
6571 if self .db_url :
6672 logger .info ("The `db_url` is set, ignore `backend` option." )
67- self .session_service = DatabaseSessionService (db_url = self .db_url )
73+ self ._session_service = DatabaseSessionService (db_url = self .db_url )
6874 else :
6975 if self .backend == "database" :
7076 logger .warning (
@@ -73,37 +79,41 @@ def model_post_init(self, __context: Any) -> None:
7379 self .backend = "sqlite"
7480 match self .backend :
7581 case "local" :
76- self .session_service = InMemorySessionService ()
82+ self ._session_service = InMemorySessionService ()
7783 case "mysql" :
78- self .session_service = MysqlSTMBackend (
84+ self ._session_service = MysqlSTMBackend (
7985 ** self .backend_configs
8086 ).session_service
8187 case "sqlite" :
82- self .session_service = SQLiteSTMBackend (
88+ self ._session_service = SQLiteSTMBackend (
8389 local_path = self .local_database_path
8490 ).session_service
8591 case "redis" :
86- self .session_service = RedisSTMBackend (
92+ self ._session_service = RedisSTMBackend (
8793 ** self .backend_configs
8894 ).session_service
8995 case "postgresql" :
90- self .session_service = PostgreSqlSTMBackend (
96+ self ._session_service = PostgreSqlSTMBackend (
9197 ** self .backend_configs
9298 ).session_service
9399
94100 if self .after_load_memory_callback :
95101 wrap_get_session_with_callbacks (
96- self .session_service , self .after_load_memory_callback
102+ self ._session_service , self .after_load_memory_callback
97103 )
98104
105+ @property
106+ def session_service (self ) -> BaseSessionService :
107+ return self ._session_service
108+
99109 async def create_session (
100110 self ,
101111 app_name : str ,
102112 user_id : str ,
103113 session_id : str ,
104114 ) -> None :
105- if isinstance (self .session_service , DatabaseSessionService ):
106- list_sessions_response = await self .session_service .list_sessions (
115+ if isinstance (self ._session_service , DatabaseSessionService ):
116+ list_sessions_response = await self ._session_service .list_sessions (
107117 app_name = app_name , user_id = user_id
108118 )
109119
@@ -112,12 +122,12 @@ async def create_session(
112122 )
113123
114124 if (
115- await self .session_service .get_session (
125+ await self ._session_service .get_session (
116126 app_name = app_name , user_id = user_id , session_id = session_id
117127 )
118128 is None
119129 ):
120130 # create a new session for this running
121- await self .session_service .create_session (
131+ await self ._session_service .create_session (
122132 app_name = app_name , user_id = user_id , session_id = session_id
123133 )
0 commit comments