@@ -91,7 +91,6 @@ def __init__(
9191 agent : "Agent" ,
9292 host : str = "0.0.0.0" ,
9393 port : int = 8000 ,
94- short_term_memory : Optional [Any ] = None ,
9594 ):
9695 self .agent = agent
9796
@@ -100,33 +99,14 @@ def __init__(
10099
101100 self .app = FastAPI ()
102101
103- # Session and artifact services for new endpoints
104- # Priority: 1. provided short_term_memory, 2. agent's short_term_memory, 3. create new
105- if short_term_memory is not None :
106- from google .adk .sessions .base_session_service import BaseSessionService
107-
108- if isinstance (short_term_memory , BaseSessionService ):
109- self .session_service = short_term_memory
110- else :
111- self .session_service = short_term_memory .session_service
112- elif (
113- hasattr (agent , "short_term_memory" ) and agent .short_term_memory is not None
114- ):
115- from google .adk .sessions .base_session_service import BaseSessionService
116-
117- if isinstance (agent .short_term_memory , BaseSessionService ):
118- self .session_service = agent .short_term_memory
119- else :
120- self .session_service = agent .short_term_memory .session_service
121- else :
122- self .session_service = InMemorySessionService ()
123102 self .artifact_service = InMemoryArtifactService ()
124103
125104 # build routes for self.app
126105 self .build ()
127106
128107 self .ws_session_mgr = WebsocketSessionManager ()
129108 self .ws_agent_mgr : dict [str , "Agent" ] = {}
109+ self .ws_session_service_mgr : dict [str , "InMemorySessionService" ] = {}
130110
131111 def build (self ):
132112 logger .info ("Build routes for server with reverse mcp" )
@@ -141,6 +121,8 @@ class InvokeRequest(BaseModel):
141121
142122 websocket_id : str
143123
124+ mcp_tool_filter : Optional [list [str ]] = None
125+
144126 class InvokeResponse (BaseModel ):
145127 """Response model for /invoke endpoint"""
146128
@@ -155,16 +137,32 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
155137
156138 agent = self .ws_agent_mgr [payload .websocket_id ]
157139
158- if not agent .tools :
140+ mcp_toolset_url = f"http://127.0.0.1:{ self .port } /mcp"
141+ mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY : payload .websocket_id }
142+
143+ has_mcp_toolset = False
144+ for tool in agent .tools :
145+ if isinstance (tool , MCPToolset ):
146+ if hasattr (tool , "_connection_params" ):
147+ conn_params = tool ._connection_params
148+ if (
149+ hasattr (conn_params , "url" )
150+ and conn_params .url == mcp_toolset_url
151+ and hasattr (conn_params , "headers" )
152+ and conn_params .headers == mcp_toolset_headers
153+ ):
154+ has_mcp_toolset = True
155+ break
156+
157+ if not has_mcp_toolset :
159158 logger .debug ("Mount fake MCPToolset to agent" )
160-
161- # we hard code the mcp url with `/mcp` to obey the mcp protocol
162159 agent .tools .append (
163160 MCPToolset (
164161 connection_params = StreamableHTTPConnectionParams (
165- url = f"http://127.0.0.1: { self . port } /mcp" ,
166- headers = { REVERSE_MCP_HEADER_KEY : payload . websocket_id } ,
162+ url = mcp_toolset_url ,
163+ headers = mcp_toolset_headers ,
167164 ),
165+ tool_filter = payload .mcp_tool_filter ,
168166 )
169167 )
170168
@@ -194,19 +192,32 @@ async def ws_endpoint(ws: WebSocket):
194192 logger .info (f"Fork agent for websocket { client_id } " )
195193 self .ws_agent_mgr [client_id ] = self .agent .clone ()
196194
195+ logger .info (f"Create session service for websocket { client_id } " )
196+ self .ws_session_service_mgr [client_id ] = InMemorySessionService ()
197+
197198 await ws .accept ()
198199 logger .info (f"Websocket { client_id } connected" )
199200
200201 while True :
201202 raw = await ws .receive_text ()
202203 await self .ws_session_mgr .handle_ws_message (client_id , raw )
203204
204- # ========== New endpoints: create_session, create_session_with_id, run_sse ==========
205- # NOTE: These must be defined BEFORE the catch-all /{path:path} route
206-
207205 class CreateSessionRequest (BaseModel ):
208206 state : Optional [dict [str , Any ]] = None
209207 session_id : Optional [str ] = None
208+ websocket_id : str
209+
210+ class RunAgentRequestWithWsId (RunAgentRequest ):
211+ websocket_id : str
212+ mcp_tool_filter : Optional [list [str ]] = None
213+
214+ def _get_session_service (websocket_id : str ) -> InMemorySessionService :
215+ """Get session service for the websocket client."""
216+ if websocket_id not in self .ws_session_service_mgr :
217+ raise HTTPException (
218+ status_code = 404 , detail = f"WebSocket client { websocket_id } not found"
219+ )
220+ return self .ws_session_service_mgr [websocket_id ]
210221
211222 @self .app .post (
212223 "/apps/{app_name}/users/{user_id}/sessions" ,
@@ -215,21 +226,22 @@ class CreateSessionRequest(BaseModel):
215226 async def create_session (
216227 app_name : str ,
217228 user_id : str ,
218- req : Optional [ CreateSessionRequest ] = None ,
229+ req : CreateSessionRequest ,
219230 ) -> Session :
220231 """Create a new session."""
221- session_id = req .session_id if req and req .session_id else str (uuid .uuid4 ())
232+ session_id = req .session_id if req .session_id else str (uuid .uuid4 ())
222233 session = Session (
223234 app_name = app_name ,
224235 user_id = user_id ,
225236 id = session_id ,
226- state = req .state if req and req .state else {},
237+ state = req .state if req .state else {},
227238 )
228- await self .session_service .create_session (
239+ session_service = _get_session_service (req .websocket_id )
240+ await session_service .create_session (
229241 app_name = app_name ,
230242 user_id = user_id ,
231243 session_id = session_id ,
232- state = req .state if req and req .state else {},
244+ state = req .state if req .state else {},
233245 )
234246 logger .info (
235247 f"Created session: { session_id } for user { user_id } in app { app_name } "
@@ -244,64 +256,84 @@ async def create_session_with_id(
244256 app_name : str ,
245257 user_id : str ,
246258 session_id : str ,
247- state : Optional [ dict [ str , Any ]] = None ,
259+ req : CreateSessionRequest ,
248260 ) -> Session :
249261 """Create a session with specific ID."""
250- await self .session_service .create_session (
262+ session_service = _get_session_service (req .websocket_id )
263+ await session_service .create_session (
251264 app_name = app_name ,
252265 user_id = user_id ,
253266 session_id = session_id ,
254- state = state if state else {},
267+ state = req . state if req . state else {},
255268 )
256269 session = Session (
257270 app_name = app_name ,
258271 user_id = user_id ,
259272 id = session_id ,
260- state = state if state else {},
273+ state = req . state if req . state else {},
261274 )
262275 logger .info (f"Created session with ID: { session_id } for user { user_id } " )
263276 return session
264277
265278 @self .app .post ("/run_sse" )
266- async def run_agent_sse (req : RunAgentRequest ) -> StreamingResponse :
279+ async def run_agent_sse (req : RunAgentRequestWithWsId ) -> StreamingResponse :
267280 """Run agent with SSE streaming."""
281+ session_service = _get_session_service (req .websocket_id )
282+
268283 # Get session
269- session = await self . session_service .get_session (
284+ session = await session_service .get_session (
270285 app_name = req .app_name ,
271286 user_id = req .user_id ,
272287 session_id = req .session_id ,
273288 )
274289 if not session :
275290 raise HTTPException (status_code = 404 , detail = "Session not found" )
276291
277- # Use the first connected websocket client, or create a new agent clone
278- websocket_id = None
279- if self .ws_agent_mgr :
280- websocket_id = list (self .ws_agent_mgr .keys ())[0 ]
281- agent = self .ws_agent_mgr [websocket_id ]
282- logger .debug (f"Using agent from websocket { websocket_id } " )
292+ # Get agent for this websocket
293+ if req .websocket_id in self .ws_agent_mgr :
294+ agent = self .ws_agent_mgr [req .websocket_id ]
295+ logger .debug (f"Using agent from websocket { req .websocket_id } " )
283296 else :
284- # No websocket connected, use original agent
285- agent = self .agent
286- logger .debug ("No websocket connected, using original agent" )
297+ raise HTTPException (
298+ status_code = 404 ,
299+ detail = f"WebSocket client { req .websocket_id } not found" ,
300+ )
287301
288302 # Mount MCPToolset if needed
289- if not agent .tools :
303+ mcp_toolset_url = f"http://127.0.0.1:{ self .port } /mcp"
304+ mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY : req .websocket_id }
305+
306+ has_mcp_toolset = False
307+ for tool in agent .tools :
308+ if isinstance (tool , MCPToolset ):
309+ if hasattr (tool , "_connection_params" ):
310+ conn_params = tool ._connection_params
311+ if (
312+ hasattr (conn_params , "url" )
313+ and conn_params .url == mcp_toolset_url
314+ and hasattr (conn_params , "headers" )
315+ and conn_params .headers == mcp_toolset_headers
316+ ):
317+ has_mcp_toolset = True
318+ break
319+
320+ if not has_mcp_toolset :
290321 logger .debug ("Mount fake MCPToolset to agent for SSE" )
291322 agent .tools .append (
292323 MCPToolset (
293324 connection_params = StreamableHTTPConnectionParams (
294- url = f"http://127.0.0.1: { self . port } /mcp" ,
295- headers = { REVERSE_MCP_HEADER_KEY : websocket_id or "default" } ,
325+ url = mcp_toolset_url ,
326+ headers = mcp_toolset_headers ,
296327 ),
328+ tool_filter = req .mcp_tool_filter ,
297329 )
298330 )
299331
300332 # Create runner
301333 runner = GoogleRunner (
302334 agent = agent ,
303335 app_name = req .app_name ,
304- session_service = self . session_service ,
336+ session_service = session_service ,
305337 artifact_service = self .artifact_service ,
306338 )
307339
0 commit comments