6.1 函数调用基础
6.1.1 函数调用概念
函数调用(Function Calling)是现代AI模型的重要功能,允许AI模型调用外部函数来获取实时数据、执行计算或与外部系统交互。
核心概念:
1. 函数定义:描述函数的名称、参数和功能
2. 函数注册:将函数注册到应用(Spring AI)中,请求时作为可用工具的描述发送给AI模型
3. 函数调用:AI模型根据用户输入决定是否调用、调用哪个函数并生成参数,由应用端实际执行该函数
4. 结果处理:将函数执行结果返回给AI模型,由模型生成最终回答
6.1.2 Spring AI函数调用架构
// Function接口
/**
 * Conceptual description of a callable tool: the metadata a model sees, plus the operation
 * that runs when the model decides to call it.
 *
 * <p>NOTE(review): Spring AI's own function support is built on
 * {@link java.util.function.Function}; this interface appears to be an explanatory sketch
 * rather than a framework type — confirm before depending on it.
 *
 * @param <T> type of the function input
 * @param <R> type of the function result
 */
public interface Function<T, R> {

    /** Runs the function on {@code input} and returns its result. */
    R apply(T input);

    /** Name under which the model refers to this function. */
    String getName();

    /** Human-readable explanation of what the function does. */
    String getDescription();

    /** Schema describing the input the function expects. */
    String getInputTypeSchema();
}
// FunctionCallback接口
/**
 * String-in / string-out form of a callable function: the raw input text produced for the
 * call is handed to {@link #call(String)}, and the returned text is passed back.
 */
public interface FunctionCallback {

    /** Invokes the function with its raw textual input and returns the raw textual output. */
    String call(String functionInput);

    /** Name under which the callback is exposed. */
    String getName();

    /** Description of the callback's purpose. */
    String getDescription();

    /** Schema of the input accepted by {@link #call(String)}. */
    String getInputTypeSchema();
}
// 函数调用请求
/**
 * Per-request options that say which functions a chat call may use and how strongly the
 * model is pushed to call them.
 */
public class FunctionCallingOptions {

    /** Policy for whether the model calls a function. */
    public enum FunctionCallingMode {
        /** The model decides on its own whether to call a function. */
        AUTO,
        /** Function calling is disabled. */
        NONE,
        /** The model must call a function. */
        REQUIRED
    }

    private FunctionCallingMode functionCallingMode;
    private List<String> functions;

    /** Returns the configured calling policy, or {@code null} if unset. */
    public FunctionCallingMode getFunctionCallingMode() {
        return this.functionCallingMode;
    }

    /** Sets the calling policy. */
    public void setFunctionCallingMode(FunctionCallingMode mode) {
        this.functionCallingMode = mode;
    }

    /** Returns the names of the functions made available (the stored list itself, not a copy). */
    public List<String> getFunctions() {
        return this.functions;
    }

    /** Stores the given list of function names by reference. */
    public void setFunctions(List<String> functions) {
        this.functions = functions;
    }
}
6.2 基础函数实现
6.2.1 天气查询函数
// WeatherFunction.java
package com.example.springai.function;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;
import java.util.function.Function;
@Component
public class WeatherFunction implements Function<WeatherFunction.Request, WeatherFunction.Response> {

    /**
     * Looks up the weather for the requested city.
     *
     * <p>NOTE(review): this returns canned demo data; {@code WeatherService} (the real API
     * client) is not used here.
     *
     * @param request model-supplied request; it or its {@code location} may be {@code null},
     *                and the location may carry surrounding whitespace
     * @return mock weather for a known city, or an "unknown" placeholder otherwise
     */
    @Override
    public Response apply(Request request) {
        // The model can omit the argument; treat that as an unknown city instead of throwing an NPE.
        String location = (request == null || request.location() == null)
                ? ""
                : request.location().strip();
        return getWeatherInfo(location);
    }

    /**
     * Returns mock weather data for a handful of hard-coded cities.
     *
     * @param location non-null, already-trimmed city name
     */
    private Response getWeatherInfo(String location) {
        // Locale.ROOT keeps case folding independent of the JVM default locale.
        return switch (location.toLowerCase(java.util.Locale.ROOT)) {
            case "北京" -> new Response(location, "晴天", 25, 60, "微风");
            case "上海" -> new Response(location, "多云", 22, 70, "东南风");
            case "广州" -> new Response(location, "雨天", 28, 85, "南风");
            default -> new Response(location, "未知", 0, 0, "无数据");
        };
    }

    /**
     * 请求参数 (request parameters).
     */
    public record Request(
            @JsonProperty(required = true, value = "location")
            @JsonPropertyDescription("要查询天气的城市名称")
            String location
    ) {}

    /**
     * 响应结果 (response). {@link #toString()} renders the text returned to the model.
     */
    public record Response(
            String location,
            String weather,
            int temperature,
            int humidity,
            String wind
    ) {
        @Override
        public String toString() {
            return String.format(
                    "%s的天气:%s,温度%d°C,湿度%d%%,风力%s",
                    location, weather, temperature, humidity, wind
            );
        }
    }
}
// WeatherService.java
package com.example.springai.service;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import java.util.Map;
@Service
public class WeatherService {

    @Value("${weather.api.key:demo_key}")
    private String apiKey;

    @Value("${weather.api.url:https://api.openweathermap.org/data/2.5/weather}")
    private String apiUrl;

    private final RestTemplate restTemplate;

    public WeatherService() {
        this.restTemplate = new RestTemplate();
    }

    /**
     * Fetches the current weather for {@code city} from the configured endpoint.
     *
     * @param city city name as typed by the user; may contain spaces, non-ASCII characters or
     *             characters that are reserved in URLs
     * @return the parsed weather, or {@code null} when {@code city} is blank, the HTTP call
     *         fails, or the response lacks the expected fields
     */
    public WeatherData getCurrentWeather(String city) {
        if (city == null || city.isBlank()) {
            return null;
        }
        try {
            // Each query value is percent-encoded: plain concatenation let '&', '#' or '=' inside a
            // city name break the query string. A ready-made URI is passed so RestTemplate does not
            // apply its own URI-template encoding on top (which would turn '%' into '%25').
            java.net.URI uri = java.net.URI.create(apiUrl
                    + "?q=" + encodeQueryValue(city)
                    + "&appid=" + encodeQueryValue(apiKey)
                    + "&units=metric&lang=zh_cn");
            @SuppressWarnings("unchecked")
            Map<String, Object> response = restTemplate.getForObject(uri, Map.class);
            if (response != null) {
                return parseWeatherResponse(response);
            }
        } catch (Exception e) {
            System.err.println("获取天气数据失败: " + e.getMessage());
        }
        return null;
    }

    /** Percent-encodes a single query-string value as UTF-8. */
    private static String encodeQueryValue(String value) {
        return java.net.URLEncoder.encode(value, java.nio.charset.StandardCharsets.UTF_8);
    }

    /**
     * Extracts the fields this service uses from the decoded JSON body: {@code name},
     * {@code main.temp}, {@code main.humidity}, {@code weather[0].description} and
     * {@code wind.speed}.
     *
     * @return the parsed data, or {@code null} if {@code main}, {@code weather}, the temperature or
     *         the humidity is missing; a missing wind speed is reported as {@code 0.0}
     */
    @SuppressWarnings("unchecked")
    private WeatherData parseWeatherResponse(Map<String, Object> response) {
        Map<String, Object> main = (Map<String, Object>) response.get("main");
        java.util.List<Map<String, Object>> weatherList =
                (java.util.List<Map<String, Object>>) response.get("weather");
        if (main == null || weatherList == null || weatherList.isEmpty()) {
            return null;
        }
        Number temperature = (Number) main.get("temp");
        Number humidity = (Number) main.get("humidity");
        if (temperature == null || humidity == null) {
            return null;
        }
        Map<String, Object> weather = weatherList.get(0);
        Map<String, Object> wind = (Map<String, Object>) response.get("wind");
        Number windSpeed = wind != null ? (Number) wind.get("speed") : null;
        return new WeatherData(
                (String) response.get("name"),
                weather != null ? (String) weather.get("description") : null,
                temperature.doubleValue(),
                humidity.intValue(),
                windSpeed != null ? windSpeed.doubleValue() : 0.0
        );
    }

    /**
     * 天气数据模型 (weather data model).
     */
    public record WeatherData(
            String city,
            String description,
            double temperature,
            int humidity,
            double windSpeed
    ) {}
}
6.2.2 计算器函数
// CalculatorFunction.java
package com.example.springai.function;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;
import java.util.function.Function;
@Component
public class CalculatorFunction implements Function<CalculatorFunction.Request, CalculatorFunction.Response> {

    /**
     * Evaluates the arithmetic expression in {@code request}.
     *
     * <p>Never throws. Parse and arithmetic failures are reported through
     * {@code success}/{@code error} so the model receives a readable message.
     *
     * @param request model-supplied request (may be {@code null})
     * @return the evaluation result or a failure description
     */
    @Override
    public Response apply(Request request) {
        // Read the expression once so a null request cannot raise a second NPE inside the catch block.
        String expression = request == null ? null : request.expression();
        try {
            double result = performCalculation(expression);
            return new Response(expression, result, true, null);
        } catch (Exception e) {
            return new Response(expression, 0.0, false, e.getMessage());
        }
    }

    /**
     * Evaluates {@code expression} with normal operator precedence.
     *
     * <p>Supports {@code + - * /}, unary signs, parentheses and decimal or scientific-notation
     * numbers. The previous split-based approach had three defects:
     * <ul>
     *   <li>it silently dropped operands ("1+2+3" gave 3);</li>
     *   <li>it ignored precedence ("2*3+1" failed);</li>
     *   <li>it failed on negative operands ("-5", "3*-2").</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the expression is null, empty or malformed
     * @throws ArithmeticException      on division by zero
     */
    private double performCalculation(String expression) {
        if (expression == null) {
            throw new IllegalArgumentException("表达式不能为空");
        }
        return new Parser(expression.replaceAll("\\s+", "")).parse();
    }

    /**
     * Recursive-descent evaluator over a whitespace-free expression. A fresh instance is created
     * for each evaluation, so its mutable cursor is never shared between concurrent calls.
     */
    private static final class Parser {
        private final String text;
        private int pos;

        Parser(String text) {
            this.text = text;
        }

        /** Parses the whole input; leftover characters are reported as an error. */
        double parse() {
            double value = parseExpression();
            if (pos < text.length()) {
                throw new IllegalArgumentException("无法解析的字符: " + text.charAt(pos));
            }
            return value;
        }

        // expression := term (('+' | '-') term)*   -- left-associative, so 10-2-3 == 5
        private double parseExpression() {
            double value = parseTerm();
            while (true) {
                if (accept('+')) {
                    value += parseTerm();
                } else if (accept('-')) {
                    value -= parseTerm();
                } else {
                    return value;
                }
            }
        }

        // term := factor (('*' | '/') factor)*
        private double parseTerm() {
            double value = parseFactor();
            while (true) {
                if (accept('*')) {
                    value *= parseFactor();
                } else if (accept('/')) {
                    double divisor = parseFactor();
                    if (divisor == 0) {
                        throw new ArithmeticException("除数不能为零");
                    }
                    value /= divisor;
                } else {
                    return value;
                }
            }
        }

        // factor := ('+' | '-') factor | '(' expression ')' | number
        private double parseFactor() {
            if (accept('+')) {
                return parseFactor();
            }
            if (accept('-')) {
                return -parseFactor();
            }
            if (accept('(')) {
                double value = parseExpression();
                if (!accept(')')) {
                    throw new IllegalArgumentException("缺少右括号");
                }
                return value;
            }
            return parseNumber();
        }

        // number := (digit | '.')+ [('e' | 'E') ['+' | '-'] digit+]
        private double parseNumber() {
            int start = pos;
            while (pos < text.length() && (isDigit(text.charAt(pos)) || text.charAt(pos) == '.')) {
                pos++;
            }
            if (pos == start) {
                throw new IllegalArgumentException(pos < text.length()
                        ? "无法解析的字符: " + text.charAt(pos)
                        : "表达式不完整");
            }
            if (pos < text.length() && (text.charAt(pos) == 'e' || text.charAt(pos) == 'E')) {
                int mark = pos++;
                if (pos < text.length() && (text.charAt(pos) == '+' || text.charAt(pos) == '-')) {
                    pos++;
                }
                int exponentStart = pos;
                while (pos < text.length() && isDigit(text.charAt(pos))) {
                    pos++;
                }
                if (pos == exponentStart) {
                    pos = mark; // not an exponent after all; parse() will report the stray 'e'
                }
            }
            return Double.parseDouble(text.substring(start, pos));
        }

        /** Consumes {@code expected} if it is the next character. */
        private boolean accept(char expected) {
            if (pos < text.length() && text.charAt(pos) == expected) {
                pos++;
                return true;
            }
            return false;
        }

        /** ASCII digits only, because Double.parseDouble rejects other Unicode digits. */
        private static boolean isDigit(char c) {
            return c >= '0' && c <= '9';
        }
    }

    /**
     * 请求参数 (request parameters).
     */
    public record Request(
            @JsonProperty(required = true, value = "expression")
            @JsonPropertyDescription("要计算的数学表达式,支持 +、-、*、/ 运算")
            String expression
    ) {}

    /**
     * 响应结果 (response). {@link #toString()} renders the text returned to the model.
     */
    public record Response(
            String expression,
            double result,
            boolean success,
            String error
    ) {
        @Override
        public String toString() {
            if (success) {
                // Locale.ROOT: always '.' as the decimal separator, whatever the JVM default locale.
                return String.format(java.util.Locale.ROOT, "%s = %.2f", expression, result);
            } else {
                return String.format("计算错误:%s", error);
            }
        }
    }
}
// AdvancedCalculatorFunction.java
package com.example.springai.function;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;
import javax.script.ScriptEngine;
import javax.script.ScriptEngineManager;
import java.util.function.Function;
@Component
public class AdvancedCalculatorFunction implements Function<AdvancedCalculatorFunction.Request, AdvancedCalculatorFunction.Response> {

    /*
     * Evaluation used to go through a "JavaScript" ScriptEngine. That had two problems:
     * 1. Nashorn was removed in JDK 15, so on the Java versions this code targets (it uses
     *    records) getEngineByName returned null and every call failed.
     * 2. Evaluating model-supplied text as JavaScript allows arbitrary code execution.
     * A small whitelisted expression evaluator is used instead.
     */

    /**
     * Evaluates the expression in {@code request}.
     *
     * <p>Never throws. Failures are reported through {@code success}/{@code error}.
     *
     * @param request model-supplied request (may be {@code null})
     * @return the numeric result plus a compact formatted form, or a failure description
     */
    @Override
    public Response apply(Request request) {
        String expression = request == null ? null : request.expression();
        try {
            double numericResult = evaluateExpression(expression);
            return new Response(
                    expression,
                    numericResult,
                    true,
                    null,
                    formatResult(numericResult)
            );
        } catch (Exception e) {
            return new Response(
                    expression,
                    0.0,
                    false,
                    e.getMessage(),
                    null
            );
        }
    }

    /**
     * Evaluates a mathematical expression.
     *
     * <p>Supported syntax:
     * <ul>
     *   <li>operators {@code + - * / %}, and {@code ^} for power (right-associative, so
     *       {@code -2^2 == -4});</li>
     *   <li>parentheses and unary signs;</li>
     *   <li>constants {@code pi} and {@code e};</li>
     *   <li>the whitelisted functions listed in {@link Parser#callFunction}.</li>
     * </ul>
     * Names are case-insensitive and may carry a JavaScript-style {@code Math.} prefix.
     *
     * @throws IllegalArgumentException for empty or malformed input or unknown names
     * @throws ArithmeticException      on division/modulo by zero or a non-finite result
     */
    private double evaluateExpression(String expression) {
        if (expression == null || expression.isBlank()) {
            throw new IllegalArgumentException("表达式不能为空");
        }
        double value = new Parser(expression).parse();
        if (!Double.isFinite(value)) {
            throw new ArithmeticException("计算结果不是有限数值: " + value);
        }
        return value;
    }

    /**
     * Formats {@code result} rounded (HALF_UP) to at most six decimals, without trailing zeros
     * and independent of the default locale. For example, 12.0 becomes "12" and
     * 0.1 + 0.2 becomes "0.3".
     */
    private String formatResult(double result) {
        return java.math.BigDecimal.valueOf(result)
                .setScale(6, java.math.RoundingMode.HALF_UP)
                .stripTrailingZeros()
                .toPlainString();
    }

    /**
     * Recursive-descent evaluator. One instance is created per evaluation, so its cursor is
     * never shared between threads.
     */
    private static final class Parser {
        private static final char END = '\0';

        private final String text;
        private int pos;

        Parser(String text) {
            this.text = text;
        }

        double parse() {
            double value = parseExpression();
            if (peek() != END) {
                throw new IllegalArgumentException("无法解析的内容: " + text.substring(pos));
            }
            return value;
        }

        // expression := term (('+' | '-') term)*
        private double parseExpression() {
            double value = parseTerm();
            while (true) {
                if (accept('+')) {
                    value += parseTerm();
                } else if (accept('-')) {
                    value -= parseTerm();
                } else {
                    return value;
                }
            }
        }

        // term := unary (('*' | '/' | '%') unary)*
        private double parseTerm() {
            double value = parseUnary();
            while (true) {
                if (accept('*')) {
                    value *= parseUnary();
                } else if (accept('/')) {
                    value /= nonZero(parseUnary());
                } else if (accept('%')) {
                    value %= nonZero(parseUnary());
                } else {
                    return value;
                }
            }
        }

        // unary := ('+' | '-') unary | power
        private double parseUnary() {
            if (accept('+')) {
                return parseUnary();
            }
            if (accept('-')) {
                return -parseUnary();
            }
            return parsePower();
        }

        // power := primary ['^' unary]   -- right-associative: 2^3^2 == 2^9
        private double parsePower() {
            double base = parsePrimary();
            return accept('^') ? Math.pow(base, parseUnary()) : base;
        }

        // primary := '(' expression ')' | number | identifier ['(' arguments ')']
        private double parsePrimary() {
            char c = peek();
            if (accept('(')) {
                double value = parseExpression();
                expect(')');
                return value;
            }
            if (isDigit(c) || c == '.') {
                return parseNumber();
            }
            if (isLetter(c)) {
                return parseIdentifier();
            }
            throw new IllegalArgumentException(c == END ? "表达式不完整" : "无法解析的字符: " + c);
        }

        private double parseNumber() {
            int start = pos;
            while (pos < text.length() && (isDigit(text.charAt(pos)) || text.charAt(pos) == '.')) {
                pos++;
            }
            if (pos < text.length() && (text.charAt(pos) == 'e' || text.charAt(pos) == 'E')) {
                int mark = pos++;
                if (pos < text.length() && (text.charAt(pos) == '+' || text.charAt(pos) == '-')) {
                    pos++;
                }
                int exponentStart = pos;
                while (pos < text.length() && isDigit(text.charAt(pos))) {
                    pos++;
                }
                if (pos == exponentStart) {
                    pos = mark; // not an exponent; leave the 'e' to be reported as unparsed input
                }
            }
            return Double.parseDouble(text.substring(start, pos));
        }

        private double parseIdentifier() {
            int start = pos;
            while (pos < text.length() && isIdentifierPart(text.charAt(pos))) {
                pos++;
            }
            String name = text.substring(start, pos).toLowerCase(java.util.Locale.ROOT);
            if (name.startsWith("math.")) {
                name = name.substring("math.".length());
            }
            if (!accept('(')) {
                return constant(name);
            }
            java.util.List<Double> args = new java.util.ArrayList<>();
            if (!accept(')')) {
                do {
                    args.add(parseExpression());
                } while (accept(','));
                expect(')');
            }
            return callFunction(name, args);
        }

        private static double constant(String name) {
            return switch (name) {
                case "pi" -> Math.PI;
                case "e" -> Math.E;
                default -> throw new IllegalArgumentException("未知的常量: " + name);
            };
        }

        /** Applies a whitelisted function; anything not listed here is rejected. */
        private static double callFunction(String name, java.util.List<Double> args) {
            int arity = switch (name) {
                case "pow", "atan2", "min", "max" -> 2;
                case "sin", "cos", "tan", "asin", "acos", "atan", "sqrt", "cbrt",
                     "log", "log10", "exp", "abs", "floor", "ceil", "round" -> 1;
                default -> throw new IllegalArgumentException("不支持的函数: " + name);
            };
            if (args.size() != arity) {
                throw new IllegalArgumentException(
                        "函数 " + name + " 需要 " + arity + " 个参数, 实际为 " + args.size());
            }
            double x = args.get(0);
            return switch (name) {
                case "pow" -> Math.pow(x, args.get(1));
                case "atan2" -> Math.atan2(x, args.get(1));
                case "min" -> Math.min(x, args.get(1));
                case "max" -> Math.max(x, args.get(1));
                case "sin" -> Math.sin(x);
                case "cos" -> Math.cos(x);
                case "tan" -> Math.tan(x);
                case "asin" -> Math.asin(x);
                case "acos" -> Math.acos(x);
                case "atan" -> Math.atan(x);
                case "sqrt" -> Math.sqrt(x);
                case "cbrt" -> Math.cbrt(x);
                case "log" -> Math.log(x); // natural logarithm, as JavaScript's Math.log
                case "log10" -> Math.log10(x);
                case "exp" -> Math.exp(x);
                case "abs" -> Math.abs(x);
                case "floor" -> Math.floor(x);
                case "ceil" -> Math.ceil(x);
                case "round" -> Math.floor(x + 0.5); // JavaScript Math.round semantics, no long overflow
                default -> throw new IllegalStateException("unreachable: " + name);
            };
        }

        private static double nonZero(double divisor) {
            if (divisor == 0) {
                throw new ArithmeticException("除数不能为零");
            }
            return divisor;
        }

        /** Skips whitespace and returns the next character without consuming it, or END. */
        private char peek() {
            while (pos < text.length() && Character.isWhitespace(text.charAt(pos))) {
                pos++;
            }
            return pos < text.length() ? text.charAt(pos) : END;
        }

        private boolean accept(char expected) {
            if (peek() == expected) {
                pos++;
                return true;
            }
            return false;
        }

        private void expect(char expected) {
            if (!accept(expected)) {
                throw new IllegalArgumentException("缺少 '" + expected + "'");
            }
        }

        private static boolean isDigit(char c) {
            return c >= '0' && c <= '9';
        }

        private static boolean isLetter(char c) {
            return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
        }

        private static boolean isIdentifierPart(char c) {
            return isLetter(c) || isDigit(c) || c == '_' || c == '.';
        }
    }

    /**
     * 请求参数 (request parameters).
     */
    public record Request(
            @JsonProperty(required = true, value = "expression")
            @JsonPropertyDescription("要计算的数学表达式,支持基本运算和数学函数(sin, cos, tan, sqrt, log, pow, abs等)")
            String expression
    ) {}

    /**
     * 响应结果 (response). {@link #toString()} renders the text returned to the model.
     */
    public record Response(
            String expression,
            double result,
            boolean success,
            String error,
            String formattedResult
    ) {
        @Override
        public String toString() {
            if (success) {
                return String.format("%s = %s", expression, formattedResult);
            } else {
                return String.format("计算错误:%s", error);
            }
        }
    }
}
6.2.3 时间日期函数
// DateTimeFunction.java
package com.example.springai.function;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;
import java.time.*;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.function.Function;
@Component
public class DateTimeFunction implements Function<DateTimeFunction.Request, DateTimeFunction.Response> {

    /** DateTimeFormatter is immutable and thread-safe, so shared instances are reused. */
    private static final DateTimeFormatter DATE_TIME_FORMATTER =
            DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
    private static final DateTimeFormatter ZONED_DATE_TIME_FORMATTER =
            DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss z");

    /**
     * Executes the requested date/time operation.
     *
     * <p>Never throws. The response reports {@code success=false} in three cases: an unknown or
     * missing operation, a missing parameter for the chosen operation, or a parse/zone error.
     *
     * @param request model-supplied request (may be {@code null})
     */
    @Override
    public Response apply(Request request) {
        String operation = request == null ? null : request.operation();
        try {
            String result = processDateTimeRequest(request);
            return new Response(operation, result, true, null);
        } catch (Exception e) {
            return new Response(operation, null, false, e.getMessage());
        }
    }

    /**
     * Dispatches on the (case-insensitive) operation name.
     *
     * @throws IllegalArgumentException for a missing/unsupported operation, a missing required
     *                                  parameter, or invalid input to a helper
     */
    private String processDateTimeRequest(Request request) {
        if (request == null || request.operation() == null || request.operation().isBlank()) {
            throw new IllegalArgumentException("缺少参数 operation");
        }
        String operation = request.operation().strip().toLowerCase(java.util.Locale.ROOT);
        LocalDateTime now = LocalDateTime.now();
        return switch (operation) {
            case "current_time" -> "当前时间:" + now.format(DATE_TIME_FORMATTER);
            case "current_date" -> "当前日期:" + now.toLocalDate().toString();
            case "timestamp" -> "当前时间戳:" + Instant.now().getEpochSecond();
            case "day_of_week" -> "今天是:" + getDayOfWeekInChinese(now.getDayOfWeek());
            case "days_until_weekend" -> "距离周末还有:" + calculateDaysUntilWeekend(now) + "天";
            case "add_days" -> {
                int days = require(request.value(), "value", operation);
                yield String.format("%d天后是:%s", days, now.plusDays(days).format(DATE_TIME_FORMATTER));
            }
            case "subtract_days" -> {
                int days = require(request.value(), "value", operation);
                yield String.format("%d天前是:%s", days, now.minusDays(days).format(DATE_TIME_FORMATTER));
            }
            case "format_date" -> formatCustomDate(
                    require(request.dateString(), "date_string", operation),
                    require(request.format(), "format", operation));
            case "calculate_age" -> calculateAge(require(request.birthDate(), "birth_date", operation));
            case "time_zone_convert" -> convertTimeZone(now, require(request.timeZone(), "time_zone", operation));
            // Previously an unknown operation was returned as a successful result.
            default -> throw new IllegalArgumentException("不支持的操作:" + request.operation());
        };
    }

    /**
     * Returns {@code value} or fails with a message naming the missing parameter. Previously a
     * missing parameter fell through to the misleading "unsupported operation" text.
     */
    private static <T> T require(T value, String parameter, String operation) {
        if (value == null) {
            throw new IllegalArgumentException("操作 " + operation + " 缺少参数 " + parameter);
        }
        return value;
    }

    /**
     * 获取中文星期 (Chinese weekday name).
     */
    private String getDayOfWeekInChinese(DayOfWeek dayOfWeek) {
        return switch (dayOfWeek) {
            case MONDAY -> "星期一";
            case TUESDAY -> "星期二";
            case WEDNESDAY -> "星期三";
            case THURSDAY -> "星期四";
            case FRIDAY -> "星期五";
            case SATURDAY -> "星期六";
            case SUNDAY -> "星期日";
        };
    }

    /**
     * Days until Saturday; 0 when today is already Saturday or Sunday.
     */
    private int calculateDaysUntilWeekend(LocalDateTime now) {
        DayOfWeek currentDay = now.getDayOfWeek();
        if (currentDay == DayOfWeek.SATURDAY || currentDay == DayOfWeek.SUNDAY) {
            return 0;
        }
        return DayOfWeek.SATURDAY.getValue() - currentDay.getValue();
    }

    /**
     * Parses an ISO {@code yyyy-MM-dd} date and renders it with the caller-supplied pattern.
     *
     * @throws IllegalArgumentException if the date or pattern is invalid, or the pattern needs
     *                                  fields a date does not have (e.g. hours)
     */
    private String formatCustomDate(String dateString, String format) {
        try {
            LocalDate date = LocalDate.parse(dateString);
            DateTimeFormatter formatter = DateTimeFormatter.ofPattern(format);
            return "格式化结果:" + date.format(formatter);
        } catch (DateTimeException | IllegalArgumentException e) {
            throw new IllegalArgumentException("日期格式化失败:" + e.getMessage(), e);
        }
    }

    /**
     * Computes the age as years, months and days from an ISO {@code yyyy-MM-dd} birth date.
     *
     * @throws IllegalArgumentException if the date cannot be parsed or lies in the future
     */
    private String calculateAge(String birthDateString) {
        LocalDate birthDate;
        try {
            birthDate = LocalDate.parse(birthDateString);
        } catch (DateTimeException e) {
            throw new IllegalArgumentException("年龄计算失败:" + e.getMessage(), e);
        }
        LocalDate now = LocalDate.now();
        if (birthDate.isAfter(now)) {
            // Would otherwise produce a negative age.
            throw new IllegalArgumentException("年龄计算失败:出生日期晚于今天");
        }
        long years = ChronoUnit.YEARS.between(birthDate, now);
        long months = ChronoUnit.MONTHS.between(birthDate.plusYears(years), now);
        long days = ChronoUnit.DAYS.between(birthDate.plusYears(years).plusMonths(months), now);
        return String.format("年龄:%d岁%d个月%d天", years, months, days);
    }

    /**
     * Interprets {@code dateTime} in the JVM default zone and converts it to {@code timeZoneId}.
     *
     * @throws IllegalArgumentException if the zone ID is unknown or malformed
     */
    private String convertTimeZone(LocalDateTime dateTime, String timeZoneId) {
        try {
            ZoneId sourceZone = ZoneId.systemDefault();
            ZoneId targetZone = ZoneId.of(timeZoneId);
            ZonedDateTime sourceTime = dateTime.atZone(sourceZone);
            ZonedDateTime targetTime = sourceTime.withZoneSameInstant(targetZone);
            return String.format("时区转换:%s -> %s",
                    sourceTime.format(ZONED_DATE_TIME_FORMATTER),
                    targetTime.format(ZONED_DATE_TIME_FORMATTER));
        } catch (DateTimeException e) {
            throw new IllegalArgumentException("时区转换失败:" + e.getMessage(), e);
        }
    }

    /**
     * 请求参数 (request parameters).
     */
    public record Request(
            @JsonProperty(required = true, value = "operation")
            @JsonPropertyDescription("操作类型:current_time, current_date, timestamp, day_of_week, days_until_weekend, add_days, subtract_days, format_date, calculate_age, time_zone_convert")
            String operation,
            @JsonProperty(value = "value")
            @JsonPropertyDescription("数值参数(用于add_days, subtract_days等操作)")
            Integer value,
            @JsonProperty(value = "date_string")
            @JsonPropertyDescription("日期字符串(格式:yyyy-MM-dd)")
            String dateString,
            @JsonProperty(value = "format")
            @JsonPropertyDescription("日期格式模式")
            String format,
            @JsonProperty(value = "birth_date")
            @JsonPropertyDescription("出生日期(格式:yyyy-MM-dd)")
            String birthDate,
            @JsonProperty(value = "time_zone")
            @JsonPropertyDescription("目标时区ID(如:Asia/Shanghai, America/New_York)")
            String timeZone
    ) {}

    /**
     * 响应结果 (response).
     */
    public record Response(
            String operation,
            String result,
            boolean success,
            String error
    ) {
        @Override
        public String toString() {
            return success ? result : "错误:" + error;
        }
    }
}
6.3 数据库查询函数
6.3.1 用户查询函数
// UserQueryFunction.java
package com.example.springai.function;
import com.example.springai.entity.User;
import com.example.springai.repository.UserRepository;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
@Component
public class UserQueryFunction implements Function<UserQueryFunction.Request, UserQueryFunction.Response> {

    private final UserRepository userRepository;

    public UserQueryFunction(UserRepository userRepository) {
        this.userRepository = userRepository;
    }

    /**
     * Runs the requested user query and renders the result as text for the model.
     *
     * <p>Never throws. An unsupported operation, a missing parameter or a repository error is
     * reported with {@code success=false}.
     *
     * <p>NOTE(review): results include names and e-mail addresses that are sent to the model;
     * confirm this is acceptable for the data involved.
     *
     * @param request model-supplied request (may be {@code null})
     */
    @Override
    public Response apply(Request request) {
        String operation = request == null ? null : request.operation();
        try {
            String op = normalizeOperation(operation);
            if ("count_users".equals(op)) {
                // A single COUNT query. The count is also reported in Response.count, which was
                // previously always 0 for this operation.
                long total = userRepository.count();
                return new Response(operation, "用户总数:" + total,
                        (int) Math.min(total, Integer.MAX_VALUE), true, null);
            }
            List<User> users = executeQuery(op, request);
            return new Response(operation, formatUserList(users), users.size(), true, null);
        } catch (Exception e) {
            return new Response(operation, null, 0, false, e.getMessage());
        }
    }

    /** Trims and lower-cases the operation name; rejects a missing one. */
    private static String normalizeOperation(String operation) {
        if (operation == null || operation.isBlank()) {
            throw new IllegalArgumentException("缺少参数 operation");
        }
        return operation.strip().toLowerCase(java.util.Locale.ROOT);
    }

    /**
     * Executes every list-returning query.
     *
     * @param op      normalized operation name (anything except {@code count_users})
     * @param request the original request, source of the query parameters
     * @throws IllegalArgumentException for an unsupported operation, a missing/blank parameter,
     *                                  or a non-positive {@code days}
     */
    private List<User> executeQuery(String op, Request request) {
        return switch (op) {
            case "find_by_id" -> userRepository.findById(require(request.userId(), "user_id", op))
                    .map(List::of)
                    .orElse(List.of());
            case "find_by_email" -> userRepository.findByEmail(require(request.email(), "email", op))
                    .map(List::of)
                    .orElse(List.of());
            // A blank name would match every user via "containing"; require() rejects it.
            case "find_by_name" -> userRepository.findByNameContainingIgnoreCase(require(request.name(), "name", op));
            case "find_active_users" -> userRepository.findByActiveTrue();
            case "find_recent_users" -> {
                int days = request.days() != null ? request.days() : 7;
                if (days <= 0) {
                    throw new IllegalArgumentException("参数 days 必须为正数");
                }
                yield userRepository.findUsersCreatedInLastDays(days);
            }
            case "find_by_role" -> userRepository.findByRole(require(request.role(), "role", op));
            default -> throw new IllegalArgumentException("不支持的操作:" + request.operation());
        };
    }

    /**
     * Returns {@code value}, or fails when it is null or a blank string. Previously a missing
     * parameter silently produced "no users found".
     */
    private static <T> T require(T value, String parameter, String operation) {
        if (value == null || (value instanceof String text && text.isBlank())) {
            throw new IllegalArgumentException("操作 " + operation + " 缺少参数 " + parameter);
        }
        return value;
    }

    /**
     * Formats matched users. A single match is shown in detail; several matches are shown as a
     * one-line-per-user summary.
     */
    private String formatUserList(List<User> users) {
        if (users.isEmpty()) {
            return "未找到匹配的用户";
        }
        if (users.size() == 1) {
            User user = users.get(0);
            return String.format(
                    "用户信息:\n" +
                    "ID: %d\n" +
                    "姓名: %s\n" +
                    "邮箱: %s\n" +
                    "角色: %s\n" +
                    "状态: %s\n" +
                    "创建时间: %s",
                    user.getId(),
                    user.getName(),
                    user.getEmail(),
                    user.getRole(),
                    user.isActive() ? "活跃" : "非活跃",
                    user.getCreatedAt()
            );
        }
        String userList = users.stream()
                .map(user -> String.format("- %s (%s) - %s",
                        user.getName(),
                        user.getEmail(),
                        user.getRole()))
                .collect(Collectors.joining("\n"));
        return String.format("找到 %d 个用户:\n%s", users.size(), userList);
    }

    /**
     * 请求参数 (request parameters).
     */
    public record Request(
            @JsonProperty(required = true, value = "operation")
            @JsonPropertyDescription("查询操作:find_by_id, find_by_email, find_by_name, find_active_users, find_recent_users, count_users, find_by_role")
            String operation,
            @JsonProperty(value = "user_id")
            @JsonPropertyDescription("用户ID")
            Long userId,
            @JsonProperty(value = "email")
            @JsonPropertyDescription("用户邮箱")
            String email,
            @JsonProperty(value = "name")
            @JsonPropertyDescription("用户姓名(支持模糊匹配)")
            String name,
            @JsonProperty(value = "role")
            @JsonPropertyDescription("用户角色")
            String role,
            @JsonProperty(value = "days")
            @JsonPropertyDescription("天数(用于查询最近注册的用户)")
            Integer days
    ) {}

    /**
     * 响应结果 (response).
     */
    public record Response(
            String operation,
            String result,
            int count,
            boolean success,
            String error
    ) {
        @Override
        public String toString() {
            return success ? result : "查询失败:" + error;
        }
    }
}
// User实体类
package com.example.springai.entity;
import jakarta.persistence.*;
import java.time.LocalDateTime;
@Entity
@Table(name = "users")
public class User {

    /** Primary key, generated by the database (identity column). */
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    @Column(nullable = false)
    private String name;

    /** Must be unique across users (unique column constraint). */
    @Column(nullable = false, unique = true)
    private String email;

    @Column(nullable = false)
    private String role;

    /** Users start out active. */
    @Column(nullable = false)
    private boolean active = true;

    /**
     * Stamped when the Java object is constructed. For rows loaded from the database, the
     * persisted value presumably replaces it.
     */
    @Column(name = "created_at", nullable = false)
    private LocalDateTime createdAt;

    /** Refreshed by {@link #preUpdate()}; {@code null} until the entity is first updated. */
    @Column(name = "updated_at")
    private LocalDateTime updatedAt;

    /** No-argument constructor (needed by JPA); stamps {@code createdAt} with the current time. */
    public User() {
        createdAt = LocalDateTime.now();
    }

    /** Creates an active user with the given details, created "now". */
    public User(String name, String email, String role) {
        this.createdAt = LocalDateTime.now();
        this.name = name;
        this.email = email;
        this.role = role;
    }

    /** JPA callback run before an UPDATE: records the modification time. */
    @PreUpdate
    public void preUpdate() {
        updatedAt = LocalDateTime.now();
    }

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getEmail() {
        return email;
    }

    public void setEmail(String email) {
        this.email = email;
    }

    public String getRole() {
        return role;
    }

    public void setRole(String role) {
        this.role = role;
    }

    public boolean isActive() {
        return active;
    }

    public void setActive(boolean active) {
        this.active = active;
    }

    public LocalDateTime getCreatedAt() {
        return createdAt;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }

    public LocalDateTime getUpdatedAt() {
        return updatedAt;
    }

    public void setUpdatedAt(LocalDateTime updatedAt) {
        this.updatedAt = updatedAt;
    }
}
// UserRepository接口
package com.example.springai.repository;
import com.example.springai.entity.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.Optional;
@Repository
public interface UserRepository extends JpaRepository<User, Long> {

    Optional<User> findByEmail(String email);

    /** Case-insensitive substring match on the user name. */
    List<User> findByNameContainingIgnoreCase(String name);

    List<User> findByActiveTrue();

    List<User> findByRole(String role);

    /** Derived query: users whose {@code createdAt} is at or after {@code since}. */
    List<User> findByCreatedAtGreaterThanEqual(java.time.LocalDateTime since);

    /**
     * Users created since the start of the day {@code days} days ago.
     *
     * <p>This replaces the JPQL {@code u.createdAt >= CURRENT_DATE - :days}. The JPA query
     * language defines {@code +}/{@code -} only for numeric operands, so that date/integer
     * arithmetic is not portable and may be rejected or misinterpreted by the provider. The
     * cutoff is now computed in Java and passed as a typed parameter.
     *
     * @param days how many days back to look
     */
    default List<User> findUsersCreatedInLastDays(int days) {
        return findByCreatedAtGreaterThanEqual(
                java.time.LocalDate.now().minusDays(days).atStartOfDay());
    }

    @Query("SELECT COUNT(u) FROM User u WHERE u.active = true")
    long countActiveUsers();

    /** Users whose e-mail ends with {@code domain} (Spring Data's {@code %:param} LIKE shorthand). */
    @Query("SELECT u FROM User u WHERE u.email LIKE %:domain")
    List<User> findByEmailDomain(@Param("domain") String domain);
}
6.4 函数调用服务
6.4.1 函数注册管理器
// FunctionRegistryManager.java
package com.example.springai.service;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.function.Function;
/**
 * In-memory registry of the tool functions, with their names, descriptions and input types.
 *
 * <p>Deliberately not annotated with {@code @Service}. The single instance is created by
 * {@code FunctionConfiguration#functionRegistryManager}. Having both the stereotype and that
 * {@code @Bean} registered two bean definitions named "functionRegistryManager", which Spring
 * Boot rejects by default because bean-definition overriding is disabled.
 *
 * <p>Functions are registered during configuration and then only read.
 */
public class FunctionRegistryManager {

    // LinkedHashMap keeps registration order, so name/metadata listings are deterministic.
    private final Map<String, Function<?, ?>> functions = new LinkedHashMap<>();
    private final Map<String, FunctionMetadata> functionMetadata = new LinkedHashMap<>();

    /**
     * Registers (or replaces) a function under {@code name}.
     *
     * @throws NullPointerException if {@code name}, {@code function} or {@code inputType} is null
     */
    public <T, R> void registerFunction(String name, Function<T, R> function,
                                        String description, Class<T> inputType) {
        Objects.requireNonNull(name, "name");
        Objects.requireNonNull(function, "function");
        Objects.requireNonNull(inputType, "inputType");
        functions.put(name, function);
        functionMetadata.put(name, new FunctionMetadata(name, description, inputType));
    }

    /**
     * Returns the function registered under {@code name}, or {@code null}. The cast is unchecked;
     * callers must request the type the function was registered with.
     */
    @SuppressWarnings("unchecked")
    public <T, R> Function<T, R> getFunction(String name) {
        return (Function<T, R>) functions.get(name);
    }

    /** Returns a read-only view of the registered names, in registration order. */
    public Set<String> getFunctionNames() {
        return Collections.unmodifiableSet(functions.keySet());
    }

    /** Returns the metadata for {@code name}, or {@code null} if it is not registered. */
    public FunctionMetadata getFunctionMetadata(String name) {
        return functionMetadata.get(name);
    }

    /** Returns a read-only view of all metadata, in registration order. */
    public Collection<FunctionMetadata> getAllFunctionMetadata() {
        return Collections.unmodifiableCollection(functionMetadata.values());
    }

    /**
     * Builds chat options that enable the given functions.
     *
     * <p>NOTE(review): {@code withFunctions} appears on Spring AI's function-calling options
     * builder rather than on the generic {@code ChatOptions} builder in the versions checked —
     * verify against the Spring AI version in use.
     */
    public ChatOptions createChatOptions(List<String> functionNames) {
        return ChatOptions.builder()
                .withFunctions(functionNames)
                .build();
    }

    /**
     * 函数元数据 (function metadata).
     */
    public record FunctionMetadata(
            String name,
            String description,
            Class<?> inputType
    ) {}
}
// FunctionCallService.java
package com.example.springai.service;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Map;
@Service
public class FunctionCallService {

    /**
     * Detects "digit … arithmetic operator … digit" anywhere in a message. DOTALL lets it match
     * multi-line messages, which the former {@code String.matches(".*…")} never did.
     */
    private static final java.util.regex.Pattern ARITHMETIC_PATTERN =
            java.util.regex.Pattern.compile("\\d.*[+\\-*/].*\\d", java.util.regex.Pattern.DOTALL);

    private final ChatModel chatModel;
    private final FunctionRegistryManager functionRegistry;

    public FunctionCallService(ChatModel chatModel, FunctionRegistryManager functionRegistry) {
        this.chatModel = chatModel;
        this.functionRegistry = functionRegistry;
    }

    /**
     * Sends {@code userMessage} to the chat model with the given functions enabled.
     *
     * @param userMessage         the user's question ({@code null} is treated as empty)
     * @param availableFunctions  registered function names to enable ({@code null} means none;
     *                            e.g. a JSON body without a "functions" field)
     * @return the model's reply, or an error description; this method does not throw for
     *         failures inside the chat call
     */
    public String chatWithFunctions(String userMessage, List<String> availableFunctions) {
        String question = userMessage == null ? "" : userMessage;
        List<String> functions = availableFunctions == null ? List.of() : availableFunctions;
        try {
            PromptTemplate promptTemplate = new PromptTemplate(
                    "用户问题:{question}\n\n" +
                    "可用函数:{functions}\n\n" +
                    "请根据用户问题,决定是否需要调用函数来获取信息。如果需要,请调用相应的函数。"
            );
            Map<String, Object> variables = Map.of(
                    "question", question,
                    "functions", String.join(", ", functions)
            );
            Prompt prompt = promptTemplate.create(variables,
                    functionRegistry.createChatOptions(functions));
            ChatResponse response = chatModel.call(prompt);
            return response.getResult().getOutput().getContent();
        } catch (Exception e) {
            return "处理请求时发生错误:" + e.getMessage();
        }
    }

    /**
     * Chooses functions from keywords in the message, then delegates to
     * {@link #chatWithFunctions(String, List)}.
     */
    public String chatWithAutoFunctionSelection(String userMessage) {
        String message = userMessage == null ? "" : userMessage;
        return chatWithFunctions(message, selectFunctionsForMessage(message));
    }

    /**
     * Picks candidate functions using keyword heuristics. Only names that are actually registered
     * are returned; if nothing matches, every registered function is offered.
     */
    private List<String> selectFunctionsForMessage(String message) {
        String lowerMessage = message.toLowerCase(java.util.Locale.ROOT);
        List<String> selectedFunctions = new java.util.ArrayList<>();

        if (containsAny(lowerMessage, "天气", "温度", "下雨", "晴天")) {
            selectedFunctions.add("weatherFunction");
        }
        if (containsAny(lowerMessage, "计算", "加", "减", "乘", "除")
                || ARITHMETIC_PATTERN.matcher(lowerMessage).find()) {
            selectedFunctions.add("calculatorFunction");
            selectedFunctions.add("advancedCalculatorFunction");
        }
        if (containsAny(lowerMessage, "时间", "日期", "星期", "年龄", "时区")) {
            selectedFunctions.add("dateTimeFunction");
        }
        if (containsAny(lowerMessage, "用户", "查找", "搜索", "邮箱")) {
            selectedFunctions.add("userQueryFunction");
        }

        // Never hand the model a function name the registry does not know.
        java.util.Set<String> registered = functionRegistry.getFunctionNames();
        selectedFunctions.removeIf(name -> !registered.contains(name));
        if (selectedFunctions.isEmpty()) {
            selectedFunctions.addAll(registered);
        }
        return selectedFunctions;
    }

    /** Returns true if {@code text} contains at least one of {@code keywords}. */
    private static boolean containsAny(String text, String... keywords) {
        for (String keyword : keywords) {
            if (text.contains(keyword)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Summarizes the registered functions.
     */
    public FunctionCallStats getFunctionCallStats() {
        return new FunctionCallStats(
                functionRegistry.getFunctionNames().size(),
                functionRegistry.getAllFunctionMetadata()
        );
    }

    /**
     * 函数调用统计 (function statistics).
     */
    public record FunctionCallStats(
            int totalFunctions,
            java.util.Collection<FunctionRegistryManager.FunctionMetadata> functions
    ) {}
}
6.4.2 函数调用配置
// FunctionConfiguration.java
package com.example.springai.config;
import com.example.springai.function.*;
import com.example.springai.service.FunctionRegistryManager;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class FunctionConfiguration {

    /**
     * Creates the function registry. Each tool bean is registered under the name the chat layer
     * uses for it, together with a description and its input record type.
     */
    @Bean
    public FunctionRegistryManager functionRegistryManager(
            WeatherFunction weatherFunction,
            CalculatorFunction calculatorFunction,
            AdvancedCalculatorFunction advancedCalculatorFunction,
            DateTimeFunction dateTimeFunction,
            UserQueryFunction userQueryFunction) {
        FunctionRegistryManager manager = new FunctionRegistryManager();

        // Weather lookup
        manager.registerFunction("weatherFunction", weatherFunction,
                "获取指定城市的天气信息", WeatherFunction.Request.class);

        // Basic arithmetic
        manager.registerFunction("calculatorFunction", calculatorFunction,
                "执行基本的数学计算", CalculatorFunction.Request.class);

        // Arithmetic with math functions
        manager.registerFunction("advancedCalculatorFunction", advancedCalculatorFunction,
                "执行高级数学计算,支持数学函数", AdvancedCalculatorFunction.Request.class);

        // Date/time utilities
        manager.registerFunction("dateTimeFunction", dateTimeFunction,
                "处理日期时间相关操作", DateTimeFunction.Request.class);

        // User queries
        manager.registerFunction("userQueryFunction", userQueryFunction,
                "查询用户信息", UserQueryFunction.Request.class);

        return manager;
    }
}
6.5 函数调用控制器
6.5.1 REST API接口
// FunctionCallController.java
package com.example.springai.controller;
import com.example.springai.service.FunctionCallService;
import com.example.springai.service.FunctionRegistryManager;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.Map;
@RestController
@RequestMapping("/api/functions")
public class FunctionCallController {
private final FunctionCallService functionCallService;
private final FunctionRegistryManager functionRegistry;
public FunctionCallController(FunctionCallService functionCallService,
FunctionRegistryManager functionRegistry) {
this.functionCallService = functionCallService;
this.functionRegistry = functionRegistry;
}
/**
* 执行带函数调用的聊天
*/
@PostMapping("/chat")
public ResponseEntity<ChatResponse> chatWithFunctions(@RequestBody ChatRequest request) {
try {
String response = functionCallService.chatWithFunctions(
request.message(),
request.functions()
);
return ResponseEntity.ok(new ChatResponse(response, true, null));
} catch (Exception e) {
return ResponseEntity.ok(new ChatResponse(null, false, e.getMessage()));
}
}
/**
* 自动选择函数的聊天
*/
@PostMapping("/chat/auto")
public ResponseEntity<ChatResponse> chatWithAutoFunctions(@RequestBody AutoChatRequest request) {
try {
String response = functionCallService.chatWithAutoFunctionSelection(request.message());
return ResponseEntity.ok(new ChatResponse(response, true, null));
} catch (Exception e) {
return ResponseEntity.ok(new ChatResponse(null, false, e.getMessage()));
}
}
/**
* 获取所有可用函数
*/
@GetMapping("/available")
public ResponseEntity<List<FunctionInfo>> getAvailableFunctions() {
List<FunctionInfo> functions = functionRegistry.getAllFunctionMetadata()
.stream()
.map(metadata -> new FunctionInfo(
metadata.name(),
metadata.description(),
metadata.inputType().getSimpleName()
))
.toList();
return ResponseEntity.ok(functions);
}
/**
* 获取函数调用统计
*/
@GetMapping("/stats")
public ResponseEntity<FunctionCallService.FunctionCallStats> getFunctionStats() {
return ResponseEntity.ok(functionCallService.getFunctionCallStats());
}
/**
* 获取函数详细信息
*/
@GetMapping("/{functionName}")
public ResponseEntity<FunctionDetail> getFunctionDetail(@PathVariable String functionName) {
FunctionRegistryManager.FunctionMetadata metadata =
functionRegistry.getFunctionMetadata(functionName);
if (metadata == null) {
return ResponseEntity.notFound().build();
}
FunctionDetail detail = new FunctionDetail(
metadata.name(),
metadata.description(),
metadata.inputType().getSimpleName(),
getInputSchema(metadata.inputType())
);
return ResponseEntity.ok(detail);
}
/**
* 测试函数调用
*/
@PostMapping("/test/{functionName}")
public ResponseEntity<TestResult> testFunction(
@PathVariable String functionName,
@RequestBody Map<String, Object> input) {
try {
// 这里应该实现具体的函数测试逻辑
// 为了演示,返回模拟结果
return ResponseEntity.ok(new TestResult(
functionName,
input,
"测试成功",
true,
null
));
} catch (Exception e) {
return ResponseEntity.ok(new TestResult(
functionName,
input,
null,
false,
e.getMessage()
));
}
}
/**
* 获取输入模式
*/
private String getInputSchema(Class<?> inputType) {
// 这里应该实现JSON Schema生成逻辑
// 为了演示,返回简化的模式描述
return "输入类型:" + inputType.getSimpleName();
}
// 请求和响应记录类
public record ChatRequest(String message, List<String> functions) {}
public record AutoChatRequest(String message) {}
public record ChatResponse(String response, boolean success, String error) {}
public record FunctionInfo(String name, String description, String inputType) {}
public record FunctionDetail(String name, String description, String inputType, String schema) {}
public record TestResult(String functionName, Map<String, Object> input,
String result, boolean success, String error) {}
}
6.5.2 WebSocket实时聊天
// FunctionChatWebSocketHandler.java
package com.example.springai.websocket;
import com.example.springai.service.FunctionCallService;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Component
public class FunctionChatWebSocketHandler implements WebSocketHandler {
private final FunctionCallService functionCallService;
private final ObjectMapper objectMapper;
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
public FunctionChatWebSocketHandler(FunctionCallService functionCallService) {
this.functionCallService = functionCallService;
this.objectMapper = new ObjectMapper();
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
sessions.put(session.getId(), session);
// 发送欢迎消息
WebSocketMessage welcomeMessage = new WebSocketMessage(
"system",
"连接成功!您可以开始与AI助手对话,支持函数调用功能。",
System.currentTimeMillis()
);
session.sendMessage(new TextMessage(objectMapper.writeValueAsString(welcomeMessage)));
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
if (message instanceof TextMessage textMessage) {
try {
// 解析用户消息
Map<String, Object> messageData = objectMapper.readValue(
textMessage.getPayload(), Map.class);
String userMessage = (String) messageData.get("message");
// 发送用户消息确认
WebSocketMessage userEcho = new WebSocketMessage(
"user",
userMessage,
System.currentTimeMillis()
);
session.sendMessage(new TextMessage(objectMapper.writeValueAsString(userEcho)));
// 处理AI响应
String aiResponse = functionCallService.chatWithAutoFunctionSelection(userMessage);
WebSocketMessage aiMessage = new WebSocketMessage(
"assistant",
aiResponse,
System.currentTimeMillis()
);
session.sendMessage(new TextMessage(objectMapper.writeValueAsString(aiMessage)));
} catch (Exception e) {
WebSocketMessage errorMessage = new WebSocketMessage(
"error",
"处理消息时发生错误:" + e.getMessage(),
System.currentTimeMillis()
);
session.sendMessage(new TextMessage(objectMapper.writeValueAsString(errorMessage)));
}
}
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
sessions.remove(session.getId());
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
sessions.remove(session.getId());
}
@Override
public boolean supportsPartialMessages() {
return false;
}
/**
* WebSocket消息格式
*/
public record WebSocketMessage(
String type,
String content,
long timestamp
) {}
}
// WebSocketConfig.java(WebSocket配置)
package com.example.springai.config;
import com.example.springai.websocket.FunctionChatWebSocketHandler;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.*;
/**
 * Exposes {@link FunctionChatWebSocketHandler} at the {@code /ws/chat} WebSocket endpoint.
 *
 * <p>Security note: the allowed origins are set to {@code "*"}, so a page served from any
 * origin may open this socket. Restrict this to the trusted front-end origin(s) before
 * deploying.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    private final FunctionChatWebSocketHandler chatHandler;

    public WebSocketConfig(FunctionChatWebSocketHandler webSocketHandler) {
        this.chatHandler = webSocketHandler;
    }

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        WebSocketHandlerRegistration registration = registry.addHandler(chatHandler, "/ws/chat");
        // "*" accepts every Origin header; acceptable only for the demo.
        registration.setAllowedOrigins("*");
    }
}
6.6 配置文件
6.6.1 应用配置
# application.yml
# YAML nesting is defined by indentation: every key must stay indented under its parent.
spring:
  ai:
    openai:
      # Required: resolved from the OPENAI_API_KEY environment variable (no fallback value).
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: gpt-4
          temperature: 0.7
          max-tokens: 2000
  # In-memory H2 database with demo credentials; intended for local development only.
  datasource:
    url: jdbc:h2:mem:testdb
    driver-class-name: org.h2.Driver
    username: sa
    password: password
  jpa:
    hibernate:
      ddl-auto: create-drop
    show-sql: true
  h2:
    console:
      # Exposes the H2 web console; disable outside development.
      enabled: true

# Weather API configuration
weather:
  api:
    # Falls back to "demo_key" when WEATHER_API_KEY is not set.
    key: ${WEATHER_API_KEY:demo_key}
    url: https://api.openweathermap.org/data/2.5/weather

# Function-calling configuration (prefix "function.call", bound by FunctionCallProperties)
function:
  call:
    timeout: 30000   # presumably milliseconds; confirm against the consumer
    retry:
      max-attempts: 3
      delay: 1000    # presumably milliseconds; confirm against the consumer
    cache:
      enabled: true
      ttl: 300       # presumably seconds; confirm against the consumer

# Logging configuration
logging:
  level:
    com.example.springai: DEBUG
    org.springframework.ai: DEBUG
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} - %msg%n"
6.6.2 函数调用属性配置
// FunctionCallProperties.java
package com.example.springai.config;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
 * Typed view of the {@code function.call.*} configuration keys.
 *
 * <p>Spring binds {@code timeout}, {@code retry.max-attempts}, {@code retry.delay},
 * {@code cache.enabled} and {@code cache.ttl} through the JavaBean accessors below. Keys that
 * are not configured keep the defaults declared on the fields. The code does not state units.
 * Presumably the timeout and retry delay are milliseconds and the cache TTL is seconds (TODO
 * confirm against the consumers).
 */
@Component
@ConfigurationProperties(prefix = "function.call")
public class FunctionCallProperties {

    /** Retry settings bound from {@code function.call.retry.*}. */
    public static class Retry {

        /** Maximum number of attempts (default 3). */
        private int maxAttempts = 3;

        /** Pause between attempts (default 1000). */
        private long delay = 1_000L;

        public int getMaxAttempts() {
            return maxAttempts;
        }

        public void setMaxAttempts(int maxAttempts) {
            this.maxAttempts = maxAttempts;
        }

        public long getDelay() {
            return delay;
        }

        public void setDelay(long delay) {
            this.delay = delay;
        }
    }

    /** Result-cache settings bound from {@code function.call.cache.*}. */
    public static class Cache {

        /** Whether caching is switched on (default true). */
        private boolean enabled = true;

        /** Time-to-live of cached entries (default 300). */
        private long ttl = 300L;

        public boolean isEnabled() {
            return enabled;
        }

        public void setEnabled(boolean enabled) {
            this.enabled = enabled;
        }

        public long getTtl() {
            return ttl;
        }

        public void setTtl(long ttl) {
            this.ttl = ttl;
        }
    }

    /** Overall call timeout (default 30000). */
    private int timeout = 30_000;

    private Retry retry = new Retry();

    private Cache cache = new Cache();

    public int getTimeout() {
        return timeout;
    }

    public void setTimeout(int timeout) {
        this.timeout = timeout;
    }

    public Retry getRetry() {
        return retry;
    }

    public void setRetry(Retry retry) {
        this.retry = retry;
    }

    public Cache getCache() {
        return cache;
    }

    public void setCache(Cache cache) {
        this.cache = cache;
    }
}
6.7 本章总结
6.7.1 核心要点
函数调用基础
- 理解函数调用的概念和架构
- 掌握Spring AI的函数调用接口
- 学会函数定义和注册
函数实现
- 天气查询函数:外部API集成
- 计算器函数:数学表达式处理
- 日期时间函数:时间处理和格式化
- 数据库查询函数:数据访问集成
服务架构
- 函数注册管理器:统一函数管理
- 函数调用服务:智能函数选择
- REST API:HTTP接口支持
- WebSocket:实时聊天功能
6.7.2 最佳实践
函数设计原则
- 单一职责:每个函数专注一个功能
- 参数验证:确保输入参数的有效性
- 错误处理:提供清晰的错误信息
- 文档完善:详细的函数描述和参数说明
性能优化
- 缓存机制:缓存函数调用结果
- 超时控制:避免长时间等待
- 重试策略:处理临时性错误
- 异步处理:提高响应速度
安全考虑
- 输入验证:防止恶意输入
- 权限控制:限制函数访问权限
- 敏感数据:保护API密钥等敏感信息
- 日志记录:记录函数调用日志
6.7.3 练习题
基础练习
- 实现一个货币转换函数
- 创建一个文件操作函数
- 开发一个邮件发送函数
进阶练习
- 实现函数调用缓存机制
- 添加函数调用监控和统计
- 创建函数调用链功能
综合项目
- 构建一个智能助手系统
- 集成多个外部API服务
- 实现函数调用的可视化界面
下一章预告: 第七章将介绍Spring AI的流式处理与异步编程,包括流式响应、异步处理、响应式编程等内容。