Skip to content

Commit c8a8ca1

Browse files
authored
Merge pull request #165 from taqtiqa-mark/pr-50
Automatically switch provider based on environment variables, Ollama support: closes #13 & #50
2 parents dc2990e + c6eaaff commit c8a8ca1

2 files changed

Lines changed: 125 additions & 182 deletions

File tree

README.md

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,7 @@ This is a tutorial project of [Pocket Flow](https://github.com/The-Pocket/Pocket
8686
pip install -r requirements.txt
8787
```
8888

89-
4. Set up LLM in [`utils/call_llm.py`](./utils/call_llm.py) by providing credentials. By default, you can use the [AI Studio key](https://aistudio.google.com/app/apikey) with this client for Gemini Pro 2.5:
90-
91-
```python
92-
client = genai.Client(
93-
api_key=os.getenv("GEMINI_API_KEY", "your-api_key"),
94-
)
95-
```
96-
89+
4. Set up LLM in [`utils/call_llm.py`](./utils/call_llm.py) by providing credentials. To do so, you can put the values in a `.env` file. By default, you can use the AI Studio key with this client for Gemini Pro 2.5 by setting the `GEMINI_API_KEY` environment variable. If you want to use another LLM, you can set the `LLM_PROVIDER` environment variable (e.g. `XAI`), and then set the model, base URL, and API key (e.g. `XAI_MODEL`, `XAI_BASE_URL`, `XAI_API_KEY`). If using Ollama, the base URL is `http://localhost:11434/` and the API key can be omitted.
9790
You can use your own models. We highly recommend the latest models with thinking capabilities (Claude 3.7 with thinking, O1). You can verify that it is correctly set up by running:
9891
```bash
9992
python utils/call_llm.py

utils/call_llm.py

Lines changed: 124 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import logging
44
import json
5+
import requests
56
from datetime import datetime
67

78
# Configure logging
@@ -25,6 +26,104 @@
2526
cache_file = "llm_cache.json"
2627

2728

29+
def load_cache():
30+
try:
31+
with open(cache_file, 'r') as f:
32+
return json.load(f)
33+
except:
34+
logger.warning(f"Failed to load cache.")
35+
return {}
36+
37+
38+
def save_cache(cache):
39+
try:
40+
with open(cache_file, 'w') as f:
41+
json.dump(cache, f)
42+
except:
43+
logger.warning(f"Failed to save cache")
44+
45+
46+
def get_llm_provider():
47+
provider = os.getenv("LLM_PROVIDER")
48+
if not provider and (os.getenv("GEMINI_PROJECT_ID") or os.getenv("GEMINI_API_KEY")):
49+
provider = "GEMINI"
50+
# if necessary, add ANTHROPIC/OPENAI
51+
return provider
52+
53+
54+
def _call_llm_provider(prompt: str) -> str:
55+
"""
56+
Call an LLM provider based on environment variables.
57+
Environment variables:
58+
- LLM_PROVIDER: "OLLAMA" or "XAI"
59+
- <provider>_MODEL: Model name (e.g., OLLAMA_MODEL, XAI_MODEL)
60+
- <provider>_BASE_URL: Base URL without endpoint (e.g., OLLAMA_BASE_URL, XAI_BASE_URL)
61+
- <provider>_API_KEY: API key (e.g., OLLAMA_API_KEY, XAI_API_KEY; optional for providers that don't require it)
62+
The endpoint /v1/chat/completions will be appended to the base URL.
63+
"""
64+
logger.info(f"PROMPT: {prompt}") # log the prompt
65+
66+
# Read the provider from environment variable
67+
provider = os.environ.get("LLM_PROVIDER")
68+
if not provider:
69+
raise ValueError("LLM_PROVIDER environment variable is required")
70+
71+
# Construct the names of the other environment variables
72+
model_var = f"{provider}_MODEL"
73+
base_url_var = f"{provider}_BASE_URL"
74+
api_key_var = f"{provider}_API_KEY"
75+
76+
# Read the provider-specific variables
77+
model = os.environ.get(model_var)
78+
base_url = os.environ.get(base_url_var)
79+
api_key = os.environ.get(api_key_var, "") # API key is optional, default to empty string
80+
81+
# Validate required variables
82+
if not model:
83+
raise ValueError(f"{model_var} environment variable is required")
84+
if not base_url:
85+
raise ValueError(f"{base_url_var} environment variable is required")
86+
87+
# Append the endpoint to the base URL
88+
url = f"{base_url.rstrip('/')}/v1/chat/completions"
89+
90+
# Configure headers and payload based on provider
91+
headers = {
92+
"Content-Type": "application/json",
93+
}
94+
if api_key: # Only add Authorization header if API key is provided
95+
headers["Authorization"] = f"Bearer {api_key}"
96+
97+
payload = {
98+
"model": model,
99+
"messages": [{"role": "user", "content": prompt}],
100+
"temperature": 0.7,
101+
}
102+
103+
try:
104+
response = requests.post(url, headers=headers, json=payload)
105+
response_json = response.json() # Log the response
106+
logger.info("RESPONSE:\n%s", json.dumps(response_json, indent=2))
107+
#logger.info(f"RESPONSE: {response.json()}")
108+
response.raise_for_status()
109+
return response.json()["choices"][0]["message"]["content"]
110+
except requests.exceptions.HTTPError as e:
111+
error_message = f"HTTP error occurred: {e}"
112+
try:
113+
error_details = response.json().get("error", "No additional details")
114+
error_message += f" (Details: {error_details})"
115+
except:
116+
pass
117+
raise Exception(error_message)
118+
except requests.exceptions.ConnectionError:
119+
raise Exception(f"Failed to connect to {provider} API. Check your network connection.")
120+
except requests.exceptions.Timeout:
121+
raise Exception(f"Request to {provider} API timed out.")
122+
except requests.exceptions.RequestException as e:
123+
raise Exception(f"An error occurred while making the request to {provider}: {e}")
124+
except ValueError:
125+
raise Exception(f"Failed to parse response as JSON from {provider}. The server might have returned an invalid response.")
126+
28127
# By default, we use Google Gemini 2.5 pro, as it shows great performance for code understanding
29128
def call_llm(prompt: str, use_cache: bool = True) -> str:
30129
# Log the prompt
@@ -33,198 +132,49 @@ def call_llm(prompt: str, use_cache: bool = True) -> str:
33132
# Check cache if enabled
34133
if use_cache:
35134
# Load cache from disk
36-
cache = {}
37-
if os.path.exists(cache_file):
38-
try:
39-
with open(cache_file, "r", encoding="utf-8") as f:
40-
cache = json.load(f)
41-
except:
42-
logger.warning(f"Failed to load cache, starting with empty cache")
43-
135+
cache = load_cache()
44136
# Return from cache if exists
45137
if prompt in cache:
46138
logger.info(f"RESPONSE: {cache[prompt]}")
47139
return cache[prompt]
48140

49-
# # Call the LLM if not in cache or cache disabled
50-
# client = genai.Client(
51-
# vertexai=True,
52-
# # TODO: change to your own project id and location
53-
# project=os.getenv("GEMINI_PROJECT_ID", "your-project-id"),
54-
# location=os.getenv("GEMINI_LOCATION", "us-central1")
55-
# )
56-
57-
# You can comment the previous line and use the AI Studio key instead:
58-
client = genai.Client(
59-
api_key=os.getenv("GEMINI_API_KEY", ""),
60-
)
61-
model = os.getenv("GEMINI_MODEL", "gemini-2.5-pro")
62-
# model = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
63-
64-
response = client.models.generate_content(model=model, contents=[prompt])
65-
response_text = response.text
141+
provider = get_llm_provider()
142+
if provider == "GEMINI":
143+
response_text = _call_llm_gemini(prompt)
144+
else: # generic method using a URL that is OpenAI compatible API (Ollama, ...)
145+
response_text = _call_llm_provider(prompt)
66146

67147
# Log the response
68148
logger.info(f"RESPONSE: {response_text}")
69149

70150
# Update cache if enabled
71151
if use_cache:
72152
# Load cache again to avoid overwrites
73-
cache = {}
74-
if os.path.exists(cache_file):
75-
try:
76-
with open(cache_file, "r", encoding="utf-8") as f:
77-
cache = json.load(f)
78-
except:
79-
pass
80-
153+
cache = load_cache()
81154
# Add to cache and save
82155
cache[prompt] = response_text
83-
try:
84-
with open(cache_file, "w", encoding="utf-8") as f:
85-
json.dump(cache, f)
86-
except Exception as e:
87-
logger.error(f"Failed to save cache: {e}")
156+
save_cache(cache)
88157

89158
return response_text
90159

91160

92-
# # Use Azure OpenAI
93-
# def call_llm(prompt, use_cache: bool = True):
94-
# from openai import AzureOpenAI
95-
96-
# endpoint = "https://<azure openai name>.openai.azure.com/"
97-
# deployment = "<deployment name>"
98-
99-
# subscription_key = "<azure openai key>"
100-
# api_version = "<api version>"
101-
102-
# client = AzureOpenAI(
103-
# api_version=api_version,
104-
# azure_endpoint=endpoint,
105-
# api_key=subscription_key,
106-
# )
107-
108-
# r = client.chat.completions.create(
109-
# model=deployment,
110-
# messages=[{"role": "user", "content": prompt}],
111-
# response_format={
112-
# "type": "text"
113-
# },
114-
# max_completion_tokens=40000,
115-
# reasoning_effort="medium",
116-
# store=False
117-
# )
118-
# return r.choices[0].message.content
119-
120-
# # Use Anthropic Claude 3.7 Sonnet Extended Thinking
121-
# def call_llm(prompt, use_cache: bool = True):
122-
# from anthropic import Anthropic
123-
# client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY", "your-api-key"))
124-
# response = client.messages.create(
125-
# model="claude-3-7-sonnet-20250219",
126-
# max_tokens=21000,
127-
# thinking={
128-
# "type": "enabled",
129-
# "budget_tokens": 20000
130-
# },
131-
# messages=[
132-
# {"role": "user", "content": prompt}
133-
# ]
134-
# )
135-
# return response.content[1].text
136-
137-
# # Use OpenAI o1
138-
# def call_llm(prompt, use_cache: bool = True):
139-
# from openai import OpenAI
140-
# client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))
141-
# r = client.chat.completions.create(
142-
# model="o1",
143-
# messages=[{"role": "user", "content": prompt}],
144-
# response_format={
145-
# "type": "text"
146-
# },
147-
# reasoning_effort="medium",
148-
# store=False
149-
# )
150-
# return r.choices[0].message.content
151-
152-
# Use OpenRouter API
153-
# def call_llm(prompt: str, use_cache: bool = True) -> str:
154-
# import requests
155-
# # Log the prompt
156-
# logger.info(f"PROMPT: {prompt}")
157-
158-
# # Check cache if enabled
159-
# if use_cache:
160-
# # Load cache from disk
161-
# cache = {}
162-
# if os.path.exists(cache_file):
163-
# try:
164-
# with open(cache_file, "r", encoding="utf-8") as f:
165-
# cache = json.load(f)
166-
# except:
167-
# logger.warning(f"Failed to load cache, starting with empty cache")
168-
169-
# # Return from cache if exists
170-
# if prompt in cache:
171-
# logger.info(f"RESPONSE: {cache[prompt]}")
172-
# return cache[prompt]
173-
174-
# # OpenRouter API configuration
175-
# api_key = os.getenv("OPENROUTER_API_KEY", "")
176-
# model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-exp:free")
177-
178-
# headers = {
179-
# "Authorization": f"Bearer {api_key}",
180-
# }
181-
182-
# data = {
183-
# "model": model,
184-
# "messages": [{"role": "user", "content": prompt}]
185-
# }
186-
187-
# response = requests.post(
188-
# "https://openrouter.ai/api/v1/chat/completions",
189-
# headers=headers,
190-
# json=data
191-
# )
192-
193-
# if response.status_code != 200:
194-
# error_msg = f"OpenRouter API call failed with status {response.status_code}: {response.text}"
195-
# logger.error(error_msg)
196-
# raise Exception(error_msg)
197-
# try:
198-
# response_text = response.json()["choices"][0]["message"]["content"]
199-
# except Exception as e:
200-
# error_msg = f"Failed to parse OpenRouter response: {e}; Response: {response.text}"
201-
# logger.error(error_msg)
202-
# raise Exception(error_msg)
203-
204-
205-
# # Log the response
206-
# logger.info(f"RESPONSE: {response_text}")
207-
208-
# # Update cache if enabled
209-
# if use_cache:
210-
# # Load cache again to avoid overwrites
211-
# cache = {}
212-
# if os.path.exists(cache_file):
213-
# try:
214-
# with open(cache_file, "r", encoding="utf-8") as f:
215-
# cache = json.load(f)
216-
# except:
217-
# pass
218-
219-
# # Add to cache and save
220-
# cache[prompt] = response_text
221-
# try:
222-
# with open(cache_file, "w", encoding="utf-8") as f:
223-
# json.dump(cache, f)
224-
# except Exception as e:
225-
# logger.error(f"Failed to save cache: {e}")
226-
227-
# return response_text
161+
def _call_llm_gemini(prompt: str) -> str:
162+
if os.getenv("GEMINI_PROJECT_ID"):
163+
client = genai.Client(
164+
vertexai=True,
165+
project=os.getenv("GEMINI_PROJECT_ID"),
166+
location=os.getenv("GEMINI_LOCATION", "us-central1")
167+
)
168+
elif os.getenv("GEMINI_API_KEY"):
169+
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
170+
else:
171+
raise ValueError("Either GEMINI_PROJECT_ID or GEMINI_API_KEY must be set in the environment")
172+
model = os.getenv("GEMINI_MODEL", "gemini-2.5-pro-exp-03-25")
173+
response = client.models.generate_content(
174+
model=model,
175+
contents=[prompt]
176+
)
177+
return response.text
228178

229179
if __name__ == "__main__":
230180
test_prompt = "Hello, how are you?"

0 commit comments

Comments
 (0)