Skip to content

Commit 4b9f2cf

Browse files
committed
feat: reverse mcp with session_service_mgr
1 parent 3d869a1 commit 4b9f2cf

1 file changed

Lines changed: 85 additions & 53 deletions

File tree

veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py

Lines changed: 85 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def __init__(
9191
agent: "Agent",
9292
host: str = "0.0.0.0",
9393
port: int = 8000,
94-
short_term_memory: Optional[Any] = None,
9594
):
9695
self.agent = agent
9796

@@ -100,33 +99,14 @@ def __init__(
10099

101100
self.app = FastAPI()
102101

103-
# Session and artifact services for new endpoints
104-
# Priority: 1. provided short_term_memory, 2. agent's short_term_memory, 3. create new
105-
if short_term_memory is not None:
106-
from google.adk.sessions.base_session_service import BaseSessionService
107-
108-
if isinstance(short_term_memory, BaseSessionService):
109-
self.session_service = short_term_memory
110-
else:
111-
self.session_service = short_term_memory.session_service
112-
elif (
113-
hasattr(agent, "short_term_memory") and agent.short_term_memory is not None
114-
):
115-
from google.adk.sessions.base_session_service import BaseSessionService
116-
117-
if isinstance(agent.short_term_memory, BaseSessionService):
118-
self.session_service = agent.short_term_memory
119-
else:
120-
self.session_service = agent.short_term_memory.session_service
121-
else:
122-
self.session_service = InMemorySessionService()
123102
self.artifact_service = InMemoryArtifactService()
124103

125104
# build routes for self.app
126105
self.build()
127106

128107
self.ws_session_mgr = WebsocketSessionManager()
129108
self.ws_agent_mgr: dict[str, "Agent"] = {}
109+
self.ws_session_service_mgr: dict[str, "InMemorySessionService"] = {}
130110

131111
def build(self):
132112
logger.info("Build routes for server with reverse mcp")
@@ -141,6 +121,8 @@ class InvokeRequest(BaseModel):
141121

142122
websocket_id: str
143123

124+
mcp_tool_filter: Optional[list[str]] = None
125+
144126
class InvokeResponse(BaseModel):
145127
"""Response model for /invoke endpoint"""
146128

@@ -155,16 +137,32 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
155137

156138
agent = self.ws_agent_mgr[payload.websocket_id]
157139

158-
if not agent.tools:
140+
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
141+
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: payload.websocket_id}
142+
143+
has_mcp_toolset = False
144+
for tool in agent.tools:
145+
if isinstance(tool, MCPToolset):
146+
if hasattr(tool, "_connection_params"):
147+
conn_params = tool._connection_params
148+
if (
149+
hasattr(conn_params, "url")
150+
and conn_params.url == mcp_toolset_url
151+
and hasattr(conn_params, "headers")
152+
and conn_params.headers == mcp_toolset_headers
153+
):
154+
has_mcp_toolset = True
155+
break
156+
157+
if not has_mcp_toolset:
159158
logger.debug("Mount fake MCPToolset to agent")
160-
161-
# we hard code the mcp url with `/mcp` to obey the mcp protocol
162159
agent.tools.append(
163160
MCPToolset(
164161
connection_params=StreamableHTTPConnectionParams(
165-
url=f"http://127.0.0.1:{self.port}/mcp",
166-
headers={REVERSE_MCP_HEADER_KEY: payload.websocket_id},
162+
url=mcp_toolset_url,
163+
headers=mcp_toolset_headers,
167164
),
165+
tool_filter=payload.mcp_tool_filter,
168166
)
169167
)
170168

@@ -194,19 +192,32 @@ async def ws_endpoint(ws: WebSocket):
194192
logger.info(f"Fork agent for websocket {client_id}")
195193
self.ws_agent_mgr[client_id] = self.agent.clone()
196194

195+
logger.info(f"Create session service for websocket {client_id}")
196+
self.ws_session_service_mgr[client_id] = InMemorySessionService()
197+
197198
await ws.accept()
198199
logger.info(f"Websocket {client_id} connected")
199200

200201
while True:
201202
raw = await ws.receive_text()
202203
await self.ws_session_mgr.handle_ws_message(client_id, raw)
203204

204-
# ========== New endpoints: create_session, create_session_with_id, run_sse ==========
205-
# NOTE: These must be defined BEFORE the catch-all /{path:path} route
206-
207205
class CreateSessionRequest(BaseModel):
208206
state: Optional[dict[str, Any]] = None
209207
session_id: Optional[str] = None
208+
websocket_id: str
209+
210+
class RunAgentRequestWithWsId(RunAgentRequest):
211+
websocket_id: str
212+
mcp_tool_filter: Optional[list[str]] = None
213+
214+
def _get_session_service(websocket_id: str) -> InMemorySessionService:
215+
"""Get session service for the websocket client."""
216+
if websocket_id not in self.ws_session_service_mgr:
217+
raise HTTPException(
218+
status_code=404, detail=f"WebSocket client {websocket_id} not found"
219+
)
220+
return self.ws_session_service_mgr[websocket_id]
210221

211222
@self.app.post(
212223
"/apps/{app_name}/users/{user_id}/sessions",
@@ -215,21 +226,22 @@ class CreateSessionRequest(BaseModel):
215226
async def create_session(
216227
app_name: str,
217228
user_id: str,
218-
req: Optional[CreateSessionRequest] = None,
229+
req: CreateSessionRequest,
219230
) -> Session:
220231
"""Create a new session."""
221-
session_id = req.session_id if req and req.session_id else str(uuid.uuid4())
232+
session_id = req.session_id if req.session_id else str(uuid.uuid4())
222233
session = Session(
223234
app_name=app_name,
224235
user_id=user_id,
225236
id=session_id,
226-
state=req.state if req and req.state else {},
237+
state=req.state if req.state else {},
227238
)
228-
await self.session_service.create_session(
239+
session_service = _get_session_service(req.websocket_id)
240+
await session_service.create_session(
229241
app_name=app_name,
230242
user_id=user_id,
231243
session_id=session_id,
232-
state=req.state if req and req.state else {},
244+
state=req.state if req.state else {},
233245
)
234246
logger.info(
235247
f"Created session: {session_id} for user {user_id} in app {app_name}"
@@ -244,64 +256,84 @@ async def create_session_with_id(
244256
app_name: str,
245257
user_id: str,
246258
session_id: str,
247-
state: Optional[dict[str, Any]] = None,
259+
req: CreateSessionRequest,
248260
) -> Session:
249261
"""Create a session with specific ID."""
250-
await self.session_service.create_session(
262+
session_service = _get_session_service(req.websocket_id)
263+
await session_service.create_session(
251264
app_name=app_name,
252265
user_id=user_id,
253266
session_id=session_id,
254-
state=state if state else {},
267+
state=req.state if req.state else {},
255268
)
256269
session = Session(
257270
app_name=app_name,
258271
user_id=user_id,
259272
id=session_id,
260-
state=state if state else {},
273+
state=req.state if req.state else {},
261274
)
262275
logger.info(f"Created session with ID: {session_id} for user {user_id}")
263276
return session
264277

265278
@self.app.post("/run_sse")
266-
async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
279+
async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
267280
"""Run agent with SSE streaming."""
281+
session_service = _get_session_service(req.websocket_id)
282+
268283
# Get session
269-
session = await self.session_service.get_session(
284+
session = await session_service.get_session(
270285
app_name=req.app_name,
271286
user_id=req.user_id,
272287
session_id=req.session_id,
273288
)
274289
if not session:
275290
raise HTTPException(status_code=404, detail="Session not found")
276291

277-
# Use the first connected websocket client, or create a new agent clone
278-
websocket_id = None
279-
if self.ws_agent_mgr:
280-
websocket_id = list(self.ws_agent_mgr.keys())[0]
281-
agent = self.ws_agent_mgr[websocket_id]
282-
logger.debug(f"Using agent from websocket {websocket_id}")
292+
# Get agent for this websocket
293+
if req.websocket_id in self.ws_agent_mgr:
294+
agent = self.ws_agent_mgr[req.websocket_id]
295+
logger.debug(f"Using agent from websocket {req.websocket_id}")
283296
else:
284-
# No websocket connected, use original agent
285-
agent = self.agent
286-
logger.debug("No websocket connected, using original agent")
297+
raise HTTPException(
298+
status_code=404,
299+
detail=f"WebSocket client {req.websocket_id} not found",
300+
)
287301

288302
# Mount MCPToolset if needed
289-
if not agent.tools:
303+
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
304+
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: req.websocket_id}
305+
306+
has_mcp_toolset = False
307+
for tool in agent.tools:
308+
if isinstance(tool, MCPToolset):
309+
if hasattr(tool, "_connection_params"):
310+
conn_params = tool._connection_params
311+
if (
312+
hasattr(conn_params, "url")
313+
and conn_params.url == mcp_toolset_url
314+
and hasattr(conn_params, "headers")
315+
and conn_params.headers == mcp_toolset_headers
316+
):
317+
has_mcp_toolset = True
318+
break
319+
320+
if not has_mcp_toolset:
290321
logger.debug("Mount fake MCPToolset to agent for SSE")
291322
agent.tools.append(
292323
MCPToolset(
293324
connection_params=StreamableHTTPConnectionParams(
294-
url=f"http://127.0.0.1:{self.port}/mcp",
295-
headers={REVERSE_MCP_HEADER_KEY: websocket_id or "default"},
325+
url=mcp_toolset_url,
326+
headers=mcp_toolset_headers,
296327
),
328+
tool_filter=req.mcp_tool_filter,
297329
)
298330
)
299331

300332
# Create runner
301333
runner = GoogleRunner(
302334
agent=agent,
303335
app_name=req.app_name,
304-
session_service=self.session_service,
336+
session_service=session_service,
305337
artifact_service=self.artifact_service,
306338
)
307339

0 commit comments

Comments
 (0)