8.1 多模态AI基础
8.1.1 多模态AI概念
多模态AI(Multimodal AI)是指能够处理和理解多种类型数据(如文本、图像、音频、视频等)的人工智能系统。Spring AI提供了强大的多模态支持,特别是在图像和文本的结合处理方面。
核心特性:
1. 图像理解:分析图像内容,提取视觉信息
2. 视觉问答:基于图像内容回答问题
3. 图像描述:生成图像的文字描述
4. 视觉推理:基于图像进行逻辑推理
5. 多模态对话:结合图像和文本的对话系统
8.1.2 Spring AI多模态架构
// MultimodalModel接口
/**
 * Contract for a model that takes a {@code MultimodalPrompt} and produces a
 * {@code MultimodalResponse}.
 */
public interface MultimodalModel extends Model<MultimodalPrompt, MultimodalResponse> {

    /**
     * Performs a single multimodal call and returns its response.
     *
     * @param prompt the prompt to send to the model
     * @return the response produced for {@code prompt}
     */
    MultimodalResponse call(MultimodalPrompt prompt);

    /**
     * Streaming variant of {@link #call(MultimodalPrompt)}.
     *
     * <p>The default implementation emits the single result of {@code call}. The call is
     * wrapped in {@code Flux.defer} so that it runs once per subscription instead of
     * eagerly when this method is invoked; a failure inside {@code call} is therefore
     * delivered as an error signal to the subscriber rather than thrown from
     * {@code stream} itself.
     *
     * @param prompt the prompt to send to the model
     * @return a flux emitting the response of {@code call(prompt)}
     */
    default Flux<MultimodalResponse> stream(MultimodalPrompt prompt) {
        return Flux.defer(() -> Flux.just(call(prompt)));
    }
}
// 多模态消息类型
package com.example.springai.multimodal;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.core.io.Resource;
import java.util.List;
import java.util.ArrayList;
/**
 * Fluent holder for a text prompt together with the media items attached to it.
 *
 * <p>Instances are mutable and not synchronized.
 */
public class MultimodalMessage {

    private String text;
    private List<Media> mediaList;

    /**
     * Creates a message with the given text and no media.
     *
     * @param text the textual part of the message
     */
    public MultimodalMessage(String text) {
        this(text, null);
    }

    /**
     * Creates a message with the given text and media.
     *
     * <p>The list is copied, so the {@code addXxx} methods keep working even when the
     * caller passes an unmodifiable list such as {@code List.of(...)}, and later
     * changes to the caller's list do not leak into this message.
     *
     * @param text      the textual part of the message
     * @param mediaList initial media items; {@code null} is treated as empty
     */
    public MultimodalMessage(String text, List<Media> mediaList) {
        this.text = text;
        this.mediaList = mutableCopy(mediaList);
    }

    /**
     * Attaches an image backed by a resource.
     *
     * @param imageResource the image data
     * @return this message, for chaining
     */
    public MultimodalMessage addImage(Resource imageResource) {
        this.mediaList.add(new Media(Media.Type.IMAGE, imageResource));
        return this;
    }

    /**
     * Attaches an image supplied as Base64 text.
     *
     * @param base64Data the Base64-encoded image bytes
     * @param mimeType   the MIME type of the image
     * @return this message, for chaining
     */
    public MultimodalMessage addImageBase64(String base64Data, String mimeType) {
        this.mediaList.add(new Media(Media.Type.IMAGE, base64Data, mimeType));
        return this;
    }

    /**
     * Attaches an audio clip backed by a resource.
     *
     * @param audioResource the audio data
     * @return this message, for chaining
     */
    public MultimodalMessage addAudio(Resource audioResource) {
        this.mediaList.add(new Media(Media.Type.AUDIO, audioResource));
        return this;
    }

    /**
     * Attaches a video backed by a resource.
     *
     * @param videoResource the video data
     * @return this message, for chaining
     */
    public MultimodalMessage addVideo(Resource videoResource) {
        this.mediaList.add(new Media(Media.Type.VIDEO, videoResource));
        return this;
    }

    /**
     * Converts this message to a {@code UserMessage}.
     *
     * @return a text-only user message when no media is attached, otherwise a user
     *         message carrying the text and the attached media
     */
    public UserMessage toUserMessage() {
        if (mediaList.isEmpty()) {
            return new UserMessage(text);
        }
        return new UserMessage(text, mediaList);
    }

    public String getText() {
        return text;
    }

    public void setText(String text) {
        this.text = text;
    }

    /**
     * Returns the live internal media list; changes made through it affect this message.
     */
    public List<Media> getMediaList() {
        return mediaList;
    }

    /**
     * Replaces the attached media with a copy of the given list.
     *
     * @param mediaList the new media items; {@code null} clears the list
     */
    public void setMediaList(List<Media> mediaList) {
        this.mediaList = mutableCopy(mediaList);
    }

    /** Returns whether at least one attached item has the primary MIME type {@code image}. */
    public boolean hasImages() {
        return hasMediaOfType("image");
    }

    /** Returns whether at least one attached item has the primary MIME type {@code audio}. */
    public boolean hasAudio() {
        return hasMediaOfType("audio");
    }

    /** Returns whether at least one attached item has the primary MIME type {@code video}. */
    public boolean hasVideo() {
        return hasMediaOfType("video");
    }

    /** Returns the number of attached media items. */
    public int getMediaCount() {
        return mediaList.size();
    }

    /** Returns a fresh mutable list with the elements of {@code source}, or an empty one for null. */
    private static List<Media> mutableCopy(List<Media> source) {
        return source != null ? new ArrayList<>(source) : new ArrayList<>();
    }

    /**
     * Checks the primary MIME type of every attached item. Items that are null or
     * report no MIME type are skipped instead of raising a NullPointerException.
     */
    private boolean hasMediaOfType(String primaryType) {
        return mediaList.stream()
                .anyMatch(media -> media != null
                        && media.getMimeType() != null
                        && primaryType.equals(media.getMimeType().getType()));
    }
}
8.2 图像处理服务
8.2.1 基础图像处理
// ImageProcessingService.java
package com.example.springai.service;
import com.example.springai.multimodal.MultimodalMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
/**
 * Image-understanding helpers that send a text instruction plus one or more images to
 * the injected {@code ChatModel} and post-process the free-text reply.
 *
 * <p>The parsing helpers are keyword heuristics over the model's reply; they make no
 * attempt at real language understanding.
 */
@Service
public class ImageProcessingService {

    /** Content-type prefixes accepted by {@link #processUploadedImage(MultipartFile, String)}. */
    private static final List<String> SUPPORTED_CONTENT_TYPES =
            List.of("image/jpeg", "image/png", "image/gif", "image/webp");

    /**
     * Finds the first number that follows the word "置信度" or "confidence"
     * (case-insensitive). Only that first number is captured, so trailing punctuation
     * or later digits in the reply cannot corrupt the value.
     */
    private static final java.util.regex.Pattern CONFIDENCE_PATTERN =
            java.util.regex.Pattern.compile(
                    "(?:置信度|confidence)[^0-9]*([0-9]+(?:\\.[0-9]+)?)",
                    java.util.regex.Pattern.CASE_INSENSITIVE);

    /** Placeholder returned by {@link #extractSection} when no section is found. */
    private static final String NOT_DETECTED = "未检测到";

    private final ChatModel chatModel;

    public ImageProcessingService(ChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Asks the model for a detailed description of an image using the built-in prompt.
     *
     * @param imageResource the image to describe
     * @return the model's reply text
     */
    public String describeImage(Resource imageResource) {
        return ask("请详细描述这张图片的内容", imageResource);
    }

    /**
     * Asks the model to describe an image using a caller-supplied prompt.
     *
     * @param imageResource the image to describe
     * @param customPrompt  the instruction sent along with the image
     * @return the model's reply text
     */
    public String describeImage(Resource imageResource, String customPrompt) {
        return ask(customPrompt, imageResource);
    }

    /**
     * Asks the model a question about an image.
     *
     * @param imageResource the image the question refers to
     * @param question      the question text
     * @return the model's reply text
     */
    public String askAboutImage(Resource imageResource, String question) {
        return ask(question, imageResource);
    }

    /**
     * Runs the same prompt against every image. A failure on one image is recorded in
     * that image's result and does not abort the batch.
     *
     * @param imageResources the images to analyze
     * @param prompt         the instruction sent with each image
     * @return one result per input image, in input order
     */
    public List<ImageAnalysisResult> analyzeImages(List<Resource> imageResources, String prompt) {
        return imageResources.stream()
                .map(resource -> analyzeSingle(resource, prompt))
                .toList();
    }

    /**
     * Asks the model to compare two images.
     *
     * @param image1 the first image
     * @param image2 the second image
     * @return the model's reply text
     */
    public String compareImages(Resource image1, Resource image2) {
        return ask("请比较这两张图片的相似点和不同点,并给出详细分析", image1, image2);
    }

    /**
     * Asks the model to pick one of the given categories for an image and parses the
     * reply into a category and a confidence.
     *
     * @param imageResource the image to classify
     * @param categories    the candidate category names
     * @return the parsed classification together with the raw reply
     */
    public ImageClassificationResult classifyImage(Resource imageResource, List<String> categories) {
        String categoriesStr = String.join(", ", categories);
        String prompt = String.format(
                "请将这张图片分类到以下类别之一:%s。只返回最匹配的类别名称和置信度(0-1之间的数字)。",
                categoriesStr);
        return parseClassificationResponse(ask(prompt, imageResource), categories);
    }

    /**
     * Asks the model for a structured content analysis (objects, scene, colors, mood,
     * text) and splits the reply into those sections.
     *
     * @param imageResource the image to analyze
     * @return the extracted sections together with the full reply
     */
    public ContentDetectionResult detectContent(Resource imageResource) {
        String prompt = "请分析这张图片并检测以下内容:\n" +
                "1. 主要物体和人物\n" +
                "2. 场景和环境\n" +
                "3. 颜色和风格\n" +
                "4. 情感和氛围\n" +
                "5. 任何文字内容\n" +
                "请以结构化的方式返回结果。";
        return parseContentDetection(ask(prompt, imageResource));
    }

    /**
     * Asks the model to rate an image for inappropriate content and maps the reply to
     * a {@link SafetyLevel}.
     *
     * @param imageResource the image to check
     * @return the derived level, a short reason, and the raw reply
     */
    public SafetyCheckResult checkImageSafety(Resource imageResource) {
        String prompt = "请检查这张图片是否包含以下不当内容:\n" +
                "1. 暴力或血腥内容\n" +
                "2. 成人或不雅内容\n" +
                "3. 仇恨或歧视性内容\n" +
                "4. 危险或有害活动\n" +
                "请返回安全评级(安全/警告/危险)和具体原因。";
        return parseSafetyCheck(ask(prompt, imageResource));
    }

    /**
     * Sends an uploaded image to the model as Base64 data together with a prompt.
     *
     * @param file   the uploaded file; its content type must be JPEG, PNG, GIF or WebP
     * @param prompt the instruction sent with the image
     * @return the model's reply text
     * @throws IllegalArgumentException if the content type is missing or unsupported
     * @throws IOException              if the upload's bytes cannot be read
     */
    public String processUploadedImage(MultipartFile file, String prompt) throws IOException {
        if (!isValidImageFile(file)) {
            throw new IllegalArgumentException("不支持的图像格式");
        }
        String base64Data = Base64.getEncoder().encodeToString(file.getBytes());
        String mimeType = file.getContentType();
        MultimodalMessage message = new MultimodalMessage(prompt)
                .addImageBase64(base64Data, mimeType);
        return callModel(message);
    }

    /** Builds a message from the prompt and images and returns the model's reply text. */
    private String ask(String promptText, Resource... images) {
        MultimodalMessage message = new MultimodalMessage(promptText);
        for (Resource image : images) {
            message.addImage(image);
        }
        return callModel(message);
    }

    /** Sends the message to the chat model and unwraps the reply text. */
    private String callModel(MultimodalMessage message) {
        return chatModel.call(new Prompt(message.toUserMessage()))
                .getResult().getOutput().getContent();
    }

    /** Analyzes one image, converting any exception into a failed result. */
    private ImageAnalysisResult analyzeSingle(Resource resource, String prompt) {
        try {
            String analysis = describeImage(resource, prompt);
            return new ImageAnalysisResult(
                    resource.getFilename(), analysis, true, null, extractImageMetadata(resource));
        } catch (Exception e) {
            return new ImageAnalysisResult(resource.getFilename(), null, false, e.getMessage(), null);
        }
    }

    /**
     * Reads basic metadata (dimensions, type, alpha, filename, length) from an image.
     * On an I/O failure an {@code "error"} entry is added next to whatever was
     * collected before the failure.
     */
    private Map<String, Object> extractImageMetadata(Resource imageResource) {
        Map<String, Object> metadata = new HashMap<>();
        // ImageIO.read does not close the stream it is given, so close it here.
        try (var in = imageResource.getInputStream()) {
            BufferedImage image = ImageIO.read(in);
            if (image != null) {
                metadata.put("width", image.getWidth());
                metadata.put("height", image.getHeight());
                metadata.put("type", image.getType());
                metadata.put("hasAlpha", image.getColorModel().hasAlpha());
            }
            metadata.put("filename", imageResource.getFilename());
            metadata.put("contentLength", imageResource.contentLength());
        } catch (IOException e) {
            metadata.put("error", "无法读取图像元数据: " + e.getMessage());
        }
        return metadata;
    }

    /** Returns whether the upload declares one of the supported image content types. */
    private boolean isValidImageFile(MultipartFile file) {
        String contentType = file.getContentType();
        return contentType != null
                && SUPPORTED_CONTENT_TYPES.stream().anyMatch(contentType::startsWith);
    }

    /**
     * Derives a category and confidence from the model's free-text reply.
     *
     * <p>The category is the first candidate whose name occurs in the reply
     * (case-insensitive), or "未知" when none does. The confidence is the first number
     * after "置信度"/"confidence"; values above 1 are read as percentages, and the
     * result is capped at 1.0. Without such a number the confidence is 0.0.
     */
    private ImageClassificationResult parseClassificationResponse(String response, List<String> categories) {
        String category = "未知";
        double confidence = 0.0;
        if (response != null) {
            String lowerResponse = response.toLowerCase();
            for (String cat : categories) {
                if (lowerResponse.contains(cat.toLowerCase())) {
                    category = cat;
                    break;
                }
            }
            var matcher = CONFIDENCE_PATTERN.matcher(response);
            if (matcher.find()) {
                double value = Double.parseDouble(matcher.group(1));
                if (value > 1.0) {
                    value = value / 100.0; // reply used a percentage
                }
                confidence = Math.min(1.0, value);
            }
        }
        return new ImageClassificationResult(category, confidence, response);
    }

    /** Splits the content-analysis reply into its sections by keyword. */
    private ContentDetectionResult parseContentDetection(String analysis) {
        return new ContentDetectionResult(
                extractSection(analysis, "主要物体", "物体"),
                extractSection(analysis, "场景", "环境"),
                extractSection(analysis, "颜色", "风格"),
                extractSection(analysis, "情感", "氛围"),
                extractSection(analysis, "文字", "文本"),
                analysis
        );
    }

    /**
     * Maps the safety reply to a level by keyword: "危险" wins over "警告", anything
     * else is treated as safe. The reason now matches the derived level instead of
     * always claiming the content is safe.
     *
     * <p>NOTE(review): a reply that merely mentions "危险" (for example "未发现危险内容")
     * is still classified as DANGEROUS; the heuristic errs on the cautious side.
     */
    private SafetyCheckResult parseSafetyCheck(String response) {
        SafetyLevel level;
        String reason;
        if (response.contains("危险")) {
            level = SafetyLevel.DANGEROUS;
            reason = "检测到危险内容";
        } else if (response.contains("警告")) {
            level = SafetyLevel.WARNING;
            reason = "内容存在风险,建议人工复核";
        } else {
            level = SafetyLevel.SAFE;
            reason = "内容安全";
        }
        return new SafetyCheckResult(level, reason, response);
    }

    /**
     * Returns the text between the first colon after a keyword and the end of that
     * line, trying the keywords in order. Both the ASCII colon and the full-width
     * colon (U+FF1A) are recognised.
     *
     * @return the trimmed section text, or "未检测到" when no keyword/colon pair is found
     */
    private String extractSection(String text, String... keywords) {
        if (text == null) {
            return NOT_DETECTED;
        }
        String lowerText = text.toLowerCase();
        for (String keyword : keywords) {
            int index = lowerText.indexOf(keyword.toLowerCase());
            if (index == -1) {
                continue;
            }
            int start = indexOfColon(text, index);
            if (start == -1) {
                continue;
            }
            int end = text.indexOf('\n', start);
            if (end == -1) {
                end = text.length();
            }
            return text.substring(start + 1, end).trim();
        }
        return NOT_DETECTED;
    }

    /** Returns the index of the first ASCII or full-width colon at or after {@code from}, or -1. */
    private static int indexOfColon(String text, int from) {
        int ascii = text.indexOf(':', from);
        int fullWidth = text.indexOf('\uFF1A', from);
        if (ascii == -1) {
            return fullWidth;
        }
        if (fullWidth == -1) {
            return ascii;
        }
        return Math.min(ascii, fullWidth);
    }

    /** Outcome of analyzing one image in a batch. */
    public record ImageAnalysisResult(
            String filename,
            String analysis,
            boolean success,
            String error,
            Map<String, Object> metadata
    ) {}

    /** Parsed classification plus the raw model reply. */
    public record ImageClassificationResult(
            String category,
            double confidence,
            String rawResponse
    ) {}

    /** Sections extracted from a content-analysis reply plus the full reply. */
    public record ContentDetectionResult(
            String objects,
            String scene,
            String colors,
            String emotions,
            String text,
            String fullAnalysis
    ) {}

    /** Safety level, short reason, and the raw model reply. */
    public record SafetyCheckResult(
            SafetyLevel level,
            String reason,
            String details
    ) {}

    /** Safety rating derived from the model reply. */
    public enum SafetyLevel {
        SAFE, WARNING, DANGEROUS
    }
}
8.2.2 视觉问答服务
// VisualQuestionAnsweringService.java
package com.example.springai.service;
import com.example.springai.multimodal.MultimodalMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.ArrayList;
/**
 * Visual question answering on top of the injected {@code ChatModel}: single-image,
 * multi-image, structured, conversational and comparison questions.
 *
 * <p>Model failures are reported through the {@code success}/{@code error} fields of
 * the result records rather than thrown.
 */
@Service
public class VisualQuestionAnsweringService {

    private final ChatModel chatModel;

    public VisualQuestionAnsweringService(ChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Answers one question about one image.
     *
     * @param imageResource the image the question refers to
     * @param question      the question text
     * @return the answer, or a failed result carrying the exception message
     */
    public VQAResult answerQuestion(Resource imageResource, String question) {
        List<Resource> single = new ArrayList<>(1);
        single.add(imageResource);
        return ask(question, single);
    }

    /**
     * Answers several questions about the same image, one model call per question.
     *
     * @param imageResource the image the questions refer to
     * @param questions     the questions to ask
     * @return one result per question, in input order
     */
    public List<VQAResult> answerQuestions(Resource imageResource, List<String> questions) {
        return questions.stream()
                .map(question -> answerQuestion(imageResource, question))
                .toList();
    }

    /**
     * Answers one question about several images sent in a single model call.
     *
     * @param imageResources the images the question refers to
     * @param question       the question text
     * @return the answer, or a failed result carrying the exception message
     */
    public VQAResult answerQuestionMultiImage(List<Resource> imageResources, String question) {
        return ask(question, imageResources);
    }

    /**
     * Answers every question of a structured question set against the same image.
     *
     * @param imageResource      the image the questions refer to
     * @param structuredQuestion the category and its keyed questions
     * @return per-key results and the average confidence across them
     */
    public StructuredVQAResult answerStructuredQuestion(Resource imageResource,
                                                        StructuredQuestion structuredQuestion) {
        Map<String, VQAResult> results = new HashMap<>();
        for (Map.Entry<String, String> entry : structuredQuestion.questions().entrySet()) {
            results.put(entry.getKey(), answerQuestion(imageResource, entry.getValue()));
        }
        return new StructuredVQAResult(
                structuredQuestion.category(),
                results,
                calculateOverallConfidence(results)
        );
    }

    /**
     * Answers a follow-up question, prefixing the prompt with the earlier turns.
     *
     * @param imageResource   the image the conversation is about
     * @param history         earlier question/answer turns; {@code null} is treated as empty
     * @param currentQuestion the new question
     * @return the answer together with the 1-based turn number of this question
     */
    public ConversationalVQAResult answerConversational(Resource imageResource,
                                                        List<ConversationTurn> history,
                                                        String currentQuestion) {
        List<ConversationTurn> turns = history != null ? history : List.of();
        StringBuilder conversationContext = new StringBuilder();
        for (ConversationTurn turn : turns) {
            conversationContext.append("问题: ").append(turn.question()).append("\n");
            conversationContext.append("回答: ").append(turn.answer()).append("\n\n");
        }
        conversationContext.append("问题: ").append(currentQuestion);

        String contextualPrompt = "基于之前的对话历史和这张图片,请回答以下问题:\n\n" +
                conversationContext;
        VQAResult result = answerQuestion(imageResource, contextualPrompt);
        return new ConversationalVQAResult(
                currentQuestion,
                result.answer(),
                result.success(),
                result.error(),
                turns.size() + 1,
                result.confidence()
        );
    }

    /**
     * Answers a comparison question across two or more images.
     *
     * <p>All images are sent to the model in one call; previously only the first image
     * was actually transmitted, so the model could not compare anything.
     *
     * @param imageResources     the images to compare; at least two are required
     * @param comparisonQuestion the comparison question
     * @return the answer together with the number of images compared
     * @throws IllegalArgumentException if fewer than two images are supplied
     */
    public ComparisonVQAResult answerComparisonQuestion(List<Resource> imageResources,
                                                        String comparisonQuestion) {
        if (imageResources == null || imageResources.size() < 2) {
            throw new IllegalArgumentException("比较问答至少需要两张图片");
        }
        VQAResult result = answerQuestionMultiImage(
                imageResources, "请基于这些图片回答比较问题: " + comparisonQuestion);
        return new ComparisonVQAResult(
                comparisonQuestion,
                result.answer(),
                imageResources.size(),
                result.success(),
                result.error(),
                result.confidence()
        );
    }

    /**
     * Sends the question with all images in one model call and times it. Any exception,
     * including one raised while attaching the images, becomes a failed result.
     */
    private VQAResult ask(String question, List<Resource> images) {
        long startTime = System.currentTimeMillis();
        try {
            MultimodalMessage message = new MultimodalMessage(question);
            images.forEach(message::addImage);
            String answer = chatModel.call(new Prompt(message.toUserMessage()))
                    .getResult().getOutput().getContent();
            long processingTime = System.currentTimeMillis() - startTime;
            return new VQAResult(question, answer, true, null, processingTime,
                    calculateConfidence(answer));
        } catch (Exception e) {
            long processingTime = System.currentTimeMillis() - startTime;
            return new VQAResult(question, null, false, e.getMessage(), processingTime, 0.0);
        }
    }

    /**
     * Heuristic confidence in [0, 1] derived from the answer text: a base of 0.5,
     * bonuses for length and for certainty words, penalties for hedging words.
     *
     * <p>The negated forms "不确定" and "不明确" are removed before looking for the
     * certainty words "确定"/"明确", so a statement of uncertainty no longer earns the
     * certainty bonus.
     */
    private double calculateConfidence(String answer) {
        if (answer == null || answer.trim().isEmpty()) {
            return 0.0;
        }
        double confidence = 0.5;

        if (answer.length() > 50) {
            confidence += 0.2;
        }
        if (answer.length() > 100) {
            confidence += 0.1;
        }

        String withoutNegations = answer.replace("不确定", "").replace("不明确", "");
        if (withoutNegations.contains("确定") || withoutNegations.contains("明确")) {
            confidence += 0.2;
        }
        if (answer.contains("可能") || answer.contains("也许")) {
            confidence -= 0.1;
        }
        if (answer.contains("不确定") || answer.contains("不清楚")) {
            confidence -= 0.2;
        }
        return Math.max(0.0, Math.min(1.0, confidence));
    }

    /** Returns the mean confidence of the given results, or 0.0 when there are none. */
    private double calculateOverallConfidence(Map<String, VQAResult> results) {
        return results.values().stream()
                .mapToDouble(VQAResult::confidence)
                .average()
                .orElse(0.0);
    }

    /** Outcome of one question: answer or error, timing, and heuristic confidence. */
    public record VQAResult(
            String question,
            String answer,
            boolean success,
            String error,
            long processingTimeMs,
            double confidence
    ) {}

    /** A named set of keyed questions to ask about one image. */
    public record StructuredQuestion(
            String category,
            Map<String, String> questions
    ) {}

    /** Per-key results of a structured question set plus their average confidence. */
    public record StructuredVQAResult(
            String category,
            Map<String, VQAResult> results,
            double overallConfidence
    ) {}

    /** One earlier question/answer pair of a conversation. */
    public record ConversationTurn(
            String question,
            String answer
    ) {}

    /** Outcome of a conversational question, including its turn number. */
    public record ConversationalVQAResult(
            String question,
            String answer,
            boolean success,
            String error,
            int turnNumber,
            double confidence
    ) {}

    /** Outcome of a comparison question, including how many images were compared. */
    public record ComparisonVQAResult(
            String question,
            String answer,
            int imageCount,
            boolean success,
            String error,
            double confidence
    ) {}
}
8.3 多模态对话系统
8.3.1 多模态对话服务
// MultimodalConversationService.java
package com.example.springai.service;
import com.example.springai.multimodal.MultimodalMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Keeps in-memory conversation sessions and forwards text or image messages to the
 * injected {@code ChatModel}, recording each exchange as a turn of the session.
 *
 * <p>Sessions live in a map until {@link #clearConversation(String)} is called; nothing
 * in this class expires them. Each message is sent to the model on its own; earlier
 * turns are recorded but not included in the prompt.
 */
@Service
public class MultimodalConversationService {

    private final ChatModel chatModel;
    private final Map<String, ConversationSession> sessions = new ConcurrentHashMap<>();

    public MultimodalConversationService(ChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Creates a new session for a user.
     *
     * @param userId the owner of the session
     * @return the generated session id
     */
    public String startConversation(String userId) {
        String sessionId = UUID.randomUUID().toString();
        sessions.put(sessionId, new ConversationSession(sessionId, userId, System.currentTimeMillis()));
        return sessionId;
    }

    /**
     * Sends a text-only message.
     *
     * @param sessionId the session to use
     * @param message   the message text
     * @return the model's reply, or a failed response carrying the exception message
     * @throws IllegalArgumentException if the session does not exist
     */
    public ConversationResponse sendTextMessage(String sessionId, String message) {
        return exchange(sessionId, message, null, MessageType.TEXT);
    }

    /**
     * Sends a message with one image attached.
     *
     * @param sessionId     the session to use
     * @param imageResource the image to attach
     * @param message       the message text
     * @return the model's reply, or a failed response carrying the exception message
     * @throws IllegalArgumentException if the session does not exist
     */
    public ConversationResponse sendImageMessage(String sessionId,
                                                 Resource imageResource,
                                                 String message) {
        return exchange(sessionId, message, Collections.singletonList(imageResource), MessageType.IMAGE);
    }

    /**
     * Sends a message with several images attached.
     *
     * @param sessionId      the session to use
     * @param imageResources the images to attach
     * @param message        the message text
     * @return the model's reply, or a failed response carrying the exception message
     * @throws IllegalArgumentException if the session does not exist
     */
    public ConversationResponse sendMultiImageMessage(String sessionId,
                                                      List<Resource> imageResources,
                                                      String message) {
        return exchange(sessionId, message, imageResources, MessageType.MULTI_IMAGE);
    }

    /**
     * Returns a snapshot of a session's turns.
     *
     * @param sessionId the session to read
     * @return the session's user, turns, start time, and the time of this call
     * @throws IllegalArgumentException if the session does not exist
     */
    public ConversationHistory getConversationHistory(String sessionId) {
        ConversationSession session = getSession(sessionId);
        return new ConversationHistory(
                sessionId,
                session.getUserId(),
                session.getTurns(),
                session.getStartTime(),
                System.currentTimeMillis()
        );
    }

    /**
     * Removes a session.
     *
     * @param sessionId the session to remove
     * @return {@code true} if the session existed
     */
    public boolean clearConversation(String sessionId) {
        return sessions.remove(sessionId) != null;
    }

    /** Returns the number of sessions currently held. */
    public int getActiveSessionCount() {
        return sessions.size();
    }

    /**
     * Performs one exchange with the model and records it as a turn.
     *
     * <p>For {@code TEXT} the prompt is the bare message and no filenames are recorded;
     * otherwise every image is attached and its filename recorded. Filenames may be
     * {@code null} (a resource is not required to have one); such entries are stored
     * as-is so that a missing filename no longer turns a successful model call into a
     * failure.
     */
    private ConversationResponse exchange(String sessionId,
                                          String message,
                                          List<Resource> images,
                                          MessageType type) {
        ConversationSession session = getSession(sessionId);
        try {
            Prompt prompt;
            if (type == MessageType.TEXT) {
                prompt = new Prompt(message);
            } else {
                MultimodalMessage multimodalMessage = new MultimodalMessage(message);
                images.forEach(multimodalMessage::addImage);
                prompt = new Prompt(multimodalMessage.toUserMessage());
            }
            String response = chatModel.call(prompt).getResult().getOutput().getContent();

            List<String> filenames = (type == MessageType.TEXT)
                    ? null
                    : images.stream().map(Resource::getFilename).toList();
            session.addTurn(new ConversationTurn(
                    UUID.randomUUID().toString(),
                    message,
                    filenames,
                    response,
                    System.currentTimeMillis(),
                    type
            ));
            return new ConversationResponse(sessionId, response, type, true, null,
                    session.getTurnCount());
        } catch (Exception e) {
            return new ConversationResponse(sessionId, null, type, false, e.getMessage(),
                    session.getTurnCount());
        }
    }

    /** Looks up a session or fails with an {@link IllegalArgumentException}. */
    private ConversationSession getSession(String sessionId) {
        ConversationSession session = sessions.get(sessionId);
        if (session == null) {
            throw new IllegalArgumentException("会话不存在: " + sessionId);
        }
        return session;
    }

    /** Kind of content a turn or response carries. */
    public enum MessageType {
        TEXT, IMAGE, MULTI_IMAGE, AUDIO, VIDEO
    }

    /**
     * Mutable state of one conversation. Access to the turn list is synchronized
     * because the same session can be used by concurrent requests.
     */
    public static class ConversationSession {

        private final String sessionId;
        private final String userId;
        private final long startTime;
        private final List<ConversationTurn> turns = new ArrayList<>();

        public ConversationSession(String sessionId, String userId, long startTime) {
            this.sessionId = sessionId;
            this.userId = userId;
            this.startTime = startTime;
        }

        /** Appends a turn to this session. */
        public synchronized void addTurn(ConversationTurn turn) {
            turns.add(turn);
        }

        public String getSessionId() {
            return sessionId;
        }

        public String getUserId() {
            return userId;
        }

        public long getStartTime() {
            return startTime;
        }

        /** Returns a copy of the recorded turns. */
        public synchronized List<ConversationTurn> getTurns() {
            return new ArrayList<>(turns);
        }

        /** Returns the number of recorded turns. */
        public synchronized int getTurnCount() {
            return turns.size();
        }
    }

    /** One recorded exchange; {@code imageFilenames} is {@code null} for text turns. */
    public record ConversationTurn(
            String turnId,
            String userMessage,
            List<String> imageFilenames,
            String assistantResponse,
            long timestamp,
            MessageType messageType
    ) {}

    /** Outcome of sending one message. */
    public record ConversationResponse(
            String sessionId,
            String response,
            MessageType messageType,
            boolean success,
            String error,
            int turnNumber
    ) {}

    /** Snapshot of a session's turns. */
    public record ConversationHistory(
            String sessionId,
            String userId,
            List<ConversationTurn> turns,
            long startTime,
            long endTime
    ) {}
}
8.3.2 多模态对话控制器
// MultimodalConversationController.java
package com.example.springai.controller;
import com.example.springai.service.MultimodalConversationService;
import com.example.springai.service.MultimodalConversationService.*;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.core.io.ByteArrayResource;
import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
/**
 * REST endpoints under {@code /api/multimodal/conversation} that delegate to
 * {@code MultimodalConversationService}.
 */
@RestController
@RequestMapping("/api/multimodal/conversation")
public class MultimodalConversationController {

    private final MultimodalConversationService conversationService;

    public MultimodalConversationController(MultimodalConversationService conversationService) {
        this.conversationService = conversationService;
    }

    /**
     * Starts a new conversation for the given user.
     *
     * @return 200 with the new session id, or 400 with the exception message
     */
    @PostMapping("/start")
    public ResponseEntity<StartConversationResponse> startConversation(
            @RequestParam String userId) {
        try {
            return ResponseEntity.ok(new StartConversationResponse(
                    conversationService.startConversation(userId), true, null));
        } catch (Exception e) {
            StartConversationResponse failure =
                    new StartConversationResponse(null, false, e.getMessage());
            return ResponseEntity.badRequest().body(failure);
        }
    }

    /**
     * Sends a text message to a session.
     *
     * @return 200 with the service response, or 400 if the service call throws
     */
    @PostMapping("/text")
    public ResponseEntity<ConversationResponse> sendTextMessage(
            @RequestParam String sessionId,
            @RequestParam String message) {
        try {
            return ResponseEntity.ok(conversationService.sendTextMessage(sessionId, message));
        } catch (Exception e) {
            return failed(sessionId, MessageType.TEXT, e.getMessage());
        }
    }

    /**
     * Sends a message with one uploaded image to a session.
     *
     * @return 200 with the service response, or 400 if reading the upload or the
     *         service call throws
     */
    @PostMapping("/image")
    public ResponseEntity<ConversationResponse> sendImageMessage(
            @RequestParam String sessionId,
            @RequestParam String message,
            @RequestParam("image") MultipartFile imageFile) {
        try {
            Resource upload = asNamedResource(imageFile);
            return ResponseEntity.ok(
                    conversationService.sendImageMessage(sessionId, upload, message));
        } catch (IOException e) {
            return failed(sessionId, MessageType.IMAGE, "图像处理失败: " + e.getMessage());
        } catch (Exception e) {
            return failed(sessionId, MessageType.IMAGE, e.getMessage());
        }
    }

    /**
     * Sends a message with several uploaded images to a session.
     *
     * @return 200 with the service response, or 400 if reading an upload or the
     *         service call throws
     */
    @PostMapping("/multi-image")
    public ResponseEntity<ConversationResponse> sendMultiImageMessage(
            @RequestParam String sessionId,
            @RequestParam String message,
            @RequestParam("images") List<MultipartFile> imageFiles) {
        try {
            List<Resource> uploads = imageFiles.stream()
                    .map(this::asNamedResourceUnchecked)
                    .collect(Collectors.toList());
            return ResponseEntity.ok(
                    conversationService.sendMultiImageMessage(sessionId, uploads, message));
        } catch (Exception e) {
            return failed(sessionId, MessageType.MULTI_IMAGE, e.getMessage());
        }
    }

    /**
     * Returns the recorded turns of a session.
     *
     * @return 200 with the history, or 404 if the service call throws
     */
    @GetMapping("/history/{sessionId}")
    public ResponseEntity<ConversationHistory> getConversationHistory(
            @PathVariable String sessionId) {
        try {
            return ResponseEntity.ok(conversationService.getConversationHistory(sessionId));
        } catch (Exception e) {
            return ResponseEntity.notFound().build();
        }
    }

    /**
     * Removes a session.
     *
     * @return 200 in both cases; the body says whether a session was removed
     */
    @DeleteMapping("/clear/{sessionId}")
    public ResponseEntity<ClearConversationResponse> clearConversation(
            @PathVariable String sessionId) {
        boolean removed = conversationService.clearConversation(sessionId);
        String text;
        if (removed) {
            text = "对话已清除";
        } else {
            text = "对话不存在";
        }
        return ResponseEntity.ok(new ClearConversationResponse(removed, text));
    }

    /** Returns the number of sessions currently held by the service. */
    @GetMapping("/stats")
    public ResponseEntity<ConversationStats> getConversationStats() {
        return ResponseEntity.ok(
                new ConversationStats(conversationService.getActiveSessionCount()));
    }

    /** Builds the 400 response used when a send operation throws. */
    private ResponseEntity<ConversationResponse> failed(String sessionId,
                                                        MessageType type,
                                                        String error) {
        return ResponseEntity.badRequest()
                .body(new ConversationResponse(sessionId, null, type, false, error, 0));
    }

    /** Wraps an upload's bytes in a resource that reports the upload's original filename. */
    private Resource asNamedResource(MultipartFile file) throws IOException {
        return new ByteArrayResource(file.getBytes()) {
            @Override
            public String getFilename() {
                return file.getOriginalFilename();
            }
        };
    }

    /** Same as {@link #asNamedResource}, rethrowing a read failure as a RuntimeException. */
    private Resource asNamedResourceUnchecked(MultipartFile file) {
        try {
            return asNamedResource(file);
        } catch (IOException e) {
            throw new RuntimeException("图像处理失败", e);
        }
    }

    /** Response body of the start endpoint. */
    public record StartConversationResponse(
            String sessionId,
            boolean success,
            String error
    ) {}

    /** Response body of the clear endpoint. */
    public record ClearConversationResponse(
            boolean success,
            String message
    ) {}

    /** Response body of the stats endpoint. */
    public record ConversationStats(
            int activeSessionCount
    ) {}
}
8.4 图像生成与编辑
8.4.1 图像生成服务
// ImageGenerationService.java
package com.example.springai.service;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.ImageOptions;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.concurrent.CompletableFuture;
/**
 * Image generation on top of the injected {@code ImageModel}: single, batch,
 * asynchronous, stylized and prompt-variation generation.
 *
 * <p>Failures of the model call are reported through the {@code success}/{@code error}
 * fields of {@link ImageGenerationResult} rather than thrown.
 */
@Service
public class ImageGenerationService {

    /** Prompt suffix appended for each style by {@link #generateStylizedImage}. */
    private static final Map<ImageStyle, String> STYLE_MODIFIERS = Map.of(
            ImageStyle.REALISTIC, ", photorealistic, high quality, detailed",
            ImageStyle.ARTISTIC, ", artistic, creative, stylized",
            ImageStyle.CARTOON, ", cartoon style, animated, colorful",
            ImageStyle.ABSTRACT, ", abstract art, modern, conceptual",
            ImageStyle.VINTAGE, ", vintage style, retro, classic"
    );

    /** Suffixes used, in order, to derive prompt variations. */
    private static final List<String> VARIATION_MODIFIERS = List.of(
            ", different angle",
            ", different lighting",
            ", different color scheme",
            ", different composition",
            ", different mood"
    );

    private final ImageModel imageModel;

    public ImageGenerationService(ImageModel imageModel) {
        this.imageModel = imageModel;
    }

    /**
     * Generates an image with {@link ImageGenerationOptions#defaultOptions()}.
     *
     * @param prompt the text prompt
     * @return the generation outcome
     */
    public ImageGenerationResult generateImage(String prompt) {
        return generateImage(prompt, ImageGenerationOptions.defaultOptions());
    }

    /**
     * Generates an image with explicit options.
     *
     * <p>Only width, height and model are passed to the image model here; the
     * {@code style} carried by {@code options} is not applied to the prompt by this
     * method (see {@link #generateStylizedImage}).
     *
     * @param prompt  the text prompt
     * @param options size and model to use
     * @return the URL of the first generated image, or a failed result carrying an
     *         error message
     */
    public ImageGenerationResult generateImage(String prompt, ImageGenerationOptions options) {
        long startTime = System.currentTimeMillis();
        try {
            ImageOptions imageOptions = ImageOptions.builder()
                    .withWidth(options.width())
                    .withHeight(options.height())
                    .withModel(options.model())
                    .build();
            ImageResponse response = imageModel.call(new ImagePrompt(prompt, imageOptions));
            // Report an empty reply explicitly instead of an index-out-of-bounds message.
            if (response == null || response.getResults() == null
                    || response.getResults().isEmpty()) {
                return failure(prompt, "图像模型未返回任何结果", startTime, options);
            }
            String imageUrl = response.getResults().get(0).getOutput().getUrl();
            return new ImageGenerationResult(
                    prompt, imageUrl, true, null,
                    System.currentTimeMillis() - startTime, options);
        } catch (Exception e) {
            return failure(prompt, e.getMessage(), startTime, options);
        }
    }

    /**
     * Generates one image per prompt, sequentially, with shared options.
     *
     * @param prompts the prompts to render
     * @param options size and model to use for every prompt
     * @return one result per prompt, in input order
     */
    public List<ImageGenerationResult> generateImages(List<String> prompts,
                                                      ImageGenerationOptions options) {
        return prompts.stream()
                .map(prompt -> generateImage(prompt, options))
                .toList();
    }

    /**
     * Runs {@link #generateImage(String, ImageGenerationOptions)} asynchronously on the
     * default executor of {@link CompletableFuture#supplyAsync(java.util.function.Supplier)}.
     *
     * @param prompt  the text prompt
     * @param options size and model to use
     * @return a future completing with the generation outcome
     */
    public CompletableFuture<ImageGenerationResult> generateImageAsync(String prompt,
                                                                       ImageGenerationOptions options) {
        return CompletableFuture.supplyAsync(() -> generateImage(prompt, options));
    }

    /**
     * Generates an image after appending the style's modifier to the prompt.
     *
     * @param prompt the base prompt
     * @param style  the style to apply; {@code null} leaves the prompt unchanged
     * @return the generation outcome; its {@code prompt} is the stylized prompt
     */
    public ImageGenerationResult generateStylizedImage(String prompt, ImageStyle style) {
        return generateImage(applyStyle(prompt, style), ImageGenerationOptions.forStyle(style));
    }

    /**
     * Generates images for the base prompt and variations of it.
     *
     * @param basePrompt     the prompt to vary
     * @param variationCount desired number of prompts including the base prompt; at
     *                       most six are produced and at least the base prompt is used
     * @param options        size and model to use
     * @return one result per produced prompt
     */
    public List<ImageGenerationResult> generateVariations(String basePrompt,
                                                          int variationCount,
                                                          ImageGenerationOptions options) {
        return generateImages(generatePromptVariations(basePrompt, variationCount), options);
    }

    /** Builds a failed result, measuring elapsed time from {@code startTime}. */
    private ImageGenerationResult failure(String prompt,
                                          String error,
                                          long startTime,
                                          ImageGenerationOptions options) {
        return new ImageGenerationResult(
                prompt, null, false, error, System.currentTimeMillis() - startTime, options);
    }

    /**
     * Appends the style's modifier to the prompt. A {@code null} style is handled
     * explicitly because the immutable map rejects null lookups.
     */
    private String applyStyle(String prompt, ImageStyle style) {
        if (style == null) {
            return prompt;
        }
        return prompt + STYLE_MODIFIERS.getOrDefault(style, "");
    }

    /**
     * Returns the base prompt followed by up to {@code count - 1} variations, limited
     * by the number of available modifiers.
     */
    private List<String> generatePromptVariations(String basePrompt, int count) {
        List<String> variations = new java.util.ArrayList<>();
        variations.add(basePrompt);
        int extra = Math.min(count - 1, VARIATION_MODIFIERS.size());
        for (int i = 0; i < extra; i++) {
            variations.add(basePrompt + VARIATION_MODIFIERS.get(i));
        }
        return variations;
    }

    /** Visual styles that can be applied to a prompt. */
    public enum ImageStyle {
        REALISTIC, ARTISTIC, CARTOON, ABSTRACT, VINTAGE
    }

    /** Size, model and style requested for a generation. */
    public record ImageGenerationOptions(
            int width,
            int height,
            String model,
            ImageStyle style
    ) {
        /** Returns 1024x1024, model "dall-e-3", style REALISTIC. */
        public static ImageGenerationOptions defaultOptions() {
            return new ImageGenerationOptions(1024, 1024, "dall-e-3", ImageStyle.REALISTIC);
        }

        /** Returns 1024x1024, model "dall-e-3", with the given style. */
        public static ImageGenerationOptions forStyle(ImageStyle style) {
            return new ImageGenerationOptions(1024, 1024, "dall-e-3", style);
        }
    }

    /** Outcome of one generation: image URL or error, timing, and the options used. */
    public record ImageGenerationResult(
            String prompt,
            String imageUrl,
            boolean success,
            String error,
            long processingTimeMs,
            ImageGenerationOptions options
    ) {}
}
8.4.2 图像生成控制器
// ImageGenerationController.java
package com.example.springai.controller;
import com.example.springai.service.ImageGenerationService;
import com.example.springai.service.ImageGenerationService.*;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.concurrent.CompletableFuture;
@RestController
@RequestMapping("/api/image/generation")
public class ImageGenerationController {
private final ImageGenerationService imageGenerationService;
/**
 * Creates the controller with the service it delegates every endpoint to.
 *
 * @param imageGenerationService the image generation service
 */
public ImageGenerationController(ImageGenerationService imageGenerationService) {
this.imageGenerationService = imageGenerationService;
}
/**
 * Generates an image from a prompt with the service's default options.
 *
 * @param prompt the text prompt
 * @return 200 with the service result, or 400 with a failed result if the service
 *         call throws
 */
@PostMapping("/generate")
public ResponseEntity<ImageGenerationResult> generateImage(
        @RequestParam String prompt) {
    try {
        return ResponseEntity.ok(imageGenerationService.generateImage(prompt));
    } catch (Exception e) {
        ImageGenerationResult failure =
                new ImageGenerationResult(prompt, null, false, e.getMessage(), 0, null);
        return ResponseEntity.badRequest().body(failure);
    }
}
/**
 * Generates an image from a prompt with caller-supplied size, model and style.
 *
 * @return 200 with the service result, or 400 with a failed result if building the
 *         options or the service call throws
 */
@PostMapping("/generate/advanced")
public ResponseEntity<ImageGenerationResult> generateImageWithOptions(
        @RequestParam String prompt,
        @RequestParam(defaultValue = "1024") int width,
        @RequestParam(defaultValue = "1024") int height,
        @RequestParam(defaultValue = "dall-e-3") String model,
        @RequestParam(defaultValue = "REALISTIC") ImageStyle style) {
    try {
        ImageGenerationOptions requested =
                new ImageGenerationOptions(width, height, model, style);
        return ResponseEntity.ok(imageGenerationService.generateImage(prompt, requested));
    } catch (Exception e) {
        ImageGenerationResult failure =
                new ImageGenerationResult(prompt, null, false, e.getMessage(), 0, null);
        return ResponseEntity.badRequest().body(failure);
    }
}
/**
 * Generates one image per prompt in the request body, using shared options.
 *
 * @return 200 with one result per prompt, or an empty 400 if the service call throws
 */
@PostMapping("/generate/batch")
public ResponseEntity<List<ImageGenerationResult>> generateImages(
        @RequestBody List<String> prompts,
        @RequestParam(defaultValue = "1024") int width,
        @RequestParam(defaultValue = "1024") int height,
        @RequestParam(defaultValue = "dall-e-3") String model,
        @RequestParam(defaultValue = "REALISTIC") ImageStyle style) {
    try {
        ImageGenerationOptions shared =
                new ImageGenerationOptions(width, height, model, style);
        return ResponseEntity.ok(imageGenerationService.generateImages(prompts, shared));
    } catch (Exception e) {
        return ResponseEntity.badRequest().build();
    }
}
/**
* 异步图像生成
*/
@PostMapping("/generate/async")
public ResponseEntity<String> generateImageAsync(
@RequestParam String prompt,
@RequestParam(defaultValue = "1024") int width,
@RequestParam(defaultValue = "1024") int height,
@RequestParam(defaultValue = "dall-e-3") String model,
@RequestParam(defaultValue = "REALISTIC") ImageStyle style) {
try {
ImageGenerationOptions options = new ImageGenerationOptions(width, height, model, style);
CompletableFuture<ImageGenerationResult> future =
imageGenerationService.generateImageAsync(prompt, options);
// 这里可以返回任务ID,客户端可以轮询结果
return ResponseEntity.ok("图像生成任务已启动,请稍后查询结果");
} catch (Exception e) {
return ResponseEntity.badRequest().body("任务启动失败: " + e.getMessage());
}
}
/**
* 风格化图像生成
*/
@PostMapping("/generate/stylized")
public ResponseEntity<ImageGenerationResult> generateStylizedImage(
@RequestParam String prompt,
@RequestParam ImageStyle style) {
try {
ImageGenerationResult result = imageGenerationService.generateStylizedImage(prompt, style);
return ResponseEntity.ok(result);
} catch (Exception e) {
return ResponseEntity.badRequest()
.body(new ImageGenerationResult(prompt, null, false, e.getMessage(), 0, null));
}
}
/**
* 变体图像生成
*/
@PostMapping("/generate/variations")
public ResponseEntity<List<ImageGenerationResult>> generateVariations(
@RequestParam String basePrompt,
@RequestParam(defaultValue = "3") int variationCount,
@RequestParam(defaultValue = "1024") int width,
@RequestParam(defaultValue = "1024") int height,
@RequestParam(defaultValue = "dall-e-3") String model,
@RequestParam(defaultValue = "REALISTIC") ImageStyle style) {
try {
ImageGenerationOptions options = new ImageGenerationOptions(width, height, model, style);
List<ImageGenerationResult> results = imageGenerationService.generateVariations(
basePrompt, variationCount, options);
return ResponseEntity.ok(results);
} catch (Exception e) {
return ResponseEntity.badRequest().build();
}
}
/**
* 获取支持的图像风格
*/
@GetMapping("/styles")
public ResponseEntity<ImageStyle[]> getSupportedStyles() {
return ResponseEntity.ok(ImageStyle.values());
}
}
8.5 配置文件
8.5.1 application.yml
# Spring AI multimodal configuration
spring:
  ai:
    # OpenAI settings
    openai:
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          # NOTE(review): confirm this vision model id is still offered by the provider
          model: gpt-4-vision-preview
          temperature: 0.7
          max-tokens: 1000
      image:
        options:
          model: dall-e-3
          size: 1024x1024
          quality: standard
          style: natural
    # Vector store settings (used for image embeddings)
    vectorstore:
      chroma:
        url: http://localhost:8000
        collection-name: image-embeddings
  # File upload settings
  servlet:
    multipart:
      max-file-size: 10MB
      max-request-size: 50MB
      enabled: true

# Multimodal processing settings (bound to the "multimodal" properties prefix)
multimodal:
  image:
    # Image processing settings
    processing:
      max-size: 10485760 # 10MB
      allowed-formats:
        - jpg
        - jpeg
        - png
        - gif
        - webp
      resize:
        enabled: true
        max-width: 2048
        max-height: 2048
    # Image generation settings
    generation:
      default-model: dall-e-3
      default-size: 1024x1024
      timeout: 60000 # 60 seconds
      retry:
        max-attempts: 3
        delay: 1000
    # Safety check settings
    safety:
      enabled: true
      strict-mode: false
      blocked-content:
        - violence
        - adult
        - hate
  # Conversation settings
  conversation:
    session:
      timeout: 3600000 # 1 hour
      max-turns: 100
      cleanup-interval: 300000 # 5 minutes
    # Cache settings
    cache:
      enabled: true
      ttl: 1800 # 30 minutes
      max-size: 1000

# Logging settings
logging:
  level:
    com.example.springai.multimodal: DEBUG
    org.springframework.ai: INFO
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"
8.5.2 多模态配置类
// MultimodalProperties.java
package com.example.springai.config;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import java.util.List;
/**
 * Mutable JavaBean holding the settings found under the {@code multimodal}
 * configuration prefix.
 *
 * <p>The nested classes mirror the nesting of the configuration keys. Every field is
 * initialised with a default value that applies when no value is set through the
 * corresponding setter.
 */
@Component
@ConfigurationProperties(prefix = "multimodal")
public class MultimodalProperties {

    private Image image = new Image();
    private Conversation conversation = new Conversation();

    public Image getImage() {
        return image;
    }

    public void setImage(Image image) {
        this.image = image;
    }

    public Conversation getConversation() {
        return conversation;
    }

    public void setConversation(Conversation conversation) {
        this.conversation = conversation;
    }

    /** Image-related settings: processing, generation and safety. */
    public static class Image {

        private Processing processing = new Processing();
        private Generation generation = new Generation();
        private Safety safety = new Safety();

        public Processing getProcessing() {
            return processing;
        }

        public void setProcessing(Processing processing) {
            this.processing = processing;
        }

        public Generation getGeneration() {
            return generation;
        }

        public void setGeneration(Generation generation) {
            this.generation = generation;
        }

        public Safety getSafety() {
            return safety;
        }

        public void setSafety(Safety safety) {
            this.safety = safety;
        }

        /** Limits applied to incoming images. */
        public static class Processing {

            /** Maximum image size; the default is 10 MB. */
            private long maxSize = 10L * 1024 * 1024;

            /** Accepted image formats. */
            private List<String> allowedFormats = List.of("jpg", "jpeg", "png", "gif", "webp");

            private Resize resize = new Resize();

            public long getMaxSize() {
                return maxSize;
            }

            public void setMaxSize(long maxSize) {
                this.maxSize = maxSize;
            }

            public List<String> getAllowedFormats() {
                return allowedFormats;
            }

            public void setAllowedFormats(List<String> allowedFormats) {
                this.allowedFormats = allowedFormats;
            }

            public Resize getResize() {
                return resize;
            }

            public void setResize(Resize resize) {
                this.resize = resize;
            }

            /** Resize switch and the maximum dimensions that go with it. */
            public static class Resize {

                private boolean enabled = true;
                private int maxWidth = 2048;
                private int maxHeight = 2048;

                public boolean isEnabled() {
                    return enabled;
                }

                public void setEnabled(boolean enabled) {
                    this.enabled = enabled;
                }

                public int getMaxWidth() {
                    return maxWidth;
                }

                public void setMaxWidth(int maxWidth) {
                    this.maxWidth = maxWidth;
                }

                public int getMaxHeight() {
                    return maxHeight;
                }

                public void setMaxHeight(int maxHeight) {
                    this.maxHeight = maxHeight;
                }
            }
        }

        /** Defaults for image generation requests. */
        public static class Generation {

            private String defaultModel = "dall-e-3";
            private String defaultSize = "1024x1024";

            /** Generation timeout; the accompanying configuration documents 60000 as 60 seconds. */
            private long timeout = 60_000;

            private Retry retry = new Retry();

            public String getDefaultModel() {
                return defaultModel;
            }

            public void setDefaultModel(String defaultModel) {
                this.defaultModel = defaultModel;
            }

            public String getDefaultSize() {
                return defaultSize;
            }

            public void setDefaultSize(String defaultSize) {
                this.defaultSize = defaultSize;
            }

            public long getTimeout() {
                return timeout;
            }

            public void setTimeout(long timeout) {
                this.timeout = timeout;
            }

            public Retry getRetry() {
                return retry;
            }

            public void setRetry(Retry retry) {
                this.retry = retry;
            }

            /** Retry settings for generation. */
            public static class Retry {

                private int maxAttempts = 3;

                /** Delay setting for retries; presumably milliseconds (TODO confirm). */
                private long delay = 1_000;

                public int getMaxAttempts() {
                    return maxAttempts;
                }

                public void setMaxAttempts(int maxAttempts) {
                    this.maxAttempts = maxAttempts;
                }

                public long getDelay() {
                    return delay;
                }

                public void setDelay(long delay) {
                    this.delay = delay;
                }
            }
        }

        /** Content safety check settings. */
        public static class Safety {

            private boolean enabled = true;
            private boolean strictMode = false;
            private List<String> blockedContent = List.of("violence", "adult", "hate");

            public boolean isEnabled() {
                return enabled;
            }

            public void setEnabled(boolean enabled) {
                this.enabled = enabled;
            }

            public boolean isStrictMode() {
                return strictMode;
            }

            public void setStrictMode(boolean strictMode) {
                this.strictMode = strictMode;
            }

            public List<String> getBlockedContent() {
                return blockedContent;
            }

            public void setBlockedContent(List<String> blockedContent) {
                this.blockedContent = blockedContent;
            }
        }
    }

    /** Conversation-related settings: session handling and caching. */
    public static class Conversation {

        private Session session = new Session();
        private Cache cache = new Cache();

        public Session getSession() {
            return session;
        }

        public void setSession(Session session) {
            this.session = session;
        }

        public Cache getCache() {
            return cache;
        }

        public void setCache(Cache cache) {
            this.cache = cache;
        }

        /** Session lifetime settings. */
        public static class Session {

            /** Session timeout; the default corresponds to 1 hour. */
            private long timeout = 3_600_000;

            private int maxTurns = 100;

            /** Cleanup interval; the default corresponds to 5 minutes. */
            private long cleanupInterval = 300_000;

            public long getTimeout() {
                return timeout;
            }

            public void setTimeout(long timeout) {
                this.timeout = timeout;
            }

            public int getMaxTurns() {
                return maxTurns;
            }

            public void setMaxTurns(int maxTurns) {
                this.maxTurns = maxTurns;
            }

            public long getCleanupInterval() {
                return cleanupInterval;
            }

            public void setCleanupInterval(long cleanupInterval) {
                this.cleanupInterval = cleanupInterval;
            }
        }

        /** Cache settings. */
        public static class Cache {

            private boolean enabled = true;

            /**
             * Time to live; the default corresponds to 30 minutes, so unlike the
             * session timeouts this value is not in milliseconds.
             */
            private long ttl = 1_800;

            /** Cache size limit; presumably a number of entries (TODO confirm). */
            private int maxSize = 1_000;

            public boolean isEnabled() {
                return enabled;
            }

            public void setEnabled(boolean enabled) {
                this.enabled = enabled;
            }

            public long getTtl() {
                return ttl;
            }

            public void setTtl(long ttl) {
                this.ttl = ttl;
            }

            public int getMaxSize() {
                return maxSize;
            }

            public void setMaxSize(int maxSize) {
                this.maxSize = maxSize;
            }
        }
    }
}
8.6 本章总结
8.6.1 核心要点
多模态AI基础
- 理解多模态AI的概念和应用场景
- 掌握Spring AI的多模态架构设计
- 学会处理文本、图像等多种模态数据
图像处理服务
- 实现图像描述、问答、分类等功能
- 掌握图像安全检查和内容过滤
- 学会批量处理和性能优化
视觉问答系统
- 构建单图像和多图像问答功能
- 实现结构化问答和对话式问答
- 掌握置信度计算和结果评估
多模态对话
- 设计会话管理和状态维护
- 实现文本、图像混合对话
- 掌握对话历史和上下文管理
图像生成
- 实现基于文本的图像生成
- 掌握风格化和变体生成
- 学会异步处理和批量生成
8.6.2 最佳实践
性能优化
- 合理设置图像尺寸和质量
- 使用异步处理提高响应速度
- 实现缓存机制减少重复计算
安全考虑
- 实施内容安全检查
- 限制文件大小和格式
- 保护用户隐私和数据安全
用户体验
- 提供实时反馈和进度提示
- 支持多种交互方式
- 优化错误处理和异常情况
系统设计
- 模块化设计便于扩展
- 配置化管理提高灵活性
- 监控和日志记录便于维护
8.6.3 练习题
基础练习
- 实现一个简单的图像描述API
- 创建图像分类服务
- 构建基础的视觉问答功能
进阶练习
- 设计多轮对话的图像问答系统
- 实现图像生成的批量处理
- 构建图像内容安全检查服务
综合项目
- 开发完整的多模态聊天应用
- 构建图像处理和生成的工作流
- 设计多模态内容管理系统
通过本章学习,你将掌握Spring AI框架中多模态AI的核心功能,能够构建处理文本、图像等多种数据类型的智能应用,为用户提供丰富的交互体验。