Skip to content

Commit 3d869a1

Browse files
committed
feat: add create_session and run_sse on reverse mcp app
1 parent 214deb3 commit 3d869a1

1 file changed

Lines changed: 159 additions & 3 deletions

File tree

veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@
1515
import asyncio
1616
import json
1717
import uuid
18-
from typing import TYPE_CHECKING
19-
20-
from fastapi import FastAPI, Request, Response, WebSocket
18+
from typing import TYPE_CHECKING, Any, Optional
19+
20+
from fastapi import FastAPI, HTTPException, Request, Response, WebSocket
21+
from fastapi.responses import StreamingResponse
22+
from google.adk.artifacts import InMemoryArtifactService
23+
from google.adk.cli.adk_web_server import RunAgentRequest
24+
from google.adk.runners import Runner as GoogleRunner
25+
from google.adk.sessions import InMemorySessionService, Session
2126
from google.adk.tools.mcp_tool.mcp_session_manager import (
2227
StreamableHTTPConnectionParams,
2328
)
@@ -86,13 +91,37 @@ def __init__(
8691
agent: "Agent",
8792
host: str = "0.0.0.0",
8893
port: int = 8000,
94+
short_term_memory: Optional[Any] = None,
8995
):
9096
self.agent = agent
9197

9298
self.host = host
9399
self.port = port
94100

95101
self.app = FastAPI()
102+
103+
# Session and artifact services for new endpoints
104+
# Priority: 1. provided short_term_memory, 2. agent's short_term_memory, 3. create new
105+
if short_term_memory is not None:
106+
from google.adk.sessions.base_session_service import BaseSessionService
107+
108+
if isinstance(short_term_memory, BaseSessionService):
109+
self.session_service = short_term_memory
110+
else:
111+
self.session_service = short_term_memory.session_service
112+
elif (
113+
hasattr(agent, "short_term_memory") and agent.short_term_memory is not None
114+
):
115+
from google.adk.sessions.base_session_service import BaseSessionService
116+
117+
if isinstance(agent.short_term_memory, BaseSessionService):
118+
self.session_service = agent.short_term_memory
119+
else:
120+
self.session_service = agent.short_term_memory.session_service
121+
else:
122+
self.session_service = InMemorySessionService()
123+
self.artifact_service = InMemoryArtifactService()
124+
96125
# build routes for self.app
97126
self.build()
98127

@@ -172,8 +201,135 @@ async def ws_endpoint(ws: WebSocket):
172201
raw = await ws.receive_text()
173202
await self.ws_session_mgr.handle_ws_message(client_id, raw)
174203

204+
# ========== New endpoints: create_session, create_session_with_id, run_sse ==========
205+
# NOTE: These must be defined BEFORE the catch-all /{path:path} route
206+
207+
class CreateSessionRequest(BaseModel):
208+
state: Optional[dict[str, Any]] = None
209+
session_id: Optional[str] = None
210+
211+
@self.app.post(
212+
"/apps/{app_name}/users/{user_id}/sessions",
213+
response_model_exclude_none=True,
214+
)
215+
async def create_session(
216+
app_name: str,
217+
user_id: str,
218+
req: Optional[CreateSessionRequest] = None,
219+
) -> Session:
220+
"""Create a new session."""
221+
session_id = req.session_id if req and req.session_id else str(uuid.uuid4())
222+
session = Session(
223+
app_name=app_name,
224+
user_id=user_id,
225+
id=session_id,
226+
state=req.state if req and req.state else {},
227+
)
228+
await self.session_service.create_session(
229+
app_name=app_name,
230+
user_id=user_id,
231+
session_id=session_id,
232+
state=req.state if req and req.state else {},
233+
)
234+
logger.info(
235+
f"Created session: {session_id} for user {user_id} in app {app_name}"
236+
)
237+
return session
238+
239+
@self.app.post(
240+
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
241+
response_model_exclude_none=True,
242+
)
243+
async def create_session_with_id(
244+
app_name: str,
245+
user_id: str,
246+
session_id: str,
247+
state: Optional[dict[str, Any]] = None,
248+
) -> Session:
249+
"""Create a session with specific ID."""
250+
await self.session_service.create_session(
251+
app_name=app_name,
252+
user_id=user_id,
253+
session_id=session_id,
254+
state=state if state else {},
255+
)
256+
session = Session(
257+
app_name=app_name,
258+
user_id=user_id,
259+
id=session_id,
260+
state=state if state else {},
261+
)
262+
logger.info(f"Created session with ID: {session_id} for user {user_id}")
263+
return session
264+
265+
@self.app.post("/run_sse")
266+
async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
267+
"""Run agent with SSE streaming."""
268+
# Get session
269+
session = await self.session_service.get_session(
270+
app_name=req.app_name,
271+
user_id=req.user_id,
272+
session_id=req.session_id,
273+
)
274+
if not session:
275+
raise HTTPException(status_code=404, detail="Session not found")
276+
277+
# Use the first connected websocket client, or create a new agent clone
278+
websocket_id = None
279+
if self.ws_agent_mgr:
280+
websocket_id = list(self.ws_agent_mgr.keys())[0]
281+
agent = self.ws_agent_mgr[websocket_id]
282+
logger.debug(f"Using agent from websocket {websocket_id}")
283+
else:
284+
# No websocket connected, use original agent
285+
agent = self.agent
286+
logger.debug("No websocket connected, using original agent")
287+
288+
# Mount MCPToolset if needed
289+
if not agent.tools:
290+
logger.debug("Mount fake MCPToolset to agent for SSE")
291+
agent.tools.append(
292+
MCPToolset(
293+
connection_params=StreamableHTTPConnectionParams(
294+
url=f"http://127.0.0.1:{self.port}/mcp",
295+
headers={REVERSE_MCP_HEADER_KEY: websocket_id or "default"},
296+
),
297+
)
298+
)
299+
300+
# Create runner
301+
runner = GoogleRunner(
302+
agent=agent,
303+
app_name=req.app_name,
304+
session_service=self.session_service,
305+
artifact_service=self.artifact_service,
306+
)
307+
308+
async def event_generator():
309+
try:
310+
async for event in runner.run_async(
311+
user_id=req.user_id,
312+
session_id=req.session_id,
313+
new_message=req.new_message,
314+
state_delta=req.state_delta,
315+
):
316+
event_json = event.model_dump_json(
317+
exclude_none=True, by_alias=True
318+
)
319+
logger.debug(f"SSE event: {event_json}")
320+
yield f"data: {event_json}\n\n"
321+
except Exception as e:
322+
logger.exception(f"Error in event_generator: {e}")
323+
yield f"data: {json.dumps({'error': str(e)})}\n\n"
324+
325+
return StreamingResponse(
326+
event_generator(),
327+
media_type="text/event-stream",
328+
)
329+
175330
# build the fake MPC server,
176331
# and intercept all requests to the client websocket client.
332+
# NOTE: This catch-all route must be defined LAST
177333
@self.app.api_route("/{path:path}", methods=["GET", "POST"])
178334
async def mcp_proxy(path: str, request: Request):
179335
client_id = request.headers.get(REVERSE_MCP_HEADER_KEY)

0 commit comments

Comments
 (0)