|
15 | 15 | import asyncio |
16 | 16 | import json |
17 | 17 | import uuid |
18 | | -from typing import TYPE_CHECKING |
19 | | - |
20 | | -from fastapi import FastAPI, Request, Response, WebSocket |
| 18 | +from typing import TYPE_CHECKING, Any, Optional |
| 19 | + |
| 20 | +from fastapi import FastAPI, HTTPException, Request, Response, WebSocket |
| 21 | +from fastapi.responses import StreamingResponse |
| 22 | +from google.adk.artifacts import InMemoryArtifactService |
| 23 | +from google.adk.cli.adk_web_server import RunAgentRequest |
| 24 | +from google.adk.runners import Runner as GoogleRunner |
| 25 | +from google.adk.sessions import InMemorySessionService, Session |
21 | 26 | from google.adk.tools.mcp_tool.mcp_session_manager import ( |
22 | 27 | StreamableHTTPConnectionParams, |
23 | 28 | ) |
@@ -86,13 +91,37 @@ def __init__( |
86 | 91 | agent: "Agent", |
87 | 92 | host: str = "0.0.0.0", |
88 | 93 | port: int = 8000, |
| 94 | + short_term_memory: Optional[Any] = None, |
89 | 95 | ): |
90 | 96 | self.agent = agent |
91 | 97 |
|
92 | 98 | self.host = host |
93 | 99 | self.port = port |
94 | 100 |
|
95 | 101 | self.app = FastAPI() |
| 102 | + |
| 103 | + # Session and artifact services for new endpoints |
| 104 | + # Priority: 1. provided short_term_memory, 2. agent's short_term_memory, 3. create new |
| 105 | + if short_term_memory is not None: |
| 106 | + from google.adk.sessions.base_session_service import BaseSessionService |
| 107 | + |
| 108 | + if isinstance(short_term_memory, BaseSessionService): |
| 109 | + self.session_service = short_term_memory |
| 110 | + else: |
| 111 | + self.session_service = short_term_memory.session_service |
| 112 | + elif ( |
| 113 | + hasattr(agent, "short_term_memory") and agent.short_term_memory is not None |
| 114 | + ): |
| 115 | + from google.adk.sessions.base_session_service import BaseSessionService |
| 116 | + |
| 117 | + if isinstance(agent.short_term_memory, BaseSessionService): |
| 118 | + self.session_service = agent.short_term_memory |
| 119 | + else: |
| 120 | + self.session_service = agent.short_term_memory.session_service |
| 121 | + else: |
| 122 | + self.session_service = InMemorySessionService() |
| 123 | + self.artifact_service = InMemoryArtifactService() |
| 124 | + |
96 | 125 | # build routes for self.app |
97 | 126 | self.build() |
98 | 127 |
|
@@ -172,8 +201,135 @@ async def ws_endpoint(ws: WebSocket): |
172 | 201 | raw = await ws.receive_text() |
173 | 202 | await self.ws_session_mgr.handle_ws_message(client_id, raw) |
174 | 203 |
|
| 204 | + # ========== New endpoints: create_session, create_session_with_id, run_sse ========== |
| 205 | + # NOTE: These must be defined BEFORE the catch-all /{path:path} route |
| 206 | + |
| 207 | + class CreateSessionRequest(BaseModel): |
| 208 | + state: Optional[dict[str, Any]] = None |
| 209 | + session_id: Optional[str] = None |
| 210 | + |
| 211 | + @self.app.post( |
| 212 | + "/apps/{app_name}/users/{user_id}/sessions", |
| 213 | + response_model_exclude_none=True, |
| 214 | + ) |
| 215 | + async def create_session( |
| 216 | + app_name: str, |
| 217 | + user_id: str, |
| 218 | + req: Optional[CreateSessionRequest] = None, |
| 219 | + ) -> Session: |
| 220 | + """Create a new session.""" |
| 221 | + session_id = req.session_id if req and req.session_id else str(uuid.uuid4()) |
| 222 | + session = Session( |
| 223 | + app_name=app_name, |
| 224 | + user_id=user_id, |
| 225 | + id=session_id, |
| 226 | + state=req.state if req and req.state else {}, |
| 227 | + ) |
| 228 | + await self.session_service.create_session( |
| 229 | + app_name=app_name, |
| 230 | + user_id=user_id, |
| 231 | + session_id=session_id, |
| 232 | + state=req.state if req and req.state else {}, |
| 233 | + ) |
| 234 | + logger.info( |
| 235 | + f"Created session: {session_id} for user {user_id} in app {app_name}" |
| 236 | + ) |
| 237 | + return session |
| 238 | + |
| 239 | + @self.app.post( |
| 240 | + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", |
| 241 | + response_model_exclude_none=True, |
| 242 | + ) |
| 243 | + async def create_session_with_id( |
| 244 | + app_name: str, |
| 245 | + user_id: str, |
| 246 | + session_id: str, |
| 247 | + state: Optional[dict[str, Any]] = None, |
| 248 | + ) -> Session: |
| 249 | + """Create a session with specific ID.""" |
| 250 | + await self.session_service.create_session( |
| 251 | + app_name=app_name, |
| 252 | + user_id=user_id, |
| 253 | + session_id=session_id, |
| 254 | + state=state if state else {}, |
| 255 | + ) |
| 256 | + session = Session( |
| 257 | + app_name=app_name, |
| 258 | + user_id=user_id, |
| 259 | + id=session_id, |
| 260 | + state=state if state else {}, |
| 261 | + ) |
| 262 | + logger.info(f"Created session with ID: {session_id} for user {user_id}") |
| 263 | + return session |
| 264 | + |
| 265 | + @self.app.post("/run_sse") |
| 266 | + async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: |
| 267 | + """Run agent with SSE streaming.""" |
| 268 | + # Get session |
| 269 | + session = await self.session_service.get_session( |
| 270 | + app_name=req.app_name, |
| 271 | + user_id=req.user_id, |
| 272 | + session_id=req.session_id, |
| 273 | + ) |
| 274 | + if not session: |
| 275 | + raise HTTPException(status_code=404, detail="Session not found") |
| 276 | + |
| 277 | + # Use the first connected websocket client, or create a new agent clone |
| 278 | + websocket_id = None |
| 279 | + if self.ws_agent_mgr: |
| 280 | + websocket_id = list(self.ws_agent_mgr.keys())[0] |
| 281 | + agent = self.ws_agent_mgr[websocket_id] |
| 282 | + logger.debug(f"Using agent from websocket {websocket_id}") |
| 283 | + else: |
| 284 | + # No websocket connected, use original agent |
| 285 | + agent = self.agent |
| 286 | + logger.debug("No websocket connected, using original agent") |
| 287 | + |
| 288 | + # Mount MCPToolset if needed |
| 289 | + if not agent.tools: |
| 290 | + logger.debug("Mount fake MCPToolset to agent for SSE") |
| 291 | + agent.tools.append( |
| 292 | + MCPToolset( |
| 293 | + connection_params=StreamableHTTPConnectionParams( |
| 294 | + url=f"http://127.0.0.1:{self.port}/mcp", |
| 295 | + headers={REVERSE_MCP_HEADER_KEY: websocket_id or "default"}, |
| 296 | + ), |
| 297 | + ) |
| 298 | + ) |
| 299 | + |
| 300 | + # Create runner |
| 301 | + runner = GoogleRunner( |
| 302 | + agent=agent, |
| 303 | + app_name=req.app_name, |
| 304 | + session_service=self.session_service, |
| 305 | + artifact_service=self.artifact_service, |
| 306 | + ) |
| 307 | + |
| 308 | + async def event_generator(): |
| 309 | + try: |
| 310 | + async for event in runner.run_async( |
| 311 | + user_id=req.user_id, |
| 312 | + session_id=req.session_id, |
| 313 | + new_message=req.new_message, |
| 314 | + state_delta=req.state_delta, |
| 315 | + ): |
| 316 | + event_json = event.model_dump_json( |
| 317 | + exclude_none=True, by_alias=True |
| 318 | + ) |
| 319 | + logger.debug(f"SSE event: {event_json}") |
| 320 | + yield f"data: {event_json}\n\n" |
| 321 | + except Exception as e: |
| 322 | + logger.exception(f"Error in event_generator: {e}") |
| 323 | + yield f"data: {json.dumps({'error': str(e)})}\n\n" |
| 324 | + |
| 325 | + return StreamingResponse( |
| 326 | + event_generator(), |
| 327 | + media_type="text/event-stream", |
| 328 | + ) |
| 329 | + |
175 | 330 | # build the fake MPC server, |
176 | 331 | # and intercept all requests to the client websocket client. |
| 332 | + # NOTE: This catch-all route must be defined LAST |
177 | 333 | @self.app.api_route("/{path:path}", methods=["GET", "POST"]) |
178 | 334 | async def mcp_proxy(path: str, request: Request): |
179 | 335 | client_id = request.headers.get(REVERSE_MCP_HEADER_KEY) |
|
0 commit comments