Skip to content

Commit b597bb9

Browse files
committed
update call_llm() to use environ variable for LLM
1 parent dc2990e commit b597bb9

1 file changed

Lines changed: 72 additions & 200 deletions

File tree

utils/call_llm.py

Lines changed: 72 additions & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import logging
44
import json
5+
import requests
56
from datetime import datetime
67

78
# Configure logging
@@ -24,207 +25,78 @@
2425
# Simple cache configuration
2526
cache_file = "llm_cache.json"
2627

27-
28-
# By default, we use Google Gemini 2.5 Pro, as it shows great performance for code understanding
29-
def call_llm(prompt: str, use_cache: bool = True) -> str:
30-
# Log the prompt
31-
logger.info(f"PROMPT: {prompt}")
32-
33-
# Check cache if enabled
34-
if use_cache:
35-
# Load cache from disk
36-
cache = {}
37-
if os.path.exists(cache_file):
38-
try:
39-
with open(cache_file, "r", encoding="utf-8") as f:
40-
cache = json.load(f)
41-
except:
42-
logger.warning(f"Failed to load cache, starting with empty cache")
43-
44-
# Return from cache if exists
45-
if prompt in cache:
46-
logger.info(f"RESPONSE: {cache[prompt]}")
47-
return cache[prompt]
48-
49-
# # Call the LLM if not in cache or cache disabled
50-
# client = genai.Client(
51-
# vertexai=True,
52-
# # TODO: change to your own project id and location
53-
# project=os.getenv("GEMINI_PROJECT_ID", "your-project-id"),
54-
# location=os.getenv("GEMINI_LOCATION", "us-central1")
55-
# )
56-
57-
# You can comment the previous line and use the AI Studio key instead:
58-
client = genai.Client(
59-
api_key=os.getenv("GEMINI_API_KEY", ""),
60-
)
61-
model = os.getenv("GEMINI_MODEL", "gemini-2.5-pro")
62-
# model = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
63-
64-
response = client.models.generate_content(model=model, contents=[prompt])
65-
response_text = response.text
66-
67-
# Log the response
68-
logger.info(f"RESPONSE: {response_text}")
69-
70-
# Update cache if enabled
71-
if use_cache:
72-
# Load cache again to avoid overwrites
73-
cache = {}
74-
if os.path.exists(cache_file):
75-
try:
76-
with open(cache_file, "r", encoding="utf-8") as f:
77-
cache = json.load(f)
78-
except:
79-
pass
80-
81-
# Add to cache and save
82-
cache[prompt] = response_text
def call_llm(prompt, use_cache: bool = True) -> str:
    """
    Call an LLM provider selected via environment variables.

    Environment variables:
        - LLM_PROVIDER: provider name, e.g. "OLLAMA" or "XAI"
        - <provider>_MODEL: model name (e.g. OLLAMA_MODEL, XAI_MODEL)
        - <provider>_BASE_URL: base URL without endpoint (e.g. OLLAMA_BASE_URL, XAI_BASE_URL)
        - <provider>_API_KEY: API key (e.g. OLLAMA_API_KEY, XAI_API_KEY;
          optional for providers that don't require it)

    The endpoint /v1/chat/completions is appended to the base URL
    (OpenAI-compatible chat-completions API).

    Args:
        prompt: User prompt, sent as a single "user" chat message.
        use_cache: Accepted for backward compatibility with earlier
            versions of this function; caching is NOT currently
            implemented in this code path.

    Returns:
        The assistant message content of the first choice.

    Raises:
        ValueError: If a required environment variable is missing.
        Exception: On HTTP errors, connection failures, timeouts, or
            unparseable (non-JSON) responses.
    """
    logger.info("PROMPT: %s", prompt)  # lazy %-formatting: skipped if level disabled

    # Read the provider from environment variable
    provider = os.environ.get("LLM_PROVIDER")
    if not provider:
        raise ValueError("LLM_PROVIDER environment variable is required")

    # Construct the names of the provider-specific environment variables
    model_var = f"{provider}_MODEL"
    base_url_var = f"{provider}_BASE_URL"
    api_key_var = f"{provider}_API_KEY"

    # Read the provider-specific variables
    model = os.environ.get(model_var)
    base_url = os.environ.get(base_url_var)
    api_key = os.environ.get(api_key_var, "")  # API key is optional, default to empty string

    # Validate required variables
    if not model:
        raise ValueError(f"{model_var} environment variable is required")
    if not base_url:
        raise ValueError(f"{base_url_var} environment variable is required")

    # Append the endpoint to the base URL
    url = f"{base_url}/v1/chat/completions"

    # Configure headers and payload based on provider
    headers = {
        "Content-Type": "application/json",
    }
    if api_key:  # Only add Authorization header if API key is provided
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
    }

    try:
        # A timeout is essential: without one the Timeout handler below can
        # never fire and an unresponsive server would hang the caller forever.
        response = requests.post(url, headers=headers, json=payload, timeout=300)
        # Parse once and reuse; raises ValueError on a non-JSON body.
        response_json = response.json()
        logger.info("RESPONSE:\n%s", json.dumps(response_json, indent=2))
        response.raise_for_status()
        return response_json["choices"][0]["message"]["content"]
    except requests.exceptions.HTTPError as e:
        error_message = f"HTTP error occurred: {e}"
        try:
            # response_json was parsed above (raise_for_status runs after
            # parsing); surface any server-provided error details.
            error_details = response_json.get("error", "No additional details")
            error_message += f" (Details: {error_details})"
        except Exception:
            # Best effort only — the error body may not be a JSON object.
            pass
        raise Exception(error_message)
    except requests.exceptions.ConnectionError:
        raise Exception(f"Failed to connect to {provider} API. Check your network connection.")
    except requests.exceptions.Timeout:
        raise Exception(f"Request to {provider} API timed out.")
    except requests.exceptions.RequestException as e:
        raise Exception(f"An error occurred while making the request to {provider}: {e}")
    except ValueError:
        raise Exception(f"Failed to parse response as JSON from {provider}. The server might have returned an invalid response.")
228100

229101
if __name__ == "__main__":
230102
test_prompt = "Hello, how are you?"

0 commit comments

Comments
 (0)