6.1 函数调用基础

6.1.1 函数调用概念

函数调用(Function Calling)是现代AI模型的重要功能,允许AI模型调用外部函数来获取实时数据、执行计算或与外部系统交互。

核心概念:
1. 函数定义:描述函数的名称、参数和功能
2. 函数注册:将函数注册到AI模型中,使模型知道有哪些函数可用
3. 函数调用:AI模型根据用户输入决定调用哪个函数并生成调用参数;函数本身由应用程序执行,而不是由模型执行
4. 结果处理:将函数执行结果返回给模型,由模型据此生成最终回答

6.1.2 Spring AI函数调用架构

/**
 * Illustrative contract of a model-callable function: the metadata the model is shown
 * (name, description, input schema) plus the invocation itself.
 *
 * <p>This is a conceptual sketch. The concrete functions later in this chapter implement
 * {@code java.util.function.Function} instead.
 *
 * @param <T> type the model-supplied arguments are bound to
 * @param <R> type of the value produced for the model
 */
public interface Function<T, R> {

    /** Returns the name under which the model refers to this function. */
    String getName();

    /** Returns the human-readable description that tells the model what the function does. */
    String getDescription();

    /** Returns the schema describing the input parameters. */
    String getInputTypeSchema();

    /**
     * Executes the function.
     *
     * @param input the bound arguments
     * @return the function's result
     */
    R apply(T input);
}

/**
 * String-in/string-out form of a callable function: the caller hands over the raw
 * argument text produced by the model and receives the raw result text back.
 */
public interface FunctionCallback {

    /** Returns the function's name as seen by the model. */
    String getName();

    /** Returns a description of what the function does. */
    String getDescription();

    /** Returns the schema of the expected input. */
    String getInputTypeSchema();

    /**
     * Invokes the function.
     *
     * @param functionInput the serialized arguments
     * @return the serialized result
     */
    String call(String functionInput);
}

/**
 * Per-request options: which functions the model may use and whether it must, may,
 * or must not call one. Both properties are {@code null} until set.
 */
public class FunctionCallingOptions {

    /** Policy controlling whether the model calls a function. */
    public enum FunctionCallingMode {
        /** The model decides on its own whether to call a function. */
        AUTO,
        /** The model does not call any function. */
        NONE,
        /** The model must call a function. */
        REQUIRED
    }

    /** Names of the functions offered to the model (stored as given, not copied). */
    private List<String> functions;

    /** Calling policy for this request. */
    private FunctionCallingMode functionCallingMode;

    public List<String> getFunctions() {
        return this.functions;
    }

    public void setFunctions(List<String> functions) {
        this.functions = functions;
    }

    public FunctionCallingMode getFunctionCallingMode() {
        return this.functionCallingMode;
    }

    public void setFunctionCallingMode(FunctionCallingMode mode) {
        this.functionCallingMode = mode;
    }
}

6.2 基础函数实现

6.2.1 天气查询函数

// WeatherFunction.java
package com.example.springai.function;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;

import java.util.function.Function;

/**
 * Demo weather tool that returns canned weather data for a few Chinese cities.
 *
 * <p>No real weather API is called. Unknown (or missing) cities yield a placeholder
 * response instead of an exception, so the model can still produce an answer.
 */
@Component
public class WeatherFunction implements Function<WeatherFunction.Request, WeatherFunction.Response> {

    /** Display name used when the model calls the function without a usable location. */
    private static final String UNKNOWN_LOCATION = "未知地点";

    /**
     * Looks up (mock) weather for the requested location.
     *
     * @param request model-supplied arguments
     * @return the weather for a known city, or an "未知"/"无数据" placeholder otherwise
     */
    @Override
    public Response apply(Request request) {
        // Mock weather API call.
        return getWeatherInfo(request.location());
    }

    /**
     * Returns mock weather data for {@code location}.
     *
     * <p>A real weather API should be called here; fixed data is returned for the demo.
     */
    private Response getWeatherInfo(String location) {
        // The response echoes what the caller asked for; null/blank would render as "null的天气".
        String displayName = (location == null || location.isBlank()) ? UNKNOWN_LOCATION : location;

        return switch (normalizeCity(location)) {
            case "北京" -> new Response(displayName, "晴天", 25, 60, "微风");
            case "上海" -> new Response(displayName, "多云", 22, 70, "东南风");
            case "广州" -> new Response(displayName, "雨天", 28, 85, "南风");
            default -> new Response(displayName, "未知", 0, 0, "无数据");
        };
    }

    /**
     * Builds the lookup key for a city name.
     *
     * <p>The key is trimmed and lower-cased with {@link java.util.Locale#ROOT}, which is
     * locale-independent. A trailing "市" is dropped so that "北京市" matches "北京".
     * Returns "" for null.
     */
    private static String normalizeCity(String location) {
        if (location == null) {
            return "";
        }
        String key = location.strip().toLowerCase(java.util.Locale.ROOT);
        if (key.length() > 1 && key.endsWith("市")) {
            key = key.substring(0, key.length() - 1);
        }
        return key;
    }

    /**
     * Arguments supplied by the model.
     */
    public record Request(
        @JsonProperty(required = true, value = "location")
        @JsonPropertyDescription("要查询天气的城市名称")
        String location
    ) {}

    /**
     * Weather result; {@link #toString()} renders a one-line Chinese summary.
     */
    public record Response(
        String location,
        String weather,
        int temperature,
        int humidity,
        String wind
    ) {
        @Override
        public String toString() {
            return String.format(
                "%s的天气:%s,温度%d°C,湿度%d%%,风力%s",
                location, weather, temperature, humidity, wind
            );
        }
    }
}

// WeatherService.java
package com.example.springai.service;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.Map;

/**
 * Client for an OpenWeatherMap-style "current weather" HTTP endpoint.
 *
 * <p>The endpoint URL and API key come from the {@code weather.api.url} and
 * {@code weather.api.key} properties, with the defaults shown on the fields.
 */
@Service
public class WeatherService {

    @Value("${weather.api.key:demo_key}")
    private String apiKey;

    @Value("${weather.api.url:https://api.openweathermap.org/data/2.5/weather}")
    private String apiUrl;

    private final RestTemplate restTemplate;

    public WeatherService() {
        this.restTemplate = new RestTemplate();
    }

    /**
     * Fetches current weather data for {@code city}.
     *
     * @param city city name as provided by the user/model; may contain non-ASCII characters
     * @return the parsed weather data, or {@code null} when the city is blank, the HTTP call
     *         fails, or the response cannot be parsed (failures are printed to stderr)
     */
    public WeatherData getCurrentWeather(String city) {
        if (city == null || city.isBlank()) {
            return null;
        }
        try {
            // City and key are passed as URI template variables. RestTemplate expands and
            // percent-encodes them, so names like "New York" or "北京" (or a stray '&')
            // cannot break the query string, unlike plain string concatenation.
            String urlTemplate = apiUrl + "?q={city}&appid={appid}&units=metric&lang=zh_cn";

            @SuppressWarnings("unchecked")
            Map<String, Object> response =
                restTemplate.getForObject(urlTemplate, Map.class, city.strip(), apiKey);

            if (response != null) {
                return parseWeatherResponse(response);
            }
        } catch (Exception e) {
            System.err.println("获取天气数据失败: " + e.getMessage());
        }

        return null;
    }

    /**
     * Converts the JSON map returned by the API into {@link WeatherData}.
     *
     * <p>"main.temp" and "main.humidity" are required. The description falls back to
     * {@code null} when the "weather" array is absent or empty. The wind speed falls back
     * to 0.0 when "wind" or "wind.speed" is absent.
     *
     * @throws IllegalStateException if the "main" section is missing
     */
    @SuppressWarnings("unchecked")
    private WeatherData parseWeatherResponse(Map<String, Object> response) {
        Map<String, Object> main = (Map<String, Object>) response.get("main");
        if (main == null) {
            throw new IllegalStateException("天气API响应缺少 main 字段");
        }

        String description = null;
        if (response.get("weather") instanceof java.util.List<?> weatherList
                && !weatherList.isEmpty()
                && weatherList.get(0) instanceof Map<?, ?> firstWeather) {
            description = (String) firstWeather.get("description");
        }

        Map<String, Object> wind = (Map<String, Object>) response.get("wind");
        double windSpeed = wind != null && wind.get("speed") instanceof Number speed
            ? speed.doubleValue()
            : 0.0;

        return new WeatherData(
            (String) response.get("name"),
            description,
            ((Number) main.get("temp")).doubleValue(),
            ((Number) main.get("humidity")).intValue(),
            windSpeed
        );
    }

    /**
     * Weather data model.
     */
    public record WeatherData(
        String city,
        String description,
        double temperature,
        int humidity,
        double windSpeed
    ) {}
}

6.2.2 计算器函数

// CalculatorFunction.java
package com.example.springai.function;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;

import java.util.function.Function;

/**
 * Basic calculator tool supporting {@code + - * /} with the usual precedence, unary
 * signs and parentheses.
 *
 * <p>Failures (malformed input, division by zero) are reported in the {@link Response}
 * with {@code success=false} rather than thrown, so the model can relay them.
 */
@Component
public class CalculatorFunction implements Function<CalculatorFunction.Request, CalculatorFunction.Response> {

    @Override
    public Response apply(Request request) {
        try {
            double result = performCalculation(request.expression());
            return new Response(request.expression(), result, true, null);
        } catch (Exception e) {
            return new Response(request.expression(), 0.0, false, e.getMessage());
        }
    }

    /**
     * Evaluates an arithmetic expression.
     *
     * <p>A proper parser is required so that chained and mixed operations
     * ({@code 1+2+3}, {@code 2*3+4}, {@code 3*-2}) and bare negative numbers ({@code -5})
     * evaluate correctly. Splitting on the first operator drops or rejects such operands.
     *
     * @throws ArithmeticException      on division by zero
     * @throws IllegalArgumentException on empty or malformed input
     *                                  (NumberFormatException for malformed numbers)
     */
    private double performCalculation(String expression) {
        if (expression == null || expression.isBlank()) {
            throw new IllegalArgumentException("表达式不能为空");
        }
        // Whitespace has no meaning in these expressions.
        return new ArithmeticParser(expression.replaceAll("\\s+", "")).parse();
    }

    /**
     * Recursive-descent evaluator for:
     * <pre>
     * expression := term (('+' | '-') term)*
     * term       := factor (('*' | '/') factor)*
     * factor     := ('+' | '-') factor | '(' expression ')' | number
     * number     := digits/'.' [('e' | 'E') ['+' | '-'] digits]
     * </pre>
     * One instance per expression; not thread-safe.
     */
    private static final class ArithmeticParser {

        /** Limits nested parentheses/unary signs so hostile input cannot overflow the stack. */
        private static final int MAX_DEPTH = 200;

        private final String input;
        private int pos;
        private int depth;

        ArithmeticParser(String input) {
            this.input = input;
        }

        /** Parses the whole input; trailing unparsed characters are an error. */
        double parse() {
            double value = parseExpression();
            if (pos < input.length()) {
                throw new IllegalArgumentException(
                    "无法解析的字符 '" + input.charAt(pos) + "'(位置 " + pos + ")");
            }
            return value;
        }

        private double parseExpression() {
            double value = parseTerm();
            while (pos < input.length()) {
                char op = input.charAt(pos);
                if (op == '+') {
                    pos++;
                    value += parseTerm();
                } else if (op == '-') {
                    pos++;
                    value -= parseTerm();
                } else {
                    break;
                }
            }
            return value;
        }

        private double parseTerm() {
            double value = parseFactor();
            while (pos < input.length()) {
                char op = input.charAt(pos);
                if (op == '*') {
                    pos++;
                    value *= parseFactor();
                } else if (op == '/') {
                    pos++;
                    double divisor = parseFactor();
                    if (divisor == 0) {
                        throw new ArithmeticException("除数不能为零");
                    }
                    value /= divisor;
                } else {
                    break;
                }
            }
            return value;
        }

        private double parseFactor() {
            if (++depth > MAX_DEPTH) {
                throw new IllegalArgumentException("表达式嵌套过深");
            }
            try {
                if (pos >= input.length()) {
                    throw new IllegalArgumentException("表达式意外结束");
                }
                char c = input.charAt(pos);
                if (c == '+') {
                    pos++;
                    return parseFactor();
                }
                if (c == '-') {
                    pos++;
                    return -parseFactor();
                }
                if (c == '(') {
                    pos++;
                    double value = parseExpression();
                    if (pos >= input.length() || input.charAt(pos) != ')') {
                        throw new IllegalArgumentException("缺少右括号 ')'");
                    }
                    pos++;
                    return value;
                }
                return parseNumber();
            } finally {
                depth--;
            }
        }

        private double parseNumber() {
            int start = pos;
            while (pos < input.length() && (isDigit(input.charAt(pos)) || input.charAt(pos) == '.')) {
                pos++;
            }
            // Optional exponent such as "1e5" or "2.5E-3".
            if (pos > start && pos < input.length()
                    && (input.charAt(pos) == 'e' || input.charAt(pos) == 'E')) {
                int exponentStart = pos++;
                if (pos < input.length() && (input.charAt(pos) == '+' || input.charAt(pos) == '-')) {
                    pos++;
                }
                int digitsStart = pos;
                while (pos < input.length() && isDigit(input.charAt(pos))) {
                    pos++;
                }
                if (pos == digitsStart) {
                    // Not an exponent (e.g. "2e"); leave the 'e' to be reported as unexpected.
                    pos = exponentStart;
                }
            }
            if (pos == start) {
                throw new IllegalArgumentException("期望数字(位置 " + start + ")");
            }
            // Malformed literals such as "1.2.3" raise NumberFormatException here.
            return Double.parseDouble(input.substring(start, pos));
        }

        /** ASCII digits only: Double.parseDouble does not accept other Unicode digits. */
        private static boolean isDigit(char c) {
            return c >= '0' && c <= '9';
        }
    }

    /**
     * Arguments supplied by the model.
     */
    public record Request(
        @JsonProperty(required = true, value = "expression")
        @JsonPropertyDescription("要计算的数学表达式,支持 +、-、*、/ 运算")
        String expression
    ) {}

    /**
     * Calculation result; {@link #toString()} shows the result to two decimals or the error.
     */
    public record Response(
        String expression,
        double result,
        boolean success,
        String error
    ) {
        @Override
        public String toString() {
            if (success) {
                return String.format("%s = %.2f", expression, result);
            } else {
                return String.format("计算错误:%s", error);
            }
        }
    }
}

// AdvancedCalculatorFunction.java
package com.example.springai.function;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;

import javax.script.ScriptEngine;
import javax.script.ScriptEngineManager;
import java.util.function.Function;

/**
 * Scientific calculator tool. It supports:
 * <ul>
 *   <li>arithmetic ({@code + - * / %}, with {@code ^} for powers) and parentheses;</li>
 *   <li>the constants {@code pi} and {@code e};</li>
 *   <li>the functions {@code sin cos tan asin acos atan sqrt cbrt log log10 exp abs floor ceil}
 *       (one argument) and {@code pow min max} (two arguments).</li>
 * </ul>
 * {@code log} is the natural logarithm.
 *
 * <p>Expressions are evaluated by a small built-in parser, not a {@code javax.script}
 * JavaScript engine, for two reasons:
 * <ul>
 *   <li>A script engine executes arbitrary script supplied by the model or user, which is
 *       remote code execution.</li>
 *   <li>The JDK no longer bundles a JavaScript engine (Nashorn was removed in Java 15).
 *       Unless a standalone engine is added as a dependency,
 *       {@code getEngineByName("JavaScript")} returns {@code null} on the Java versions this
 *       code targets, since it uses records.</li>
 * </ul>
 */
@Component
public class AdvancedCalculatorFunction implements Function<AdvancedCalculatorFunction.Request, AdvancedCalculatorFunction.Response> {

    @Override
    public Response apply(Request request) {
        try {
            double numericResult = evaluateExpression(request.expression());

            return new Response(
                request.expression(),
                numericResult,
                true,
                null,
                formatResult(numericResult)
            );
        } catch (Exception e) {
            return new Response(
                request.expression(),
                0.0,
                false,
                e.getMessage(),
                null
            );
        }
    }

    /**
     * Evaluates a mathematical expression with IEEE-754 double semantics.
     * Division by zero yields Infinity or NaN rather than an error.
     *
     * @throws IllegalArgumentException on empty, malformed or unsupported input
     */
    private double evaluateExpression(String expression) {
        if (expression == null || expression.isBlank()) {
            throw new IllegalArgumentException("表达式不能为空");
        }
        return new ExpressionParser(expression).parse();
    }

    /**
     * Formats a result without trailing zeros: integral values print without a decimal
     * point, and others print with at most 6 decimals.
     */
    private String formatResult(double result) {
        if (result == (long) result) {
            return String.valueOf((long) result);
        } else {
            // Locale.ROOT guarantees '.' as the decimal separator, which the trimming relies on.
            return String.format(java.util.Locale.ROOT, "%.6f", result)
                .replaceAll("0+$", "")
                .replaceAll("\\.$", "");
        }
    }

    /**
     * Recursive-descent evaluator for:
     * <pre>
     * expression := term (('+' | '-') term)*
     * term       := unary (('*' | '/' | '%') unary)*
     * unary      := ('+' | '-') unary | power
     * power      := primary ['^' unary]            (right-associative)
     * primary    := number | '(' expression ')' | name | name '(' [expression (',' expression)*] ')'
     * </pre>
     * Whitespace between tokens is ignored. One instance per expression; not thread-safe.
     */
    private static final class ExpressionParser {

        /** Limits nesting so hostile input cannot overflow the stack. */
        private static final int MAX_DEPTH = 200;

        private final String input;
        private int pos;
        private int depth;

        ExpressionParser(String input) {
            this.input = input;
        }

        /** Parses the whole input; trailing unparsed characters are an error. */
        double parse() {
            double value = parseExpression();
            skipWhitespace();
            if (pos < input.length()) {
                throw error("无法解析的字符 '" + input.charAt(pos) + "'");
            }
            return value;
        }

        private double parseExpression() {
            double value = parseTerm();
            while (true) {
                if (consume('+')) {
                    value += parseTerm();
                } else if (consume('-')) {
                    value -= parseTerm();
                } else {
                    return value;
                }
            }
        }

        private double parseTerm() {
            double value = parseUnary();
            while (true) {
                if (consume('*')) {
                    value *= parseUnary();
                } else if (consume('/')) {
                    value /= parseUnary();
                } else if (consume('%')) {
                    value %= parseUnary();
                } else {
                    return value;
                }
            }
        }

        private double parseUnary() {
            if (++depth > MAX_DEPTH) {
                throw error("表达式嵌套过深");
            }
            try {
                if (consume('+')) {
                    return parseUnary();
                }
                if (consume('-')) {
                    return -parseUnary();
                }
                return parsePower();
            } finally {
                depth--;
            }
        }

        private double parsePower() {
            double base = parsePrimary();
            // The exponent is parsed as a unary, so "2^-1" and "2^3^2" (= 2^9) work.
            return consume('^') ? Math.pow(base, parseUnary()) : base;
        }

        private double parsePrimary() {
            skipWhitespace();
            if (pos >= input.length()) {
                throw error("表达式意外结束");
            }
            char c = input.charAt(pos);
            if (c == '(') {
                pos++;
                double value = parseExpression();
                expect(')');
                return value;
            }
            if (isAsciiDigit(c) || c == '.') {
                return parseNumber();
            }
            if (Character.isLetter(c)) {
                return parseIdentifier();
            }
            throw error("无法解析的字符 '" + c + "'");
        }

        private double parseNumber() {
            int start = pos;
            while (pos < input.length() && (isAsciiDigit(input.charAt(pos)) || input.charAt(pos) == '.')) {
                pos++;
            }
            // Optional exponent such as "1e5" or "2.5E-3".
            if (pos < input.length() && (input.charAt(pos) == 'e' || input.charAt(pos) == 'E')) {
                int exponentStart = pos++;
                if (pos < input.length() && (input.charAt(pos) == '+' || input.charAt(pos) == '-')) {
                    pos++;
                }
                int digitsStart = pos;
                while (pos < input.length() && isAsciiDigit(input.charAt(pos))) {
                    pos++;
                }
                if (pos == digitsStart) {
                    // Not an exponent (e.g. "2e"); leave the 'e' to be reported as unexpected.
                    pos = exponentStart;
                }
            }
            String literal = input.substring(start, pos);
            try {
                return Double.parseDouble(literal);
            } catch (NumberFormatException e) {
                throw error("无效的数字 '" + literal + "'");
            }
        }

        /** Parses a constant ({@code pi}, {@code e}) or a function call; names are case-insensitive. */
        private double parseIdentifier() {
            int start = pos;
            while (pos < input.length() && Character.isLetterOrDigit(input.charAt(pos))) {
                pos++;
            }
            String name = input.substring(start, pos).toLowerCase(java.util.Locale.ROOT);

            if (!consume('(')) {
                return switch (name) {
                    case "pi" -> Math.PI;
                    case "e" -> Math.E;
                    default -> throw error("未知的常量 '" + name + "'");
                };
            }

            java.util.List<Double> args = new java.util.ArrayList<>();
            if (!consume(')')) {
                do {
                    args.add(parseExpression());
                } while (consume(','));
                expect(')');
            }
            return applyFunction(name, args);
        }

        private double applyFunction(String name, java.util.List<Double> args) {
            int arity = switch (name) {
                case "pow", "min", "max" -> 2;
                case "sin", "cos", "tan", "asin", "acos", "atan", "sqrt", "cbrt",
                     "log", "log10", "exp", "abs", "floor", "ceil" -> 1;
                default -> throw error("未知的函数 '" + name + "'");
            };
            if (args.size() != arity) {
                throw error("函数 " + name + " 需要 " + arity + " 个参数,实际为 " + args.size());
            }

            double a = args.get(0);
            return switch (name) {
                case "sin" -> Math.sin(a);
                case "cos" -> Math.cos(a);
                case "tan" -> Math.tan(a);
                case "asin" -> Math.asin(a);
                case "acos" -> Math.acos(a);
                case "atan" -> Math.atan(a);
                case "sqrt" -> Math.sqrt(a);
                case "cbrt" -> Math.cbrt(a);
                case "log" -> Math.log(a);
                case "log10" -> Math.log10(a);
                case "exp" -> Math.exp(a);
                case "abs" -> Math.abs(a);
                case "floor" -> Math.floor(a);
                case "ceil" -> Math.ceil(a);
                case "pow" -> Math.pow(a, args.get(1));
                case "min" -> Math.min(a, args.get(1));
                case "max" -> Math.max(a, args.get(1));
                // Unreachable: unknown names were rejected by the arity switch above.
                default -> throw new IllegalStateException(name);
            };
        }

        private boolean consume(char expected) {
            skipWhitespace();
            if (pos < input.length() && input.charAt(pos) == expected) {
                pos++;
                return true;
            }
            return false;
        }

        private void expect(char expected) {
            if (!consume(expected)) {
                throw error("缺少 '" + expected + "'");
            }
        }

        private void skipWhitespace() {
            while (pos < input.length() && Character.isWhitespace(input.charAt(pos))) {
                pos++;
            }
        }

        private IllegalArgumentException error(String message) {
            return new IllegalArgumentException(message + "(位置 " + pos + ")");
        }

        private static boolean isAsciiDigit(char c) {
            return c >= '0' && c <= '9';
        }
    }

    /**
     * Arguments supplied by the model.
     */
    public record Request(
        @JsonProperty(required = true, value = "expression")
        @JsonPropertyDescription("要计算的数学表达式,支持基本运算和数学函数(sin, cos, tan, sqrt, log, pow, abs等)")
        String expression
    ) {}

    /**
     * Calculation result; {@code formattedResult} is the trimmed textual form of {@code result}.
     */
    public record Response(
        String expression,
        double result,
        boolean success,
        String error,
        String formattedResult
    ) {
        @Override
        public String toString() {
            if (success) {
                return String.format("%s = %s", expression, formattedResult);
            } else {
                return String.format("计算错误:%s", error);
            }
        }
    }
}

6.2.3 时间日期函数

// DateTimeFunction.java
package com.example.springai.function;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;

import java.time.*;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.function.Function;

/**
 * Date/time utility tool. The {@code operation} argument selects one of:
 * <ul>
 *   <li>the current time, date or epoch timestamp;</li>
 *   <li>the day of the week, or the number of days until the weekend;</li>
 *   <li>adding or subtracting days;</li>
 *   <li>formatting a date;</li>
 *   <li>computing an age;</li>
 *   <li>converting the current time to another time zone.</li>
 * </ul>
 *
 * <p>The following are reported as {@code success=false} with a message:
 * <ul>
 *   <li>an unknown operation;</li>
 *   <li>a missing parameter;</li>
 *   <li>unparsable input;</li>
 *   <li>a birth date in the future.</li>
 * </ul>
 * This lets the model tell a real answer from a failure.
 */
@Component
public class DateTimeFunction implements Function<DateTimeFunction.Request, DateTimeFunction.Response> {

    /** Output pattern for local date-times. DateTimeFormatter is immutable, so one instance is shared. */
    private static final DateTimeFormatter DATE_TIME_FORMATTER =
        DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /** Output pattern for zoned date-times (adds the zone name). */
    private static final DateTimeFormatter ZONED_DATE_TIME_FORMATTER =
        DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss z");

    @Override
    public Response apply(Request request) {
        try {
            String result = processDateTimeRequest(request);
            return new Response(request.operation(), result, true, null);
        } catch (Exception e) {
            return new Response(request.operation(), null, false, e.getMessage());
        }
    }

    /**
     * Dispatches on the (case-insensitive) operation name.
     *
     * @throws IllegalArgumentException for an unknown operation or a missing/invalid parameter
     */
    private String processDateTimeRequest(Request request) {
        String operation = request.operation();
        if (operation == null) {
            throw new IllegalArgumentException("缺少必要参数:operation");
        }
        LocalDateTime now = LocalDateTime.now();

        // Locale.ROOT: with e.g. a Turkish default locale "TIME_ZONE_CONVERT" would not
        // lower-case to "time_zone_convert".
        return switch (operation.toLowerCase(java.util.Locale.ROOT)) {
            case "current_time" -> "当前时间:" + now.format(DATE_TIME_FORMATTER);
            case "current_date" -> "当前日期:" + now.toLocalDate();
            case "timestamp" -> "当前时间戳:" + Instant.now().getEpochSecond();
            case "day_of_week" -> "今天是:" + getDayOfWeekInChinese(now.getDayOfWeek());
            case "days_until_weekend" -> "距离周末还有:" + calculateDaysUntilWeekend(now) + "天";
            case "add_days" -> {
                int days = require(request.value(), "value", operation);
                yield String.format("%d天后是:%s", days, now.plusDays(days).format(DATE_TIME_FORMATTER));
            }
            case "subtract_days" -> {
                int days = require(request.value(), "value", operation);
                yield String.format("%d天前是:%s", days, now.minusDays(days).format(DATE_TIME_FORMATTER));
            }
            case "format_date" -> formatCustomDate(
                require(request.dateString(), "date_string", operation),
                require(request.format(), "format", operation));
            case "calculate_age" -> calculateAge(require(request.birthDate(), "birth_date", operation));
            case "time_zone_convert" -> convertTimeZone(now, require(request.timeZone(), "time_zone", operation));
            default -> throw new IllegalArgumentException("不支持的操作:" + operation);
        };
    }

    /**
     * Returns {@code value}, or fails with a message naming the missing parameter
     * (by its JSON name, which is what the model sees).
     */
    private static <T> T require(T value, String parameterName, String operation) {
        if (value == null) {
            throw new IllegalArgumentException("操作 " + operation + " 缺少必要参数:" + parameterName);
        }
        return value;
    }

    /**
     * Maps a {@link DayOfWeek} to its Chinese name.
     */
    private String getDayOfWeekInChinese(DayOfWeek dayOfWeek) {
        return switch (dayOfWeek) {
            case MONDAY -> "星期一";
            case TUESDAY -> "星期二";
            case WEDNESDAY -> "星期三";
            case THURSDAY -> "星期四";
            case FRIDAY -> "星期五";
            case SATURDAY -> "星期六";
            case SUNDAY -> "星期日";
        };
    }

    /**
     * Returns the days until Saturday, or 0 when today is Saturday or Sunday.
     */
    private int calculateDaysUntilWeekend(LocalDateTime now) {
        DayOfWeek currentDay = now.getDayOfWeek();
        if (currentDay == DayOfWeek.SATURDAY || currentDay == DayOfWeek.SUNDAY) {
            return 0;
        }
        return DayOfWeek.SATURDAY.getValue() - currentDay.getValue();
    }

    /**
     * Parses an ISO date (yyyy-MM-dd) and re-formats it with a user-supplied pattern.
     *
     * @throws IllegalArgumentException if the date or pattern is invalid, or the pattern
     *                                  needs fields a date does not have (e.g. HH)
     */
    private String formatCustomDate(String dateString, String format) {
        try {
            LocalDate date = LocalDate.parse(dateString);
            return "格式化结果:" + date.format(DateTimeFormatter.ofPattern(format));
        } catch (DateTimeException | IllegalArgumentException e) {
            throw new IllegalArgumentException("日期格式化失败:" + e.getMessage(), e);
        }
    }

    /**
     * Computes the age for an ISO birth date (yyyy-MM-dd) as years/months/days.
     *
     * <p>{@link Period#between} keeps months in 0..11, including for 29-Feb birthdays
     * (on 28 Feb the result is "…岁11个月30天", never "…岁12个月0天").
     *
     * @throws IllegalArgumentException if the date cannot be parsed or lies in the future
     */
    private String calculateAge(String birthDateString) {
        LocalDate birthDate;
        try {
            birthDate = LocalDate.parse(birthDateString);
        } catch (DateTimeException e) {
            throw new IllegalArgumentException("年龄计算失败:" + e.getMessage(), e);
        }

        LocalDate today = LocalDate.now();
        if (birthDate.isAfter(today)) {
            throw new IllegalArgumentException("年龄计算失败:出生日期晚于当前日期");
        }

        Period age = Period.between(birthDate, today);
        return String.format("年龄:%d岁%d个月%d天", age.getYears(), age.getMonths(), age.getDays());
    }

    /**
     * Interprets {@code dateTime} in the JVM's default zone and shows the same instant
     * in {@code timeZoneId}.
     *
     * @throws IllegalArgumentException if the zone ID is invalid or unknown
     */
    private String convertTimeZone(LocalDateTime dateTime, String timeZoneId) {
        try {
            ZonedDateTime sourceTime = dateTime.atZone(ZoneId.systemDefault());
            ZonedDateTime targetTime = sourceTime.withZoneSameInstant(ZoneId.of(timeZoneId));

            return String.format("时区转换:%s -> %s",
                sourceTime.format(ZONED_DATE_TIME_FORMATTER),
                targetTime.format(ZONED_DATE_TIME_FORMATTER));
        } catch (DateTimeException e) {
            throw new IllegalArgumentException("时区转换失败:" + e.getMessage(), e);
        }
    }

    /**
     * Arguments supplied by the model. Only {@code operation} is always required; the
     * others are needed by specific operations.
     */
    public record Request(
        @JsonProperty(required = true, value = "operation")
        @JsonPropertyDescription("操作类型:current_time, current_date, timestamp, day_of_week, days_until_weekend, add_days, subtract_days, format_date, calculate_age, time_zone_convert")
        String operation,

        @JsonProperty(value = "value")
        @JsonPropertyDescription("数值参数(用于add_days, subtract_days等操作)")
        Integer value,

        @JsonProperty(value = "date_string")
        @JsonPropertyDescription("日期字符串(格式:yyyy-MM-dd)")
        String dateString,

        @JsonProperty(value = "format")
        @JsonPropertyDescription("日期格式模式")
        String format,

        @JsonProperty(value = "birth_date")
        @JsonPropertyDescription("出生日期(格式:yyyy-MM-dd)")
        String birthDate,

        @JsonProperty(value = "time_zone")
        @JsonPropertyDescription("目标时区ID(如:Asia/Shanghai, America/New_York)")
        String timeZone
    ) {}

    /**
     * Operation outcome; {@link #toString()} shows the result or the error.
     */
    public record Response(
        String operation,
        String result,
        boolean success,
        String error
    ) {
        @Override
        public String toString() {
            return success ? result : "错误:" + error;
        }
    }
}

6.3 数据库查询函数

6.3.1 用户查询函数

// UserQueryFunction.java
package com.example.springai.function;

import com.example.springai.entity.User;
import com.example.springai.repository.UserRepository;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * User lookup tool backed by {@link UserRepository}. The {@code operation} argument
 * selects the query; the other arguments are required only by the operations that use them.
 */
@Component
public class UserQueryFunction implements Function<UserQueryFunction.Request, UserQueryFunction.Response> {

    private final UserRepository userRepository;

    public UserQueryFunction(UserRepository userRepository) {
        this.userRepository = userRepository;
    }

    /**
     * Runs the requested query.
     *
     * <p>{@code Response.count} is the number of users found, or the total user count
     * for {@code count_users}. Any failure yields {@code success=false} with a message.
     */
    @Override
    public Response apply(Request request) {
        try {
            String operation = normalizeOperation(request.operation());

            if ("count_users".equals(operation)) {
                // A single count query; the total goes into both the text and the count field.
                long total = userRepository.count();
                int count = (int) Math.min(total, Integer.MAX_VALUE);
                return new Response(request.operation(), "用户总数:" + total, count, true, null);
            }

            List<User> users = executeQuery(operation, request);
            return new Response(request.operation(), formatUserList(users), users.size(), true, null);
        } catch (Exception e) {
            return new Response(request.operation(), null, 0, false, e.getMessage());
        }
    }

    /** Lower-cases the operation with a locale-independent rule; a missing operation is an error. */
    private static String normalizeOperation(String operation) {
        if (operation == null) {
            throw new IllegalArgumentException("缺少必要参数:operation");
        }
        return operation.toLowerCase(java.util.Locale.ROOT);
    }

    /**
     * Executes every list-returning operation ({@code count_users} is handled by the caller).
     *
     * @throws IllegalArgumentException for an unknown operation or a missing parameter
     */
    private List<User> executeQuery(String operation, Request request) {
        return switch (operation) {
            case "find_by_id" -> userRepository.findById(require(request.userId(), "user_id", operation))
                .map(List::of)
                .orElse(List.of());
            case "find_by_email" -> userRepository.findByEmail(require(request.email(), "email", operation))
                .map(List::of)
                .orElse(List.of());
            case "find_by_name" ->
                userRepository.findByNameContainingIgnoreCase(require(request.name(), "name", operation));
            case "find_active_users" -> userRepository.findByActiveTrue();
            case "find_recent_users" ->
                userRepository.findUsersCreatedInLastDays(request.days() != null ? request.days() : 7);
            case "find_by_role" -> userRepository.findByRole(require(request.role(), "role", operation));
            default -> throw new IllegalArgumentException("不支持的操作:" + request.operation());
        };
    }

    /**
     * Returns {@code value}, or fails naming the missing parameter by its JSON name.
     * Reporting it is clearer to the model than answering "未找到匹配的用户".
     */
    private static <T> T require(T value, String parameterName, String operation) {
        if (value == null) {
            throw new IllegalArgumentException("操作 " + operation + " 缺少必要参数:" + parameterName);
        }
        return value;
    }

    /**
     * Renders the users as text. A single user gets full details; several users get
     * one summary line each.
     */
    private String formatUserList(List<User> users) {
        if (users.isEmpty()) {
            return "未找到匹配的用户";
        }

        if (users.size() == 1) {
            User user = users.get(0);
            return String.format(
                "用户信息:\n" +
                "ID: %d\n" +
                "姓名: %s\n" +
                "邮箱: %s\n" +
                "角色: %s\n" +
                "状态: %s\n" +
                "创建时间: %s",
                user.getId(),
                user.getName(),
                user.getEmail(),
                user.getRole(),
                user.isActive() ? "活跃" : "非活跃",
                user.getCreatedAt()
            );
        }

        // Brief list for multiple users.
        String userList = users.stream()
            .map(user -> String.format("- %s (%s) - %s",
                user.getName(),
                user.getEmail(),
                user.getRole()))
            .collect(Collectors.joining("\n"));

        return String.format("找到 %d 个用户:\n%s", users.size(), userList);
    }

    /**
     * Arguments supplied by the model.
     */
    public record Request(
        @JsonProperty(required = true, value = "operation")
        @JsonPropertyDescription("查询操作:find_by_id, find_by_email, find_by_name, find_active_users, find_recent_users, count_users, find_by_role")
        String operation,

        @JsonProperty(value = "user_id")
        @JsonPropertyDescription("用户ID")
        Long userId,

        @JsonProperty(value = "email")
        @JsonPropertyDescription("用户邮箱")
        String email,

        @JsonProperty(value = "name")
        @JsonPropertyDescription("用户姓名(支持模糊匹配)")
        String name,

        @JsonProperty(value = "role")
        @JsonPropertyDescription("用户角色")
        String role,

        @JsonProperty(value = "days")
        @JsonPropertyDescription("天数(用于查询最近注册的用户)")
        Integer days
    ) {}

    /**
     * Query outcome; {@link #toString()} shows the result text or the error.
     */
    public record Response(
        String operation,
        String result,
        int count,
        boolean success,
        String error
    ) {
        @Override
        public String toString() {
            return success ? result : "查询失败:" + error;
        }
    }
}

// User实体类
package com.example.springai.entity;

import jakarta.persistence.*;
import java.time.LocalDateTime;

/**
 * JPA entity mapped to the {@code users} table.
 *
 * <p>{@code createdAt} is stamped when the object is constructed, and {@code updatedAt}
 * is refreshed by the {@link PreUpdate} callback before each update.
 */
@Entity
@Table(name = "users")
public class User {

    /** Primary key generated by the database identity column. */
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    @Column(nullable = false)
    private String name;

    /** Unique across all users. */
    @Column(nullable = false, unique = true)
    private String email;

    @Column(nullable = false)
    private String role;

    /** New users start out active. */
    @Column(nullable = false)
    private boolean active = true;

    @Column(name = "created_at", nullable = false)
    private LocalDateTime createdAt;

    /** {@code null} until the entity is first updated. */
    @Column(name = "updated_at")
    private LocalDateTime updatedAt;

    /** Creates an empty user stamped with the current time (also used by JPA). */
    public User() {
        createdAt = LocalDateTime.now();
    }

    /** Creates an active user with the given identity and role. */
    public User(String name, String email, String role) {
        this();
        this.name = name;
        this.email = email;
        this.role = role;
    }

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getEmail() {
        return email;
    }

    public void setEmail(String email) {
        this.email = email;
    }

    public String getRole() {
        return role;
    }

    public void setRole(String role) {
        this.role = role;
    }

    public boolean isActive() {
        return active;
    }

    public void setActive(boolean active) {
        this.active = active;
    }

    public LocalDateTime getCreatedAt() {
        return createdAt;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }

    public LocalDateTime getUpdatedAt() {
        return updatedAt;
    }

    public void setUpdatedAt(LocalDateTime updatedAt) {
        this.updatedAt = updatedAt;
    }

    /** JPA lifecycle hook: records the modification time before an update is flushed. */
    @PreUpdate
    public void preUpdate() {
        updatedAt = LocalDateTime.now();
    }
}

// UserRepository接口
package com.example.springai.repository;

import com.example.springai.entity.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

import java.util.List;
import java.util.Optional;

/**
 * Spring Data JPA repository for {@link User}.
 */
@Repository
public interface UserRepository extends JpaRepository<User, Long> {

    /** Finds the user with exactly this e-mail address (the column is unique). */
    Optional<User> findByEmail(String email);

    /** Finds users whose name contains {@code name}, ignoring case. */
    List<User> findByNameContainingIgnoreCase(String name);

    List<User> findByActiveTrue();

    List<User> findByRole(String role);

    /** Finds users whose {@code createdAt} is at or after {@code since}. */
    List<User> findByCreatedAtGreaterThanEqual(java.time.LocalDateTime since);

    /**
     * Finds users created since the start of the day {@code days} days ago
     * (e.g. {@code days = 7} means since 00:00 seven days before today).
     *
     * <p>Implemented as a default method over a derived query. Standard JPQL defines no
     * "date minus integer" arithmetic, so a query like {@code CURRENT_DATE - :days} is
     * provider-dependent or rejected.
     * NOTE(review): "today" is taken from the JVM's default time zone, not the
     * database server's clock.
     */
    default List<User> findUsersCreatedInLastDays(int days) {
        return findByCreatedAtGreaterThanEqual(
            java.time.LocalDate.now().minusDays(days).atStartOfDay());
    }

    @Query("SELECT COUNT(u) FROM User u WHERE u.active = true")
    long countActiveUsers();

    /**
     * Finds users whose e-mail ends with {@code domain}.
     * NOTE(review): this is a suffix match, so "example.com" also matches "@notexample.com".
     */
    @Query("SELECT u FROM User u WHERE u.email LIKE %:domain")
    List<User> findByEmailDomain(@Param("domain") String domain);
}

6.4 函数调用服务

6.4.1 函数注册管理器

// FunctionRegistryManager.java
package com.example.springai.service;

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.stereotype.Service;

import java.util.*;
import java.util.function.Function;

/**
 * In-memory registry of the functions (tools) offered to the chat model, keyed by name,
 * with their description and input type.
 *
 * <p>Intentionally <em>not</em> annotated with {@code @Service}. The instance is created
 * and populated by the {@code @Bean} method {@code functionRegistryManager} in
 * {@code FunctionConfiguration}. A component-scanned copy would be a second bean named
 * "functionRegistryManager", and Spring Boot refuses to start on that clash by default
 * because bean-definition overriding is disabled.
 *
 * <p>Registration is expected to happen once during configuration. The class is not
 * synchronized for concurrent registration. Functions are kept in registration order.
 */
public class FunctionRegistryManager {

    // LinkedHashMap: deterministic, registration-ordered iteration for getFunctionNames().
    private final Map<String, Function<?, ?>> functions = new LinkedHashMap<>();
    private final Map<String, FunctionMetadata> functionMetadata = new LinkedHashMap<>();

    /**
     * Registers (or replaces) a function under {@code name}.
     *
     * @throws NullPointerException if {@code name} or {@code function} is null
     */
    public <T, R> void registerFunction(String name, Function<T, R> function,
                                       String description, Class<T> inputType) {
        Objects.requireNonNull(name, "name");
        Objects.requireNonNull(function, "function");
        functions.put(name, function);
        functionMetadata.put(name, new FunctionMetadata(name, description, inputType));
    }

    /**
     * Returns the function registered under {@code name}, or {@code null} if there is none.
     * The caller is responsible for choosing type arguments that match the registration.
     */
    @SuppressWarnings("unchecked")
    public <T, R> Function<T, R> getFunction(String name) {
        return (Function<T, R>) functions.get(name);
    }

    /** Returns a read-only view of all registered names, in registration order. */
    public Set<String> getFunctionNames() {
        return Collections.unmodifiableSet(functions.keySet());
    }

    /** Returns the metadata for {@code name}, or {@code null} if there is none. */
    public FunctionMetadata getFunctionMetadata(String name) {
        return functionMetadata.get(name);
    }

    /** Returns a read-only view of all metadata, in registration order. */
    public Collection<FunctionMetadata> getAllFunctionMetadata() {
        return Collections.unmodifiableCollection(functionMetadata.values());
    }

    /**
     * Builds chat options that enable the given functions.
     * NOTE(review): the builder method for enabling functions differs between Spring AI
     * versions (e.g. a function-calling-specific options builder); confirm against the
     * version in use.
     */
    public ChatOptions createChatOptions(List<String> functionNames) {
        return ChatOptions.builder()
            .withFunctions(functionNames)
            .build();
    }

    /**
     * Descriptive metadata for one registered function.
     */
    public record FunctionMetadata(
        String name,
        String description,
        Class<?> inputType
    ) {}
}

// FunctionCallService.java
package com.example.springai.service;

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Map;

/**
 * Runs chat requests with a set of registered functions enabled. The set can be chosen
 * explicitly or picked heuristically from keywords in the user's message.
 */
@Service
public class FunctionCallService {

    /**
     * Matches "digit … arithmetic operator … digit" anywhere in a message.
     * DOTALL lets the gaps span line breaks. The pattern is compiled once, not on every call.
     */
    private static final java.util.regex.Pattern ARITHMETIC_EXPRESSION =
        java.util.regex.Pattern.compile("\\d.*[+\\-*/].*\\d", java.util.regex.Pattern.DOTALL);

    private final ChatModel chatModel;
    private final FunctionRegistryManager functionRegistry;

    public FunctionCallService(ChatModel chatModel, FunctionRegistryManager functionRegistry) {
        this.chatModel = chatModel;
        this.functionRegistry = functionRegistry;
    }

    /**
     * Sends {@code userMessage} to the model with {@code availableFunctions} enabled.
     *
     * @return the model's reply text, or an error description if anything fails
     */
    public String chatWithFunctions(String userMessage, List<String> availableFunctions) {
        try {
            // Build the prompt.
            PromptTemplate promptTemplate = new PromptTemplate(
                "用户问题:{question}\n\n" +
                "可用函数:{functions}\n\n" +
                "请根据用户问题,决定是否需要调用函数来获取信息。如果需要,请调用相应的函数。"
            );

            Map<String, Object> variables = Map.of(
                "question", userMessage,
                "functions", String.join(", ", availableFunctions)
            );

            Prompt prompt = promptTemplate.create(variables,
                functionRegistry.createChatOptions(availableFunctions));

            // Call the model.
            ChatResponse response = chatModel.call(prompt);

            return response.getResult().getOutput().getContent();

        } catch (Exception e) {
            return "处理请求时发生错误:" + e.getMessage();
        }
    }

    /**
     * Like {@link #chatWithFunctions}, but picks the functions from keywords in the message.
     */
    public String chatWithAutoFunctionSelection(String userMessage) {
        List<String> selectedFunctions = selectFunctionsForMessage(userMessage);
        return chatWithFunctions(userMessage, selectedFunctions);
    }

    /**
     * Keyword heuristic mapping a message to candidate function names. Falls back to every
     * registered function when nothing matches.
     *
     * <p>A {@code null} message is treated as empty. Every function is then offered, and
     * {@link #chatWithFunctions} reports any downstream failure through its own error handling.
     */
    private List<String> selectFunctionsForMessage(String message) {
        List<String> selectedFunctions = new java.util.ArrayList<>();
        String lowerMessage = message == null ? "" : message.toLowerCase(java.util.Locale.ROOT);

        // Weather
        if (lowerMessage.contains("天气") || lowerMessage.contains("温度") ||
            lowerMessage.contains("下雨") || lowerMessage.contains("晴天")) {
            selectedFunctions.add("weatherFunction");
        }

        // Calculation
        if (lowerMessage.contains("计算") || lowerMessage.contains("加") ||
            lowerMessage.contains("减") || lowerMessage.contains("乘") ||
            lowerMessage.contains("除") || ARITHMETIC_EXPRESSION.matcher(lowerMessage).find()) {
            selectedFunctions.add("calculatorFunction");
            selectedFunctions.add("advancedCalculatorFunction");
        }

        // Date / time
        if (lowerMessage.contains("时间") || lowerMessage.contains("日期") ||
            lowerMessage.contains("星期") || lowerMessage.contains("年龄") ||
            lowerMessage.contains("时区")) {
            selectedFunctions.add("dateTimeFunction");
        }

        // User queries
        if (lowerMessage.contains("用户") || lowerMessage.contains("查找") ||
            lowerMessage.contains("搜索") || lowerMessage.contains("邮箱")) {
            selectedFunctions.add("userQueryFunction");
        }

        // Nothing matched: offer every registered function.
        if (selectedFunctions.isEmpty()) {
            selectedFunctions.addAll(functionRegistry.getFunctionNames());
        }

        return selectedFunctions;
    }

    /**
     * Returns the number of registered functions and their metadata.
     */
    public FunctionCallStats getFunctionCallStats() {
        return new FunctionCallStats(
            functionRegistry.getFunctionNames().size(),
            functionRegistry.getAllFunctionMetadata()
        );
    }

    /**
     * Summary of the registered functions.
     */
    public record FunctionCallStats(
        int totalFunctions,
        java.util.Collection<FunctionRegistryManager.FunctionMetadata> functions
    ) {}
}

6.4.2 函数调用配置

// FunctionConfiguration.java
package com.example.springai.config;

import com.example.springai.function.*;
import com.example.springai.service.FunctionRegistryManager;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Spring configuration that builds the {@link FunctionRegistryManager} bean and registers each
 * function bean with its name, description and request type.
 */
@Configuration
public class FunctionConfiguration {

    /**
     * Creates the function registry and registers the built-in functions.
     *
     * <p>Registration order: weather, calculator, advanced calculator, date/time, user query.
     *
     * @return the populated registry
     */
    @Bean
    public FunctionRegistryManager functionRegistryManager(
            WeatherFunction weatherFunction,
            CalculatorFunction calculatorFunction,
            AdvancedCalculatorFunction advancedCalculatorFunction,
            DateTimeFunction dateTimeFunction,
            UserQueryFunction userQueryFunction) {

        var manager = new FunctionRegistryManager();

        // Weather lookup
        manager.registerFunction("weatherFunction", weatherFunction,
                "获取指定城市的天气信息", WeatherFunction.Request.class);

        // Basic arithmetic
        manager.registerFunction("calculatorFunction", calculatorFunction,
                "执行基本的数学计算", CalculatorFunction.Request.class);

        // Arithmetic with math functions
        manager.registerFunction("advancedCalculatorFunction", advancedCalculatorFunction,
                "执行高级数学计算,支持数学函数", AdvancedCalculatorFunction.Request.class);

        // Date/time operations
        manager.registerFunction("dateTimeFunction", dateTimeFunction,
                "处理日期时间相关操作", DateTimeFunction.Request.class);

        // User lookup
        manager.registerFunction("userQueryFunction", userQueryFunction,
                "查询用户信息", UserQueryFunction.Request.class);

        return manager;
    }
}

6.5 函数调用控制器

6.5.1 REST API接口

// FunctionCallController.java
package com.example.springai.controller;

import com.example.springai.service.FunctionCallService;
import com.example.springai.service.FunctionRegistryManager;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.util.List;
import java.util.Map;

/**
 * REST endpoints for chatting with function calling and for inspecting the functions held by
 * {@link FunctionRegistryManager}.
 *
 * <p>The chat endpoints answer HTTP 200 with a {@link ChatResponse} envelope and report failures
 * through {@code success=false} plus an {@code error} text.
 */
@RestController
@RequestMapping("/api/functions")
public class FunctionCallController {

    /** Error text returned when a chat request carries no usable message. */
    private static final String BLANK_MESSAGE_ERROR = "消息内容不能为空";

    private final FunctionCallService functionCallService;
    private final FunctionRegistryManager functionRegistry;

    public FunctionCallController(FunctionCallService functionCallService,
                                  FunctionRegistryManager functionRegistry) {
        this.functionCallService = functionCallService;
        this.functionRegistry = functionRegistry;
    }

    /**
     * Chats using an explicit list of function names.
     *
     * @param request the user message and the function names to make available
     * @return the model's reply, or an error envelope when the message is blank or the call fails
     */
    @PostMapping("/chat")
    public ResponseEntity<ChatResponse> chatWithFunctions(@RequestBody ChatRequest request) {
        if (isBlank(request.message())) {
            return ResponseEntity.ok(new ChatResponse(null, false, BLANK_MESSAGE_ERROR));
        }
        try {
            String response = functionCallService.chatWithFunctions(
                request.message(),
                request.functions()
            );
            return ResponseEntity.ok(new ChatResponse(response, true, null));
        } catch (Exception e) {
            return ResponseEntity.ok(new ChatResponse(null, false, describe(e)));
        }
    }

    /**
     * Chats while letting the service choose which functions to expose.
     *
     * @param request the user message
     * @return the model's reply, or an error envelope when the message is blank or the call fails
     */
    @PostMapping("/chat/auto")
    public ResponseEntity<ChatResponse> chatWithAutoFunctions(@RequestBody AutoChatRequest request) {
        if (isBlank(request.message())) {
            return ResponseEntity.ok(new ChatResponse(null, false, BLANK_MESSAGE_ERROR));
        }
        try {
            String response = functionCallService.chatWithAutoFunctionSelection(request.message());
            return ResponseEntity.ok(new ChatResponse(response, true, null));
        } catch (Exception e) {
            return ResponseEntity.ok(new ChatResponse(null, false, describe(e)));
        }
    }

    /**
     * Lists every registered function with its name, description and input type's simple name.
     */
    @GetMapping("/available")
    public ResponseEntity<List<FunctionInfo>> getAvailableFunctions() {
        List<FunctionInfo> functions = functionRegistry.getAllFunctionMetadata()
            .stream()
            .map(metadata -> new FunctionInfo(
                metadata.name(),
                metadata.description(),
                metadata.inputType().getSimpleName()
            ))
            .toList();

        return ResponseEntity.ok(functions);
    }

    /**
     * Returns the registry statistics computed by {@link FunctionCallService}.
     */
    @GetMapping("/stats")
    public ResponseEntity<FunctionCallService.FunctionCallStats> getFunctionStats() {
        return ResponseEntity.ok(functionCallService.getFunctionCallStats());
    }

    /**
     * Returns details for one registered function.
     *
     * @param functionName the registered function name
     * @return the function detail, or 404 when no function with that name is registered
     */
    @GetMapping("/{functionName}")
    public ResponseEntity<FunctionDetail> getFunctionDetail(@PathVariable String functionName) {
        FunctionRegistryManager.FunctionMetadata metadata =
            functionRegistry.getFunctionMetadata(functionName);

        if (metadata == null) {
            return ResponseEntity.notFound().build();
        }

        FunctionDetail detail = new FunctionDetail(
            metadata.name(),
            metadata.description(),
            metadata.inputType().getSimpleName(),
            getInputSchema(metadata.inputType())
        );

        return ResponseEntity.ok(detail);
    }

    /**
     * Placeholder endpoint for exercising a function with a raw input map.
     *
     * @param functionName the registered function name
     * @param input        arbitrary JSON input, echoed back in the result
     * @return 404 for unknown functions (consistent with {@link #getFunctionDetail}), otherwise a
     *         fixed placeholder result
     */
    @PostMapping("/test/{functionName}")
    public ResponseEntity<TestResult> testFunction(
            @PathVariable String functionName,
            @RequestBody Map<String, Object> input) {

        // Previously any name, even an unregistered one, was reported as a successful test.
        if (functionRegistry.getFunctionMetadata(functionName) == null) {
            return ResponseEntity.notFound().build();
        }

        // TODO: actually invoke the function; for demonstration a fixed placeholder is returned.
        return ResponseEntity.ok(new TestResult(
            functionName,
            input,
            "测试成功",
            true,
            null
        ));
    }

    /**
     * Describes a function's input type.
     *
     * <p>TODO: generate a real JSON Schema; this demo returns only the simple type name.
     */
    private String getInputSchema(Class<?> inputType) {
        return "输入类型:" + inputType.getSimpleName();
    }

    /** Returns {@code true} for {@code null}, empty, or whitespace-only text. */
    private static boolean isBlank(String text) {
        return text == null || text.isBlank();
    }

    /**
     * Produces a non-null error text for an exception, falling back to the exception's simple
     * class name when it carries no message (e.g. many {@code NullPointerException}s).
     *
     * <p>NOTE(review): this text is returned to API clients and may expose internal details;
     * consider mapping exceptions to generic client-facing messages.
     */
    private static String describe(Exception e) {
        String message = e.getMessage();
        return (message == null || message.isBlank()) ? e.getClass().getSimpleName() : message;
    }

    // Request and response records
    public record ChatRequest(String message, List<String> functions) {}
    public record AutoChatRequest(String message) {}
    public record ChatResponse(String response, boolean success, String error) {}
    public record FunctionInfo(String name, String description, String inputType) {}
    public record FunctionDetail(String name, String description, String inputType, String schema) {}
    public record TestResult(String functionName, Map<String, Object> input,
                           String result, boolean success, String error) {}
}

6.5.2 WebSocket实时聊天

// FunctionChatWebSocketHandler.java
package com.example.springai.websocket;

import com.example.springai.service.FunctionCallService;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Component
public class FunctionChatWebSocketHandler implements WebSocketHandler {
    
    private final FunctionCallService functionCallService;
    private final ObjectMapper objectMapper;
    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
    
    public FunctionChatWebSocketHandler(FunctionCallService functionCallService) {
        this.functionCallService = functionCallService;
        this.objectMapper = new ObjectMapper();
    }
    
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        sessions.put(session.getId(), session);
        
        // 发送欢迎消息
        WebSocketMessage welcomeMessage = new WebSocketMessage(
            "system",
            "连接成功!您可以开始与AI助手对话,支持函数调用功能。",
            System.currentTimeMillis()
        );
        
        session.sendMessage(new TextMessage(objectMapper.writeValueAsString(welcomeMessage)));
    }
    
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        if (message instanceof TextMessage textMessage) {
            try {
                // 解析用户消息
                Map<String, Object> messageData = objectMapper.readValue(
                    textMessage.getPayload(), Map.class);
                
                String userMessage = (String) messageData.get("message");
                
                // 发送用户消息确认
                WebSocketMessage userEcho = new WebSocketMessage(
                    "user",
                    userMessage,
                    System.currentTimeMillis()
                );
                session.sendMessage(new TextMessage(objectMapper.writeValueAsString(userEcho)));
                
                // 处理AI响应
                String aiResponse = functionCallService.chatWithAutoFunctionSelection(userMessage);
                
                WebSocketMessage aiMessage = new WebSocketMessage(
                    "assistant",
                    aiResponse,
                    System.currentTimeMillis()
                );
                
                session.sendMessage(new TextMessage(objectMapper.writeValueAsString(aiMessage)));
                
            } catch (Exception e) {
                WebSocketMessage errorMessage = new WebSocketMessage(
                    "error",
                    "处理消息时发生错误:" + e.getMessage(),
                    System.currentTimeMillis()
                );
                
                session.sendMessage(new TextMessage(objectMapper.writeValueAsString(errorMessage)));
            }
        }
    }
    
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        sessions.remove(session.getId());
    }
    
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        sessions.remove(session.getId());
    }
    
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }
    
    /**
     * WebSocket消息格式
     */
    public record WebSocketMessage(
        String type,
        String content,
        long timestamp
    ) {}
}

// WebSocketConfig.java(WebSocket配置)
package com.example.springai.config;

import com.example.springai.websocket.FunctionChatWebSocketHandler;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.*;

/**
 * Registers {@link FunctionChatWebSocketHandler} at {@code /ws/chat}.
 *
 * <p>NOTE(review): {@code setAllowedOrigins("*")} accepts handshakes from any origin, which
 * permits cross-site WebSocket hijacking if the endpoint ever relies on cookies or other
 * ambient credentials. Consider restricting this to known origins outside of demos.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    private final FunctionChatWebSocketHandler chatHandler;

    public WebSocketConfig(FunctionChatWebSocketHandler webSocketHandler) {
        this.chatHandler = webSocketHandler;
    }

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        WebSocketHandlerRegistration registration = registry.addHandler(chatHandler, "/ws/chat");
        registration.setAllowedOrigins("*");
    }
}

6.6 配置文件

6.6.1 应用配置

# application.yml
spring:
  ai:
    openai:
      # Resolved from OPENAI_API_KEY; no fallback value is specified.
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: gpt-4
          temperature: 0.7
          max-tokens: 2000
    
  # In-memory H2 database (demo setup; the data does not persist).
  datasource:
    url: jdbc:h2:mem:testdb
    driver-class-name: org.h2.Driver
    username: sa
    password: password
    
  # Development-oriented settings: schema created/dropped with the app, SQL echoed to the log.
  jpa:
    hibernate:
      ddl-auto: create-drop
    show-sql: true
    
  # Enables the H2 web console; intended for local development only.
  h2:
    console:
      enabled: true

# Weather API settings
weather:
  api:
    # Falls back to "demo_key" when WEATHER_API_KEY is not set.
    key: ${WEATHER_API_KEY:demo_key}
    url: https://api.openweathermap.org/data/2.5/weather

# Function-calling settings (bound by FunctionCallProperties, prefix "function.call")
function:
  call:
    # Presumably milliseconds — confirm with the consumer of FunctionCallProperties.
    timeout: 30000
    retry:
      max-attempts: 3
      # Presumably milliseconds between attempts — confirm.
      delay: 1000
    cache:
      enabled: true
      # Presumably seconds — confirm.
      ttl: 300

# Logging
logging:
  level:
    com.example.springai: DEBUG
    org.springframework.ai: DEBUG
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} - %msg%n"

6.6.2 函数调用属性配置

// FunctionCallProperties.java
package com.example.springai.config;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * Binds the {@code function.call.*} configuration properties: call timeout, retry policy and
 * result-cache settings. Every value has a built-in default used when it is not configured.
 */
@Component
@ConfigurationProperties(prefix = "function.call")
public class FunctionCallProperties {

    /** Call timeout, default 30000 (presumably milliseconds — confirm with the consumer). */
    private int timeout = 30_000;

    /** Retry policy; never replaced unless a setter is called. */
    private Retry retry = new Retry();

    /** Result-cache settings. */
    private Cache cache = new Cache();

    public int getTimeout() {
        return this.timeout;
    }

    public void setTimeout(int timeout) {
        this.timeout = timeout;
    }

    public Retry getRetry() {
        return this.retry;
    }

    public void setRetry(Retry retry) {
        this.retry = retry;
    }

    public Cache getCache() {
        return this.cache;
    }

    public void setCache(Cache cache) {
        this.cache = cache;
    }

    /** Retry policy: attempt count (default 3) and delay between attempts (default 1000). */
    public static class Retry {

        private int maxAttempts = 3;
        private long delay = 1_000L;

        public int getMaxAttempts() {
            return this.maxAttempts;
        }

        public void setMaxAttempts(int maxAttempts) {
            this.maxAttempts = maxAttempts;
        }

        public long getDelay() {
            return this.delay;
        }

        public void setDelay(long delay) {
            this.delay = delay;
        }
    }

    /** Cache settings: on/off switch (default on) and time-to-live (default 300). */
    public static class Cache {

        private boolean enabled = true;
        private long ttl = 300L;

        public boolean isEnabled() {
            return this.enabled;
        }

        public void setEnabled(boolean enabled) {
            this.enabled = enabled;
        }

        public long getTtl() {
            return this.ttl;
        }

        public void setTtl(long ttl) {
            this.ttl = ttl;
        }
    }
}

6.7 本章总结

6.7.1 核心要点

  1. 函数调用基础

    • 理解函数调用的概念和架构
    • 掌握Spring AI的函数调用接口
    • 学会函数定义和注册
  2. 函数实现

    • 天气查询函数:外部API集成
    • 计算器函数:数学表达式处理
    • 日期时间函数:时间处理和格式化
    • 数据库查询函数:数据访问集成
  3. 服务架构

    • 函数注册管理器:统一函数管理
    • 函数调用服务:智能函数选择
    • REST API:HTTP接口支持
    • WebSocket:实时聊天功能

6.7.2 最佳实践

  1. 函数设计原则

    • 单一职责:每个函数专注一个功能
    • 参数验证:确保输入参数的有效性
    • 错误处理:提供清晰的错误信息
    • 文档完善:详细的函数描述和参数说明
  2. 性能优化

    • 缓存机制:缓存函数调用结果
    • 超时控制:避免长时间等待
    • 重试策略:处理临时性错误
    • 异步处理:提高响应速度
  3. 安全考虑

    • 输入验证:防止恶意输入
    • 权限控制:限制函数访问权限
    • 敏感数据:保护API密钥等敏感信息
    • 日志记录:记录函数调用日志

6.7.3 练习题

  1. 基础练习

    • 实现一个货币转换函数
    • 创建一个文件操作函数
    • 开发一个邮件发送函数
  2. 进阶练习

    • 实现函数调用缓存机制
    • 添加函数调用监控和统计
    • 创建函数调用链功能
  3. 综合项目

    • 构建一个智能助手系统
    • 集成多个外部API服务
    • 实现函数调用的可视化界面

下一章预告: 第七章将介绍Spring AI的流式处理与异步编程,包括流式响应、异步处理、响应式编程等内容。