diff --git a/spring-ai-modules/pom.xml b/spring-ai-modules/pom.xml index 09fc0ed8f1ae..67a2e4700837 100644 --- a/spring-ai-modules/pom.xml +++ b/spring-ai-modules/pom.xml @@ -23,6 +23,7 @@ spring-ai-agentic-patterns spring-ai-chat-stream spring-ai-introduction + spring-ai-llm-as-a-judge spring-ai-mcp spring-ai-mcp-elicitations diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/pom.xml b/spring-ai-modules/spring-ai-llm-as-a-judge/pom.xml new file mode 100644 index 000000000000..58dc05009c92 --- /dev/null +++ b/spring-ai-modules/spring-ai-llm-as-a-judge/pom.xml @@ -0,0 +1,88 @@ + + + 4.0.0 + spring-ai-llm-as-a-judge + jar + spring-ai-llm-as-a-judge + + + com.baeldung + spring-ai-modules + 0.0.1 + ../pom.xml + + + + + + org.springframework.ai + spring-ai-bom + ${spring-ai.version} + pom + import + + + + + + + org.springframework.boot + spring-boot-starter-web + + + org.springframework.ai + spring-ai-starter-model-openai + ${spring-ai.version} + + + org.commonmark + commonmark + 0.27.1 + + + org.commonmark + commonmark-ext-gfm-tables + ${commonmark-ext-gfm-tables.version} + + + org.springframework.boot + spring-boot-devtools + runtime + true + + + org.projectlombok + lombok + true + + + org.springframework.boot + spring-boot-starter-test + test + + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + ${spring.boot.mainclass} + + + + + + + 21 + 3.5.0 + 1.1.2 + 0.21.0 + 1.5.18 + + + diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/ChatConfig.java b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/ChatConfig.java new file mode 100644 index 000000000000..0b179caef439 --- /dev/null +++ b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/ChatConfig.java @@ -0,0 +1,34 @@ +package com.baeldung.springai; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.beans.factory.annotation.Value; +import 
org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Wires up the LLM-as-a-judge setup: the primary {@link ChatClient} is decorated
 * with an {@link LlmJudgeAdvisor}, which itself talks to the model through a
 * separate, advisor-free ChatClient so judge calls are not re-judged.
 */
@Configuration
public class ChatConfig {

    // Primary client used by application services; every call() goes through the judge advisor.
    @Bean
    public ChatClient chatClient(
        ChatClient.Builder builder,
        LlmJudgeAdvisor judgeAdvisor
    ) {
        return builder
            .defaultAdvisors(judgeAdvisor)
            .build();
    }

    // Judge advisor with externalized tuning knobs; defaults (0.7 / 2) apply when
    // the judge.* properties are absent.
    @Bean
    public LlmJudgeAdvisor llmJudgeAdvisor(
        ChatClient.Builder builder,
        @Value("${judge.score-threshold:0.7}") double scoreThreshold,
        @Value("${judge.max-refinements:2}") int maxRefinements
    ) {
        // builder.build() creates a plain client (no default advisors) for evaluation calls
        return new LlmJudgeAdvisor(
            builder.build(),
            scoreThreshold,
            maxRefinements
        );
    }

}
diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/LlmJudgeAdvisor.java b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/LlmJudgeAdvisor.java
new file mode 100644
index 000000000000..83b505c923af
--- /dev/null
+++ b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/LlmJudgeAdvisor.java
@@ -0,0 +1,88 @@
package com.baeldung.springai;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
import org.springframework.ai.chat.prompt.Prompt;

// Advisor implementing the "LLM as a judge" pattern: each generated answer is
// scored by a second model call and regenerated with feedback when it scores
// below the configured threshold.
public class LlmJudgeAdvisor implements CallAdvisor {

    private static final String JUDGE_SYSTEM_PROMPT = """
        You are a strict quality evaluator for AI-generated answers.

        Given a user question and an AI-generated answer, rate the answer quality.

        Use this rubric:
        - 1.0: Complete, accurate, and clearly explained
        - 0.7: Mostly correct but missing details or clarity
        - 0.4: Partially correct or overly vague
        - 0.0: Incorrect, irrelevant, or harmful

        Respond ONLY with a valid JSON object. 
Do not add any explanation outside the JSON. + Format: {"score": <0.0 to 1.0>, "feedback": ""} + """; + + private final ChatClient judgeClient; + private final double scoreThreshold; + private final int maxRefinements; + + public LlmJudgeAdvisor( + ChatClient judgeClient, + double scoreThreshold, + int maxRefinements + ) { + this.judgeClient = judgeClient; + this.scoreThreshold = scoreThreshold; + this.maxRefinements = maxRefinements; + } + + @Override + public ChatClientResponse adviseCall(ChatClientRequest request, CallAdvisorChain chain) { + for (int attempt = 1; attempt <= maxRefinements + 1; attempt++) { + ChatClientResponse response = chain.copy(this).nextCall(request); + if(attempt > maxRefinements) { + return response; + }; + Verdict verdict = evaluate(request, response); + if (verdict.score() >= scoreThreshold) { + return response; + } + + request = addFeedback(request, verdict.feedback()); + } + return chain.copy(this).nextCall(request); + } + + private Verdict evaluate(ChatClientRequest request, ChatClientResponse response) { + String question = request.prompt().getUserMessage().getText(); + String answer = response.chatResponse().getResult().getOutput().getText(); + + return judgeClient.prompt() + .system(JUDGE_SYSTEM_PROMPT) + .user("Question: " + question + "\n\nAnswer: " + answer) + .call() + .entity(Verdict.class); + } + + private ChatClientRequest addFeedback(ChatClientRequest original, String feedback) { + Prompt augmented = original.prompt() + .augmentUserMessage(msg -> msg.mutate() + .text(msg.getText() + + "\n\nYour previous answer was insufficient. 
Feedback: " + feedback + + "\nPlease provide an improved answer.") + .build()); + return original.mutate().prompt(augmented).build(); + } + + @Override + public String getName() { + return "LlmJudgeAdvisor"; + } + + @Override + public int getOrder() { + return 0; + } +} \ No newline at end of file diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/SpringBootAiLlmAsAJudgeApplication.java b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/SpringBootAiLlmAsAJudgeApplication.java new file mode 100644 index 000000000000..65085f5ad44b --- /dev/null +++ b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/SpringBootAiLlmAsAJudgeApplication.java @@ -0,0 +1,13 @@ +package com.baeldung.springai; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication +public class SpringBootAiLlmAsAJudgeApplication { + + public static void main(String[] args) { + SpringApplication.run(SpringBootAiLlmAsAJudgeApplication.class, args); + } + +} diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/TravelController.java b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/TravelController.java new file mode 100644 index 000000000000..fce0f0a61dff --- /dev/null +++ b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/TravelController.java @@ -0,0 +1,33 @@ +package com.baeldung.springai; + +import lombok.RequiredArgsConstructor; +import org.commonmark.parser.Parser; +import org.commonmark.renderer.html.HtmlRenderer; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + 
+@RestController +@RequestMapping("/api/travel") +@RequiredArgsConstructor +public class TravelController { + + private final Parser parser = Parser.builder().build(); + private final HtmlRenderer renderer = HtmlRenderer.builder().build(); + + private final TravelService travelService; + + @GetMapping( + value = "/tips", + produces = MediaType.TEXT_HTML_VALUE + ) + public String getTips( + @RequestParam(defaultValue = "Paris", name = "destination") + String destination + ) { + final var markdown = travelService.getTravelTip(destination); + return renderer.render(parser.parse(markdown)); + } +} \ No newline at end of file diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/TravelService.java b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/TravelService.java new file mode 100644 index 000000000000..8575a4227084 --- /dev/null +++ b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/TravelService.java @@ -0,0 +1,23 @@ +package com.baeldung.springai; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; +import org.springframework.stereotype.Service; + +@Service +public class TravelService { + + private final ChatClient chatClient; + + public TravelService(ChatClient chatClient) { + this.chatClient = chatClient; + } + + public String getTravelTip(String destination) { + return this.chatClient + .prompt() + .user("Give me three insider tips for a trip to: " + destination) + .call() + .content(); + } +} \ No newline at end of file diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/Verdict.java b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/Verdict.java new file mode 100644 index 000000000000..d4b40df70dc6 --- /dev/null +++ b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/java/com/baeldung/springai/Verdict.java @@ -0,0 
+1,4 @@
package com.baeldung.springai;

// Structured verdict the judge LLM returns as JSON: a quality score in [0.0, 1.0]
// plus free-text feedback used to refine the next attempt.
public record Verdict(double score, String feedback) {
}
diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/resources/application.yaml b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/resources/application.yaml
new file mode 100644
index 000000000000..80eee225208d
--- /dev/null
+++ b/spring-ai-modules/spring-ai-llm-as-a-judge/src/main/resources/application.yaml
@@ -0,0 +1,11 @@
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: gpt-5-nano
          temperature: 1
# Tuning knobs read by ChatConfig (see @Value bindings there).
judge:
  score-threshold: 0.75
  max-refinements: 2
\ No newline at end of file
diff --git a/spring-ai-modules/spring-ai-llm-as-a-judge/src/test/java/com/baeldung/springai/LlmJudgeAdvisorTest.java b/spring-ai-modules/spring-ai-llm-as-a-judge/src/test/java/com/baeldung/springai/LlmJudgeAdvisorTest.java
new file mode 100644
index 000000000000..24e47594a062
--- /dev/null
+++ b/spring-ai-modules/spring-ai-llm-as-a-judge/src/test/java/com/baeldung/springai/LlmJudgeAdvisorTest.java
@@ -0,0 +1,81 @@
package com.baeldung.springai;

import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.bean.override.mockito.MockitoBean;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.when;

// Exercises the advisor end-to-end against a mocked ChatModel: the mock serves
// both the application answers and the judge verdicts, in call order.
@SpringBootTest(
    properties = {
        "judge.score-threshold=0.7",
        "judge.max-refinements=2"
    }
)
class LlmJudgeAdvisorTest {

@MockitoBean
    ChatModel chatModel;

    @Autowired
    ChatClient chatClient;

    @Test
    void givenLowQualityAnswer_whenAdvisorRuns_thenAnswerIsRefined() {
        // Stub sequence alternates answer / judge verdict:
        //   1) first answer, 2) verdict 0.3 (below 0.7 threshold -> refine),
        //   3) refined answer, 4) verdict 0.9 (accepted -> returned to caller).
        when(chatModel.call(any(Prompt.class)))
            .thenReturn(buildChatResponse("It runs Java."))
            .thenReturn(buildChatResponse("""
                {"score": 0.3, "feedback": "Too vague."}
                """))
            .thenReturn(buildChatResponse("The JVM executes Java bytecode, manages memory, and enables platform independence."))
            .thenReturn(buildChatResponse("""
                {"score": 0.9, "feedback": "Complete and accurate."}
                """));

        String result = chatClient.prompt()
            .user("Explain what a JVM is.")
            .call()
            .content();

        assertThat(result).contains("bytecode");
    }

    @Test
    void givenLowQualityAnswer_whenAdvisorRuns_thenAnswerIsRefinedOnlyTwice() {
        // Both verdicts fail, so with max-refinements=2 the advisor makes exactly
        // 5 model calls: 3 answers + 2 judge evaluations; the final answer is
        // returned unjudged.
        when(chatModel.call(any(Prompt.class)))
            .thenReturn(buildChatResponse("It runs Java."))
            .thenReturn(buildChatResponse("""
                {"score": 0.3, "feedback": "Too vague."}
                """))
            .thenReturn(buildChatResponse("The JVM runs Java bytecode."))
            .thenReturn(buildChatResponse("""
                {"score": 0.4, "feedback": "Still too vague."}
                """))
            .thenReturn(buildChatResponse("The JVM executes Java bytecode, manages memory, and enables platform independence."));

        String result = chatClient.prompt()
            .user("Explain what a JVM is.")
            .call()
            .content();

        assertThat(result).contains("bytecode");
        Mockito.verify(chatModel, Mockito.times(5)).call(any(Prompt.class));
    }

    // Wraps plain text in the minimal ChatResponse structure the mock must return.
    private ChatResponse buildChatResponse(String content) {
        return new ChatResponse(List.of(new Generation(new AssistantMessage(content))));
    }

}
\ No newline at end of file