Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 90 additions & 21 deletions veadk/cli/cli_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,95 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import click

from veadk.memory.long_term_memory import LongTermMemory
from veadk.memory.short_term_memory import ShortTermMemory


def _get_stm_from_module(module) -> ShortTermMemory:
return module.agent_run_config.short_term_memory


def _get_stm_from_env() -> ShortTermMemory:
import os

from veadk.utils.logger import get_logger

logger = get_logger(__name__)

short_term_memory_backend = os.getenv("SHORT_TERM_MEMORY_BACKEND")
if not short_term_memory_backend: # prevent None or empty string
short_term_memory_backend = "local"
logger.info(f"Short term memory: backend={short_term_memory_backend}")

return ShortTermMemory(backend=short_term_memory_backend) # type: ignore


def _get_ltm_from_module(module) -> LongTermMemory | None:
agent = module.agent_run_config.agent

if not hasattr(agent, "long_term_memory"):
return None
else:
return agent.long_term_memory


def _get_ltm_from_env() -> LongTermMemory | None:
import os

from veadk.utils.logger import get_logger

logger = get_logger(__name__)

long_term_memory_backend = os.getenv("LONG_TERM_MEMORY_BACKEND")

if long_term_memory_backend:
logger.info(f"Long term memory: backend={long_term_memory_backend}")
return LongTermMemory(backend=long_term_memory_backend) # type: ignore
else:
logger.warning("No long term memory backend settings detected.")
return None


def _get_memory(
module_path: str,
) -> tuple[ShortTermMemory, LongTermMemory | None]:
from veadk.utils.logger import get_logger
from veadk.utils.misc import load_module_from_file

logger = get_logger(__name__)

# 1. load user module
try:
module_file_path = module_path
module = load_module_from_file(
module_name="agent_and_mem", file_path=f"{module_file_path}/agent.py"
)
except Exception as e:
logger.error(
f"Failed to get memory config from `agent.py`: {e}. Fallback to get memory from environment variables."
)
return _get_stm_from_env(), _get_ltm_from_env()

if not hasattr(module, "agent_run_config"):
logger.error(
"You must export `agent_run_config` as a global variable in `agent.py`. Fallback to get memory from environment variables."
)
return _get_stm_from_env(), _get_ltm_from_env()

# 2. try to get short term memory
# short term memory must exist in user code, as we use `default_factory` to init it
short_term_memory = _get_stm_from_module(module)

# 3. try to get long term memory
long_term_memory = _get_ltm_from_module(module)
if not long_term_memory:
long_term_memory = _get_ltm_from_env()

return short_term_memory, long_term_memory


@click.command()
@click.option("--host", default="127.0.0.1", help="Host to run the web server on")
Expand All @@ -24,7 +111,6 @@ def web(host: str) -> None:

from google.adk.cli.utils.shared_value import SharedValue

from veadk.memory.short_term_memory import ShortTermMemory
from veadk.utils.logger import get_logger

logger = get_logger(__name__)
Expand All @@ -51,26 +137,9 @@ def init_for_veadk(
self.current_app_name_ref = SharedValue(value="")
self.runner_dict = {}

short_term_memory_backend = os.getenv("SHORT_TERM_MEMORY_BACKEND")
if not short_term_memory_backend: # prevent None or empty string
short_term_memory_backend = "local"
logger.info(f"Short term memory: backend={short_term_memory_backend}")

long_term_memory_backend = os.getenv("LONG_TERM_MEMORY_BACKEND")
long_term_memory = None

if long_term_memory_backend:
from veadk.memory.long_term_memory import LongTermMemory

logger.info(f"Long term memory: backend={long_term_memory_backend}")
long_term_memory = LongTermMemory(backend=long_term_memory_backend) # type: ignore
else:
logger.info("No long term memory backend settings detected.")

self.session_service = ShortTermMemory(
backend=short_term_memory_backend # type: ignore
).session_service

# parse VeADK memories
short_term_memory, long_term_memory = _get_memory(module_path=agents_dir)
self.session_service = short_term_memory.session_service
self.memory_service = long_term_memory

import google.adk.cli.adk_web_server
Expand Down
87 changes: 82 additions & 5 deletions veadk/cloud/cloud_agent_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.

import os
import socket
import subprocess
import time
from pathlib import Path
from typing import Any

Expand All @@ -22,7 +25,7 @@
from veadk.config import getenv
from veadk.integrations.ve_faas.ve_faas import VeFaaS
from veadk.utils.logger import get_logger
from veadk.utils.misc import formatted_timestamp
from veadk.utils.misc import formatted_timestamp, load_module_from_file

logger = get_logger(__name__)

Expand Down Expand Up @@ -65,9 +68,9 @@ def _prepare(self, path: str, name: str):
# prepare template files if not have
template_files = [
"app.py",
"studio_app.py",
# "studio_app.py",
"run.sh",
"requirements.txt",
# "requirements.txt",
"__init__.py",
]
for template_file in template_files:
Expand All @@ -88,6 +91,70 @@ def _prepare(self, path: str, name: str):

shutil.copy(template_file_path, os.path.join(path, template_file))

# copy user's requirements.txt
if os.path.exists(os.path.join(path, "requirements.txt")):
logger.warning(
f"Local agent project path `{path}` contains a `requirements.txt` file. Skip copy requirements."
)
return

module = load_module_from_file(
module_name="agent_source", file_path=f"{path}/agent.py"
)

requirement_file_path = module.agent_run_config.requirement_file_path
shutil.copy(requirement_file_path, os.path.join(path, "requirements.txt"))

logger.info(
f"Copy requirement file: from {requirement_file_path} to {path}/requirements.txt"
)

def _try_launch_fastapi_server(self, path: str):
"""Try to launch a fastapi server for tests according to user's configuration.

Args:
path (str): Local agent project path.
"""
RUN_SH = f"{path}/run.sh"

HOST = "0.0.0.0"
PORT = 8000

# Prepare environment variables
os.environ["_FAAS_FUNC_TIMEOUT"] = "900"
env = os.environ.copy()

process = subprocess.Popen(
["bash", RUN_SH],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
env=env,
bufsize=1,
)

timeout = 30
start_time = time.time()

for line in process.stdout: # type: ignore
print(line, end="")

if time.time() - start_time > timeout:
process.terminate()
raise RuntimeError(f"FastAPI server failed to start on {HOST}:{PORT}")
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(0.1)
s.connect(("127.0.0.1", PORT))
logger.info(f"FastAPI server is listening on {HOST}:{PORT}")
logger.info("Local deplyment test successfully.")
break
except (ConnectionRefusedError, socket.timeout):
continue

process.terminate()
process.wait()

def deploy(
self,
application_name: str,
Expand All @@ -97,15 +164,22 @@ def deploy(
gateway_upstream_name: str = "",
use_studio: bool = False,
use_adk_web: bool = False,
local_test: bool = False,
) -> CloudApp:
"""Deploy local agent project to Volcengine FaaS platform.

Args:
application_name (str): Expected VeFaaS application name.
path (str): Local agent project path.
name (str): Volcengine FaaS function name.
gateway_name (str): Gateway name.
gateway_service_name (str): Gateway service name.
gateway_upstream_name (str): Gateway upstream name.
use_studio (bool): Whether to use Studio [deprecated].
use_adk_web (bool): Whether to use ADK Web.
local_test (bool): Whether to run local test for FastAPI Server.

Returns:
str: Volcengine FaaS function endpoint.
CloudApp: The deployed cloud application instance.
"""
assert not (use_studio and use_adk_web), (
"use_studio and use_adk_web can not be True at the same time."
Expand Down Expand Up @@ -136,6 +210,9 @@ def deploy(
path = str(Path(path).resolve())
self._prepare(path, application_name)

if local_test:
self._try_launch_fastapi_server(path)

if not gateway_name:
gateway_name = f"{application_name}-gw-{formatted_timestamp()}"
if not gateway_service_name:
Expand Down
Loading
Loading