|
2 | 2 | import os |
3 | 3 | import logging |
4 | 4 | import json |
| 5 | +import requests |
5 | 6 | from datetime import datetime |
6 | 7 |
|
7 | 8 | # Configure logging |
|
24 | 25 | # Simple cache configuration |
25 | 26 | cache_file = "llm_cache.json" |
26 | 27 |
|
27 | | - |
28 | | -# By default, we Google Gemini 2.5 pro, as it shows great performance for code understanding |
29 | | -def call_llm(prompt: str, use_cache: bool = True) -> str: |
30 | | - # Log the prompt |
31 | | - logger.info(f"PROMPT: {prompt}") |
32 | | - |
33 | | - # Check cache if enabled |
34 | | - if use_cache: |
35 | | - # Load cache from disk |
36 | | - cache = {} |
37 | | - if os.path.exists(cache_file): |
38 | | - try: |
39 | | - with open(cache_file, "r", encoding="utf-8") as f: |
40 | | - cache = json.load(f) |
41 | | - except: |
42 | | - logger.warning(f"Failed to load cache, starting with empty cache") |
43 | | - |
44 | | - # Return from cache if exists |
45 | | - if prompt in cache: |
46 | | - logger.info(f"RESPONSE: {cache[prompt]}") |
47 | | - return cache[prompt] |
48 | | - |
49 | | - # # Call the LLM if not in cache or cache disabled |
50 | | - # client = genai.Client( |
51 | | - # vertexai=True, |
52 | | - # # TODO: change to your own project id and location |
53 | | - # project=os.getenv("GEMINI_PROJECT_ID", "your-project-id"), |
54 | | - # location=os.getenv("GEMINI_LOCATION", "us-central1") |
55 | | - # ) |
56 | | - |
57 | | - # You can comment the previous line and use the AI Studio key instead: |
58 | | - client = genai.Client( |
59 | | - api_key=os.getenv("GEMINI_API_KEY", ""), |
60 | | - ) |
61 | | - model = os.getenv("GEMINI_MODEL", "gemini-2.5-pro") |
62 | | - # model = os.getenv("GEMINI_MODEL", "gemini-2.5-flash") |
63 | | - |
64 | | - response = client.models.generate_content(model=model, contents=[prompt]) |
65 | | - response_text = response.text |
66 | | - |
67 | | - # Log the response |
68 | | - logger.info(f"RESPONSE: {response_text}") |
69 | | - |
70 | | - # Update cache if enabled |
71 | | - if use_cache: |
72 | | - # Load cache again to avoid overwrites |
73 | | - cache = {} |
74 | | - if os.path.exists(cache_file): |
75 | | - try: |
76 | | - with open(cache_file, "r", encoding="utf-8") as f: |
77 | | - cache = json.load(f) |
78 | | - except: |
79 | | - pass |
80 | | - |
81 | | - # Add to cache and save |
82 | | - cache[prompt] = response_text |
| 28 | +def call_llm(prompt, use_cache: bool = True) -> str: |
| 29 | + """ |
| 30 | + Call an LLM provider based on environment variables. |
| 31 | + Environment variables: |
| 32 | + - LLM_PROVIDER: "OLLAMA" or "XAI" |
| 33 | + - <provider>_MODEL: Model name (e.g., OLLAMA_MODEL, XAI_MODEL) |
| 34 | + - <provider>_BASE_URL: Base URL without endpoint (e.g., OLLAMA_BASE_URL, XAI_BASE_URL) |
| 35 | + - <provider>_API_KEY: API key (e.g., OLLAMA_API_KEY, XAI_API_KEY; optional for providers that don't require it) |
| 36 | + The endpoint /v1/chat/completions will be appended to the base URL. |
| 37 | + """ |
| 38 | + logger.info(f"PROMPT: {prompt}") # log the prompt |
| 39 | + |
| 40 | + # Read the provider from environment variable |
| 41 | + provider = os.environ.get("LLM_PROVIDER") |
| 42 | + if not provider: |
| 43 | + raise ValueError("LLM_PROVIDER environment variable is required") |
| 44 | + |
| 45 | + # Construct the names of the other environment variables |
| 46 | + model_var = f"{provider}_MODEL" |
| 47 | + base_url_var = f"{provider}_BASE_URL" |
| 48 | + api_key_var = f"{provider}_API_KEY" |
| 49 | + |
| 50 | + # Read the provider-specific variables |
| 51 | + model = os.environ.get(model_var) |
| 52 | + base_url = os.environ.get(base_url_var) |
| 53 | + api_key = os.environ.get(api_key_var, "") # API key is optional, default to empty string |
| 54 | + |
| 55 | + # Validate required variables |
| 56 | + if not model: |
| 57 | + raise ValueError(f"{model_var} environment variable is required") |
| 58 | + if not base_url: |
| 59 | + raise ValueError(f"{base_url_var} environment variable is required") |
| 60 | + |
| 61 | + # Append the endpoint to the base URL |
| 62 | + url = f"{base_url}/v1/chat/completions" |
| 63 | + |
| 64 | + # Configure headers and payload based on provider |
| 65 | + headers = { |
| 66 | + "Content-Type": "application/json", |
| 67 | + } |
| 68 | + if api_key: # Only add Authorization header if API key is provided |
| 69 | + headers["Authorization"] = f"Bearer {api_key}" |
| 70 | + |
| 71 | + payload = { |
| 72 | + "model": model, |
| 73 | + "messages": [{"role": "user", "content": prompt}], |
| 74 | + "temperature": 0.7, |
| 75 | + } |
| 76 | + |
| 77 | + try: |
| 78 | + response = requests.post(url, headers=headers, json=payload) |
| 79 | + response_json = response.json() # Log the response |
| 80 | + logger.info("RESPONSE:\n%s", json.dumps(response_json, indent=2)) |
| 81 | + #logger.info(f"RESPONSE: {response.json()}") |
| 82 | + response.raise_for_status() |
| 83 | + return response.json()["choices"][0]["message"]["content"] |
| 84 | + except requests.exceptions.HTTPError as e: |
| 85 | + error_message = f"HTTP error occurred: {e}" |
83 | 86 | try: |
84 | | - with open(cache_file, "w", encoding="utf-8") as f: |
85 | | - json.dump(cache, f) |
86 | | - except Exception as e: |
87 | | - logger.error(f"Failed to save cache: {e}") |
88 | | - |
89 | | - return response_text |
90 | | - |
91 | | - |
92 | | -# # Use Azure OpenAI |
93 | | -# def call_llm(prompt, use_cache: bool = True): |
94 | | -# from openai import AzureOpenAI |
95 | | - |
96 | | -# endpoint = "https://<azure openai name>.openai.azure.com/" |
97 | | -# deployment = "<deployment name>" |
98 | | - |
99 | | -# subscription_key = "<azure openai key>" |
100 | | -# api_version = "<api version>" |
101 | | - |
102 | | -# client = AzureOpenAI( |
103 | | -# api_version=api_version, |
104 | | -# azure_endpoint=endpoint, |
105 | | -# api_key=subscription_key, |
106 | | -# ) |
107 | | - |
108 | | -# r = client.chat.completions.create( |
109 | | -# model=deployment, |
110 | | -# messages=[{"role": "user", "content": prompt}], |
111 | | -# response_format={ |
112 | | -# "type": "text" |
113 | | -# }, |
114 | | -# max_completion_tokens=40000, |
115 | | -# reasoning_effort="medium", |
116 | | -# store=False |
117 | | -# ) |
118 | | -# return r.choices[0].message.content |
119 | | - |
120 | | -# # Use Anthropic Claude 3.7 Sonnet Extended Thinking |
121 | | -# def call_llm(prompt, use_cache: bool = True): |
122 | | -# from anthropic import Anthropic |
123 | | -# client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY", "your-api-key")) |
124 | | -# response = client.messages.create( |
125 | | -# model="claude-3-7-sonnet-20250219", |
126 | | -# max_tokens=21000, |
127 | | -# thinking={ |
128 | | -# "type": "enabled", |
129 | | -# "budget_tokens": 20000 |
130 | | -# }, |
131 | | -# messages=[ |
132 | | -# {"role": "user", "content": prompt} |
133 | | -# ] |
134 | | -# ) |
135 | | -# return response.content[1].text |
136 | | - |
137 | | -# # Use OpenAI o1 |
138 | | -# def call_llm(prompt, use_cache: bool = True): |
139 | | -# from openai import OpenAI |
140 | | -# client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key")) |
141 | | -# r = client.chat.completions.create( |
142 | | -# model="o1", |
143 | | -# messages=[{"role": "user", "content": prompt}], |
144 | | -# response_format={ |
145 | | -# "type": "text" |
146 | | -# }, |
147 | | -# reasoning_effort="medium", |
148 | | -# store=False |
149 | | -# ) |
150 | | -# return r.choices[0].message.content |
151 | | - |
152 | | -# Use OpenRouter API |
153 | | -# def call_llm(prompt: str, use_cache: bool = True) -> str: |
154 | | -# import requests |
155 | | -# # Log the prompt |
156 | | -# logger.info(f"PROMPT: {prompt}") |
157 | | - |
158 | | -# # Check cache if enabled |
159 | | -# if use_cache: |
160 | | -# # Load cache from disk |
161 | | -# cache = {} |
162 | | -# if os.path.exists(cache_file): |
163 | | -# try: |
164 | | -# with open(cache_file, "r", encoding="utf-8") as f: |
165 | | -# cache = json.load(f) |
166 | | -# except: |
167 | | -# logger.warning(f"Failed to load cache, starting with empty cache") |
168 | | - |
169 | | -# # Return from cache if exists |
170 | | -# if prompt in cache: |
171 | | -# logger.info(f"RESPONSE: {cache[prompt]}") |
172 | | -# return cache[prompt] |
173 | | - |
174 | | -# # OpenRouter API configuration |
175 | | -# api_key = os.getenv("OPENROUTER_API_KEY", "") |
176 | | -# model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-exp:free") |
177 | | - |
178 | | -# headers = { |
179 | | -# "Authorization": f"Bearer {api_key}", |
180 | | -# } |
181 | | - |
182 | | -# data = { |
183 | | -# "model": model, |
184 | | -# "messages": [{"role": "user", "content": prompt}] |
185 | | -# } |
186 | | - |
187 | | -# response = requests.post( |
188 | | -# "https://openrouter.ai/api/v1/chat/completions", |
189 | | -# headers=headers, |
190 | | -# json=data |
191 | | -# ) |
192 | | - |
193 | | -# if response.status_code != 200: |
194 | | -# error_msg = f"OpenRouter API call failed with status {response.status_code}: {response.text}" |
195 | | -# logger.error(error_msg) |
196 | | -# raise Exception(error_msg) |
197 | | -# try: |
198 | | -# response_text = response.json()["choices"][0]["message"]["content"] |
199 | | -# except Exception as e: |
200 | | -# error_msg = f"Failed to parse OpenRouter response: {e}; Response: {response.text}" |
201 | | -# logger.error(error_msg) |
202 | | -# raise Exception(error_msg) |
203 | | - |
204 | | - |
205 | | -# # Log the response |
206 | | -# logger.info(f"RESPONSE: {response_text}") |
207 | | - |
208 | | -# # Update cache if enabled |
209 | | -# if use_cache: |
210 | | -# # Load cache again to avoid overwrites |
211 | | -# cache = {} |
212 | | -# if os.path.exists(cache_file): |
213 | | -# try: |
214 | | -# with open(cache_file, "r", encoding="utf-8") as f: |
215 | | -# cache = json.load(f) |
216 | | -# except: |
217 | | -# pass |
218 | | - |
219 | | -# # Add to cache and save |
220 | | -# cache[prompt] = response_text |
221 | | -# try: |
222 | | -# with open(cache_file, "w", encoding="utf-8") as f: |
223 | | -# json.dump(cache, f) |
224 | | -# except Exception as e: |
225 | | -# logger.error(f"Failed to save cache: {e}") |
226 | | - |
227 | | -# return response_text |
| 87 | + error_details = response.json().get("error", "No additional details") |
| 88 | + error_message += f" (Details: {error_details})" |
| 89 | + except: |
| 90 | + pass |
| 91 | + raise Exception(error_message) |
| 92 | + except requests.exceptions.ConnectionError: |
| 93 | + raise Exception(f"Failed to connect to {provider} API. Check your network connection.") |
| 94 | + except requests.exceptions.Timeout: |
| 95 | + raise Exception(f"Request to {provider} API timed out.") |
| 96 | + except requests.exceptions.RequestException as e: |
| 97 | + raise Exception(f"An error occurred while making the request to {provider}: {e}") |
| 98 | + except ValueError: |
| 99 | + raise Exception(f"Failed to parse response as JSON from {provider}. The server might have returned an invalid response.") |
228 | 100 |
|
229 | 101 | if __name__ == "__main__": |
230 | 102 | test_prompt = "Hello, how are you?" |
|
0 commit comments