Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spring-ai-modules/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<module>spring-ai-agentic-patterns</module>
<module>spring-ai-chat-stream</module>
<module>spring-ai-introduction</module>
<module>spring-ai-llm-as-a-judge</module>
<module>spring-ai-mcp</module>
<module>spring-ai-mcp-elicitations</module>
<!-- <module>spring-ai-mcp-oauth</module>--><!-- test failures -->
Expand Down
88 changes: 88 additions & 0 deletions spring-ai-modules/spring-ai-llm-as-a-judge/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <artifactId>spring-ai-llm-as-a-judge</artifactId>
    <packaging>jar</packaging>
    <name>spring-ai-llm-as-a-judge</name>

    <parent>
        <groupId>com.baeldung</groupId>
        <artifactId>spring-ai-modules</artifactId>
        <version>0.0.1</version>
        <relativePath>../pom.xml</relativePath>
    </parent>

    <dependencyManagement>
        <dependencies>
            <!-- Spring AI BOM manages versions for all spring-ai artifacts below. -->
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>${spring-ai.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- Version is managed by the imported spring-ai-bom; no explicit version needed. -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-model-openai</artifactId>
        </dependency>
        <!-- commonmark core and its extensions must share one version: mixing
             core 0.27.1 with the 0.21.0 tables extension risks runtime
             incompatibility. Both now use ${commonmark.version}. -->
        <dependency>
            <groupId>org.commonmark</groupId>
            <artifactId>commonmark</artifactId>
            <version>${commonmark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.commonmark</groupId>
            <artifactId>commonmark-ext-gfm-tables</artifactId>
            <version>${commonmark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-devtools</artifactId>
            <scope>runtime</scope>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <mainClass>${spring.boot.mainclass}</mainClass>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <properties>
        <java.version>21</java.version>
        <spring-boot.version>3.5.0</spring-boot.version>
        <spring-ai.version>1.1.2</spring-ai.version>
        <!-- Shared version for org.commonmark core and all its extensions. -->
        <commonmark.version>0.27.1</commonmark.version>
        <logback.version>1.5.18</logback.version>
    </properties>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.baeldung.springai;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class ChatConfig {

    /**
     * Builds the advisor that scores each model answer and asks for a
     * refinement when the score falls below the configured threshold.
     * The judge gets its own plain {@link ChatClient} (no advisors),
     * so judging never recursively triggers the advisor itself.
     */
    @Bean
    public LlmJudgeAdvisor llmJudgeAdvisor(
        ChatClient.Builder builder,
        @Value("${judge.score-threshold:0.7}") double scoreThreshold,
        @Value("${judge.max-refinements:2}") int maxRefinements
    ) {
        ChatClient judgeClient = builder.build();
        return new LlmJudgeAdvisor(judgeClient, scoreThreshold, maxRefinements);
    }

    /**
     * Primary chat client used by the application; every call is routed
     * through the judge advisor registered as a default advisor.
     */
    @Bean
    public ChatClient chatClient(ChatClient.Builder builder, LlmJudgeAdvisor judgeAdvisor) {
        return builder.defaultAdvisors(judgeAdvisor).build();
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package com.baeldung.springai;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
import org.springframework.ai.chat.prompt.Prompt;

/**
 * {@link CallAdvisor} implementing the "LLM as a judge" pattern: after each
 * model answer, a second (judge) model scores the answer against a rubric.
 * If the score is below the threshold, the judge's feedback is appended to
 * the user message and the answer is regenerated, up to a fixed number of
 * refinement rounds.
 */
public class LlmJudgeAdvisor implements CallAdvisor {

    // Rubric prompt for the judge model; it must reply with strict JSON that
    // maps onto the Verdict record.
    private static final String JUDGE_SYSTEM_PROMPT = """
        You are a strict quality evaluator for AI-generated answers.

        Given a user question and an AI-generated answer, rate the answer quality.

        Use this rubric:
        - 1.0: Complete, accurate, and clearly explained
        - 0.7: Mostly correct but missing details or clarity
        - 0.4: Partially correct or overly vague
        - 0.0: Incorrect, irrelevant, or harmful

        Respond ONLY with a valid JSON object. Do not add any explanation outside the JSON.
        Format: {"score": <0.0 to 1.0>, "feedback": "<one concise sentence>"}
        """;

    private final ChatClient judgeClient;
    private final double scoreThreshold;
    private final int maxRefinements;

    /**
     * @param judgeClient    chat client used exclusively for scoring answers
     * @param scoreThreshold minimum acceptable score (rubric range 0.0 to 1.0)
     * @param maxRefinements maximum number of refinement rounds after the first answer
     */
    public LlmJudgeAdvisor(
        ChatClient judgeClient,
        double scoreThreshold,
        int maxRefinements
    ) {
        this.judgeClient = judgeClient;
        this.scoreThreshold = scoreThreshold;
        this.maxRefinements = maxRefinements;
    }

    /**
     * Generates an answer and judges it; on a failing verdict the judge's
     * feedback is folded into the request and the answer is regenerated.
     * The final attempt (attempt maxRefinements + 1) is returned without
     * being judged, so the refinement budget bounds the number of judge calls.
     */
    @Override
    public ChatClientResponse adviseCall(ChatClientRequest request, CallAdvisorChain chain) {
        for (int attempt = 1; attempt <= maxRefinements + 1; attempt++) {
            ChatClientResponse response = chain.copy(this).nextCall(request);
            if (attempt > maxRefinements) {
                // Refinement budget exhausted: accept this answer without judging it.
                return response;
            }
            Verdict verdict = evaluate(request, response);
            if (verdict.score() >= scoreThreshold) {
                return response;
            }

            request = addFeedback(request, verdict.feedback());
        }
        // Only reachable when maxRefinements < 0 (loop body never runs);
        // fall back to a single plain call.
        return chain.copy(this).nextCall(request);
    }

    /**
     * Asks the judge model to score the answer for the original question,
     * deserializing its JSON reply into a {@link Verdict}.
     */
    private Verdict evaluate(ChatClientRequest request, ChatClientResponse response) {
        String question = request.prompt().getUserMessage().getText();
        String answer = response.chatResponse().getResult().getOutput().getText();

        return judgeClient.prompt()
            .system(JUDGE_SYSTEM_PROMPT)
            .user("Question: " + question + "\n\nAnswer: " + answer)
            .call()
            .entity(Verdict.class);
    }

    /**
     * Returns a copy of the request whose user message is augmented with the
     * judge's feedback and an instruction to produce an improved answer.
     */
    private ChatClientRequest addFeedback(ChatClientRequest original, String feedback) {
        Prompt augmented = original.prompt()
            .augmentUserMessage(msg -> msg.mutate()
                .text(msg.getText()
                    + "\n\nYour previous answer was insufficient. Feedback: " + feedback
                    + "\nPlease provide an improved answer.")
                .build());
        return original.mutate().prompt(augmented).build();
    }

    @Override
    public String getName() {
        return "LlmJudgeAdvisor";
    }

    @Override
    public int getOrder() {
        return 0;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.baeldung.springai;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class SpringBootAiLlmAsAJudgeApplication {

    /** Bootstraps the Spring context for the LLM-as-a-judge sample application. */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(SpringBootAiLlmAsAJudgeApplication.class);
        application.run(args);
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.baeldung.springai;

import lombok.RequiredArgsConstructor;
import org.commonmark.parser.Parser;
import org.commonmark.renderer.html.HtmlRenderer;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/api/travel")
public class TravelController {

    // commonmark parser/renderer pair used to turn the model's Markdown into HTML.
    private final Parser markdownParser = Parser.builder().build();
    private final HtmlRenderer htmlRenderer = HtmlRenderer.builder().build();

    private final TravelService travelService;

    /** Explicit constructor injection; same signature Lombok would generate. */
    public TravelController(TravelService travelService) {
        this.travelService = travelService;
    }

    /**
     * Returns travel tips for the requested destination (default "Paris"),
     * rendered from the model's Markdown output into an HTML fragment.
     */
    @GetMapping(value = "/tips", produces = MediaType.TEXT_HTML_VALUE)
    public String getTips(
        @RequestParam(defaultValue = "Paris", name = "destination") String destination
    ) {
        String markdown = travelService.getTravelTip(destination);
        return htmlRenderer.render(markdownParser.parse(markdown));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.baeldung.springai;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.stereotype.Service;

@Service
public class TravelService {

    private final ChatClient chatClient;

    public TravelService(ChatClient chatClient) {
        this.chatClient = chatClient;
    }

    /**
     * Asks the chat model for three insider tips for the given destination.
     * The response is returned as raw text (typically Markdown).
     */
    public String getTravelTip(String destination) {
        var request = chatClient
            .prompt()
            .user("Give me three insider tips for a trip to: " + destination);
        return request.call().content();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.baeldung.springai;

/**
 * Structured verdict produced by the judge model and deserialized from its
 * JSON reply.
 *
 * @param score    quality score — the judge's rubric uses the range 0.0 to 1.0,
 *                 but the value is not validated here
 * @param feedback one-sentence explanation used to augment the retry prompt
 */
public record Verdict(double score, String feedback) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# OpenAI chat model configuration; the API key is read from the environment.
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: gpt-5-nano
          # NOTE(review): gpt-5 family models appear to accept only the default
          # temperature of 1 — confirm before changing.
          temperature: 1
# LlmJudgeAdvisor tuning: answers scoring below score-threshold are refined,
# at most max-refinements times (the final attempt is returned unjudged).
judge:
  score-threshold: 0.75
  max-refinements: 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package com.baeldung.springai;

import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.bean.override.mockito.MockitoBean;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.when;

@SpringBootTest(
    properties = {
        "judge.score-threshold=0.7",
        "judge.max-refinements=2"
    }
)
class LlmJudgeAdvisorTest {

    // Mocking the single ChatModel means both the answering client and the
    // judge client consume the same stubbed responses, strictly in call order.
    @MockitoBean
    ChatModel chatModel;

    @Autowired
    ChatClient chatClient;

    @Test
    void givenLowQualityAnswer_whenAdvisorRuns_thenAnswerIsRefined() {
        // Call order: (1) first answer, (2) judge verdict 0.3 -> below the 0.7
        // threshold, triggers refinement, (3) refined answer, (4) judge verdict
        // 0.9 -> accepted, so the refined answer is returned.
        when(chatModel.call(any(Prompt.class)))
            .thenReturn(buildChatResponse("It runs Java."))
            .thenReturn(buildChatResponse("""
                {"score": 0.3, "feedback": "Too vague."}
                """))
            .thenReturn(buildChatResponse("The JVM executes Java bytecode, manages memory, and enables platform independence."))
            .thenReturn(buildChatResponse("""
                {"score": 0.9, "feedback": "Complete and accurate."}
                """));

        String result = chatClient.prompt()
            .user("Explain what a JVM is.")
            .call()
            .content();

        assertThat(result).contains("bytecode");
    }

    @Test
    void givenLowQualityAnswer_whenAdvisorRuns_thenAnswerIsRefinedOnlyTwice() {
        // Two failing verdicts exhaust max-refinements=2; the third answer
        // (call 5) is returned WITHOUT a judge call, hence exactly 5 model
        // calls: answer, verdict, answer, verdict, final answer.
        when(chatModel.call(any(Prompt.class)))
            .thenReturn(buildChatResponse("It runs Java."))
            .thenReturn(buildChatResponse("""
                {"score": 0.3, "feedback": "Too vague."}
                """))
            .thenReturn(buildChatResponse("The JVM runs Java bytecode."))
            .thenReturn(buildChatResponse("""
                {"score": 0.4, "feedback": "Still too vague."}
                """))
            .thenReturn(buildChatResponse("The JVM executes Java bytecode, manages memory, and enables platform independence."));

        String result = chatClient.prompt()
            .user("Explain what a JVM is.")
            .call()
            .content();

        assertThat(result).contains("bytecode");
        Mockito.verify(chatModel, Mockito.times(5)).call(any(Prompt.class));
    }

    // Wraps plain text in the single-generation ChatResponse shape the
    // ChatClient expects from the model.
    private ChatResponse buildChatResponse(String content) {
        return new ChatResponse(List.of(new Generation(new AssistantMessage(content))));
    }

}