Skip to content

Commit c6eaaff

Browse files
committed
switch provider based on environment variables
1 parent b597bb9 commit c6eaaff

2 files changed

Lines changed: 81 additions & 10 deletions

File tree

README.md

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,7 @@ This is a tutorial project of [Pocket Flow](https://github.com/The-Pocket/Pocket
8686
pip install -r requirements.txt
8787
```
8888

89-
4. Set up LLM in [`utils/call_llm.py`](./utils/call_llm.py) by providing credentials. By default, you can use the [AI Studio key](https://aistudio.google.com/app/apikey) with this client for Gemini Pro 2.5:
90-
91-
```python
92-
client = genai.Client(
93-
api_key=os.getenv("GEMINI_API_KEY", "your-api_key"),
94-
)
95-
```
96-
89+
4. Set up LLM in [`utils/call_llm.py`](./utils/call_llm.py) by providing credentials. To do so, you can put the values in a `.env` file. By default, you can use the AI Studio key with this client for Gemini Pro 2.5 by setting the `GEMINI_API_KEY` environment variable. If you want to use another LLM, you can set the `LLM_PROVIDER` environment variable (e.g. `XAI`), and then set the model, url, and API key (e.g. `XAI_MODEL`, `XAI_URL`, `XAI_API_KEY`). If using Ollama, the url is `http://localhost:11434/` and the API key can be omitted.
9790
You can use your own models. We highly recommend the latest models with thinking capabilities (Claude 3.7 with thinking, O1). You can verify that it is correctly set up by running:
9891
```bash
9992
python utils/call_llm.py

utils/call_llm.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,33 @@
2525
# Simple cache configuration
2626
cache_file = "llm_cache.json"
2727

28-
def call_llm(prompt, use_cache: bool = True) -> str:
28+
29+
def load_cache():
    """Load the prompt->response cache from disk.

    Returns:
        dict: Cached prompt/response pairs, or an empty dict when the
        cache file is missing or unreadable.
    """
    try:
        with open(cache_file, 'r') as f:
            return json.load(f)
    # Narrow handling: a missing or corrupt cache file is non-fatal,
    # but a bare `except:` would also hide real bugs (e.g. NameError).
    except (OSError, json.JSONDecodeError) as e:
        logger.warning("Failed to load cache: %s", e)
        return {}
36+
37+
38+
def save_cache(cache):
    """Persist the prompt->response cache to disk (best effort).

    Args:
        cache (dict): Prompt/response pairs to serialize as JSON.
    """
    try:
        with open(cache_file, 'w') as f:
            json.dump(cache, f)
    # Caching is best-effort: never let a disk or serialization failure
    # break the LLM call itself. Catch specific errors, not bare except.
    except (OSError, TypeError, ValueError) as e:
        logger.warning("Failed to save cache: %s", e)
44+
45+
46+
def get_llm_provider():
    """Determine which LLM backend to use.

    Returns the LLM_PROVIDER environment variable when set; otherwise
    falls back to "GEMINI" if Gemini credentials are present, or the
    unset value when nothing is configured.
    """
    configured = os.getenv("LLM_PROVIDER")
    if configured:
        return configured
    # No explicit provider: infer Gemini from its credential variables.
    if os.getenv("GEMINI_PROJECT_ID") or os.getenv("GEMINI_API_KEY"):
        return "GEMINI"
    # if necessary, add ANTHROPIC/OPENAI
    return configured
52+
53+
54+
def _call_llm_provider(prompt: str) -> str:
2955
"""
3056
Call an LLM provider based on environment variables.
3157
Environment variables:
@@ -59,7 +85,7 @@ def call_llm(prompt, use_cache: bool = True) -> str:
5985
raise ValueError(f"{base_url_var} environment variable is required")
6086

6187
# Append the endpoint to the base URL
62-
url = f"{base_url}/v1/chat/completions"
88+
url = f"{base_url.rstrip('/')}/v1/chat/completions"
6389

6490
# Configure headers and payload based on provider
6591
headers = {
@@ -98,6 +124,58 @@ def call_llm(prompt, use_cache: bool = True) -> str:
98124
except ValueError:
99125
raise Exception(f"Failed to parse response as JSON from {provider}. The server might have returned an invalid response.")
100126

127+
# We default to Google Gemini 2.5 Pro, which shows strong performance
# for code understanding.
def call_llm(prompt: str, use_cache: bool = True) -> str:
    """Send a prompt to the configured LLM and return its text response.

    Responses are cached on disk keyed by the exact prompt string, so
    repeated identical prompts are answered without a network call.

    Args:
        prompt: The text to send to the model.
        use_cache: When True, consult and update the on-disk cache.

    Returns:
        The model's text response.
    """
    logger.info(f"PROMPT: {prompt}")

    if use_cache:
        cached = load_cache()
        if prompt in cached:
            # Cache hit: log and return without contacting any provider.
            logger.info(f"RESPONSE: {cached[prompt]}")
            return cached[prompt]

    # Route to the backend selected via environment variables.
    if get_llm_provider() == "GEMINI":
        response_text = _call_llm_gemini(prompt)
    else:
        # Generic OpenAI-compatible HTTP API (Ollama, xAI, ...).
        response_text = _call_llm_provider(prompt)

    logger.info(f"RESPONSE: {response_text}")

    if use_cache:
        # Reload before writing to reduce the chance of clobbering
        # entries written by a concurrent call.
        cached = load_cache()
        cached[prompt] = response_text
        save_cache(cached)

    return response_text
159+
160+
161+
def _call_llm_gemini(prompt: str) -> str:
    """Call Google Gemini and return the generated text.

    Credentials come from the environment: GEMINI_PROJECT_ID selects the
    Vertex AI path, otherwise GEMINI_API_KEY selects the AI Studio path.

    Raises:
        ValueError: If neither credential variable is set.
    """
    project_id = os.getenv("GEMINI_PROJECT_ID")
    api_key = os.getenv("GEMINI_API_KEY")

    if project_id:
        # Vertex AI authenticates via project + location, not an API key.
        client = genai.Client(
            vertexai=True,
            project=project_id,
            location=os.getenv("GEMINI_LOCATION", "us-central1"),
        )
    elif api_key:
        client = genai.Client(api_key=api_key)
    else:
        raise ValueError("Either GEMINI_PROJECT_ID or GEMINI_API_KEY must be set in the environment")

    model = os.getenv("GEMINI_MODEL", "gemini-2.5-pro-exp-03-25")
    response = client.models.generate_content(model=model, contents=[prompt])
    return response.text
178+
101179
if __name__ == "__main__":
102180
test_prompt = "Hello, how are you?"
103181

0 commit comments

Comments
 (0)