|
25 | 25 | # Simple cache configuration |
26 | 26 | cache_file = "llm_cache.json" |
27 | 27 |
|
28 | | -def call_llm(prompt, use_cache: bool = True) -> str: |
| 28 | + |
def load_cache():
    """Load the LLM response cache from disk.

    Returns:
        dict: Mapping of prompt -> cached response text; an empty dict when
        the cache file is missing, unreadable, or not valid JSON.
    """
    try:
        with open(cache_file, 'r') as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        # A missing or corrupt cache is non-fatal: start with an empty cache.
        # (Specific exceptions instead of a bare `except:` so bugs and
        # KeyboardInterrupt are not silently swallowed.)
        logger.warning("Failed to load cache.")
        return {}
| 36 | + |
| 37 | + |
def save_cache(cache):
    """Persist the LLM response cache to disk as JSON.

    Args:
        cache (dict): Mapping of prompt -> response text.

    Caching is best-effort: failures are logged and swallowed so a broken
    cache never breaks the actual LLM call.
    """
    try:
        with open(cache_file, 'w') as f:
            json.dump(cache, f)
    except (OSError, TypeError, ValueError):
        # OSError: disk/permission problems; TypeError/ValueError: cache
        # contains values json cannot serialize. No bare `except:` so real
        # bugs and KeyboardInterrupt propagate.
        logger.warning("Failed to save cache")
| 44 | + |
| 45 | + |
def get_llm_provider():
    """Resolve which LLM provider backend to use.

    Precedence: the explicit LLM_PROVIDER environment variable wins;
    otherwise "GEMINI" is inferred when Gemini credentials are present.
    Returns None when no provider can be determined.
    """
    explicit = os.getenv("LLM_PROVIDER")
    if explicit:
        return explicit
    if os.getenv("GEMINI_PROJECT_ID") or os.getenv("GEMINI_API_KEY"):
        return "GEMINI"
    # if necessary, add ANTHROPIC/OPENAI
    return explicit
| 52 | + |
| 53 | + |
| 54 | +def _call_llm_provider(prompt: str) -> str: |
29 | 55 | """ |
30 | 56 | Call an LLM provider based on environment variables. |
31 | 57 | Environment variables: |
@@ -59,7 +85,7 @@ def call_llm(prompt, use_cache: bool = True) -> str: |
59 | 85 | raise ValueError(f"{base_url_var} environment variable is required") |
60 | 86 |
|
61 | 87 | # Append the endpoint to the base URL |
62 | | - url = f"{base_url}/v1/chat/completions" |
| 88 | + url = f"{base_url.rstrip('/')}/v1/chat/completions" |
63 | 89 |
|
64 | 90 | # Configure headers and payload based on provider |
65 | 91 | headers = { |
@@ -98,6 +124,58 @@ def call_llm(prompt, use_cache: bool = True) -> str: |
98 | 124 | except ValueError: |
99 | 125 | raise Exception(f"Failed to parse response as JSON from {provider}. The server might have returned an invalid response.") |
100 | 126 |
|
# Default provider is Google Gemini 2.5 Pro, which performs well on code understanding.
def call_llm(prompt: str, use_cache: bool = True) -> str:
    """Send a prompt to the configured LLM provider, with optional disk caching.

    Args:
        prompt: Text to send to the model; also used as the cache key.
        use_cache: When True, return a previously cached response if one
            exists and persist the fresh response afterwards.

    Returns:
        The model's response text.
    """
    logger.info(f"PROMPT: {prompt}")

    if use_cache:
        try:
            hit = load_cache()[prompt]
        except KeyError:
            pass
        else:
            logger.info(f"RESPONSE: {hit}")
            return hit

    # Dispatch on the provider resolved from environment variables.
    if get_llm_provider() == "GEMINI":
        response_text = _call_llm_gemini(prompt)
    else:
        # Generic path: any OpenAI-compatible API endpoint (Ollama, ...).
        response_text = _call_llm_provider(prompt)

    logger.info(f"RESPONSE: {response_text}")

    if use_cache:
        # Re-read the cache before writing so entries added by other
        # processes in the meantime are not clobbered.
        fresh = load_cache()
        fresh[prompt] = response_text
        save_cache(fresh)

    return response_text
| 159 | + |
| 160 | + |
def _call_llm_gemini(prompt: str) -> str:
    """Call Google Gemini and return the response text.

    Authenticates via Vertex AI when GEMINI_PROJECT_ID is set, otherwise via
    GEMINI_API_KEY. Model defaults to gemini-2.5-pro-exp-03-25 and can be
    overridden with GEMINI_MODEL.

    Raises:
        ValueError: When neither credential variable is configured.
    """
    project_id = os.getenv("GEMINI_PROJECT_ID")
    api_key = os.getenv("GEMINI_API_KEY")
    if project_id:
        # Vertex AI path: authenticate with project + location.
        client = genai.Client(
            vertexai=True,
            project=project_id,
            location=os.getenv("GEMINI_LOCATION", "us-central1"),
        )
    elif api_key:
        # Direct Gemini API path.
        client = genai.Client(api_key=api_key)
    else:
        raise ValueError("Either GEMINI_PROJECT_ID or GEMINI_API_KEY must be set in the environment")
    model_name = os.getenv("GEMINI_MODEL", "gemini-2.5-pro-exp-03-25")
    response = client.models.generate_content(model=model_name, contents=[prompt])
    return response.text
| 178 | + |
101 | 179 | if __name__ == "__main__": |
102 | 180 | test_prompt = "Hello, how are you?" |
103 | 181 |
|
|
0 commit comments