概述

本章将深入介绍如何开发和实现 gRPC 服务端，包括服务注册、方法实现、错误处理、中间件集成等核心内容。我们将通过实际代码示例，学习如何构建高质量、可扩展的 gRPC 服务。下文的 Go 示例代码由一个 Python 辅助类以字符串形式给出，其中每个 create_* 方法返回一个完整的 Go 源文件。

学习目标

  • 掌握 gRPC 服务端的基本架构和实现模式
  • 学习如何实现四种不同类型的 RPC 方法
  • 了解服务生命周期管理和优雅关闭
  • 掌握错误处理和状态码的使用
  • 学习中间件和拦截器的开发

gRPC 服务端架构

from enum import Enum
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Callable, AsyncIterator
from abc import ABC, abstractmethod
import asyncio
import logging
import time

class ServerState(Enum):
    """Lifecycle states of a gRPC server.

    Each member's value is its lowercase name. In the code shown here only
    ``STOPPED`` is referenced (as the initial state of ``GRPCServerManager``).
    """
    STOPPED = "stopped"    # not serving
    STARTING = "starting"  # startup in progress
    RUNNING = "running"    # serving requests
    STOPPING = "stopping"  # shutdown in progress
    ERROR = "error"        # failure state

class MethodType(Enum):
    """The four kinds of gRPC method.

    Values are lowercase snake_case strings. Not referenced elsewhere in the
    code shown here.
    """
    UNARY = "unary"                                      # one request, one response
    SERVER_STREAMING = "server_streaming"                # one request, stream of responses
    CLIENT_STREAMING = "client_streaming"                # stream of requests, one response
    BIDIRECTIONAL_STREAMING = "bidirectional_streaming"  # streams in both directions

class LogLevel(Enum):
    """Log severity levels, from most to least verbose.

    Values are lowercase strings. Not referenced elsewhere in the code shown
    here.
    """
    DEBUG = "debug"
    INFO = "info"
    WARN = "warn"
    ERROR = "error"

@dataclass
class ServerConfig:
    """Settings for a gRPC server; every field has a default.

    Sizes are in bytes and durations in milliseconds, as the field suffixes
    indicate.
    """

    # Listening endpoint.
    host: str = "localhost"
    port: int = 50051
    # Worker count; not consumed by the code shown here.
    max_workers: int = 10
    # Per-message size limits: 4 MiB in each direction.
    max_receive_message_length: int = 4 << 20
    max_send_message_length: int = 4 << 20
    # Keepalive pings.
    keepalive_time_ms: int = 30_000
    keepalive_timeout_ms: int = 5_000
    keepalive_permit_without_calls: bool = True
    # Connection lifetime limits.
    max_connection_idle_ms: int = 300_000
    max_connection_age_ms: int = 600_000
    # Optional built-in services.
    enable_reflection: bool = True
    enable_health_check: bool = True

@dataclass
class RequestContext:
    """Metadata describing one incoming RPC.

    Not constructed anywhere in the code shown here; presumably filled in by
    an interceptor/middleware layer — confirm against callers.
    """
    method_name: str  # RPC method name; presumably the full "/pkg.Service/Method" form — TODO confirm
    peer: str  # remote peer identifier; exact format (address, host:port) not fixed here
    metadata: Dict[str, str]  # request metadata as a string-to-string map
    start_time: float  # when handling began; presumably time.time() seconds — confirm
    request_id: str  # identifier for this request
    user_id: Optional[str] = None  # authenticated user, when known
    trace_id: Optional[str] = None  # tracing identifier, when present

class GRPCServerManager:
    """gRPC 服务器管理器"""
    
    def __init__(self, config: ServerConfig):
        self.config = config
        self.state = ServerState.STOPPED
        self.services = []
        self.interceptors = []
        self.middleware = []
        
    def create_basic_server_implementation(self) -> str:
        """创建基础服务器实现"""
        return """
// server.go - gRPC 服务器基础实现
package main

import (
    "context"
    "fmt"
    "log"
    "net"
    "os"
    "os/signal"
    "syscall"
    "time"
    
    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/health"
    "google.golang.org/grpc/health/grpc_health_v1"
    "google.golang.org/grpc/keepalive"
    "google.golang.org/grpc/reflection"
    "google.golang.org/grpc/status"
    
    pb "./proto/user"
)

// UserServer 用户服务实现
type UserServer struct {
    pb.UnimplementedUserServiceServer
    users map[string]*pb.User
}

// NewUserServer 创建用户服务实例
func NewUserServer() *UserServer {
    return &UserServer{
        users: make(map[string]*pb.User),
    }
}

// CreateUser 创建用户(一元调用)
func (s *UserServer) CreateUser(ctx context.Context, req *pb.CreateUserRequest) (*pb.CreateUserResponse, error) {
    // 输入验证
    if req.Username == "" {
        return nil, status.Errorf(codes.InvalidArgument, "username is required")
    }
    if req.Email == "" {
        return nil, status.Errorf(codes.InvalidArgument, "email is required")
    }
    
    // 检查用户是否已存在
    for _, user := range s.users {
        if user.Username == req.Username {
            return nil, status.Errorf(codes.AlreadyExists, "username already exists")
        }
        if user.Email == req.Email {
            return nil, status.Errorf(codes.AlreadyExists, "email already exists")
        }
    }
    
    // 创建新用户
    user := &pb.User{
        Id:       generateUserID(),
        Username: req.Username,
        Email:    req.Email,
        FullName: req.FullName,
        Status:   pb.UserStatus_ACTIVE,
        CreatedAt: timestamppb.Now(),
        UpdatedAt: timestamppb.Now(),
    }
    
    s.users[user.Id] = user
    
    return &pb.CreateUserResponse{
        User:    user,
        Success: true,
    }, nil
}

// GetUser 获取用户信息(一元调用)
func (s *UserServer) GetUser(ctx context.Context, req *pb.GetUserRequest) (*pb.GetUserResponse, error) {
    user, exists := s.users[req.UserId]
    if !exists {
        return nil, status.Errorf(codes.NotFound, "user not found")
    }
    
    return &pb.GetUserResponse{
        User: user,
    }, nil
}

// ListUsers 列出用户(服务端流)
func (s *UserServer) ListUsers(req *pb.ListUsersRequest, stream pb.UserService_ListUsersServer) error {
    for _, user := range s.users {
        // 检查上下文是否被取消
        if err := stream.Context().Err(); err != nil {
            return err
        }
        
        // 发送用户数据
        if err := stream.Send(user); err != nil {
            return err
        }
        
        // 模拟处理延迟
        time.Sleep(100 * time.Millisecond)
    }
    
    return nil
}

// BatchCreateUsers 批量创建用户(客户端流)
func (s *UserServer) BatchCreateUsers(stream pb.UserService_BatchCreateUsersServer) error {
    var createdUsers []*pb.User
    var errors []string
    
    for {
        req, err := stream.Recv()
        if err == io.EOF {
            // 客户端完成发送
            return stream.SendAndClose(&pb.BatchCreateUsersResponse{
                Users:        createdUsers,
                SuccessCount: int32(len(createdUsers)),
                ErrorCount:   int32(len(errors)),
                Errors:       errors,
            })
        }
        if err != nil {
            return err
        }
        
        // 处理单个用户创建请求
        user := &pb.User{
            Id:       generateUserID(),
            Username: req.Username,
            Email:    req.Email,
            FullName: req.FullName,
            Status:   pb.UserStatus_ACTIVE,
            CreatedAt: timestamppb.Now(),
            UpdatedAt: timestamppb.Now(),
        }
        
        // 验证用户数据
        if err := validateUser(user); err != nil {
            errors = append(errors, fmt.Sprintf("Invalid user %s: %v", user.Username, err))
            continue
        }
        
        s.users[user.Id] = user
        createdUsers = append(createdUsers, user)
    }
}

// SyncUsers 同步用户数据(双向流)
func (s *UserServer) SyncUsers(stream pb.UserService_SyncUsersServer) error {
    for {
        req, err := stream.Recv()
        if err == io.EOF {
            return nil
        }
        if err != nil {
            return err
        }
        
        var response *pb.SyncUserResponse
        
        switch req.Operation {
        case pb.SyncOperation_CREATE:
            user, err := s.createUserInternal(req.User)
            if err != nil {
                response = &pb.SyncUserResponse{
                    Success: false,
                    Error:   err.Error(),
                }
            } else {
                response = &pb.SyncUserResponse{
                    Success: true,
                    User:    user,
                }
            }
            
        case pb.SyncOperation_UPDATE:
            user, err := s.updateUserInternal(req.User)
            if err != nil {
                response = &pb.SyncUserResponse{
                    Success: false,
                    Error:   err.Error(),
                }
            } else {
                response = &pb.SyncUserResponse{
                    Success: true,
                    User:    user,
                }
            }
            
        case pb.SyncOperation_DELETE:
            err := s.deleteUserInternal(req.User.Id)
            response = &pb.SyncUserResponse{
                Success: err == nil,
                Error:   func() string {
                    if err != nil {
                        return err.Error()
                    }
                    return ""
                }(),
            }
        }
        
        if err := stream.Send(response); err != nil {
            return err
        }
    }
}

// 辅助函数
func generateUserID() string {
    return fmt.Sprintf("user_%d", time.Now().UnixNano())
}

func validateUser(user *pb.User) error {
    if user.Username == "" {
        return fmt.Errorf("username is required")
    }
    if user.Email == "" {
        return fmt.Errorf("email is required")
    }
    return nil
}

func (s *UserServer) createUserInternal(user *pb.User) (*pb.User, error) {
    if err := validateUser(user); err != nil {
        return nil, err
    }
    
    user.Id = generateUserID()
    user.CreatedAt = timestamppb.Now()
    user.UpdatedAt = timestamppb.Now()
    
    s.users[user.Id] = user
    return user, nil
}

func (s *UserServer) updateUserInternal(user *pb.User) (*pb.User, error) {
    existing, exists := s.users[user.Id]
    if !exists {
        return nil, fmt.Errorf("user not found")
    }
    
    // 更新字段
    existing.Username = user.Username
    existing.Email = user.Email
    existing.FullName = user.FullName
    existing.UpdatedAt = timestamppb.Now()
    
    return existing, nil
}

func (s *UserServer) deleteUserInternal(userID string) error {
    if _, exists := s.users[userID]; !exists {
        return fmt.Errorf("user not found")
    }
    
    delete(s.users, userID)
    return nil
}

// main 函数
func main() {
    // 创建监听器
    lis, err := net.Listen("tcp", ":50051")
    if err != nil {
        log.Fatalf("Failed to listen: %v", err)
    }
    
    // 创建 gRPC 服务器
    s := grpc.NewServer(
        grpc.KeepaliveParams(keepalive.ServerParameters{
            Time:    30 * time.Second,
            Timeout: 5 * time.Second,
        }),
        grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
            MinTime:             5 * time.Second,
            PermitWithoutStream: true,
        }),
        grpc.MaxRecvMsgSize(4*1024*1024), // 4MB
        grpc.MaxSendMsgSize(4*1024*1024), // 4MB
    )
    
    // 注册服务
    userServer := NewUserServer()
    pb.RegisterUserServiceServer(s, userServer)
    
    // 注册健康检查服务
    healthServer := health.NewServer()
    grpc_health_v1.RegisterHealthServer(s, healthServer)
    healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)
    
    // 启用反射(开发环境)
    reflection.Register(s)
    
    // 优雅关闭处理
    go func() {
        sigChan := make(chan os.Signal, 1)
        signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
        <-sigChan
        
        log.Println("Shutting down gRPC server...")
        s.GracefulStop()
    }()
    
    log.Printf("gRPC server listening on %s", lis.Addr())
    if err := s.Serve(lis); err != nil {
        log.Fatalf("Failed to serve: %v", err)
    }
}
"""
    
    def create_middleware_system(self) -> str:
        """创建中间件系统"""
        return """
// middleware.go - gRPC 中间件系统
package middleware

import (
    "context"
    "fmt"
    "log"
    "time"
    
    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/metadata"
    "google.golang.org/grpc/status"
)

// LoggingInterceptor 日志拦截器
func LoggingInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        start := time.Now()
        
        // 获取客户端信息
        peer, _ := peer.FromContext(ctx)
        
        log.Printf("[REQUEST] Method: %s, Peer: %s, Start: %s", 
            info.FullMethod, peer.Addr, start.Format(time.RFC3339))
        
        // 调用实际的处理函数
        resp, err := handler(ctx, req)
        
        // 记录响应信息
        duration := time.Since(start)
        status := "OK"
        if err != nil {
            status = fmt.Sprintf("ERROR: %v", err)
        }
        
        log.Printf("[RESPONSE] Method: %s, Status: %s, Duration: %s", 
            info.FullMethod, status, duration)
        
        return resp, err
    }
}

// AuthInterceptor 认证拦截器
func AuthInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        // 跳过健康检查和反射服务的认证
        if isPublicMethod(info.FullMethod) {
            return handler(ctx, req)
        }
        
        // 从元数据中获取认证信息
        md, ok := metadata.FromIncomingContext(ctx)
        if !ok {
            return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
        }
        
        // 验证 Authorization header
        authHeaders := md.Get("authorization")
        if len(authHeaders) == 0 {
            return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
        }
        
        token := authHeaders[0]
        userID, err := validateToken(token)
        if err != nil {
            return nil, status.Errorf(codes.Unauthenticated, "invalid token: %v", err)
        }
        
        // 将用户ID添加到上下文
        ctx = context.WithValue(ctx, "user_id", userID)
        
        return handler(ctx, req)
    }
}

// RateLimitInterceptor 限流拦截器
func RateLimitInterceptor(limiter *RateLimiter) grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        // 获取客户端IP
        peer, _ := peer.FromContext(ctx)
        clientIP := extractClientIP(peer.Addr.String())
        
        // 检查限流
        if !limiter.Allow(clientIP) {
            return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded")
        }
        
        return handler(ctx, req)
    }
}

// MetricsInterceptor 指标收集拦截器
func MetricsInterceptor(metrics *MetricsCollector) grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        start := time.Now()
        
        // 增加请求计数
        metrics.IncRequestCount(info.FullMethod)
        
        // 调用处理函数
        resp, err := handler(ctx, req)
        
        // 记录响应时间
        duration := time.Since(start)
        metrics.RecordDuration(info.FullMethod, duration)
        
        // 记录错误
        if err != nil {
            code := status.Code(err)
            metrics.IncErrorCount(info.FullMethod, code.String())
        }
        
        return resp, err
    }
}

// RecoveryInterceptor 恢复拦截器
func RecoveryInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
        defer func() {
            if r := recover(); r != nil {
                log.Printf("[PANIC] Method: %s, Error: %v", info.FullMethod, r)
                err = status.Errorf(codes.Internal, "internal server error")
            }
        }()
        
        return handler(ctx, req)
    }
}

// 流式拦截器
func LoggingStreamInterceptor() grpc.StreamServerInterceptor {
    return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
        start := time.Now()
        
        log.Printf("[STREAM_START] Method: %s, Start: %s", 
            info.FullMethod, start.Format(time.RFC3339))
        
        err := handler(srv, stream)
        
        duration := time.Since(start)
        status := "OK"
        if err != nil {
            status = fmt.Sprintf("ERROR: %v", err)
        }
        
        log.Printf("[STREAM_END] Method: %s, Status: %s, Duration: %s", 
            info.FullMethod, status, duration)
        
        return err
    }
}

// 辅助函数
func isPublicMethod(method string) bool {
    publicMethods := []string{
        "/grpc.health.v1.Health/Check",
        "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
    }
    
    for _, publicMethod := range publicMethods {
        if method == publicMethod {
            return true
        }
    }
    return false
}

func validateToken(token string) (string, error) {
    // 这里应该实现实际的token验证逻辑
    // 例如JWT验证、数据库查询等
    if token == "" {
        return "", fmt.Errorf("empty token")
    }
    
    // 简单示例:假设token格式为 "Bearer user_id"
    if len(token) > 7 && token[:7] == "Bearer " {
        return token[7:], nil
    }
    
    return "", fmt.Errorf("invalid token format")
}

func extractClientIP(addr string) string {
    // 从地址中提取IP
    if idx := strings.LastIndex(addr, ":"); idx != -1 {
        return addr[:idx]
    }
    return addr
}
"""
    
    def create_error_handling_system(self) -> str:
        """创建错误处理系统"""
        return """
// errors.go - gRPC 错误处理系统
package errors

import (
    "fmt"
    
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
    "google.golang.org/protobuf/types/known/anypb"
)

// BusinessError 业务错误接口
type BusinessError interface {
    error
    Code() string
    Message() string
    Details() map[string]interface{}
}

// ValidationError 验证错误
type ValidationError struct {
    Field   string `json:"field"`
    Message string `json:"message"`
    Value   interface{} `json:"value,omitempty"`
}

func (e ValidationError) Error() string {
    return fmt.Sprintf("validation failed for field '%s': %s", e.Field, e.Message)
}

// BusinessErrorImpl 业务错误实现
type BusinessErrorImpl struct {
    code    string
    message string
    details map[string]interface{}
}

func NewBusinessError(code, message string) *BusinessErrorImpl {
    return &BusinessErrorImpl{
        code:    code,
        message: message,
        details: make(map[string]interface{}),
    }
}

func (e *BusinessErrorImpl) Error() string {
    return e.message
}

func (e *BusinessErrorImpl) Code() string {
    return e.code
}

func (e *BusinessErrorImpl) Message() string {
    return e.message
}

func (e *BusinessErrorImpl) Details() map[string]interface{} {
    return e.details
}

func (e *BusinessErrorImpl) WithDetail(key string, value interface{}) *BusinessErrorImpl {
    e.details[key] = value
    return e
}

// 预定义的业务错误
var (
    ErrUserNotFound     = NewBusinessError("USER_NOT_FOUND", "User not found")
    ErrUserAlreadyExists = NewBusinessError("USER_ALREADY_EXISTS", "User already exists")
    ErrInvalidEmail     = NewBusinessError("INVALID_EMAIL", "Invalid email format")
    ErrInvalidPassword  = NewBusinessError("INVALID_PASSWORD", "Invalid password")
    ErrPermissionDenied = NewBusinessError("PERMISSION_DENIED", "Permission denied")
)

// ToGRPCError 将业务错误转换为gRPC错误
func ToGRPCError(err error) error {
    if err == nil {
        return nil
    }
    
    switch e := err.(type) {
    case *BusinessErrorImpl:
        return businessErrorToGRPC(e)
    case ValidationError:
        return validationErrorToGRPC(e)
    default:
        // 未知错误,返回内部错误
        return status.Errorf(codes.Internal, "internal server error")
    }
}

func businessErrorToGRPC(err *BusinessErrorImpl) error {
    var code codes.Code
    
    switch err.Code() {
    case "USER_NOT_FOUND":
        code = codes.NotFound
    case "USER_ALREADY_EXISTS":
        code = codes.AlreadyExists
    case "INVALID_EMAIL", "INVALID_PASSWORD":
        code = codes.InvalidArgument
    case "PERMISSION_DENIED":
        code = codes.PermissionDenied
    default:
        code = codes.Internal
    }
    
    st := status.New(code, err.Message())
    
    // 添加详细信息
    if len(err.Details()) > 0 {
        details, _ := anypb.New(&ErrorDetails{
            Code:    err.Code(),
            Details: err.Details(),
        })
        st, _ = st.WithDetails(details)
    }
    
    return st.Err()
}

func validationErrorToGRPC(err ValidationError) error {
    st := status.New(codes.InvalidArgument, "validation failed")
    
    details, _ := anypb.New(&ValidationErrorDetails{
        Field:   err.Field,
        Message: err.Message,
        Value:   fmt.Sprintf("%v", err.Value),
    })
    
    st, _ = st.WithDetails(details)
    return st.Err()
}

// 错误详情消息
type ErrorDetails struct {
    Code    string                 `json:"code"`
    Details map[string]interface{} `json:"details"`
}

type ValidationErrorDetails struct {
    Field   string `json:"field"`
    Message string `json:"message"`
    Value   string `json:"value"`
}

// 验证函数
func ValidateCreateUserRequest(req *pb.CreateUserRequest) error {
    if req.Username == "" {
        return ValidationError{
            Field:   "username",
            Message: "username is required",
        }
    }
    
    if len(req.Username) < 3 {
        return ValidationError{
            Field:   "username",
            Message: "username must be at least 3 characters",
            Value:   req.Username,
        }
    }
    
    if req.Email == "" {
        return ValidationError{
            Field:   "email",
            Message: "email is required",
        }
    }
    
    if !isValidEmail(req.Email) {
        return ValidationError{
            Field:   "email",
            Message: "invalid email format",
            Value:   req.Email,
        }
    }
    
    return nil
}

func isValidEmail(email string) bool {
    // 简单的邮箱验证
    return strings.Contains(email, "@") && strings.Contains(email, ".")
}
"""
    
    def create_testing_framework(self) -> str:
        """创建测试框架"""
        return """
// server_test.go - gRPC 服务测试
package main

import (
    "context"
    "net"
    "testing"
    "time"
    
    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
    "google.golang.org/grpc/test/bufconn"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    
    pb "./proto/user"
)

const bufSize = 1024 * 1024

var lis *bufconn.Listener

func init() {
    lis = bufconn.Listen(bufSize)
    s := grpc.NewServer()
    pb.RegisterUserServiceServer(s, NewUserServer())
    go func() {
        if err := s.Serve(lis); err != nil {
            log.Fatalf("Server exited with error: %v", err)
        }
    }()
}

func bufDialer(context.Context, string) (net.Conn, error) {
    return lis.Dial()
}

func TestUserService_CreateUser(t *testing.T) {
    ctx := context.Background()
    conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
    require.NoError(t, err)
    defer conn.Close()
    
    client := pb.NewUserServiceClient(conn)
    
    tests := []struct {
        name    string
        request *pb.CreateUserRequest
        wantErr bool
        errCode codes.Code
    }{
        {
            name: "valid user creation",
            request: &pb.CreateUserRequest{
                Username: "testuser",
                Email:    "test@example.com",
                FullName: "Test User",
            },
            wantErr: false,
        },
        {
            name: "missing username",
            request: &pb.CreateUserRequest{
                Email:    "test@example.com",
                FullName: "Test User",
            },
            wantErr: true,
            errCode: codes.InvalidArgument,
        },
        {
            name: "missing email",
            request: &pb.CreateUserRequest{
                Username: "testuser",
                FullName: "Test User",
            },
            wantErr: true,
            errCode: codes.InvalidArgument,
        },
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            resp, err := client.CreateUser(ctx, tt.request)
            
            if tt.wantErr {
                assert.Error(t, err)
                st, ok := status.FromError(err)
                assert.True(t, ok)
                assert.Equal(t, tt.errCode, st.Code())
            } else {
                assert.NoError(t, err)
                assert.NotNil(t, resp)
                assert.True(t, resp.Success)
                assert.NotNil(t, resp.User)
                assert.Equal(t, tt.request.Username, resp.User.Username)
                assert.Equal(t, tt.request.Email, resp.User.Email)
            }
        })
    }
}

func TestUserService_ListUsers(t *testing.T) {
    ctx := context.Background()
    conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
    require.NoError(t, err)
    defer conn.Close()
    
    client := pb.NewUserServiceClient(conn)
    
    // 首先创建一些用户
    users := []*pb.CreateUserRequest{
        {Username: "user1", Email: "user1@example.com", FullName: "User One"},
        {Username: "user2", Email: "user2@example.com", FullName: "User Two"},
        {Username: "user3", Email: "user3@example.com", FullName: "User Three"},
    }
    
    for _, user := range users {
        _, err := client.CreateUser(ctx, user)
        require.NoError(t, err)
    }
    
    // 测试列出用户
    stream, err := client.ListUsers(ctx, &pb.ListUsersRequest{})
    require.NoError(t, err)
    
    var receivedUsers []*pb.User
    for {
        user, err := stream.Recv()
        if err == io.EOF {
            break
        }
        require.NoError(t, err)
        receivedUsers = append(receivedUsers, user)
    }
    
    assert.Len(t, receivedUsers, len(users))
}

func TestUserService_BatchCreateUsers(t *testing.T) {
    ctx := context.Background()
    conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
    require.NoError(t, err)
    defer conn.Close()
    
    client := pb.NewUserServiceClient(conn)
    
    stream, err := client.BatchCreateUsers(ctx)
    require.NoError(t, err)
    
    // 发送多个用户创建请求
    users := []*pb.CreateUserRequest{
        {Username: "batch1", Email: "batch1@example.com", FullName: "Batch User 1"},
        {Username: "batch2", Email: "batch2@example.com", FullName: "Batch User 2"},
        {Username: "batch3", Email: "batch3@example.com", FullName: "Batch User 3"},
    }
    
    for _, user := range users {
        err := stream.Send(user)
        require.NoError(t, err)
    }
    
    resp, err := stream.CloseAndRecv()
    require.NoError(t, err)
    
    assert.Equal(t, int32(len(users)), resp.SuccessCount)
    assert.Equal(t, int32(0), resp.ErrorCount)
    assert.Len(t, resp.Users, len(users))
}

// 基准测试
func BenchmarkUserService_CreateUser(b *testing.B) {
    ctx := context.Background()
    conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
    require.NoError(b, err)
    defer conn.Close()
    
    client := pb.NewUserServiceClient(conn)
    
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        req := &pb.CreateUserRequest{
            Username: fmt.Sprintf("user%d", i),
            Email:    fmt.Sprintf("user%d@example.com", i),
            FullName: fmt.Sprintf("User %d", i),
        }
        
        _, err := client.CreateUser(ctx, req)
        if err != nil {
            b.Fatal(err)
        }
    }
}
"""

# Render every Go listing in turn and, after each, print a banner plus a
# checklist of what that listing demonstrates. Each checklist is emitted as a
# single newline-joined print, which writes exactly the same text as printing
# its lines one by one.
server_mgr = GRPCServerManager(ServerConfig())

server_impl = server_mgr.create_basic_server_implementation()
print("\n".join([
    "=== gRPC 服务器基础实现 ===",
    "✓ 一元调用实现",
    "✓ 服务端流实现",
    "✓ 客户端流实现",
    "✓ 双向流实现",
    "✓ 优雅关闭处理",
]))

middleware_system = server_mgr.create_middleware_system()
print("\n".join([
    "\n=== 中间件系统 ===",
    "✓ 日志拦截器",
    "✓ 认证拦截器",
    "✓ 限流拦截器",
    "✓ 指标收集拦截器",
    "✓ 恢复拦截器",
]))

error_system = server_mgr.create_error_handling_system()
print("\n".join([
    "\n=== 错误处理系统 ===",
    "✓ 业务错误定义",
    "✓ 验证错误处理",
    "✓ gRPC状态码转换",
    "✓ 错误详情附加",
]))

testing_framework = server_mgr.create_testing_framework()
print("\n".join([
    "\n=== 测试框架 ===",
    "✓ 单元测试",
    "✓ 集成测试",
    "✓ 流式测试",
    "✓ 基准测试",
]))

服务生命周期管理

1. 服务启动流程

// 完整的服务启动流程
// StartServer builds a gRPC server from config, registers the services and
// optional built-ins, then blocks serving on config.Host:config.Port until
// the server stops. It returns an error if config is nil, the listener cannot
// be created, or Serve fails.
//
// NOTE(review): the *grpc.Server is not returned, so callers cannot hand it
// to GracefulShutdown; consider returning it or accepting a context.
func StartServer(config *ServerConfig) error {
    if config == nil {
        return fmt.Errorf("start server: nil config")
    }

    // 1. Create the listener. JoinHostPort brackets IPv6 literals
    // ("::1" -> "[::1]:50051"), which "%s:%d" formatting does not.
    addr := net.JoinHostPort(config.Host, fmt.Sprint(config.Port))
    lis, err := net.Listen("tcp", addr)
    if err != nil {
        return fmt.Errorf("failed to listen on %s: %w", addr, err)
    }

    // 2. Configure server options.
    opts := []grpc.ServerOption{
        grpc.KeepaliveParams(keepalive.ServerParameters{
            Time:    time.Duration(config.KeepaliveTimeMs) * time.Millisecond,
            Timeout: time.Duration(config.KeepaliveTimeoutMs) * time.Millisecond,
        }),
        grpc.MaxRecvMsgSize(config.MaxReceiveMessageLength),
        grpc.MaxSendMsgSize(config.MaxSendMessageLength),
    }

    // 3. Add interceptors using grpc-go's built-in chaining (grpc >= 1.28),
    // which removes the go-grpc-middleware dependency. The first interceptor
    // is the outermost, so recovery goes first to also catch panics raised
    // inside the other interceptors.
    opts = append(opts,
        grpc.ChainUnaryInterceptor(
            RecoveryInterceptor(),
            LoggingInterceptor(),
            AuthInterceptor(),
            MetricsInterceptor(metricsCollector),
        ),
        grpc.ChainStreamInterceptor(
            LoggingStreamInterceptor(),
            // NOTE: unary interceptors never run for streaming RPCs; add
            // stream equivalents of recovery and auth here.
        ),
    )

    // 4. Create the server.
    s := grpc.NewServer(opts...)

    // 5. Register services.
    registerServices(s)

    // 6. Enable optional features.
    if config.EnableReflection {
        reflection.Register(s)
    }

    if config.EnableHealthCheck {
        healthServer := health.NewServer()
        grpc_health_v1.RegisterHealthServer(s, healthServer)
        healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)
    }

    // 7. Start serving; this blocks until the server is stopped.
    log.Printf("gRPC server starting on %s", lis.Addr())
    if err := s.Serve(lis); err != nil {
        return fmt.Errorf("serve: %w", err)
    }
    return nil
}

2. 优雅关闭

// GracefulShutdown blocks until SIGINT or SIGTERM arrives, then stops server
// gracefully. If in-flight RPCs have not finished within timeout, it falls
// back to server.Stop, which closes the remaining connections.
func GracefulShutdown(server *grpc.Server, timeout time.Duration) {
    // Create the shutdown signal channel.
    sigChan := make(chan os.Signal, 1)
    signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

    // Wait for the shutdown signal.
    <-sigChan
    // Restore default signal handling right away. Otherwise the channel stays
    // registered and a second Ctrl-C during a slow shutdown is silently
    // swallowed instead of terminating the process.
    signal.Stop(sigChan)
    log.Println("Received shutdown signal")

    // Bound how long the graceful phase may take.
    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel()

    // GracefulStop blocks until pending RPCs finish, so run it in a goroutine
    // and race it against the deadline.
    done := make(chan struct{})
    go func() {
        server.GracefulStop()
        close(done)
    }()

    select {
    case <-done:
        log.Println("Server gracefully stopped")
    case <-ctx.Done():
        log.Println("Shutdown timeout, forcing stop")
        server.Stop()
    }
}

总结

通过本章的学习，您应该已经掌握了：

  1. 服务端架构：基本结构、配置管理、生命周期
  2. 方法实现：四种 RPC 模式的具体实现
  3. 中间件系统：拦截器链、认证、日志、监控
  4. 错误处理：业务错误、验证错误、状态码转换
  5. 测试策略：单元测试、集成测试、性能测试

在下一章中，我们将学习如何开发 gRPC 客户端，包括连接管理、调用模式和最佳实践。