Skip to content

Commit 5f90080

Browse files
committed
fix: reconstruct vanna tool
1 parent b452553 commit 5f90080

8 files changed

Lines changed: 348 additions & 152 deletions

File tree

veadk/tools/vanna_tools/agent_memory.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
"""
4242
self.agent_memory = agent_memory
4343
self.vanna_tool = VannaSaveQuestionToolArgsTool()
44-
self.access_groups = access_groups or ["admin"] # Default: only admin
44+
self.access_groups = access_groups or ["admin", "user"]
4545

4646
super().__init__(
4747
name="save_question_tool_args", # Keep the same name as Vanna
@@ -86,14 +86,19 @@ def _create_vanna_context(
8686
) -> VannaToolContext:
8787
"""Create Vanna context from Veadk ToolContext."""
8888
user_id = tool_context.user_id
89+
session_id = tool_context.session.id
8990
user_email = tool_context.state.get("user_email", "user@example.com")
9091

91-
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
92+
vanna_user = User(
93+
id=user_id + "_" + session_id,
94+
email=user_email,
95+
group_memberships=user_groups,
96+
)
9297

9398
vanna_context = VannaToolContext(
9499
user=vanna_user,
95-
conversation_id=tool_context.session.id,
96-
request_id=tool_context.session.id,
100+
conversation_id=session_id,
101+
request_id=session_id,
97102
agent_memory=self.agent_memory,
98103
)
99104

@@ -190,14 +195,19 @@ def _create_vanna_context(
190195
) -> VannaToolContext:
191196
"""Create Vanna context from Veadk ToolContext."""
192197
user_id = tool_context.user_id
198+
session_id = tool_context.session.id
193199
user_email = tool_context.state.get("user_email", "user@example.com")
194200

195-
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
201+
vanna_user = User(
202+
id=user_id + "_" + session_id,
203+
email=user_email,
204+
group_memberships=user_groups,
205+
)
196206

197207
vanna_context = VannaToolContext(
198208
user=vanna_user,
199-
conversation_id=tool_context.session.id,
200-
request_id=tool_context.session.id,
209+
conversation_id=session_id,
210+
request_id=session_id,
201211
agent_memory=self.agent_memory,
202212
)
203213

@@ -286,14 +296,19 @@ def _create_vanna_context(
286296
) -> VannaToolContext:
287297
"""Create Vanna context from Veadk ToolContext."""
288298
user_id = tool_context.user_id
299+
session_id = tool_context.session.id
289300
user_email = tool_context.state.get("user_email", "user@example.com")
290301

291-
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
302+
vanna_user = User(
303+
id=user_id + "_" + session_id,
304+
email=user_email,
305+
group_memberships=user_groups,
306+
)
292307

293308
vanna_context = VannaToolContext(
294309
user=vanna_user,
295-
conversation_id=tool_context.session.id,
296-
request_id=tool_context.session.id,
310+
conversation_id=session_id,
311+
request_id=session_id,
297312
agent_memory=self.agent_memory,
298313
)
299314

veadk/tools/vanna_tools/examples/agent.py

Lines changed: 7 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -12,46 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615
from veadk import Agent, Runner
17-
18-
# Import Vanna dependencies for initialization
19-
from vanna.integrations.sqlite import SqliteRunner
20-
from vanna.tools import LocalFileSystem
21-
from vanna.integrations.local.agent_memory import DemoAgentMemory
22-
import httpx
23-
24-
# Import the refactored class-based tools
25-
from veadk.tools.vanna_tools.run_sql import RunSqlTool
26-
from veadk.tools.vanna_tools.visualize_data import VisualizeDataTool
27-
from veadk.tools.vanna_tools.file_system import WriteFileTool
28-
from veadk.tools.vanna_tools.agent_memory import (
29-
SaveQuestionToolArgsTool,
30-
SearchSavedCorrectToolUsesTool,
31-
)
32-
from veadk.tools.vanna_tools.summarize_data import SummarizeDataTool
33-
16+
from veadk.tools.vanna_tools.vanna_toolset import VannaToolSet
3417
from google.adk.sessions import InMemorySessionService
3518

3619

37-
# Setup SQLite database
38-
def setup_sqlite():
39-
"""Download and setup the Chinook SQLite database."""
40-
db_path = "/tmp/Chinook.sqlite"
41-
if not os.path.exists(db_path):
42-
print("Downloading Chinook.sqlite...")
43-
url = "https://vanna.ai/Chinook.sqlite"
44-
try:
45-
with open(db_path, "wb") as f:
46-
with httpx.stream("GET", url) as response:
47-
for chunk in response.iter_bytes():
48-
f.write(chunk)
49-
print("Database downloaded successfully!")
50-
except Exception as e:
51-
print(f"Error downloading database: {e}")
52-
return db_path
53-
54-
5520
# Create a session with user groups for access control
5621
async def create_session(user_groups: list = ["user"]):
5722
session_service = InMemorySessionService()
@@ -63,57 +28,10 @@ async def create_session(user_groups: list = ["user"]):
6328
return session_service, example_session
6429

6530

66-
# Initialize user-customizable resources
67-
db_path = setup_sqlite()
68-
69-
# 1. SQL Runner - can be SqliteRunner, PostgresRunner, MySQLRunner, etc.
70-
sqlite_runner = SqliteRunner(database_path=db_path)
71-
72-
# 2. File System - customize working directory as needed
73-
file_system = LocalFileSystem(working_directory="/tmp/data_storage")
74-
if not os.path.exists("/tmp/data_storage"):
75-
os.makedirs("/tmp/data_storage", exist_ok=True)
76-
77-
# 3. Agent Memory - customize memory implementation and capacity
78-
agent_memory = DemoAgentMemory(max_items=1000)
79-
80-
# Initialize tools with user-defined components and access control
81-
# Tool names now match Vanna's original names for compatibility
82-
run_sql_tool = RunSqlTool(
83-
sql_runner=sqlite_runner,
84-
file_system=file_system,
85-
agent_memory=agent_memory,
86-
access_groups=["admin", "user"], # Both admin and user can use
31+
vanna_toolset = VannaToolSet(
32+
connection_string="sqlite:///tmp/Chinook.sqlite", file_storage="/tmp/vanna_files"
8733
)
8834

89-
visualize_data_tool = VisualizeDataTool(
90-
file_system=file_system,
91-
agent_memory=agent_memory,
92-
access_groups=["admin", "user"],
93-
)
94-
95-
write_file_tool = WriteFileTool(
96-
file_system=file_system,
97-
agent_memory=agent_memory,
98-
access_groups=["admin", "user"],
99-
)
100-
101-
# Memory tools: save only for admin, search for all users
102-
save_tool = SaveQuestionToolArgsTool(
103-
agent_memory=agent_memory,
104-
access_groups=["admin"], # Only admin can save
105-
)
106-
107-
search_tool = SearchSavedCorrectToolUsesTool(
108-
agent_memory=agent_memory,
109-
access_groups=["admin", "user"], # All users can search
110-
)
111-
112-
summarize_data_tool = SummarizeDataTool(
113-
file_system=file_system,
114-
agent_memory=agent_memory,
115-
access_groups=["admin", "user"],
116-
)
11735

11836
# Define the Veadk Agent with class-based tools
11937
agent: Agent = Agent(
@@ -276,14 +194,7 @@ async def create_session(user_groups: list = ["user"]):
276194
5. `search_saved_correct_tool_uses` - Search for similar tool usage patterns
277195
6. `summarize_data` - Generate statistical summaries of CSV files
278196
""",
279-
tools=[
280-
run_sql_tool,
281-
visualize_data_tool,
282-
write_file_tool,
283-
save_tool,
284-
search_tool,
285-
summarize_data_tool,
286-
],
197+
tools=[vanna_toolset],
287198
model_extra_config={"extra_body": {"thinking": {"type": "disabled"}}},
288199
)
289200

@@ -311,12 +222,7 @@ async def main(prompt: str, user_groups: list = None) -> str:
311222
if __name__ == "__main__":
312223
import asyncio
313224

314-
# print("=== Example 1: Regular User ===")
315-
# user_input = "What are the top 5 selling albums?"
316-
# response = asyncio.run(main(user_input, user_groups=['user']))
317-
# print(response)
318-
319-
print("\n=== Example 2: Admin User (can save patterns) ===")
320-
admin_input = "What are the top 5 selling albums?"
321-
response = asyncio.run(main(admin_input, user_groups=["admin"]))
225+
print("=== Example 1: Regular User ===")
226+
user_input = "What are the top 5 selling albums?"
227+
response = asyncio.run(main(user_input, user_groups=["user"]))
322228
print(response)

veadk/tools/vanna_tools/file_system.py

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,26 @@ def _check_access(self, user_groups: List[str]) -> bool:
8787
def _create_vanna_context(
8888
self, tool_context: ToolContext, user_groups: List[str]
8989
) -> VannaToolContext:
90+
"""Create Vanna context from Veadk ToolContext."""
9091
user_id = tool_context.user_id
92+
session_id = tool_context.session.id
9193
user_email = tool_context.state.get("user_email", "user@example.com")
9294

93-
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
94-
return VannaToolContext(
95+
vanna_user = User(
96+
id=user_id + "_" + session_id,
97+
email=user_email,
98+
group_memberships=user_groups,
99+
)
100+
101+
vanna_context = VannaToolContext(
95102
user=vanna_user,
96-
conversation_id=tool_context.session.id,
97-
request_id=tool_context.session.id,
103+
conversation_id=session_id,
104+
request_id=session_id,
98105
agent_memory=self.agent_memory,
99106
)
100107

108+
return vanna_context
109+
101110
async def run_async(
102111
self, *, args: Dict[str, Any], tool_context: ToolContext
103112
) -> str:
@@ -168,17 +177,26 @@ def _check_access(self, user_groups: List[str]) -> bool:
168177
def _create_vanna_context(
169178
self, tool_context: ToolContext, user_groups: List[str]
170179
) -> VannaToolContext:
180+
"""Create Vanna context from Veadk ToolContext."""
171181
user_id = tool_context.user_id
182+
session_id = tool_context.session.id
172183
user_email = tool_context.state.get("user_email", "user@example.com")
173184

174-
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
175-
return VannaToolContext(
185+
vanna_user = User(
186+
id=user_id + "_" + session_id,
187+
email=user_email,
188+
group_memberships=user_groups,
189+
)
190+
191+
vanna_context = VannaToolContext(
176192
user=vanna_user,
177-
conversation_id=tool_context.session.id,
178-
request_id=tool_context.session.id,
193+
conversation_id=session_id,
194+
request_id=session_id,
179195
agent_memory=self.agent_memory,
180196
)
181197

198+
return vanna_context
199+
182200
async def run_async(
183201
self, *, args: Dict[str, Any], tool_context: ToolContext
184202
) -> str:
@@ -245,17 +263,26 @@ def _check_access(self, user_groups: List[str]) -> bool:
245263
def _create_vanna_context(
246264
self, tool_context: ToolContext, user_groups: List[str]
247265
) -> VannaToolContext:
266+
"""Create Vanna context from Veadk ToolContext."""
248267
user_id = tool_context.user_id
268+
session_id = tool_context.session.id
249269
user_email = tool_context.state.get("user_email", "user@example.com")
250270

251-
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
252-
return VannaToolContext(
271+
vanna_user = User(
272+
id=user_id + "_" + session_id,
273+
email=user_email,
274+
group_memberships=user_groups,
275+
)
276+
277+
vanna_context = VannaToolContext(
253278
user=vanna_user,
254-
conversation_id=tool_context.session.id,
255-
request_id=tool_context.session.id,
279+
conversation_id=session_id,
280+
request_id=session_id,
256281
agent_memory=self.agent_memory,
257282
)
258283

284+
return vanna_context
285+
259286
async def run_async(
260287
self, *, args: Dict[str, Any], tool_context: ToolContext
261288
) -> str:
@@ -327,17 +354,26 @@ def _check_access(self, user_groups: List[str]) -> bool:
327354
def _create_vanna_context(
328355
self, tool_context: ToolContext, user_groups: List[str]
329356
) -> VannaToolContext:
357+
"""Create Vanna context from Veadk ToolContext."""
330358
user_id = tool_context.user_id
359+
session_id = tool_context.session.id
331360
user_email = tool_context.state.get("user_email", "user@example.com")
332361

333-
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
334-
return VannaToolContext(
362+
vanna_user = User(
363+
id=user_id + "_" + session_id,
364+
email=user_email,
365+
group_memberships=user_groups,
366+
)
367+
368+
vanna_context = VannaToolContext(
335369
user=vanna_user,
336-
conversation_id=tool_context.session.id,
337-
request_id=tool_context.session.id,
370+
conversation_id=session_id,
371+
request_id=session_id,
338372
agent_memory=self.agent_memory,
339373
)
340374

375+
return vanna_context
376+
341377
async def run_async(
342378
self, *, args: Dict[str, Any], tool_context: ToolContext
343379
) -> str:
@@ -429,17 +465,26 @@ def _check_access(self, user_groups: List[str]) -> bool:
429465
def _create_vanna_context(
430466
self, tool_context: ToolContext, user_groups: List[str]
431467
) -> VannaToolContext:
468+
"""Create Vanna context from Veadk ToolContext."""
432469
user_id = tool_context.user_id
470+
session_id = tool_context.session.id
433471
user_email = tool_context.state.get("user_email", "user@example.com")
434472

435-
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
436-
return VannaToolContext(
473+
vanna_user = User(
474+
id=user_id + "_" + session_id,
475+
email=user_email,
476+
group_memberships=user_groups,
477+
)
478+
479+
vanna_context = VannaToolContext(
437480
user=vanna_user,
438-
conversation_id=tool_context.session.id,
439-
request_id=tool_context.session.id,
481+
conversation_id=session_id,
482+
request_id=session_id,
440483
agent_memory=self.agent_memory,
441484
)
442485

486+
return vanna_context
487+
443488
async def run_async(
444489
self, *, args: Dict[str, Any], tool_context: ToolContext
445490
) -> str:

0 commit comments

Comments
 (0)