3.1 Spring AI模型抽象
3.1.1 模型接口体系
Spring AI提供了统一的模型抽象接口,支持多种AI能力:
/**
 * Core model interface: declares a single {@code call} method that maps a request to a response.
 *
 * @param <TRequest>  type of the request accepted by {@link #call}
 * @param <TResponse> type of the response returned by {@link #call}
 */
public interface Model<TRequest, TResponse> {
/**
 * Invokes the model with the given request.
 *
 * @param request the request to hand to the model
 * @return the response produced for {@code request}
 */
TResponse call(TRequest request);
}
/**
 * Chat model interface: a {@code Model} specialised to take a {@code Prompt} and return a
 * {@code ChatResponse}, with an additional {@code stream} variant that returns a {@code Flux}.
 */
public interface ChatModel extends Model<Prompt, ChatResponse> {
/**
 * Calls the chat model with the given prompt and returns one response.
 *
 * @param prompt the prompt to send
 * @return the chat response
 */
ChatResponse call(Prompt prompt);
/**
 * Calls the chat model with the given prompt and returns the responses as a {@code Flux}.
 * NOTE(review): presumably emits partial responses as they are generated; confirm against the
 * concrete implementation.
 *
 * @param prompt the prompt to send
 * @return a {@code Flux} of chat responses
 */
Flux<ChatResponse> stream(Prompt prompt);
}
/**
 * Embedding model interface: a {@code Model} specialised to take an {@code EmbeddingRequest} and
 * return an {@code EmbeddingResponse}, plus two convenience overloads that return raw vectors.
 */
public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {
/**
 * Calls the embedding model with a full request object.
 *
 * @param request the embedding request
 * @return the embedding response
 */
EmbeddingResponse call(EmbeddingRequest request);
/**
 * Embeds a single text.
 *
 * @param text the text to embed
 * @return the embedding as a list of doubles
 */
List<Double> embed(String text);
/**
 * Embeds several texts.
 * NOTE(review): presumably one vector per input text, in input order; confirm against the
 * concrete implementation.
 *
 * @param texts the texts to embed
 * @return a list of embeddings, each a list of doubles
 */
List<List<Double>> embed(List<String> texts);
}
/**
 * Image model interface: a {@code Model} specialised to take an {@code ImagePrompt} and return an
 * {@code ImageResponse}.
 */
public interface ImageModel extends Model<ImagePrompt, ImageResponse> {
/**
 * Calls the image model with the given prompt.
 *
 * @param prompt the image prompt to send
 * @return the image response
 */
ImageResponse call(ImagePrompt prompt);
}
3.1.2 模型配置抽象
// CustomChatOptions.java
package com.example.springai.config;
import org.springframework.ai.chat.ChatOptions;
/**
 * Mutable holder for chat generation options (model name, sampling parameters, token limit and
 * stop sequences) that implements {@code ChatOptions}.
 *
 * <p>Instances can be created with the no-arg constructor and setters, or through
 * {@link #builder()}. Every property is nullable; {@code null} means "not set".
 */
public class CustomChatOptions implements ChatOptions {

    private String model;
    private Double temperature;
    private Integer maxTokens;
    private Double topP;
    private Integer topK;
    private Double frequencyPenalty;
    private Double presencePenalty;
    // Fully qualified because this snippet does not import java.util.List.
    private java.util.List<String> stopSequences;

    /** Creates an instance with every option unset ({@code null}). */
    public CustomChatOptions() {}

    /**
     * Creates a new builder.
     *
     * @return a fresh {@link Builder} with every option unset
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link CustomChatOptions}. Each setter returns this builder;
     * {@link #build()} may be called repeatedly and returns an independent instance each time.
     */
    public static class Builder {

        // Accumulates the values; never handed out directly (see build()).
        private final CustomChatOptions options = new CustomChatOptions();

        /** Sets the model name. */
        public Builder model(String model) {
            options.model = model;
            return this;
        }

        /** Sets the sampling temperature. */
        public Builder temperature(Double temperature) {
            options.temperature = temperature;
            return this;
        }

        /** Sets the maximum number of tokens. */
        public Builder maxTokens(Integer maxTokens) {
            options.maxTokens = maxTokens;
            return this;
        }

        /** Sets the top-p value. */
        public Builder topP(Double topP) {
            options.topP = topP;
            return this;
        }

        /** Sets the top-k value. */
        public Builder topK(Integer topK) {
            options.topK = topK;
            return this;
        }

        /** Sets the frequency penalty. */
        public Builder frequencyPenalty(Double frequencyPenalty) {
            options.frequencyPenalty = frequencyPenalty;
            return this;
        }

        /** Sets the presence penalty. */
        public Builder presencePenalty(Double presencePenalty) {
            options.presencePenalty = presencePenalty;
            return this;
        }

        /** Sets the stop sequences. */
        public Builder stopSequences(java.util.List<String> stopSequences) {
            options.stopSequences = stopSequences;
            return this;
        }

        /**
         * Builds a new options object from the values collected so far.
         *
         * <p>A copy is returned (including a copy of the stop-sequence list) so that further
         * calls on this builder cannot change an object that was already built.
         *
         * @return a new, independent {@link CustomChatOptions}
         */
        public CustomChatOptions build() {
            CustomChatOptions built = new CustomChatOptions();
            built.model = options.model;
            built.temperature = options.temperature;
            built.maxTokens = options.maxTokens;
            built.topP = options.topP;
            built.topK = options.topK;
            built.frequencyPenalty = options.frequencyPenalty;
            built.presencePenalty = options.presencePenalty;
            built.stopSequences = options.stopSequences == null
                    ? null
                    : new java.util.ArrayList<>(options.stopSequences);
            return built;
        }
    }

    // Getters and setters

    /** @return the model name, or {@code null} if unset */
    public String getModel() {
        return model;
    }

    /** @param model the model name */
    public void setModel(String model) {
        this.model = model;
    }

    /** @return the sampling temperature, or {@code null} if unset */
    public Double getTemperature() {
        return temperature;
    }

    /** @param temperature the sampling temperature */
    public void setTemperature(Double temperature) {
        this.temperature = temperature;
    }

    /** @return the maximum number of tokens, or {@code null} if unset */
    public Integer getMaxTokens() {
        return maxTokens;
    }

    /** @param maxTokens the maximum number of tokens */
    public void setMaxTokens(Integer maxTokens) {
        this.maxTokens = maxTokens;
    }

    /** @return the top-p value, or {@code null} if unset */
    public Double getTopP() {
        return topP;
    }

    /** @param topP the top-p value */
    public void setTopP(Double topP) {
        this.topP = topP;
    }

    /** @return the top-k value, or {@code null} if unset */
    public Integer getTopK() {
        return topK;
    }

    /** @param topK the top-k value */
    public void setTopK(Integer topK) {
        this.topK = topK;
    }

    /** @return the frequency penalty, or {@code null} if unset */
    public Double getFrequencyPenalty() {
        return frequencyPenalty;
    }

    /** @param frequencyPenalty the frequency penalty */
    public void setFrequencyPenalty(Double frequencyPenalty) {
        this.frequencyPenalty = frequencyPenalty;
    }

    /** @return the presence penalty, or {@code null} if unset */
    public Double getPresencePenalty() {
        return presencePenalty;
    }

    /** @param presencePenalty the presence penalty */
    public void setPresencePenalty(Double presencePenalty) {
        this.presencePenalty = presencePenalty;
    }

    /** @return the stop sequences, or {@code null} if unset */
    public java.util.List<String> getStopSequences() {
        return stopSequences;
    }

    /** @param stopSequences the stop sequences */
    public void setStopSequences(java.util.List<String> stopSequences) {
        this.stopSequences = stopSequences;
    }
}
3.2 OpenAI集成
3.2.1 OpenAI配置
// OpenAIConfiguration.java
package com.example.springai.config;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
import java.time.Duration;
/**
 * Spring configuration that declares the OpenAI API clients and the chat, embedding and image
 * model beans built on top of them.
 *
 * <p>Active when {@code spring.ai.openai.enabled} is {@code true} or is not set at all.
 */
@Configuration
@ConditionalOnProperty(name = "spring.ai.openai.enabled", havingValue = "true", matchIfMissing = true)
public class OpenAIConfiguration {

    /** API key read from {@code spring.ai.openai.api-key} (no default). */
    @Value("${spring.ai.openai.api-key}")
    private String apiKey;

    /** Base URL read from {@code spring.ai.openai.base-url}; defaults to https://api.openai.com. */
    @Value("${spring.ai.openai.base-url:https://api.openai.com}")
    private String baseUrl;

    /**
     * Value of {@code spring.ai.openai.timeout}; defaults to 60.
     * NOTE(review): this field is injected but never read in this class, so no timeout is
     * actually applied to the RestTemplate built below.
     */
    @Value("${spring.ai.openai.timeout:60}")
    private int timeoutSeconds;

    /**
     * Builds the chat/embedding API client from the configured key, base URL and a customised
     * RestTemplate.
     *
     * @return the configured {@code OpenAiApi}
     */
    @Bean
    public OpenAiApi openAiApi() {
        RestTemplate httpClient = customRestTemplate();
        return OpenAiApi.builder()
                .withApiKey(apiKey)
                .withBaseUrl(baseUrl)
                .withRestTemplate(httpClient)
                .build();
    }

    /**
     * Builds the image API client from the configured base URL, key and a customised
     * RestTemplate.
     *
     * @return the configured {@code OpenAiImageApi}
     */
    @Bean
    public OpenAiImageApi openAiImageApi() {
        RestTemplate httpClient = customRestTemplate();
        return new OpenAiImageApi(baseUrl, apiKey, httpClient);
    }

    /**
     * @param openAiApi the API client bean
     * @return a chat model backed by {@code openAiApi}
     */
    @Bean
    public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) {
        OpenAiChatModel chatModel = new OpenAiChatModel(openAiApi);
        return chatModel;
    }

    /**
     * @param openAiApi the API client bean
     * @return an embedding model backed by {@code openAiApi}
     */
    @Bean
    public OpenAiEmbeddingModel openAiEmbeddingModel(OpenAiApi openAiApi) {
        OpenAiEmbeddingModel embeddingModel = new OpenAiEmbeddingModel(openAiApi);
        return embeddingModel;
    }

    /**
     * @param openAiImageApi the image API client bean
     * @return an image model backed by {@code openAiImageApi}
     */
    @Bean
    public OpenAiImageModel openAiImageModel(OpenAiImageApi openAiImageApi) {
        OpenAiImageModel imageModel = new OpenAiImageModel(openAiImageApi);
        return imageModel;
    }

    /**
     * Creates a new RestTemplate whose only customisation is an interceptor that adds a
     * {@code User-Agent} header to every outgoing request. A fresh instance is created on each
     * call.
     */
    private RestTemplate customRestTemplate() {
        RestTemplate template = new RestTemplate();
        template.getInterceptors().add((httpRequest, payload, chain) -> {
            httpRequest.getHeaders().add("User-Agent", "Spring-AI-Demo/1.0");
            return chain.execute(httpRequest, payload);
        });
        return template;
    }
}
3.2.2 OpenAI聊天服务
// OpenAIChatService.java
package com.example.springai.service;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.Map;
/**
 * Convenience service around {@link OpenAiChatModel} offering plain, templated and streaming
 * chat calls plus two task-specific prompts (code generation and summarisation).
 *
 * <p>All blocking methods return the text content of {@code ChatResponse.getResult()} and throw
 * {@link IllegalStateException} when the model returns no result at all.
 */
@Service
public class OpenAIChatService {

    private final OpenAiChatModel chatModel;

    /**
     * @param chatModel the OpenAI chat model every call is delegated to
     */
    public OpenAIChatService(OpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Basic chat: sends the message as-is.
     *
     * @param message the user message
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String chat(String message) {
        return callForContent(new Prompt(message));
    }

    /**
     * Chat with explicit options.
     *
     * @param message the user message
     * @param options options attached to the prompt
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String chatWithOptions(String message, OpenAiChatOptions options) {
        return callForContent(new Prompt(message, options));
    }

    /**
     * Templated chat: renders {@code template} with {@code variables} and sends the result.
     *
     * @param template  the prompt template text
     * @param variables values for the template placeholders
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String chatWithTemplate(String template, Map<String, Object> variables) {
        PromptTemplate promptTemplate = new PromptTemplate(template);
        return callForContent(promptTemplate.create(variables));
    }

    /**
     * Streaming chat.
     *
     * <p>Chunks that carry no result or no content are skipped: Reactor's {@code map} fails with
     * a NullPointerException when the mapper returns {@code null}, so {@code mapNotNull} is used.
     *
     * @param message the user message
     * @return a stream of non-null content fragments
     */
    public Flux<String> streamChat(String message) {
        Prompt prompt = new Prompt(message);
        return chatModel.stream(prompt).mapNotNull(OpenAIChatService::contentOrNull);
    }

    /**
     * Chat with a system instruction and a user message.
     *
     * <p>NOTE(review): both texts are rendered into one templated prompt string rather than being
     * sent as separate system/user messages; consider role-typed messages if the model should
     * treat the system text as instructions.
     *
     * @param systemMessage the system instruction text
     * @param userMessage   the user message text
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String multiTurnChat(String systemMessage, String userMessage) {
        PromptTemplate template = new PromptTemplate(
                "System: {system}\nUser: {user}\nAssistant:"
        );
        Map<String, Object> variables = Map.of(
                "system", systemMessage,
                "user", userMessage
        );
        return callForContent(template.create(variables));
    }

    /**
     * Code generation: asks the model to implement {@code description} in {@code language},
     * using model {@code gpt-4}, temperature 0.3 and at most 2000 tokens.
     *
     * @param language    the programming language to use
     * @param description what the code should do
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String generateCode(String language, String description) {
        String template = """
                请用{language}编程语言实现以下功能:
                {description}
                要求:
                1. 代码要有详细注释
                2. 遵循最佳实践
                3. 包含错误处理
                4. 提供使用示例
                """;
        PromptTemplate promptTemplate = new PromptTemplate(template);
        Map<String, Object> variables = Map.of(
                "language", language,
                "description", description
        );
        Prompt rendered = promptTemplate.create(variables);
        OpenAiChatOptions options = OpenAiChatOptions.builder()
                .withModel("gpt-4")
                .withTemperature(0.3)
                .withMaxTokens(2000)
                .build();
        // Re-wrap the rendered instructions so the options travel with the prompt.
        return callForContent(new Prompt(rendered.getInstructions(), options));
    }

    /**
     * Text summarisation: asks the model for a summary no longer than {@code maxLength}
     * characters.
     *
     * @param text      the text to summarise
     * @param maxLength the maximum summary length stated in the prompt
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String summarizeText(String text, int maxLength) {
        String template = """
                请将以下文本总结为不超过{maxLength}字的摘要:
                {text}
                总结要求:
                1. 保留核心信息
                2. 语言简洁明了
                3. 逻辑清晰
                """;
        PromptTemplate promptTemplate = new PromptTemplate(template);
        Map<String, Object> variables = Map.of(
                "text", text,
                "maxLength", maxLength
        );
        return callForContent(promptTemplate.create(variables));
    }

    /** Calls the model and returns the reply content, failing clearly when there is no result. */
    private String callForContent(Prompt prompt) {
        ChatResponse response = chatModel.call(prompt);
        if (response == null || response.getResult() == null
                || response.getResult().getOutput() == null) {
            throw new IllegalStateException("模型未返回任何结果");
        }
        return response.getResult().getOutput().getContent();
    }

    /** Returns the content of a streamed chunk, or {@code null} when the chunk carries none. */
    private static String contentOrNull(ChatResponse response) {
        if (response == null || response.getResult() == null
                || response.getResult().getOutput() == null) {
            return null;
        }
        return response.getResult().getOutput().getContent();
    }
}
3.2.3 OpenAI嵌入服务
// OpenAIEmbeddingService.java
package com.example.springai.service;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Embedding helpers built on {@link OpenAiEmbeddingModel}: raw embedding calls, cosine
 * similarity between two texts, and a simple threshold-based clustering.
 */
@Service
public class OpenAIEmbeddingService {

    private final OpenAiEmbeddingModel embeddingModel;

    /**
     * @param embeddingModel the embedding model every call is delegated to
     */
    public OpenAIEmbeddingService(OpenAiEmbeddingModel embeddingModel) {
        this.embeddingModel = embeddingModel;
    }

    /**
     * Embeds a single text.
     *
     * @param text the text to embed
     * @return the embedding vector
     */
    public List<Double> embedText(String text) {
        return embeddingModel.embed(text);
    }

    /**
     * Embeds several texts in one call.
     *
     * @param texts the texts to embed
     * @return the embedding vectors as returned by the model
     */
    public List<List<Double>> embedTexts(List<String> texts) {
        return embeddingModel.embed(texts);
    }

    /**
     * Embeds texts with an explicitly chosen embedding model.
     *
     * @param texts the texts to embed
     * @param model the embedding model name to request
     * @return the full embedding response
     */
    public EmbeddingResponse embedWithOptions(List<String> texts, String model) {
        OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder()
                .withModel(model)
                .build();
        EmbeddingRequest request = new EmbeddingRequest(texts, options);
        return embeddingModel.call(request);
    }

    /**
     * Computes the cosine similarity between the embeddings of two texts.
     *
     * <p>Both texts are embedded in a single batch call instead of two separate calls. This
     * relies on the batch result being positionally aligned with the input, the same assumption
     * {@link #clusterTexts} makes.
     *
     * @param text1 the first text
     * @param text2 the second text
     * @return the cosine similarity, or {@code 0.0} if either embedding has zero norm
     * @throws IllegalStateException    if the model does not return exactly two embeddings
     * @throws IllegalArgumentException if the two embeddings differ in dimension
     */
    public double calculateSimilarity(String text1, String text2) {
        // Arrays.asList (unlike List.of) tolerates null elements, as the original single calls did.
        List<List<Double>> vectors = embedTexts(java.util.Arrays.asList(text1, text2));
        if (vectors == null || vectors.size() != 2) {
            throw new IllegalStateException("嵌入结果数量与输入文本数量不一致");
        }
        return cosineSimilarity(vectors.get(0), vectors.get(1));
    }

    /**
     * Cosine similarity of two equally sized vectors.
     *
     * @return the similarity, or {@code 0.0} when either vector has zero norm (including empty
     *         vectors), where the quotient would otherwise be NaN
     * @throws IllegalArgumentException if the vectors differ in size
     */
    private double cosineSimilarity(List<Double> vectorA, List<Double> vectorB) {
        if (vectorA.size() != vectorB.size()) {
            throw new IllegalArgumentException("向量维度不匹配");
        }
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vectorA.size(); i++) {
            double a = vectorA.get(i);
            double b = vectorB.get(i);
            dotProduct += a * b;
            normA += a * a;
            normB += b * b;
        }
        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        if (denominator == 0.0) {
            return 0.0;
        }
        return dotProduct / denominator;
    }

    /**
     * Greedy clustering: each unassigned text starts a new cluster and pulls in every later
     * unassigned text whose similarity to that starting text is strictly above {@code threshold}.
     *
     * @param texts     the texts to cluster; {@code null} or empty yields no clusters
     * @param threshold similarity a text must exceed to join a cluster
     * @return the clusters, in order of their first member
     * @throws IllegalStateException if the model returns a different number of embeddings than
     *                               texts
     */
    public List<List<String>> clusterTexts(List<String> texts, double threshold) {
        List<List<String>> clusters = new java.util.ArrayList<>();
        if (texts == null || texts.isEmpty()) {
            // Nothing to cluster; avoid a pointless remote call.
            return clusters;
        }
        List<List<Double>> embeddings = embedTexts(texts);
        if (embeddings == null || embeddings.size() != texts.size()) {
            throw new IllegalStateException("嵌入结果数量与输入文本数量不一致");
        }
        boolean[] assigned = new boolean[texts.size()];
        for (int i = 0; i < texts.size(); i++) {
            if (assigned[i]) {
                continue;
            }
            List<String> cluster = new java.util.ArrayList<>();
            cluster.add(texts.get(i));
            assigned[i] = true;
            for (int j = i + 1; j < texts.size(); j++) {
                if (assigned[j]) {
                    continue;
                }
                double similarity = cosineSimilarity(embeddings.get(i), embeddings.get(j));
                if (similarity > threshold) {
                    cluster.add(texts.get(j));
                    assigned[j] = true;
                }
            }
            clusters.add(cluster);
        }
        return clusters;
    }
}
3.3 Azure OpenAI集成
3.3.1 Azure OpenAI配置
// AzureOpenAIConfiguration.java
package com.example.springai.config;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
import org.springframework.ai.azure.openai.api.AzureOpenAiApi;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Spring configuration that declares the Azure OpenAI API client and the chat and embedding
 * model beans built on it.
 *
 * <p>Active only when {@code spring.ai.azure.openai.enabled} is set to {@code true}.
 */
@Configuration
@ConditionalOnProperty(name = "spring.ai.azure.openai.enabled", havingValue = "true")
public class AzureOpenAIConfiguration {

    /** API key read from {@code spring.ai.azure.openai.api-key} (no default). */
    @Value("${spring.ai.azure.openai.api-key}")
    private String apiKey;

    /** Endpoint read from {@code spring.ai.azure.openai.endpoint} (no default). */
    @Value("${spring.ai.azure.openai.endpoint}")
    private String endpoint;

    /** API version read from {@code spring.ai.azure.openai.api-version}; defaults to 2023-12-01-preview. */
    @Value("${spring.ai.azure.openai.api-version:2023-12-01-preview}")
    private String apiVersion;

    /**
     * @return an API client created from the configured endpoint, key and API version
     */
    @Bean
    public AzureOpenAiApi azureOpenAiApi() {
        AzureOpenAiApi api = new AzureOpenAiApi(endpoint, apiKey, apiVersion);
        return api;
    }

    /**
     * @param azureOpenAiApi the API client bean
     * @return a chat model backed by {@code azureOpenAiApi}
     */
    @Bean
    public AzureOpenAiChatModel azureOpenAiChatModel(AzureOpenAiApi azureOpenAiApi) {
        AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(azureOpenAiApi);
        return chatModel;
    }

    /**
     * @param azureOpenAiApi the API client bean
     * @return an embedding model backed by {@code azureOpenAiApi}
     */
    @Bean
    public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(AzureOpenAiApi azureOpenAiApi) {
        AzureOpenAiEmbeddingModel embeddingModel = new AzureOpenAiEmbeddingModel(azureOpenAiApi);
        return embeddingModel;
    }
}
3.3.2 Azure OpenAI服务
// AzureOpenAIService.java
package com.example.springai.service;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.stereotype.Service;
/**
 * Chat service backed by {@link AzureOpenAiChatModel}.
 *
 * <p>NOTE(review): Spring Boot's documentation recommends {@code @ConditionalOnBean} only on
 * auto-configuration classes because it depends on bean registration order; confirm this
 * component-scanned service is registered as intended.
 */
@Service
@ConditionalOnBean(AzureOpenAiChatModel.class)
public class AzureOpenAIService {

    private final AzureOpenAiChatModel chatModel;

    /**
     * @param chatModel the Azure OpenAI chat model every call is delegated to
     */
    public AzureOpenAIService(AzureOpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Azure OpenAI chat with temperature 0.7 and at most 1000 tokens.
     *
     * @param message        the user message
     * @param deploymentName the Azure deployment to call
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String chat(String message, String deploymentName) {
        AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder()
                .withDeploymentName(deploymentName)
                .withTemperature(0.7)
                .withMaxTokens(1000)
                .build();
        return callForContent(new Prompt(message, options));
    }

    /**
     * Enterprise chat: prepends a fixed safety instruction to the user message and uses more
     * conservative sampling settings.
     *
     * <p>NOTE(review): the instruction text and the user text are concatenated into a single
     * prompt string; user input can therefore sit next to, and potentially override, the
     * instructions. Consider sending them as separate system/user messages.
     *
     * @param message        the user question
     * @param deploymentName the Azure deployment to call
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String enterpriseChat(String message, String deploymentName) {
        AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder()
                .withDeploymentName(deploymentName)
                .withTemperature(0.3) // more conservative temperature
                .withMaxTokens(500)
                .withTopP(0.9)
                .withFrequencyPenalty(0.5)
                .withPresencePenalty(0.5)
                .build();
        // Enterprise safety instructions placed in front of the user question.
        String systemPrompt = """
                你是一个企业级AI助手。请遵循以下原则:
                1. 提供准确、专业的信息
                2. 避免敏感或不当内容
                3. 保护用户隐私
                4. 遵守企业政策
                """;
        String fullMessage = systemPrompt + "\n\n用户问题:" + message;
        return callForContent(new Prompt(fullMessage, options));
    }

    /**
     * Calls the model and returns the reply content. A response without a result is reported as
     * an explicit {@link IllegalStateException} rather than a bare NullPointerException.
     */
    private String callForContent(Prompt prompt) {
        ChatResponse response = chatModel.call(prompt);
        if (response == null || response.getResult() == null
                || response.getResult().getOutput() == null) {
            throw new IllegalStateException("模型未返回任何结果");
        }
        return response.getResult().getOutput().getContent();
    }
}
3.4 Ollama本地模型集成
3.4.1 Ollama配置
// OllamaConfiguration.java
package com.example.springai.config;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Spring configuration that declares the Ollama API client and the chat and embedding model
 * beans built on it.
 *
 * <p>Active only when {@code spring.ai.ollama.enabled} is set to {@code true}.
 */
@Configuration
@ConditionalOnProperty(name = "spring.ai.ollama.enabled", havingValue = "true")
public class OllamaConfiguration {

    /** Base URL read from {@code spring.ai.ollama.base-url}; defaults to http://localhost:11434. */
    @Value("${spring.ai.ollama.base-url:http://localhost:11434}")
    private String baseUrl;

    /**
     * @return an API client pointing at the configured base URL
     */
    @Bean
    public OllamaApi ollamaApi() {
        OllamaApi api = new OllamaApi(baseUrl);
        return api;
    }

    /**
     * @param ollamaApi the API client bean
     * @return a chat model backed by {@code ollamaApi}
     */
    @Bean
    public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi) {
        OllamaChatModel chatModel = new OllamaChatModel(ollamaApi);
        return chatModel;
    }

    /**
     * @param ollamaApi the API client bean
     * @return an embedding model backed by {@code ollamaApi}
     */
    @Bean
    public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi) {
        OllamaEmbeddingModel embeddingModel = new OllamaEmbeddingModel(ollamaApi);
        return embeddingModel;
    }
}
3.4.2 Ollama服务
// OllamaService.java
package com.example.springai.service;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaChatOptions;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.List;
/**
 * Service for locally hosted Ollama models: chat (blocking and streaming), two preset
 * assistants, and model discovery / pulling through {@link OllamaApi}.
 *
 * <p>NOTE(review): Spring Boot's documentation recommends {@code @ConditionalOnBean} only on
 * auto-configuration classes because it depends on bean registration order; confirm this
 * component-scanned service is registered as intended.
 */
@Service
@ConditionalOnBean(OllamaChatModel.class)
public class OllamaService {

    private final OllamaChatModel chatModel;
    private final OllamaApi ollamaApi;

    /**
     * @param chatModel the Ollama chat model used for all chat calls
     * @param ollamaApi the low-level API client used for model management
     */
    public OllamaService(OllamaChatModel chatModel, OllamaApi ollamaApi) {
        this.chatModel = chatModel;
        this.ollamaApi = ollamaApi;
    }

    /**
     * Lists the models reported by {@code ollamaApi.listModels()}.
     *
     * @return the models reported by the API
     */
    public List<OllamaModel> getAvailableModels() {
        return ollamaApi.listModels().models();
    }

    /**
     * Chat with a local model at temperature 0.8.
     *
     * @param message the user message
     * @param model   the model name to use
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String chat(String message, String model) {
        OllamaChatOptions options = OllamaChatOptions.create()
                .withModel(model)
                .withTemperature(0.8);
        return callForContent(new Prompt(message, options));
    }

    /**
     * Streaming chat with a local model at temperature 0.8.
     *
     * <p>Chunks without a result or content are skipped: Reactor's {@code map} fails with a
     * NullPointerException when the mapper returns {@code null}, so {@code mapNotNull} is used.
     *
     * @param message the user message
     * @param model   the model name to use
     * @return a stream of non-null content fragments
     */
    public Flux<String> streamChat(String message, String model) {
        OllamaChatOptions options = OllamaChatOptions.create()
                .withModel(model)
                .withTemperature(0.8);
        Prompt prompt = new Prompt(message, options);
        return chatModel.stream(prompt).mapNotNull(OllamaService::contentOrNull);
    }

    /**
     * Code assistant using the {@code codellama} model (temperature 0.3, top-p 0.9). A fixed
     * instruction text is placed in front of the request.
     *
     * @param codeRequest the user's coding request
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String codeAssistant(String codeRequest) {
        String systemPrompt = """
                你是一个专业的代码助手。请帮助用户:
                1. 编写高质量的代码
                2. 解释代码逻辑
                3. 优化代码性能
                4. 修复代码错误
                请用简洁明了的方式回答。
                """;
        OllamaChatOptions options = OllamaChatOptions.create()
                .withModel("codellama")
                .withTemperature(0.3)
                .withTopP(0.9);
        String fullPrompt = systemPrompt + "\n\n" + codeRequest;
        return callForContent(new Prompt(fullPrompt, options));
    }

    /**
     * Chinese-language chat using the {@code qwen:7b} model at temperature 0.7.
     *
     * @param message the user message
     * @return the content of the model's reply
     * @throws IllegalStateException if the model returns no result
     */
    public String chineseChat(String message) {
        OllamaChatOptions options = OllamaChatOptions.create()
                .withModel("qwen:7b")
                .withTemperature(0.7);
        return callForContent(new Prompt(message, options));
    }

    /**
     * Checks whether a model with exactly this name is reported by the server.
     *
     * <p>This is a best-effort probe: any failure while listing models (for example the server
     * being unreachable) is deliberately reported as "not available" instead of being thrown.
     *
     * @param modelName the exact model name to look for
     * @return {@code true} if a model with that name is listed; {@code false} otherwise,
     *         including when {@code modelName} is {@code null} or the listing fails
     */
    public boolean isModelAvailable(String modelName) {
        if (modelName == null) {
            return false;
        }
        try {
            List<OllamaModel> models = getAvailableModels();
            return models.stream()
                    .anyMatch(model -> modelName.equals(model.name()));
        } catch (Exception ignored) {
            // Intentional: availability probing must not fail the caller.
            return false;
        }
    }

    /**
     * Pulls a model through the Ollama API.
     *
     * @param modelName the model to pull
     * @throws IllegalArgumentException if {@code modelName} is {@code null} or blank
     * @throws RuntimeException         if the pull fails; the original failure is the cause
     */
    public void pullModel(String modelName) {
        if (modelName == null || modelName.isBlank()) {
            throw new IllegalArgumentException("模型名称不能为空");
        }
        try {
            ollamaApi.pullModel(modelName);
        } catch (Exception e) {
            throw new RuntimeException("拉取模型失败: " + modelName, e);
        }
    }

    /** Calls the model and returns the reply content, failing clearly when there is no result. */
    private String callForContent(Prompt prompt) {
        ChatResponse response = chatModel.call(prompt);
        if (response == null || response.getResult() == null
                || response.getResult().getOutput() == null) {
            throw new IllegalStateException("模型未返回任何结果");
        }
        return response.getResult().getOutput().getContent();
    }

    /** Returns the content of a streamed chunk, or {@code null} when the chunk carries none. */
    private static String contentOrNull(ChatResponse response) {
        if (response == null || response.getResult() == null
                || response.getResult().getOutput() == null) {
            return null;
        }
        return response.getResult().getOutput().getContent();
    }
}
3.5 多模型管理器
3.5.1 模型管理器
// ModelManager.java
package com.example.springai.service;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Registry of named chat models ("openai", "azure", "ollama") that tracks per-model request
 * metrics and can route a request to the model that currently looks best.
 *
 * <p>NOTE(review): the constructor requires all three model beans, yet the Azure and Ollama
 * configurations in this chapter are only active when explicitly enabled; with either disabled,
 * injection of this service will fail unless those dependencies are made optional.
 */
@Service
public class ModelManager {

    /** Model name used by {@link #smartChat} when no model has any metrics. */
    private static final String DEFAULT_MODEL = "openai";

    /** Success rates are compared in steps of 1/10 so that near-equal rates tie. */
    private static final int SUCCESS_RATE_BUCKETS = 10;

    private final Map<String, ChatModel> chatModels = new ConcurrentHashMap<>();
    private final Map<String, ModelMetrics> modelMetrics = new ConcurrentHashMap<>();

    /**
     * Registers the three chat models under the names "openai", "azure" and "ollama" and creates
     * an empty metrics record for each. A {@code null} model is skipped rather than registered.
     *
     * @param openAiChatModel      the OpenAI chat model
     * @param azureOpenAiChatModel the Azure OpenAI chat model
     * @param ollamaChatModel      the Ollama chat model
     */
    public ModelManager(
            @Qualifier("openAiChatModel") ChatModel openAiChatModel,
            @Qualifier("azureOpenAiChatModel") ChatModel azureOpenAiChatModel,
            @Qualifier("ollamaChatModel") ChatModel ollamaChatModel
    ) {
        register("openai", openAiChatModel);
        register("azure", azureOpenAiChatModel);
        register("ollama", ollamaChatModel);
    }

    /** Adds a model and its metrics; ConcurrentHashMap rejects nulls, so those are skipped. */
    private void register(String name, ChatModel model) {
        if (model != null) {
            chatModels.put(name, model);
            modelMetrics.put(name, new ModelMetrics());
        }
    }

    /**
     * Sends a message to the named model and records the outcome in that model's metrics.
     *
     * <p>Exactly one outcome is recorded per call: success only after the reply content has been
     * extracted, failure for anything thrown before that.
     *
     * @param modelName the registered model name
     * @param message   the user message
     * @return the content of the model's reply
     * @throws IllegalArgumentException if {@code modelName} is not registered
     * @throws RuntimeException         if the model call fails; the original failure is the cause
     */
    public String chat(String modelName, String message) {
        ChatModel model = getModel(modelName);
        ModelMetrics metrics = modelMetrics.get(modelName);
        long startTime = System.currentTimeMillis();
        String content;
        try {
            ChatResponse response = model.call(new Prompt(message));
            content = response.getResult().getOutput().getContent();
        } catch (Exception e) {
            metrics.recordFailure(System.currentTimeMillis() - startTime);
            throw new RuntimeException("模型调用失败: " + modelName, e);
        }
        metrics.recordSuccess(System.currentTimeMillis() - startTime);
        return content;
    }

    /**
     * Streams a chat reply from the named model.
     *
     * <p>The model name is validated immediately. Timing starts when the returned {@code Flux} is
     * subscribed to, and the elapsed time is recorded on completion or error. Chunks without
     * content are skipped, because Reactor's {@code map} fails when the mapper returns
     * {@code null}.
     *
     * @param modelName the registered model name
     * @param message   the user message
     * @return a stream of non-null content fragments
     * @throws IllegalArgumentException if {@code modelName} is not registered
     */
    public Flux<String> streamChat(String modelName, String message) {
        ChatModel model = getModel(modelName);
        ModelMetrics metrics = modelMetrics.get(modelName);
        Prompt prompt = new Prompt(message);
        return Flux.defer(() -> {
            long startTime = System.currentTimeMillis();
            return model.stream(prompt)
                    .mapNotNull(ModelManager::contentOrNull)
                    .doOnError(error ->
                            metrics.recordFailure(System.currentTimeMillis() - startTime))
                    .doOnComplete(() ->
                            metrics.recordSuccess(System.currentTimeMillis() - startTime));
        });
    }

    /**
     * Smart routing: picks a model based on the collected metrics and sends the message to it.
     *
     * @param message the user message
     * @return the content of the chosen model's reply
     */
    public String smartChat(String message) {
        String bestModel = selectBestModel();
        return chat(bestModel, message);
    }

    /**
     * Looks up a registered model.
     *
     * @throws IllegalArgumentException if the name is {@code null} or not registered
     */
    private ChatModel getModel(String modelName) {
        // ConcurrentHashMap.get(null) throws NullPointerException, so treat null as unknown.
        ChatModel model = modelName == null ? null : chatModels.get(modelName);
        if (model == null) {
            throw new IllegalArgumentException("未知的模型: " + modelName);
        }
        return model;
    }

    /**
     * Picks the model with the highest success-rate bucket; ties are broken by the lower average
     * response time and then by name, so the choice is deterministic.
     *
     * <p>Rates are bucketed (instead of compared with a tolerance) to keep the ordering
     * transitive. A model with no recorded requests is placed in the top bucket so that untried
     * models get used at least once instead of being ranked last forever.
     */
    private String selectBestModel() {
        String bestName = null;
        int bestBucket = -1;
        double bestAverage = 0.0;
        for (Map.Entry<String, ModelMetrics> entry : modelMetrics.entrySet()) {
            String name = entry.getKey();
            ModelMetrics metrics = entry.getValue();
            int bucket = successBucket(metrics);
            double average = metrics.getAverageResponseTime();
            boolean better = bestName == null
                    || bucket > bestBucket
                    || (bucket == bestBucket && average < bestAverage)
                    || (bucket == bestBucket && average == bestAverage
                            && name.compareTo(bestName) < 0);
            if (better) {
                bestName = name;
                bestBucket = bucket;
                bestAverage = average;
            }
        }
        return bestName != null ? bestName : DEFAULT_MODEL;
    }

    /** Maps a model's success rate to a bucket in 0..SUCCESS_RATE_BUCKETS (untried = top). */
    private static int successBucket(ModelMetrics metrics) {
        if (metrics.getTotalRequests() == 0) {
            return SUCCESS_RATE_BUCKETS;
        }
        return (int) Math.round(metrics.getSuccessRate() * SUCCESS_RATE_BUCKETS);
    }

    /** Returns the content of a streamed chunk, or {@code null} when the chunk carries none. */
    private static String contentOrNull(ChatResponse response) {
        if (response == null || response.getResult() == null
                || response.getResult().getOutput() == null) {
            return null;
        }
        return response.getResult().getOutput().getContent();
    }

    /**
     * Returns the metrics per model name.
     *
     * @return an unmodifiable copy of the name-to-metrics map (the metrics objects are shared)
     */
    public Map<String, ModelMetrics> getModelMetrics() {
        return Map.copyOf(modelMetrics);
    }

    /**
     * Request counters for one model. All methods are {@code synchronized} on the instance.
     */
    public static class ModelMetrics {

        private long totalRequests = 0;
        private long successfulRequests = 0;
        private long totalResponseTime = 0;

        /**
         * Records a successful request.
         *
         * @param responseTime elapsed time of the request in milliseconds
         */
        public synchronized void recordSuccess(long responseTime) {
            totalRequests++;
            successfulRequests++;
            totalResponseTime += responseTime;
        }

        /**
         * Records a failed request; its duration still counts towards the average.
         *
         * @param responseTime elapsed time of the request in milliseconds
         */
        public synchronized void recordFailure(long responseTime) {
            totalRequests++;
            totalResponseTime += responseTime;
        }

        /** @return successful / total requests, or {@code 0.0} when nothing was recorded */
        public synchronized double getSuccessRate() {
            return totalRequests == 0 ? 0.0 : (double) successfulRequests / totalRequests;
        }

        /** @return mean recorded duration over all requests, or {@code 0.0} when none */
        public synchronized double getAverageResponseTime() {
            return totalRequests == 0 ? 0.0 : (double) totalResponseTime / totalRequests;
        }

        /** @return the number of recorded requests */
        public synchronized long getTotalRequests() {
            return totalRequests;
        }

        /** @return the number of recorded successful requests */
        public synchronized long getSuccessfulRequests() {
            return successfulRequests;
        }
    }
}
3.5.2 模型管理控制器
// ModelController.java
package com.example.springai.controller;
import com.example.springai.service.ModelManager;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import java.util.Map;
/**
 * REST endpoints under {@code /models} that expose {@link ModelManager}: chat with a named
 * model, smart-routed chat, streaming chat and metrics.
 */
@RestController
@RequestMapping("/models")
public class ModelController {

    private final ModelManager modelManager;

    /**
     * @param modelManager the manager every request is delegated to
     */
    public ModelController(ModelManager modelManager) {
        this.modelManager = modelManager;
    }

    /**
     * Chat with a specific model.
     *
     * @param modelName the model name taken from the path
     * @param message   the raw request body, used as the user message
     * @return the reply returned by {@code ModelManager.chat}
     */
    @PostMapping("/{modelName}/chat")
    public String chat(@PathVariable String modelName, @RequestBody String message) {
        final String reply = modelManager.chat(modelName, message);
        return reply;
    }

    /**
     * Smart-routed chat.
     *
     * @param message the raw request body, used as the user message
     * @return the reply returned by {@code ModelManager.smartChat}
     */
    @PostMapping("/smart/chat")
    public String smartChat(@RequestBody String message) {
        final String reply = modelManager.smartChat(message);
        return reply;
    }

    /**
     * Streaming chat, produced as {@code text/plain}.
     *
     * @param modelName the model name taken from the path
     * @param message   the raw request body, used as the user message
     * @return the stream returned by {@code ModelManager.streamChat}
     */
    @PostMapping(value = "/{modelName}/stream", produces = MediaType.TEXT_PLAIN_VALUE)
    public Flux<String> streamChat(@PathVariable String modelName, @RequestBody String message) {
        final Flux<String> fragments = modelManager.streamChat(modelName, message);
        return fragments;
    }

    /**
     * Model metrics.
     *
     * @return the map returned by {@code ModelManager.getModelMetrics}
     */
    @GetMapping("/metrics")
    public Map<String, ModelManager.ModelMetrics> getMetrics() {
        final Map<String, ModelManager.ModelMetrics> snapshot = modelManager.getModelMetrics();
        return snapshot;
    }
}
3.6 错误处理和重试机制
3.6.1 全局异常处理
// GlobalExceptionHandler.java
package com.example.springai.exception;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import java.time.LocalDateTime;
import java.util.Map;
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(RuntimeException.class)
public ResponseEntity<Map<String, Object>> handleRuntimeException(RuntimeException e) {
Map<String, Object> error = Map.of(
"timestamp", LocalDateTime.now(),
"status", HttpStatus.INTERNAL_SERVER_ERROR.value(),
"error", "Internal Server Error",
"message", e.getMessage(),
"path", "/api"
);
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(error);
}
@ExceptionHandler(IllegalArgumentException.class)
public ResponseEntity<Map<String, Object>> handleIllegalArgumentException(IllegalArgumentException e) {
Map<String, Object> error = Map.of(
"timestamp", LocalDateTime.now(),
"status", HttpStatus.BAD_REQUEST.value(),
"error", "Bad Request",
"message", e.getMessage(),
"path", "/api"
);
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(error);
}
}
3.6.2 重试配置
// RetryConfiguration.java
package com.example.springai.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.retry.annotation.EnableRetry;
import org.springframework.retry.backoff.ExponentialBackOffPolicy;
import org.springframework.retry.policy.SimpleRetryPolicy;
import org.springframework.retry.support.RetryTemplate;
import java.util.Map;
/**
 * Enables Spring Retry and provides a shared {@link RetryTemplate} with a bounded number of
 * attempts and exponential back-off.
 */
@Configuration
@EnableRetry
public class RetryConfiguration {

    /**
     * Builds a {@link RetryTemplate} that makes at most 3 attempts, waiting 1s, 2s, ... (capped
     * at 10s) between them.
     *
     * <p>{@code SimpleRetryPolicy} takes its retryable-exception map through the constructor; it
     * has no setter for it. Cause traversal is enabled so that a {@code SocketTimeoutException}
     * or {@code ConnectException} wrapped in another exception is still recognised.
     * {@code IllegalArgumentException} is excluded because it signals a bad request that a retry
     * cannot fix.
     *
     * @return the configured retry template
     */
    @Bean
    public RetryTemplate retryTemplate() {
        // Exceptions to retry (true) or to fail on immediately (false).
        Map<Class<? extends Throwable>, Boolean> retryableExceptions = Map.of(
                IllegalArgumentException.class, false,
                RuntimeException.class, true,
                java.net.SocketTimeoutException.class, true,
                java.net.ConnectException.class, true
        );

        // Retry policy: 3 attempts in total, inspecting exception causes as well.
        SimpleRetryPolicy retryPolicy = new SimpleRetryPolicy(3, retryableExceptions, true);

        // Back-off policy: 1s initial interval, doubling, capped at 10s.
        ExponentialBackOffPolicy backOffPolicy = new ExponentialBackOffPolicy();
        backOffPolicy.setInitialInterval(1000);
        backOffPolicy.setMultiplier(2.0);
        backOffPolicy.setMaxInterval(10000);

        RetryTemplate retryTemplate = new RetryTemplate();
        retryTemplate.setRetryPolicy(retryPolicy);
        retryTemplate.setBackOffPolicy(backOffPolicy);
        return retryTemplate;
    }
}
3.7 本章总结
3.7.1 核心要点
- 模型抽象:Spring AI提供统一的模型接口
- 多提供商支持:OpenAI、Azure OpenAI、Ollama等
- 配置管理:灵活的配置选项和环境管理
- 错误处理:完善的异常处理和重试机制
- 性能监控:模型指标收集和智能路由
3.7.2 最佳实践
- 配置安全:使用环境变量管理敏感信息
- 多模型策略:根据场景选择合适的模型
- 错误恢复:实现重试和降级机制
- 性能优化:监控模型性能并智能路由
- 成本控制:合理配置模型参数
3.7.3 练习题
基础练习
- 配置OpenAI和Ollama两个模型提供商
- 实现一个简单的模型切换功能
- 添加基础的错误处理
进阶练习
- 实现模型性能监控和指标收集
- 创建智能模型路由器
- 添加模型响应缓存机制
高级练习
- 实现多模型负载均衡
- 创建模型健康检查系统
- 实现基于成本的模型选择策略
下一章我们将学习提示词工程和模板系统的使用。