Overview

This chapter takes a deep dive into gRPC's four RPC patterns (unary plus server-side, client-side, and bidirectional streaming), focusing on advanced streaming applications and performance-optimization strategies. We will learn how to build efficient streaming applications and handle large-scale data transfer.
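
For reference, the four call shapes map directly onto the client interface that protoc-gen-go-grpc generates. The sketch below is illustrative: the stream types match the pb package assumed by this chapter's examples, while the unary Ping method is a hypothetical addition shown only to round out the comparison.

// The four gRPC call shapes as they appear in a generated client interface.
type StreamingServiceClient interface {
    // Unary: one request, one response.
    Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error)
    // Server streaming: one request, a stream of responses.
    DownloadFile(ctx context.Context, in *DownloadRequest, opts ...grpc.CallOption) (StreamingService_DownloadFileClient, error)
    // Client streaming: a stream of requests, then one response.
    UploadFile(ctx context.Context, opts ...grpc.CallOption) (StreamingService_UploadFileClient, error)
    // Bidirectional streaming: both sides send independently.
    ChatStream(ctx context.Context, opts ...grpc.CallOption) (StreamingService_ChatStreamClient, error)
}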

Learning Objectives

  • Understand the four gRPC RPC patterns in depth
  • Master buffering and backpressure handling for streamed data
  • Learn performance-tuning techniques for streaming
  • Understand how to monitor and debug streaming applications
  • Master large-file transfer and real-time data-stream processing

Streaming Architecture

The Python scaffolding below defines the stream configuration and metrics types, plus a StreamProcessor whose methods emit the Go examples used throughout the rest of this chapter.

from enum import Enum
from dataclasses import dataclass

class StreamType(Enum):
    """流类型枚举"""
    UNARY = "unary"
    SERVER_STREAM = "server_stream"
    CLIENT_STREAM = "client_stream"
    BIDIRECTIONAL_STREAM = "bidirectional_stream"

class FlowControlStrategy(Enum):
    """流控策略枚举"""
    BUFFER_BASED = "buffer_based"
    WINDOW_BASED = "window_based"
    RATE_BASED = "rate_based"
    ADAPTIVE = "adaptive"

class CompressionType(Enum):
    """压缩类型枚举"""
    NONE = "none"
    GZIP = "gzip"
    DEFLATE = "deflate"
    SNAPPY = "snappy"
    LZ4 = "lz4"

@dataclass
class StreamConfig:
    """流配置"""
    buffer_size: int = 1024 * 1024  # 1MB
    max_message_size: int = 4 * 1024 * 1024  # 4MB
    flow_control: FlowControlStrategy = FlowControlStrategy.BUFFER_BASED
    compression: CompressionType = CompressionType.GZIP
    enable_backpressure: bool = True
    max_concurrent_streams: int = 100
    stream_timeout_ms: int = 300000  # 5 minutes
    keepalive_interval_ms: int = 30000
    window_size: int = 65536
    max_window_size: int = 1024 * 1024
    enable_flow_control: bool = True
    chunk_size: int = 64 * 1024  # 64KB
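
# Illustrative only: a config tuned for bulk file transfer might enlarge
# buffers and chunks while keeping backpressure enabled (these values are
# assumptions, not taken from the text).
bulk_transfer_config = StreamConfig(
    buffer_size=4 * 1024 * 1024,      # 4MB buffer
    chunk_size=256 * 1024,            # 256KB chunks
    compression=CompressionType.LZ4,
    flow_control=FlowControlStrategy.ADAPTIVE,
)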

@dataclass
class StreamMetrics:
    """流指标"""
    messages_sent: int = 0
    messages_received: int = 0
    bytes_sent: int = 0
    bytes_received: int = 0
    stream_duration_ms: int = 0
    error_count: int = 0
    throughput_mbps: float = 0.0
    latency_ms: float = 0.0
    buffer_utilization: float = 0.0

class StreamProcessor:
    """流处理器"""
    
    def __init__(self, config: StreamConfig):
        self.config = config
        self.metrics = StreamMetrics()
        self.active_streams = {}
        
    def create_server_streaming_example(self) -> str:
        """创建服务端流示例"""
        return """
// server_streaming.go - server-side streaming handlers
package streaming

import (
    "fmt"
    "io"
    "log"
    "sync"
    "time"
    
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
    
    pb "./proto/streaming"
)

// StreamingServer implements the streaming service.
type StreamingServer struct {
    pb.UnimplementedStreamingServiceServer
    dataStore map[string][]byte
    mu        sync.RWMutex
}

// NewStreamingServer creates a streaming server.
func NewStreamingServer() *StreamingServer {
    return &StreamingServer{
        dataStore: make(map[string][]byte),
    }
}

// DownloadFile streams a file to the client (server streaming).
func (s *StreamingServer) DownloadFile(req *pb.DownloadRequest, stream pb.StreamingService_DownloadFileServer) error {
    log.Printf("Starting file download: %s", req.Filename)
    
    // Look up the (in-memory) file data.
    s.mu.RLock()
    fileData, exists := s.dataStore[req.Filename]
    s.mu.RUnlock()
    
    if !exists {
        return status.Errorf(codes.NotFound, "file not found: %s", req.Filename)
    }
    
    chunkSize := 64 * 1024 // 64KB chunks
    totalChunks := (len(fileData) + chunkSize - 1) / chunkSize
    
    for i := 0; i < totalChunks; i++ {
        // Stop early if the client cancelled the stream.
        if stream.Context().Err() != nil {
            return stream.Context().Err()
        }
        
        start := i * chunkSize
        end := start + chunkSize
        if end > len(fileData) {
            end = len(fileData)
        }
        
        chunk := &pb.FileChunk{
            ChunkId:    int32(i),
            Data:       fileData[start:end],
            TotalSize:  int64(len(fileData)),
            ChunkSize:  int32(end - start),
            IsLastChunk: i == totalChunks-1,
            Checksum:   calculateChecksum(fileData[start:end]),
        }
        
        if err := stream.Send(chunk); err != nil {
            log.Printf("Error sending chunk %d: %v", i, err)
            return err
        }
        
        // Simple throttling: simulate network delay.
        if req.ThrottleMs > 0 {
            time.Sleep(time.Duration(req.ThrottleMs) * time.Millisecond)
        }
        
        log.Printf("Sent chunk %d/%d (%d bytes)", i+1, totalChunks, end-start)
    }
    
    log.Printf("File download completed: %s (%d bytes)", req.Filename, len(fileData))
    return nil
}

// StreamLogs streams log entries (server streaming).
func (s *StreamingServer) StreamLogs(req *pb.LogStreamRequest, stream pb.StreamingService_StreamLogsServer) error {
    log.Printf("Starting log stream for service: %s", req.ServiceName)
    
    ticker := time.NewTicker(time.Duration(req.IntervalMs) * time.Millisecond)
    defer ticker.Stop()
    
    logCounter := 0
    startTime := time.Now()
    
    for {
        select {
        case <-stream.Context().Done():
            log.Printf("Log stream cancelled for service: %s", req.ServiceName)
            return stream.Context().Err()
        case <-ticker.C:
            logCounter++
            
            logEntry := &pb.LogEntry{
                Timestamp:   time.Now().Unix(),
                Level:       getRandomLogLevel(),
                ServiceName: req.ServiceName,
                Message:     fmt.Sprintf("Log message #%d from %s", logCounter, req.ServiceName),
                Metadata: map[string]string{
                    "request_id": fmt.Sprintf("req_%d", logCounter),
                    "duration":   time.Since(startTime).String(),
                },
            }
            
            if err := stream.Send(logEntry); err != nil {
                log.Printf("Error sending log entry: %v", err)
                return err
            }
            
            // Stop once the requested number of logs has been sent.
            if req.MaxLogs > 0 && logCounter >= int(req.MaxLogs) {
                log.Printf("Reached max logs limit: %d", req.MaxLogs)
                return nil
            }
        }
    }
}

// StreamMetrics streams metric samples (server streaming).
func (s *StreamingServer) StreamMetrics(req *pb.MetricsRequest, stream pb.StreamingService_StreamMetricsServer) error {
    log.Printf("Starting metrics stream for: %v", req.MetricNames)
    
    ticker := time.NewTicker(time.Duration(req.IntervalMs) * time.Millisecond)
    defer ticker.Stop()
    
    for {
        select {
        case <-stream.Context().Done():
            return stream.Context().Err()
        case <-ticker.C:
            metrics := &pb.MetricsData{
                Timestamp: time.Now().Unix(),
                Metrics:   make(map[string]float64),
            }
            
            // Generate simulated metric values.
            for _, metricName := range req.MetricNames {
                metrics.Metrics[metricName] = generateMetricValue(metricName)
            }
            
            if err := stream.Send(metrics); err != nil {
                return err
            }
        }
    }
}

// Helper functions
func calculateChecksum(data []byte) string {
    // Toy checksum for illustration; real code should use hash/crc32 or crypto/sha256.
    sum := 0
    for _, b := range data {
        sum += int(b)
    }
    return fmt.Sprintf("%x", sum)
}

func getRandomLogLevel() string {
    levels := []string{"DEBUG", "INFO", "WARN", "ERROR"}
    return levels[time.Now().Nanosecond()%len(levels)]
}

func generateMetricValue(metricName string) float64 {
    // Produce a pseudo-random metric value.
    switch metricName {
    case "cpu_usage":
        return float64(time.Now().Nanosecond()%100) / 100.0
    case "memory_usage":
        return float64(time.Now().Nanosecond()%80) / 100.0
    case "request_rate":
        return float64(time.Now().Nanosecond()%1000) / 10.0
    default:
        return float64(time.Now().Nanosecond()%100)
    }
}

// UploadFile receives a file chunk-by-chunk (client streaming).
func (s *StreamingServer) UploadFile(stream pb.StreamingService_UploadFileServer) error {
    log.Println("Starting file upload")
    
    var filename string
    var fileData []byte
    var totalSize int64
    receivedChunks := make(map[int32]bool)
    
    for {
        chunk, err := stream.Recv()
        if err == io.EOF {
            // Client finished sending; leave the receive loop.
            break
        }
        if err != nil {
            log.Printf("Error receiving chunk: %v", err)
            return err
        }
        
        if filename == "" {
            filename = chunk.Filename
            totalSize = chunk.TotalSize
            fileData = make([]byte, totalSize)
            log.Printf("Uploading file: %s (size: %d bytes)", filename, totalSize)
        }
        
        // Skip duplicate chunks.
        if receivedChunks[chunk.ChunkId] {
            log.Printf("Duplicate chunk received: %d", chunk.ChunkId)
            continue
        }
        
        // Verify the chunk checksum.
        expectedChecksum := calculateChecksum(chunk.Data)
        if chunk.Checksum != expectedChecksum {
            return status.Errorf(codes.DataLoss, "checksum mismatch for chunk %d", chunk.ChunkId)
        }
        
        // Copy the chunk into place; chunks are a fixed 64KB except the last one.
        offset := int64(chunk.ChunkId) * 64 * 1024
        copy(fileData[offset:], chunk.Data)
        receivedChunks[chunk.ChunkId] = true
        
        log.Printf("Received chunk %d (%d bytes)", chunk.ChunkId, len(chunk.Data))
    }
    
    // Persist the file in the in-memory store.
    s.mu.Lock()
    s.dataStore[filename] = fileData
    s.mu.Unlock()
    
    response := &pb.UploadResponse{
        Filename:     filename,
        Size:         int64(len(fileData)),
        Checksum:     calculateChecksum(fileData),
        ChunksCount:  int32(len(receivedChunks)),
        UploadTime:   time.Now().Unix(),
    }
    
    log.Printf("File upload completed: %s (%d bytes)", filename, len(fileData))
    return stream.SendAndClose(response)
}

// ChatStream echoes chat messages back to the client (bidirectional streaming).
func (s *StreamingServer) ChatStream(stream pb.StreamingService_ChatStreamServer) error {
    log.Println("Starting chat stream")
    
    // Channels decouple receiving from processing.
    messageChan := make(chan *pb.ChatMessage, 100)
    errorChan := make(chan error, 1)
    
    // Receive messages in a dedicated goroutine.
    go func() {
        defer close(messageChan)
        for {
            msg, err := stream.Recv()
            if err == io.EOF {
                return
            }
            if err != nil {
                errorChan <- err
                return
            }
            messageChan <- msg
        }
    }()
    
    // Process messages as they arrive.
    for {
        select {
        case msg, ok := <-messageChan:
            if !ok {
                log.Println("Chat stream completed")
                return nil
            }
            
            // Handle the received message.
            log.Printf("Received message from %s: %s", msg.Username, msg.Content)
            
            // Build the echo response.
            response := &pb.ChatMessage{
                Username:  "server",
                Content:   fmt.Sprintf("Echo: %s", msg.Content),
                Timestamp: time.Now().Unix(),
                MessageId: fmt.Sprintf("resp_%d", time.Now().Nanosecond()),
            }
            
            if err := stream.Send(response); err != nil {
                log.Printf("Error sending response: %v", err)
                return err
            }
            
        case err := <-errorChan:
            log.Printf("Chat stream error: %v", err)
            return err
            
        case <-stream.Context().Done():
            log.Println("Chat stream cancelled")
            return stream.Context().Err()
        }
    }
}

// ProcessDataStream processes packets and streams results back (bidirectional streaming).
func (s *StreamingServer) ProcessDataStream(stream pb.StreamingService_ProcessDataStreamServer) error {
    log.Println("Starting data processing stream")
    
    processor := NewDataProcessor()
    
    for {
        data, err := stream.Recv()
        if err == io.EOF {
            // Send the final aggregate result.
            finalResult := processor.GetFinalResult()
            return stream.Send(finalResult)
        }
        if err != nil {
            return err
        }
        
        // Process the incoming packet.
        result := processor.ProcessData(data)
        
        // Send the intermediate result.
        if err := stream.Send(result); err != nil {
            return err
        }
    }
}

// DataProcessor aggregates values across streamed packets.
type DataProcessor struct {
    processedCount int
    totalSum       float64
    mu             sync.Mutex
}

func NewDataProcessor() *DataProcessor {
    return &DataProcessor{}
}

func (dp *DataProcessor) ProcessData(data *pb.DataPacket) *pb.ProcessResult {
    dp.mu.Lock()
    defer dp.mu.Unlock()
    
    dp.processedCount++
    
    // Simulate per-packet processing.
    var sum float64
    for _, value := range data.Values {
        sum += value
        dp.totalSum += value
    }
    
    return &pb.ProcessResult{
        PacketId:       data.PacketId,
        ProcessedCount: int32(dp.processedCount),
        Sum:           sum,
        Average:       sum / float64(len(data.Values)),
        TotalSum:      dp.totalSum,
        Timestamp:     time.Now().Unix(),
    }
}

func (dp *DataProcessor) GetFinalResult() *pb.ProcessResult {
    dp.mu.Lock()
    defer dp.mu.Unlock()
    
    return &pb.ProcessResult{
        PacketId:       -1, // -1 marks the final result
        ProcessedCount: int32(dp.processedCount),
        TotalSum:      dp.totalSum,
        Average:       dp.totalSum / float64(dp.processedCount),
        Timestamp:     time.Now().Unix(),
    }
}
"""
    
    def create_client_streaming_example(self) -> str:
        """创建客户端流示例"""
        return """
// client_streaming.go - client-side streaming examples
package streaming

import (
    "context"
    "fmt"
    "io"
    "log"
    "os"
    "time"
    
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
    
    pb "./proto/streaming"
)

// StreamingClient wraps the connection and the generated client.
type StreamingClient struct {
    client pb.StreamingServiceClient
    conn   *grpc.ClientConn
}

// NewStreamingClient dials the server and builds a client.
func NewStreamingClient(address string) (*StreamingClient, error) {
    conn, err := grpc.Dial(address, grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        return nil, err
    }
    
    return &StreamingClient{
        client: pb.NewStreamingServiceClient(conn),
        conn:   conn,
    }, nil
}

// Close releases the underlying connection.
func (c *StreamingClient) Close() error {
    return c.conn.Close()
}

// DownloadFileStream downloads a file (server streaming).
func (c *StreamingClient) DownloadFileStream(ctx context.Context, filename string) error {
    req := &pb.DownloadRequest{
        Filename:   filename,
        ThrottleMs: 10, // 10ms delay between chunks
    }
    
    stream, err := c.client.DownloadFile(ctx, req)
    if err != nil {
        return fmt.Errorf("failed to start download: %w", err)
    }
    
    var fileData []byte
    var totalSize int64
    receivedChunks := 0
    
    log.Printf("Starting download: %s", filename)
    startTime := time.Now()
    
    for {
        chunk, err := stream.Recv()
        if err == io.EOF {
            break
        }
        if err != nil {
            return fmt.Errorf("error receiving chunk: %w", err)
        }
        
        if totalSize == 0 {
            totalSize = chunk.TotalSize
            fileData = make([]byte, totalSize)
        }
        
        // Verify the chunk checksum.
        expectedChecksum := calculateChecksum(chunk.Data)
        if chunk.Checksum != expectedChecksum {
            return fmt.Errorf("checksum mismatch for chunk %d", chunk.ChunkId)
        }
        
        // Copy the chunk into place; chunks are a fixed 64KB except the last one.
        offset := int64(chunk.ChunkId) * 64 * 1024
        copy(fileData[offset:], chunk.Data)
        receivedChunks++
        
        progress := float64(len(chunk.Data)+int(offset)) / float64(totalSize) * 100
        log.Printf("Downloaded chunk %d: %d bytes (%.1f%%)", chunk.ChunkId, len(chunk.Data), progress)
        
        if chunk.IsLastChunk {
            break
        }
    }
    
    duration := time.Since(startTime)
    throughput := float64(totalSize) / duration.Seconds() / 1024 / 1024 // MB/s
    
    log.Printf("Download completed: %s (%d bytes, %d chunks, %.2f MB/s)", 
        filename, totalSize, receivedChunks, throughput)
    
    // Write the downloaded file to disk.
    return os.WriteFile(fmt.Sprintf("downloaded_%s", filename), fileData, 0644)
}

// UploadFileStream uploads a file (client streaming).
func (c *StreamingClient) UploadFileStream(ctx context.Context, filename string) error {
    // Read the whole file into memory.
    fileData, err := os.ReadFile(filename)
    if err != nil {
        return fmt.Errorf("failed to read file: %w", err)
    }
    
    stream, err := c.client.UploadFile(ctx)
    if err != nil {
        return fmt.Errorf("failed to start upload: %w", err)
    }
    
    chunkSize := 64 * 1024 // 64KB
    totalChunks := (len(fileData) + chunkSize - 1) / chunkSize
    
    log.Printf("Starting upload: %s (%d bytes, %d chunks)", filename, len(fileData), totalChunks)
    startTime := time.Now()
    
    for i := 0; i < totalChunks; i++ {
        start := i * chunkSize
        end := start + chunkSize
        if end > len(fileData) {
            end = len(fileData)
        }
        
        chunk := &pb.FileChunk{
            Filename:    filename,
            ChunkId:     int32(i),
            Data:        fileData[start:end],
            TotalSize:   int64(len(fileData)),
            ChunkSize:   int32(end - start),
            IsLastChunk: i == totalChunks-1,
            Checksum:    calculateChecksum(fileData[start:end]),
        }
        
        if err := stream.Send(chunk); err != nil {
            return fmt.Errorf("failed to send chunk %d: %w", i, err)
        }
        
        progress := float64(end) / float64(len(fileData)) * 100
        log.Printf("Uploaded chunk %d/%d: %d bytes (%.1f%%)", i+1, totalChunks, end-start, progress)
        
        // Simple pacing: avoid overwhelming the server.
        time.Sleep(1 * time.Millisecond)
    }
    
    response, err := stream.CloseAndRecv()
    if err != nil {
        return fmt.Errorf("failed to close upload: %w", err)
    }
    
    duration := time.Since(startTime)
    throughput := float64(len(fileData)) / duration.Seconds() / 1024 / 1024 // MB/s
    
    log.Printf("Upload completed: %s (%d bytes, %.2f MB/s)", 
        response.Filename, response.Size, throughput)
    
    return nil
}

// StreamLogsClient consumes the server-side log stream.
func (c *StreamingClient) StreamLogsClient(ctx context.Context, serviceName string) error {
    req := &pb.LogStreamRequest{
        ServiceName: serviceName,
        IntervalMs:  1000, // 1-second interval
        MaxLogs:     10,   // at most 10 log entries
    }
    
    stream, err := c.client.StreamLogs(ctx, req)
    if err != nil {
        return fmt.Errorf("failed to start log stream: %w", err)
    }
    
    log.Printf("Starting log stream for service: %s", serviceName)
    
    for {
        logEntry, err := stream.Recv()
        if err == io.EOF {
            log.Println("Log stream completed")
            break
        }
        if err != nil {
            return fmt.Errorf("error receiving log: %w", err)
        }
        
        log.Printf("[%s] %s: %s", 
            time.Unix(logEntry.Timestamp, 0).Format("15:04:05"),
            logEntry.Level,
            logEntry.Message)
    }
    
    return nil
}

// ChatStreamClient runs a bidirectional chat session.
func (c *StreamingClient) ChatStreamClient(ctx context.Context, username string) error {
    stream, err := c.client.ChatStream(ctx)
    if err != nil {
        return fmt.Errorf("failed to start chat stream: %w", err)
    }
    
    // Receive responses in a dedicated goroutine.
    go func() {
        for {
            msg, err := stream.Recv()
            if err == io.EOF {
                log.Println("Chat stream ended")
                return
            }
            if err != nil {
                log.Printf("Error receiving message: %v", err)
                return
            }
            
            log.Printf("[%s]: %s", msg.Username, msg.Content)
        }
    }()
    
    // Send a few messages.
    messages := []string{
        "Hello, server!",
        "How are you?",
        "This is a test message",
        "Goodbye!",
    }
    
    for i, content := range messages {
        msg := &pb.ChatMessage{
            Username:  username,
            Content:   content,
            Timestamp: time.Now().Unix(),
            MessageId: fmt.Sprintf("msg_%d", i),
        }
        
        if err := stream.Send(msg); err != nil {
            return fmt.Errorf("failed to send message: %w", err)
        }
        
        log.Printf("Sent: %s", content)
        time.Sleep(2 * time.Second)
    }
    
    // Half-close the send direction.
    if err := stream.CloseSend(); err != nil {
        return fmt.Errorf("failed to close send: %w", err)
    }
    
    // Give the receive goroutine time to drain.
    time.Sleep(1 * time.Second)
    return nil
}

// ProcessDataStreamClient streams packets to the server and consumes results.
func (c *StreamingClient) ProcessDataStreamClient(ctx context.Context) error {
    stream, err := c.client.ProcessDataStream(ctx)
    if err != nil {
        return fmt.Errorf("failed to start data stream: %w", err)
    }
    
    // Receive results in a dedicated goroutine.
    resultChan := make(chan *pb.ProcessResult, 100)
    errorChan := make(chan error, 1)
    
    go func() {
        defer close(resultChan)
        for {
            result, err := stream.Recv()
            if err == io.EOF {
                return
            }
            if err != nil {
                errorChan <- err
                return
            }
            resultChan <- result
        }
    }()
    
    // Send the data packets.
    for i := 0; i < 5; i++ {
        data := &pb.DataPacket{
            PacketId: int32(i),
            Values:   generateRandomValues(100),
            Metadata: map[string]string{
                "source": "client",
                "batch":  fmt.Sprintf("batch_%d", i),
            },
        }
        
        if err := stream.Send(data); err != nil {
            return fmt.Errorf("failed to send data packet %d: %w", i, err)
        }
        
        log.Printf("Sent data packet %d with %d values", i, len(data.Values))
        time.Sleep(500 * time.Millisecond)
    }
    
    // Half-close the send direction.
    if err := stream.CloseSend(); err != nil {
        return fmt.Errorf("failed to close send: %w", err)
    }
    
    // Drain the results.
    for {
        select {
        case result, ok := <-resultChan:
            if !ok {
                log.Println("Data processing completed")
                return nil
            }
            
            if result.PacketId == -1 {
                log.Printf("Final result: processed %d packets, total sum: %.2f, average: %.2f",
                    result.ProcessedCount, result.TotalSum, result.Average)
            } else {
                log.Printf("Processed packet %d: sum=%.2f, avg=%.2f",
                    result.PacketId, result.Sum, result.Average)
            }
            
        case err := <-errorChan:
            return fmt.Errorf("data processing error: %w", err)
            
        case <-ctx.Done():
            return ctx.Err()
        }
    }
}

// generateRandomValues produces pseudo-random values for the demo.
func generateRandomValues(count int) []float64 {
    values := make([]float64, count)
    for i := 0; i < count; i++ {
        values[i] = float64(time.Now().Nanosecond()%1000) / 10.0
    }
    return values
}

// RunClientDemo exercises each streaming mode against a local server
// (package streaming cannot define a program entry point, so this is
// an exported helper rather than func main).
func RunClientDemo() {
    client, err := NewStreamingClient("localhost:50051")
    if err != nil {
        log.Fatalf("Failed to create client: %v", err)
    }
    defer client.Close()
    
    ctx := context.Background()
    
    // Exercise the file-download stream.
    if err := client.DownloadFileStream(ctx, "test.txt"); err != nil {
        log.Printf("Download failed: %v", err)
    }
    
    // Exercise the file-upload stream.
    if err := client.UploadFileStream(ctx, "upload_test.txt"); err != nil {
        log.Printf("Upload failed: %v", err)
    }
    
    // Exercise the log stream.
    if err := client.StreamLogsClient(ctx, "user-service"); err != nil {
        log.Printf("Log stream failed: %v", err)
    }
    
    // Exercise the chat stream.
    if err := client.ChatStreamClient(ctx, "client_user"); err != nil {
        log.Printf("Chat stream failed: %v", err)
    }
    
    // Exercise the data-processing stream.
    if err := client.ProcessDataStreamClient(ctx); err != nil {
        log.Printf("Data processing failed: %v", err)
    }
}
"""
    
    def create_performance_optimization(self) -> str:
        """创建性能优化配置"""
        return """
// performance.go - performance tuning for streaming
package streaming

import (
    "context"
    "log"
    "runtime"
    "sync"
    "sync/atomic"
    "time"
    
    "google.golang.org/grpc"
    "google.golang.org/grpc/keepalive"
    "google.golang.org/grpc/stats"
)

// PerformanceConfig holds streaming performance tuning knobs.
type PerformanceConfig struct {
    // Buffer settings
    SendBufferSize    int
    ReceiveBufferSize int
    
    // Concurrency settings
    MaxConcurrentStreams int
    WorkerPoolSize       int
    
    // Flow-control settings
    InitialWindowSize     int
    InitialConnWindowSize int
    MaxMessageSize        int
    
    // Compression settings
    EnableCompression bool
    CompressionLevel  int
    
    // Keepalive settings
    KeepaliveTime    time.Duration
    KeepaliveTimeout time.Duration
    
    // Metrics settings
    EnableMetrics bool
    MetricsInterval time.Duration
}

// DefaultPerformanceConfig returns sensible defaults.
func DefaultPerformanceConfig() *PerformanceConfig {
    return &PerformanceConfig{
        SendBufferSize:        1024 * 1024,    // 1MB
        ReceiveBufferSize:     1024 * 1024,    // 1MB
        MaxConcurrentStreams:  100,
        WorkerPoolSize:        runtime.NumCPU(),
        InitialWindowSize:     65536,          // 64KB
        InitialConnWindowSize: 1024 * 1024,    // 1MB
        MaxMessageSize:        4 * 1024 * 1024, // 4MB
        EnableCompression:     true,
        CompressionLevel:      6,
        KeepaliveTime:         30 * time.Second,
        KeepaliveTimeout:      5 * time.Second,
        EnableMetrics:         true,
        MetricsInterval:       10 * time.Second,
    }
}

// PerformanceOptimizer derives tuned gRPC options and tracks metrics.
type PerformanceOptimizer struct {
    config  *PerformanceConfig
    metrics *StreamMetrics
    mu      sync.RWMutex
}

// NewPerformanceOptimizer creates an optimizer from the given config.
func NewPerformanceOptimizer(config *PerformanceConfig) *PerformanceOptimizer {
    return &PerformanceOptimizer{
        config:  config,
        metrics: &StreamMetrics{},
    }
}

// OptimizeServerOptions builds tuned grpc.ServerOption values.
func (po *PerformanceOptimizer) OptimizeServerOptions() []grpc.ServerOption {
    opts := []grpc.ServerOption{
        grpc.MaxConcurrentStreams(uint32(po.config.MaxConcurrentStreams)),
        grpc.InitialWindowSize(int32(po.config.InitialWindowSize)),
        grpc.InitialConnWindowSize(int32(po.config.InitialConnWindowSize)),
        grpc.MaxRecvMsgSize(po.config.MaxMessageSize),
        grpc.MaxSendMsgSize(po.config.MaxMessageSize),
        grpc.KeepaliveParams(keepalive.ServerParameters{
            Time:    po.config.KeepaliveTime,
            Timeout: po.config.KeepaliveTimeout,
        }),
        grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
            MinTime:             5 * time.Second,
            PermitWithoutStream: true,
        }),
    }
    
    if po.config.EnableMetrics {
        opts = append(opts, grpc.StatsHandler(&metricsHandler{optimizer: po}))
    }
    
    return opts
}

// OptimizeClientOptions builds tuned grpc.DialOption values.
func (po *PerformanceOptimizer) OptimizeClientOptions() []grpc.DialOption {
    opts := []grpc.DialOption{
        grpc.WithInitialWindowSize(int32(po.config.InitialWindowSize)),
        grpc.WithInitialConnWindowSize(int32(po.config.InitialConnWindowSize)),
        grpc.WithDefaultCallOptions(
            grpc.MaxCallRecvMsgSize(po.config.MaxMessageSize),
            grpc.MaxCallSendMsgSize(po.config.MaxMessageSize),
        ),
        grpc.WithKeepaliveParams(keepalive.ClientParameters{
            Time:                po.config.KeepaliveTime,
            Timeout:             po.config.KeepaliveTimeout,
            PermitWithoutStream: true,
        }),
    }
    
    if po.config.EnableMetrics {
        opts = append(opts, grpc.WithStatsHandler(&metricsHandler{optimizer: po}))
    }
    
    return opts
}

// StreamMetrics holds counters updated atomically by the stats handler.
type StreamMetrics struct {
    ActiveStreams     int64
    TotalStreams      int64
    MessagesSent      int64
    MessagesReceived  int64
    BytesSent         int64
    BytesReceived     int64
    ErrorCount        int64
    AverageLatency    int64 // nanoseconds
    Throughput        int64 // bytes per second
}

// metricsHandler implements the grpc stats.Handler interface.
type metricsHandler struct {
    optimizer *PerformanceOptimizer
}

// TagRPC tags an RPC (no-op here).
func (mh *metricsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
    return ctx
}

// HandleRPC updates counters from per-RPC stats events.
func (mh *metricsHandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
    switch st := s.(type) {
    case *stats.Begin:
        atomic.AddInt64(&mh.optimizer.metrics.ActiveStreams, 1)
        atomic.AddInt64(&mh.optimizer.metrics.TotalStreams, 1)
        
    case *stats.End:
        atomic.AddInt64(&mh.optimizer.metrics.ActiveStreams, -1)
        if st.Error != nil {
            atomic.AddInt64(&mh.optimizer.metrics.ErrorCount, 1)
        }
        
    case *stats.OutPayload:
        atomic.AddInt64(&mh.optimizer.metrics.MessagesSent, 1)
        atomic.AddInt64(&mh.optimizer.metrics.BytesSent, int64(st.Length))
        
    case *stats.InPayload:
        atomic.AddInt64(&mh.optimizer.metrics.MessagesReceived, 1)
        atomic.AddInt64(&mh.optimizer.metrics.BytesReceived, int64(st.Length))
    }
}

// TagConn tags a connection (no-op here).
func (mh *metricsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
    return ctx
}

// HandleConn processes connection-level stats.
func (mh *metricsHandler) HandleConn(ctx context.Context, s stats.ConnStats) {
    // Connection-level stats are not tracked in this example.
}

// GetMetrics returns a snapshot of the counters.
func (po *PerformanceOptimizer) GetMetrics() *StreamMetrics {
    po.mu.RLock()
    defer po.mu.RUnlock()
    
    return &StreamMetrics{
        ActiveStreams:    atomic.LoadInt64(&po.metrics.ActiveStreams),
        TotalStreams:     atomic.LoadInt64(&po.metrics.TotalStreams),
        MessagesSent:     atomic.LoadInt64(&po.metrics.MessagesSent),
        MessagesReceived: atomic.LoadInt64(&po.metrics.MessagesReceived),
        BytesSent:        atomic.LoadInt64(&po.metrics.BytesSent),
        BytesReceived:    atomic.LoadInt64(&po.metrics.BytesReceived),
        ErrorCount:       atomic.LoadInt64(&po.metrics.ErrorCount),
        AverageLatency:   atomic.LoadInt64(&po.metrics.AverageLatency),
        Throughput:       atomic.LoadInt64(&po.metrics.Throughput),
    }
}

// PrintMetrics logs the current metrics.
func (po *PerformanceOptimizer) PrintMetrics() {
    metrics := po.GetMetrics()
    
    log.Printf("=== Stream Metrics ===")
    log.Printf("Active Streams: %d", metrics.ActiveStreams)
    log.Printf("Total Streams: %d", metrics.TotalStreams)
    log.Printf("Messages Sent: %d", metrics.MessagesSent)
    log.Printf("Messages Received: %d", metrics.MessagesReceived)
    log.Printf("Bytes Sent: %d (%.2f MB)", metrics.BytesSent, float64(metrics.BytesSent)/1024/1024)
    log.Printf("Bytes Received: %d (%.2f MB)", metrics.BytesReceived, float64(metrics.BytesReceived)/1024/1024)
    log.Printf("Error Count: %d", metrics.ErrorCount)
    log.Printf("Average Latency: %.2f ms", float64(metrics.AverageLatency)/1000000)
    log.Printf("Throughput: %.2f MB/s", float64(metrics.Throughput)/1024/1024)
}

// StartMetricsReporting periodically logs metrics in the background.
func (po *PerformanceOptimizer) StartMetricsReporting() {
    if !po.config.EnableMetrics {
        return
    }
    
    ticker := time.NewTicker(po.config.MetricsInterval)
    go func() {
        for range ticker.C {
            po.PrintMetrics()
        }
    }()
}

// BufferedStreamWrapper adds an asynchronous send buffer in front of a server stream.
type BufferedStreamWrapper struct {
    stream     grpc.ServerStream
    sendBuffer chan interface{}
    recvBuffer chan interface{}
    ctx        context.Context
    cancel     context.CancelFunc
    wg         sync.WaitGroup
}

// NewBufferedStreamWrapper wraps a stream with buffers of the given size.
func NewBufferedStreamWrapper(stream grpc.ServerStream, bufferSize int) *BufferedStreamWrapper {
    ctx, cancel := context.WithCancel(stream.Context())
    
    wrapper := &BufferedStreamWrapper{
        stream:     stream,
        sendBuffer: make(chan interface{}, bufferSize),
        recvBuffer: make(chan interface{}, bufferSize),
        ctx:        ctx,
        cancel:     cancel,
    }
    
    // Start the send worker.
    wrapper.wg.Add(1)
    go wrapper.sendWorker()
    
    return wrapper
}

// Send enqueues a message for the send worker.
func (bsw *BufferedStreamWrapper) Send(msg interface{}) error {
    select {
    case bsw.sendBuffer <- msg:
        return nil
    case <-bsw.ctx.Done():
        return bsw.ctx.Err()
    }
}

// sendWorker drains the buffer onto the underlying stream.
func (bsw *BufferedStreamWrapper) sendWorker() {
    defer bsw.wg.Done()
    
    for {
        select {
        case msg, ok := <-bsw.sendBuffer:
            if !ok {
                return // Buffer closed: nothing left to send.
            }
            if err := bsw.stream.SendMsg(msg); err != nil {
                log.Printf("Error sending message: %v", err)
                return
            }
        case <-bsw.ctx.Done():
            return
        }
    }
}

// Close stops the send worker and waits for it to exit.
// Callers must not call Send after Close.
func (bsw *BufferedStreamWrapper) Close() {
    bsw.cancel()
    close(bsw.sendBuffer)
    bsw.wg.Wait()
}

// WorkerPool runs tasks on a fixed set of goroutines.
type WorkerPool struct {
    workers    int
    taskQueue  chan func()
    wg         sync.WaitGroup
    ctx        context.Context
    cancel     context.CancelFunc
}

// NewWorkerPool starts a pool with the given number of workers.
func NewWorkerPool(workers int) *WorkerPool {
    ctx, cancel := context.WithCancel(context.Background())
    
    pool := &WorkerPool{
        workers:   workers,
        taskQueue: make(chan func(), workers*2),
        ctx:       ctx,
        cancel:    cancel,
    }
    
    // Start the workers.
    for i := 0; i < workers; i++ {
        pool.wg.Add(1)
        go pool.worker()
    }
    
    return pool
}

// Submit enqueues a task, blocking while the queue is full.
func (wp *WorkerPool) Submit(task func()) error {
    select {
    case wp.taskQueue <- task:
        return nil
    case <-wp.ctx.Done():
        return wp.ctx.Err()
    }
}

// worker runs tasks until the pool is closed.
func (wp *WorkerPool) worker() {
    defer wp.wg.Done()
    
    for {
        select {
        case task, ok := <-wp.taskQueue:
            if !ok {
                return // Queue closed and drained.
            }
            task()
        case <-wp.ctx.Done():
            return
        }
    }
}

// Close stops the pool and waits for all workers to exit.
// Callers must not call Submit after Close.
func (wp *WorkerPool) Close() {
    wp.cancel()
    close(wp.taskQueue)
    wp.wg.Wait()
}

// Usage example
func ExamplePerformanceOptimization() {
    // Create the performance optimizer.
    config := DefaultPerformanceConfig()
    optimizer := NewPerformanceOptimizer(config)
    
    // Build a server with tuned options.
    serverOpts := optimizer.OptimizeServerOptions()
    server := grpc.NewServer(serverOpts...)
    defer server.Stop()
    
    // Start periodic metrics reporting.
    optimizer.StartMetricsReporting()
    
    // Create the worker pool.
    pool := NewWorkerPool(config.WorkerPoolSize)
    defer pool.Close()
    
    log.Println("Performance optimization enabled")
}
"""

# Create a stream-processor instance
stream_processor = StreamProcessor(StreamConfig())

# Generate the server-streaming example
server_stream = stream_processor.create_server_streaming_example()
print("=== Server Streaming ===")
print("✓ File download stream")
print("✓ Log stream")
print("✓ Metrics stream")
print("✓ Data-processing stream")

# Generate the client-streaming example
client_stream = stream_processor.create_client_streaming_example()
print("\n=== Client Streaming ===")
print("✓ File upload stream")
print("✓ Batch data stream")
print("✓ Chat stream")
print("✓ Real-time data stream")

# Generate the performance-optimization example
perf_optimization = stream_processor.create_performance_optimization()
print("\n=== Performance Optimization ===")
print("✓ Buffer tuning")
print("✓ Concurrency control")
print("✓ Flow-control management")
print("✓ Metrics monitoring")

Streaming Best Practices

1. Buffer Management

// Size the HTTP/2 flow-control windows to match your bandwidth-delay product.
grpc.WithInitialWindowSize(64 * 1024)        // 64KB per stream
grpc.WithInitialConnWindowSize(1024 * 1024)  // 1MB per connection
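
A minimal sketch of applying these options when dialing; the helper name and address are illustrative, and the insecure credentials match this chapter's client example:

// dialTuned opens a connection with explicitly tuned window sizes.
func dialTuned(addr string) (*grpc.ClientConn, error) {
    return grpc.Dial(addr,
        grpc.WithTransportCredentials(insecure.NewCredentials()),
        grpc.WithInitialWindowSize(64*1024),       // per-stream window
        grpc.WithInitialConnWindowSize(1024*1024), // per-connection window
    )
}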

2. Backpressure Handling

// Backpressure control: track buffered bytes and stop sending once the
// budget is exhausted (see the sketch after this snippet for the
// matching bookkeeping).
type BackpressureController struct {
    maxBufferSize int
    currentBuffer int
    mu            sync.Mutex
}

func (bc *BackpressureController) CanSend() bool {
    bc.mu.Lock()
    defer bc.mu.Unlock()
    return bc.currentBuffer < bc.maxBufferSize
}
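
CanSend alone cannot reclaim budget. A minimal sketch of the matching pair, assuming callers charge the controller before sending and credit it once the message is on the wire (Acquire and Release are illustrative additions, not part of the original snippet):

// Acquire reserves n bytes of send budget; false means the caller
// should wait or shed load.
func (bc *BackpressureController) Acquire(n int) bool {
    bc.mu.Lock()
    defer bc.mu.Unlock()
    if bc.currentBuffer+n > bc.maxBufferSize {
        return false
    }
    bc.currentBuffer += n
    return true
}

// Release returns n bytes of budget after the message has been sent.
func (bc *BackpressureController) Release(n int) {
    bc.mu.Lock()
    defer bc.mu.Unlock()
    bc.currentBuffer -= n
}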

3. Error Recovery

// Stream error recovery: Unavailable usually indicates a transient
// transport failure, so back off before re-establishing the stream.
func (s *StreamingServer) handleStreamError(err error) {
    if status.Code(err) == codes.Unavailable {
        // Back off, then re-establish the stream.
        time.Sleep(time.Second)
    }
}
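
On the client side, recovery is usually a retry loop around the call that creates the stream. A minimal sketch with exponential backoff, reusing the DownloadFileStream method from the client example (downloadWithRetry itself is an illustrative helper):

func downloadWithRetry(ctx context.Context, c *StreamingClient, filename string, maxAttempts int) error {
    backoff := 100 * time.Millisecond
    var err error
    for attempt := 1; attempt <= maxAttempts; attempt++ {
        if err = c.DownloadFileStream(ctx, filename); err == nil {
            return nil
        }
        if status.Code(err) != codes.Unavailable {
            return err // Only transient transport failures are retried.
        }
        log.Printf("attempt %d failed: %v; retrying in %v", attempt, err, backoff)
        select {
        case <-time.After(backoff):
            backoff *= 2 // Exponential backoff.
        case <-ctx.Done():
            return ctx.Err()
        }
    }
    return err
}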

4. Resource Cleanup

// Always release stream and connection resources.
defer func() {
    if stream != nil {
        stream.CloseSend()
    }
    if conn != nil {
        conn.Close()
    }
}()
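
In practice the cleanup pairs with a cancellable context so a hung stream cannot block shutdown. A sketch using the StreamingClient from this chapter (the helper name and timeout are illustrative):

func runChatWithCleanup(parent context.Context, c *StreamingClient) error {
    // Cancelling the context tears the stream down even if the server
    // never closes its side.
    ctx, cancel := context.WithTimeout(parent, 30*time.Second)
    defer cancel()
    
    stream, err := c.client.ChatStream(ctx)
    if err != nil {
        return err
    }
    defer stream.CloseSend() // Always half-close the send side.
    
    // ... send and receive messages ...
    return nil
}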

Summary

After working through this chapter, you should have mastered:

  1. Streaming patterns: advanced use of the four gRPC call shapes
  2. Performance optimization: buffering, concurrency, and flow-control strategies
  3. Error handling: recovery and retry mechanisms for streams
  4. Monitoring and debugging: metrics collection and performance analysis for streaming applications
  5. Best practices: buffer management, backpressure handling, and resource cleanup

In the next chapter we will cover gRPC's security mechanisms, including TLS/SSL and authentication/authorization.