Skip to content

Commit d66408f

Browse files
committed
fix type error in some files
1 parent 0d1b92a commit d66408f

5 files changed

Lines changed: 44 additions & 46 deletions

File tree

veadk/database/viking/viking_database.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717
import os
1818
import uuid
19-
from typing import Any, BinaryIO, Literal, Optional, TextIO
19+
from typing import Any, BinaryIO, Literal, TextIO
2020

2121
import requests
2222
import tos
@@ -44,42 +44,44 @@
4444

4545

4646
class VolcengineTOSConfig(BaseModel):
47-
endpoint: Optional[str] = Field(
48-
default=getenv("DATABASE_TOS_ENDPOINT", "tos-cn-beijing.volces.com"),
47+
endpoint: str = Field(
48+
default_factory=lambda: getenv(
49+
"DATABASE_TOS_ENDPOINT", "tos-cn-beijing.volces.com"
50+
),
4951
description="VikingDB TOS endpoint",
5052
)
51-
region: Optional[str] = Field(
52-
default=getenv("DATABASE_TOS_REGION", "cn-beijing"),
53+
region: str = Field(
54+
default_factory=lambda: getenv("DATABASE_TOS_REGION", "cn-beijing"),
5355
description="VikingDB TOS region",
5456
)
55-
bucket: Optional[str] = Field(
56-
default=getenv("DATABASE_TOS_BUCKET"),
57+
bucket: str = Field(
58+
default_factory=lambda: getenv("DATABASE_TOS_BUCKET"),
5759
description="VikingDB TOS bucket",
5860
)
59-
base_key: Optional[str] = Field(
61+
base_key: str = Field(
6062
default="veadk",
6163
description="VikingDB TOS base key",
6264
)
6365

6466

6567
class VikingDatabaseConfig(BaseModel):
66-
volcengine_ak: Optional[str] = Field(
67-
default=getenv("VOLCENGINE_ACCESS_KEY"),
68+
volcengine_ak: str = Field(
69+
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"),
6870
description="VikingDB access key",
6971
)
70-
volcengine_sk: Optional[str] = Field(
71-
default=getenv("VOLCENGINE_SECRET_KEY"),
72+
volcengine_sk: str = Field(
73+
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"),
7274
description="VikingDB secret key",
7375
)
74-
project: Optional[str] = Field(
75-
default=getenv("DATABASE_VIKING_PROJECT"),
76+
project: str = Field(
77+
default_factory=lambda: getenv("DATABASE_VIKING_PROJECT"),
7678
description="VikingDB project name",
7779
)
78-
region: Optional[str] = Field(
79-
default=getenv("DATABASE_VIKING_REGION"),
80+
region: str = Field(
81+
default_factory=lambda: getenv("DATABASE_VIKING_REGION"),
8082
description="VikingDB region",
8183
)
82-
tos: Optional[VolcengineTOSConfig] = Field(
84+
tos: VolcengineTOSConfig = Field(
8385
default_factory=VolcengineTOSConfig,
8486
description="VikingDB TOS configuration",
8587
)

veadk/database/viking/viking_memory_db.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import threading
1919
import time
2020
from datetime import datetime
21-
from typing import Any, Optional
21+
from typing import Any
2222

2323
from pydantic import BaseModel, Field
2424
from volcengine.ApiInfo import ApiInfo
@@ -35,20 +35,20 @@
3535

3636

3737
class VikingMemConfig(BaseModel):
38-
volcengine_ak: Optional[str] = Field(
39-
default=getenv("VOLCENGINE_ACCESS_KEY"),
38+
volcengine_ak: str = Field(
39+
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"),
4040
description="VikingDB access key",
4141
)
42-
volcengine_sk: Optional[str] = Field(
43-
default=getenv("VOLCENGINE_SECRET_KEY"),
42+
volcengine_sk: str = Field(
43+
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"),
4444
description="VikingDB secret key",
4545
)
46-
project: Optional[str] = Field(
47-
default=getenv("DATABASE_VIKING_PROJECT"),
46+
project: str = Field(
47+
default_factory=lambda: getenv("DATABASE_VIKING_PROJECT"),
4848
description="VikingDB project name",
4949
)
50-
region: Optional[str] = Field(
51-
default=getenv("DATABASE_VIKING_REGION"),
50+
region: str = Field(
51+
default_factory=lambda: getenv("DATABASE_VIKING_REGION"),
5252
description="VikingDB region",
5353
)
5454

veadk/evaluation/adk_evaluator/adk_evaluator.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Any, Optional
2020

2121
from google.adk import Runner
22+
from google.adk.agents.base_agent import BaseAgent
2223
from google.adk.artifacts import BaseArtifactService, InMemoryArtifactService
2324
from google.adk.evaluation.agent_evaluator import (
2425
NUM_RUNS,
@@ -68,7 +69,7 @@ async def ve_generate_responses( # done
6869
eval_set: EvalSet,
6970
agent: Agent,
7071
repeat_num: int = 3,
71-
agent_name: str = None,
72+
agent_name: str | None = None,
7273
):
7374
results = []
7475

@@ -90,7 +91,7 @@ async def ve_generate_responses( # done
9091
@staticmethod
9192
async def _ve_generate_inferences_from_root_agent(
9293
invocations: list[Invocation],
93-
root_agent: Agent,
94+
root_agent: BaseAgent,
9495
reset_func: Any,
9596
initial_session: Optional[SessionInput] = None,
9697
session_id: Optional[str] = None,
@@ -117,21 +118,15 @@ async def _ve_generate_inferences_from_root_agent(
117118
if not artifact_service:
118119
artifact_service = InMemoryArtifactService()
119120

120-
if getattr(root_agent, "long_term_memory", None) is not None:
121-
runner = Runner(
122-
app_name=app_name,
123-
agent=root_agent,
124-
artifact_service=artifact_service,
125-
session_service=session_service,
126-
memory_service=root_agent.long_term_memory, # add long_term_memory
127-
)
128-
else:
129-
runner = Runner(
130-
app_name=app_name,
131-
agent=root_agent,
132-
artifact_service=artifact_service,
133-
session_service=session_service,
134-
)
121+
runner = Runner(
122+
app_name=app_name,
123+
agent=root_agent,
124+
artifact_service=artifact_service,
125+
session_service=session_service,
126+
memory_service=root_agent.long_term_memory
127+
if isinstance(root_agent, Agent)
128+
else None,
129+
)
135130

136131
# Reset agent state for each query
137132
if callable(reset_func):

veadk/tools/load_knowledgebase_tool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ async def load_knowledgebase(
8585
Returns:
8686
A list of knowledgebase results.
8787
"""
88-
search_knowledgebase_response = await tool_context.search_knowledgebase(
88+
89+
search_knowledgebase_response = await tool_context.search_knowledgebase( # type: ignore[attr-defined]
8990
query, tool_context._invocation_context.app_name
9091
)
9192
return LoadKnowledgebaseResponse(

veadk/tracing/telemetry/exporters/inmemory_exporter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def export(self, spans: typing.Sequence[ReadableSpan]) -> export.SpanExportResul
4747
)
4848

4949
if span.name == "call_llm":
50-
attributes = dict(span.attributes)
50+
attributes = dict(span.attributes or {})
5151
prompt_token = attributes.get("gen_ai.usage.prompt_tokens", None)
5252
completion_token = attributes.get(
5353
"gen_ai.usage.completion_tokens", None
@@ -57,7 +57,7 @@ def export(self, spans: typing.Sequence[ReadableSpan]) -> export.SpanExportResul
5757
if completion_token:
5858
self.completion_tokens.append(completion_token)
5959
if span.name == "call_llm":
60-
attributes = dict(span.attributes)
60+
attributes = dict(span.attributes or {})
6161
session_id = attributes.get("gcp.vertex.agent.session_id", None)
6262
if session_id:
6363
if session_id not in self.session_trace_dict:

0 commit comments

Comments
 (0)