概述
本章将深入探讨 gRPC 中间件与拦截器的开发和应用,包括拦截器的工作原理、常用中间件模式、自定义拦截器开发、拦截器链管理等。我们将学习如何通过中间件实现横切关注点,提升应用的可维护性和扩展性。
学习目标
- 理解 gRPC 拦截器的工作原理和类型
- 掌握常用中间件的实现和应用
- 学习自定义拦截器的开发方法
- 了解拦截器链的管理和优化
- 掌握中间件的最佳实践
拦截器基础
from enum import Enum
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Callable, Union
from abc import ABC, abstractmethod
import time
import logging
import json
import uuid
from contextlib import contextmanager
class InterceptorType(Enum):
    """Call shape an interceptor is attached to.

    Each member's value is its lowercase snake_case name.
    """

    UNARY = "unary"                                # single request, single response
    STREAM = "stream"                              # streaming call, direction unspecified
    CLIENT_STREAM = "client_stream"                # client streams requests
    SERVER_STREAM = "server_stream"                # server streams responses
    BIDIRECTIONAL_STREAM = "bidirectional_stream"  # both sides stream
class InterceptorPhase(Enum):
    """Lifecycle point of an RPC at which interceptor logic is hooked in."""

    PRE_CALL = "pre_call"        # nominally: before the handler runs
    POST_CALL = "post_call"      # nominally: after the handler returns
    ON_ERROR = "on_error"        # nominally: when the call fails
    ON_COMPLETE = "on_complete"  # nominally: once the call has finished
class LogLevel(Enum):
    """Log severities, listed in the conventional order of increasing severity."""

    DEBUG = "debug"
    INFO = "info"
    WARN = "warn"
    ERROR = "error"
    FATAL = "fatal"
@dataclass
class InterceptorConfig:
    """Configuration record for a single interceptor.

    Attributes:
        name: Identifier of the interceptor.
        enabled: Whether the interceptor is active (default True).
        priority: Ordering key (default 0).
        timeout_ms: Timeout in milliseconds, per the field name (default 5000).
        retry_count: Number of retries (default 3).
        log_level: Logging verbosity (default LogLevel.INFO).
        metadata: Optional free-form extra settings; None when not supplied.
    """

    name: str
    enabled: bool = True
    priority: int = 0
    timeout_ms: int = 5000
    retry_count: int = 3
    log_level: LogLevel = LogLevel.INFO
    # The default is None, so the annotation must be Optional; a bare Dict
    # annotation contradicts its own default for type checkers.
    metadata: Optional[Dict[str, Any]] = None
@dataclass
class CallContext:
    """Per-call information captured for one RPC.

    The first seven fields are required (positionally in this order); the
    user and tracing identifiers are optional and default to None.
    """

    # Required identification / origin / timing fields.
    request_id: str
    method: str
    service: str
    client_ip: str
    user_agent: str
    start_time: float  # presumably a time.time() timestamp -- confirm with producers
    metadata: Dict[str, str]

    # Optional identifiers.
    user_id: Optional[str] = None
    trace_id: Optional[str] = None
    span_id: Optional[str] = None
class InterceptorManager:
    """Generates the Go source for this chapter's gRPC interceptor examples."""

    def __init__(self):
        # Settings and registry containers; each instance gets its own, empty.
        self.config = {}
        self.interceptors = []
def create_basic_interceptors(self) -> str:
"""创建基础拦截器"""
return """
// interceptors.go - 基础拦截器实现
package middleware
import (
"context"
"fmt"
"log"
"runtime"
"strings"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
// CallInfo collects what is known about one RPC: identity, origin, timing,
// outcome, and selected incoming metadata.
type CallInfo struct {
	// Identity.
	RequestID string
	Method    string // full gRPC method name, e.g. "/pkg.Service/Method"
	Service   string

	// Origin.
	ClientIP  string
	UserAgent string

	// Timing.
	StartTime time.Time
	EndTime   time.Time
	Duration  time.Duration

	// Outcome.
	StatusCode codes.Code
	Error      error

	// Incoming metadata (first value per key) and caller-supplied identifiers.
	Metadata map[string]string
	UserID   string
	TraceID  string
	SpanID   string
}
// LoggingInterceptor returns a unary server interceptor that assigns each call
// a request ID, stores it in the context under the "request_id" key, and logs
// the call's start, its outcome, and how long it took.
func LoggingInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		start := time.Now()

		requestID := generateRequestID()
		// NOTE(review): a plain string context key can collide and is flagged by
		// staticcheck (SA1029); it is kept because getRequestID reads this key.
		ctx = context.WithValue(ctx, "request_id", requestID)

		callInfo := extractCallInfo(ctx, info.FullMethod, requestID)
		// Use the interceptor's own entry time so Duration covers the whole call
		// (an unused local is also a compile error in Go).
		callInfo.StartTime = start

		log.Printf("[%s] START %s from %s", requestID, info.FullMethod, callInfo.ClientIP)

		resp, err := handler(ctx, req)

		callInfo.EndTime = time.Now()
		callInfo.Duration = callInfo.EndTime.Sub(callInfo.StartTime)
		callInfo.Error = err
		if err != nil {
			callInfo.StatusCode = status.Code(err)
			log.Printf("[%s] ERROR %s: %v (duration: %v)",
				requestID, info.FullMethod, err, callInfo.Duration)
		} else {
			callInfo.StatusCode = codes.OK
			log.Printf("[%s] SUCCESS %s (duration: %v)",
				requestID, info.FullMethod, callInfo.Duration)
		}
		return resp, err
	}
}
// StreamLoggingInterceptor is the streaming counterpart of LoggingInterceptor:
// it gives the stream a fresh "request_id" in its context and logs when the
// stream starts and how it ended.
func StreamLoggingInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		began := time.Now()
		id := generateRequestID()
		method := info.FullMethod

		call := extractCallInfo(ss.Context(), method, id)
		log.Printf("[%s] STREAM START %s from %s", id, method, call.ClientIP)

		// The handler sees a stream whose Context() carries the request ID.
		tagged := &wrappedServerStream{
			ServerStream: ss,
			ctx:          context.WithValue(ss.Context(), "request_id", id),
		}
		err := handler(srv, tagged)

		elapsed := time.Since(began)
		if err == nil {
			log.Printf("[%s] STREAM SUCCESS %s (duration: %v)", id, method, elapsed)
			return nil
		}
		log.Printf("[%s] STREAM ERROR %s: %v (duration: %v)", id, method, err, elapsed)
		return err
	}
}
// MetricsInterceptor reports, per full method name, a request count, the
// handler latency, and either a success count or an error count labelled
// with the gRPC status code.
func MetricsInterceptor(collector MetricsCollector) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		method := info.FullMethod
		begin := time.Now()
		collector.IncRequestCount(method)

		resp, err := handler(ctx, req)

		collector.RecordDuration(method, time.Since(begin))
		switch {
		case err != nil:
			collector.IncErrorCount(method, status.Code(err))
		default:
			collector.IncSuccessCount(method)
		}
		return resp, err
	}
}
// RecoveryInterceptor turns a panic raised by the handler on the calling
// goroutine into a codes.Internal error. It logs the panic value and the
// current goroutine's stack (truncated to 64 KiB).
func RecoveryInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			buf := make([]byte, 64<<10)
			n := runtime.Stack(buf, false)
			log.Printf("PANIC in %s: %v\n%s", info.FullMethod, r, buf[:n])
			// Overwrite the named result so the caller sees a gRPC status error.
			err = status.Errorf(codes.Internal, "internal server error")
		}()
		resp, err = handler(ctx, req)
		return resp, err
	}
}
// TimeoutInterceptor bounds each unary call to timeout.
//
// The handler runs on its own goroutine with a context derived from ctx that
// expires after timeout. If that context finishes first, the call returns at
// once: codes.Canceled when the caller cancelled, codes.DeadlineExceeded
// otherwise. The handler goroutine is not stopped. It is expected to observe
// ctx and return, and the buffered channel lets it finish without blocking.
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		ctx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()

		type result struct {
			resp interface{}
			err  error
		}
		// Buffered so the goroutine can always deliver its result and exit.
		resultChan := make(chan result, 1)
		go func() {
			// A panic on this goroutine is invisible to an outer
			// RecoveryInterceptor (recover only works on the panicking
			// goroutine) and would crash the process, so convert it here.
			defer func() {
				if r := recover(); r != nil {
					stack := make([]byte, 64<<10)
					stack = stack[:runtime.Stack(stack, false)]
					log.Printf("PANIC in %s: %v %s", info.FullMethod, r, stack)
					resultChan <- result{err: status.Errorf(codes.Internal, "internal server error")}
				}
			}()
			resp, err := handler(ctx, req)
			resultChan <- result{resp: resp, err: err}
		}()

		select {
		case res := <-resultChan:
			return res.resp, res.err
		case <-ctx.Done():
			// Distinguish client cancellation from our own deadline firing.
			if ctx.Err() == context.Canceled {
				return nil, status.Errorf(codes.Canceled, "request canceled")
			}
			return nil, status.Errorf(codes.DeadlineExceeded, "request timeout")
		}
	}
}
// ValidationInterceptor 验证拦截器
func ValidationInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
// 检查请求是否实现了验证接口
if validator, ok := req.(interface{ Validate() error }); ok {
if err := validator.Validate(); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "validation failed: %v", err)
}
}
return handler(ctx, req)
}
}
// RateLimitInterceptor asks limiter whether the calling host may invoke the
// method now, and rejects the call with codes.ResourceExhausted if not.
//
// The key passed to limiter is the peer's host without its port. Callers whose
// address is unavailable all share the key "unknown".
func RateLimitInterceptor(limiter RateLimiter) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		clientIP := "unknown"
		if p, ok := peer.FromContext(ctx); ok && p.Addr != nil {
			// Drop the ephemeral source port: keying on "ip:port" would give
			// every new connection its own budget and defeat the limit.
			clientIP = rateLimitHost(p.Addr.String())
		}
		if !limiter.Allow(clientIP, info.FullMethod) {
			return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded")
		}
		return handler(ctx, req)
	}
}

// rateLimitHost returns the host of a "host:port" or "[host]:port" address.
// Addresses without a port (a Unix socket path or a bare IPv6 literal, for
// example) are returned unchanged.
func rateLimitHost(addr string) string {
	if strings.HasPrefix(addr, "[") {
		if end := strings.Index(addr, "]"); end > 0 {
			return addr[1:end]
		}
		return addr
	}
	if strings.Count(addr, ":") == 1 {
		return addr[:strings.Index(addr, ":")]
	}
	return addr
}
// TracingInterceptor wraps each unary call in a span named after the full
// method. The span is tagged with the method, component and peer address,
// placed in the handler's context, and annotated with the resulting status.
func TracingInterceptor(tracer Tracer) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		method := info.FullMethod
		span := tracer.StartSpan(ctx, method)
		defer span.Finish()

		span.SetTag("grpc.method", method)
		span.SetTag("component", "grpc-server")
		if p, ok := peer.FromContext(ctx); ok {
			span.SetTag("peer.address", p.Addr.String())
		}

		resp, err := handler(tracer.ContextWithSpan(ctx, span), req)

		if err == nil {
			span.SetTag("grpc.status_code", codes.OK.String())
			return resp, nil
		}
		span.SetTag("error", true)
		span.SetTag("grpc.status_code", status.Code(err).String())
		span.LogFields("error.message", err.Error())
		return resp, err
	}
}
// Helper functions and interfaces.

// generateRequestID builds an ID of the form "req_<unix-nanoseconds>".
//
// NOTE(review): IDs are only as unique as the clock reading. Two calls that
// observe the same timestamp (concurrent requests, coarse clocks) collide.
func generateRequestID() string {
	nanos := time.Now().UnixNano()
	return "req_" + fmt.Sprint(nanos)
}
// extractCallInfo builds a CallInfo for method with StartTime set to now.
// Service is the second-to-last "/"-separated segment of method. ClientIP is
// the peer address. Every incoming metadata key is copied (first value only),
// and user-agent, x-trace-id and x-span-id are also lifted into their own fields.
func extractCallInfo(ctx context.Context, method, requestID string) *CallInfo {
	info := &CallInfo{
		RequestID: requestID,
		Method:    method,
		StartTime: time.Now(),
		Metadata:  make(map[string]string),
	}

	if segments := strings.Split(method, "/"); len(segments) >= 2 {
		info.Service = segments[len(segments)-2]
	}

	if p, ok := peer.FromContext(ctx); ok {
		info.ClientIP = p.Addr.String()
	}

	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return info
	}
	for key, values := range md {
		if len(values) == 0 {
			continue
		}
		info.Metadata[key] = values[0]
	}

	// first returns the first value for key, or "" when the key is absent.
	first := func(key string) string {
		if vals := md.Get(key); len(vals) > 0 {
			return vals[0]
		}
		return ""
	}
	info.UserAgent = first("user-agent")
	info.TraceID = first("x-trace-id")
	info.SpanID = first("x-span-id")
	return info
}
// wrappedServerStream embeds a grpc.ServerStream and overrides Context, so an
// interceptor can hand the handler a stream that carries an enriched context.
type wrappedServerStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the replacement context instead of the embedded stream's.
func (s *wrappedServerStream) Context() context.Context { return s.ctx }
// MetricsCollector receives the per-method counters and latencies emitted by
// MetricsInterceptor. The method argument is the full gRPC method name.
type MetricsCollector interface {
	IncRequestCount(method string)
	IncSuccessCount(method string)
	IncErrorCount(method string, code codes.Code)
	RecordDuration(method string, duration time.Duration)
}

// RateLimiter decides whether the given client may call method right now.
type RateLimiter interface {
	Allow(clientIP, method string) bool
}

// Tracer starts spans and attaches them to contexts.
type Tracer interface {
	StartSpan(ctx context.Context, operationName string) Span
	ContextWithSpan(ctx context.Context, span Span) context.Context
}

// Span is a single traced operation that can be tagged, annotated and finished.
type Span interface {
	SetTag(key string, value interface{})
	LogFields(fields ...interface{})
	Finish()
}
"""
def create_advanced_interceptors(self) -> str:
"""创建高级拦截器"""
return """
// advanced_interceptors.go - 高级拦截器实现
package middleware
import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// CircuitBreakerInterceptor 熔断器拦截器
func CircuitBreakerInterceptor(cb CircuitBreaker) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
// 检查熔断器状态
if !cb.Allow(info.FullMethod) {
return nil, status.Errorf(codes.Unavailable, "circuit breaker is open")
}
// 执行处理器
resp, err := handler(ctx, req)
// 记录结果
if err != nil {
cb.RecordFailure(info.FullMethod)
} else {
cb.RecordSuccess(info.FullMethod)
}
return resp, err
}
}
// CacheInterceptor serves methods that isReadOnlyMethod accepts from cache.
// On a miss the handler runs, and a successful response is stored under a key
// built from the method and request, with the TTL from getCacheTTL. Other
// methods bypass the cache entirely.
//
// NOTE(review): the key does not include caller identity, so a cached response
// is returned to every caller sending an identical request. Only use this for
// data that is not user-specific.
func CacheInterceptor(cache Cache) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		method := info.FullMethod
		if !isReadOnlyMethod(method) {
			return handler(ctx, req)
		}

		key := generateCacheKey(method, req)
		if hit, ok := cache.Get(key); ok {
			log.Printf("Cache HIT for %s", method)
			return hit, nil
		}

		resp, err := handler(ctx, req)
		if err != nil {
			return resp, err
		}
		cache.Set(key, resp, getCacheTTL(method))
		log.Printf("Cache SET for %s", method)
		return resp, nil
	}
}
// RetryInterceptor re-invokes the handler while it fails with one of
// config.RetryableCodes. It makes up to config.MaxRetries extra attempts,
// waiting config.RetryDelay before each one, and stops waiting early if ctx
// is done. A success or a non-retryable error is returned as-is. When every
// attempt fails with a retryable error, the last error is returned wrapped.
//
// NOTE(review): the handler (and its side effects) runs once per attempt, so
// only retry idempotent methods.
func RetryInterceptor(config RetryConfig) grpc.UnaryServerInterceptor {
	// With a negative MaxRetries the loop below would never run: the handler
	// would be skipped and a wrapped nil error returned. Treat it as 0.
	maxRetries := config.MaxRetries
	if maxRetries < 0 {
		maxRetries = 0
	}
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		var lastErr error
		for attempt := 0; attempt <= maxRetries; attempt++ {
			if attempt > 0 {
				select {
				case <-time.After(config.RetryDelay):
				case <-ctx.Done():
					return nil, ctx.Err()
				}
				log.Printf("Retrying %s (attempt %d/%d)", info.FullMethod, attempt+1, maxRetries+1)
			}

			resp, err := handler(ctx, req)
			if err == nil || !isRetryableError(err, config.RetryableCodes) {
				return resp, err
			}
			lastErr = err
		}
		return nil, fmt.Errorf("max retries exceeded: %w", lastErr)
	}
}
// BulkheadInterceptor caps concurrent calls per full method name. If
// semaphore.TryAcquire refuses a slot, the call is rejected with
// codes.ResourceExhausted without running the handler. Otherwise the slot is
// released when the handler returns.
func BulkheadInterceptor(semaphore Semaphore) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		method := info.FullMethod
		if acquired := semaphore.TryAcquire(method); !acquired {
			return nil, status.Errorf(codes.ResourceExhausted, "service overloaded")
		}
		defer semaphore.Release(method)
		return handler(ctx, req)
	}
}
// CompressionInterceptor replaces a successful response with
// compressor.Compress(resp) when shouldCompress says so, and sets a
// "content-encoding" response header. If compression fails, the error is
// logged and the original response is returned.
//
// NOTE(review): gRPC marshals whatever the handler chain returns with the
// method's codec. A value of a different type than the method's response
// message will fail to serialize. gRPC's built-in compression
// (grpc.UseCompressor / the encoding package) is the supported mechanism.
func CompressionInterceptor(compressor Compressor) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		resp, err := handler(ctx, req)
		if err != nil || !shouldCompress(ctx, resp) {
			return resp, err
		}

		compressed, compressErr := compressor.Compress(resp)
		if compressErr != nil {
			log.Printf("Compression failed: %v", compressErr)
			return resp, nil
		}
		grpc.SetHeader(ctx, metadata.Pairs("content-encoding", compressor.Name()))
		return compressed, nil
	}
}
// AuditInterceptor builds an AuditRecord for every unary call. The record
// holds the request ID, user ID and client IP read from the context, the
// request and response, timing, and the final status. It is handed to
// auditor.Record on a new goroutine, so auditing never delays the response.
func AuditInterceptor(auditor Auditor) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		start := time.Now()
		record := &AuditRecord{
			RequestID: getRequestID(ctx),
			Method:    info.FullMethod,
			StartTime: start,
			Request:   req,
			UserID:    getUserID(ctx),
			ClientIP:  getClientIP(ctx),
		}

		resp, err := handler(ctx, req)

		record.EndTime = time.Now()
		record.Duration = record.EndTime.Sub(record.StartTime)
		record.Response = resp
		record.Error = err
		record.StatusCode = codes.OK
		if err != nil {
			record.StatusCode = status.Code(err)
		}

		go auditor.Record(record)
		return resp, err
	}
}
// TransactionInterceptor runs methods selected by needsTransaction inside a
// transaction from txManager. The transaction is placed in the handler's
// context, committed on success, and rolled back when the handler returns an
// error or panics. Other methods run without a transaction.
func TransactionInterceptor(txManager TransactionManager) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if !needsTransaction(info.FullMethod) {
			return handler(ctx, req)
		}

		tx, err := txManager.Begin(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to begin transaction: %w", err)
		}
		ctx = txManager.ContextWithTransaction(ctx, tx)

		// Run the handler so that a panic rolls the transaction back before it
		// continues to propagate (e.g. to a RecoveryInterceptor). Without this,
		// a panicking handler leaks the open transaction.
		resp, err := func() (interface{}, error) {
			defer func() {
				if r := recover(); r != nil {
					if rollbackErr := tx.Rollback(); rollbackErr != nil {
						log.Printf("Failed to rollback transaction: %v", rollbackErr)
					}
					panic(r)
				}
			}()
			return handler(ctx, req)
		}()

		if err != nil {
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				log.Printf("Failed to rollback transaction: %v", rollbackErr)
			}
			return nil, err
		}

		if commitErr := tx.Commit(); commitErr != nil {
			return nil, fmt.Errorf("failed to commit transaction: %w", commitErr)
		}
		return resp, nil
	}
}
// Collaborator interfaces and value types used by the interceptors above.

// CircuitBreaker decides, per full method name, whether calls may proceed, and
// is told the outcome of each call it allowed.
type CircuitBreaker interface {
	Allow(method string) bool
	RecordSuccess(method string)
	RecordFailure(method string)
}

// Cache is a key/value store with a per-entry TTL.
type Cache interface {
	Get(key string) (interface{}, bool)
	Set(key string, value interface{}, ttl time.Duration)
	Delete(key string)
}

// RetryConfig controls RetryInterceptor: at most MaxRetries extra attempts,
// RetryDelay apart, and only for errors whose gRPC code is in RetryableCodes.
type RetryConfig struct {
	MaxRetries     int
	RetryDelay     time.Duration
	RetryableCodes []codes.Code
}

// Semaphore limits concurrency per full method name. TryAcquire reports
// whether a slot was obtained; Release returns it.
type Semaphore interface {
	TryAcquire(method string) bool
	Release(method string)
}

// Compressor transforms response values and names its encoding.
type Compressor interface {
	Name() string
	Compress(data interface{}) (interface{}, error)
	Decompress(data interface{}) (interface{}, error)
}

// Auditor persists audit records.
type Auditor interface {
	Record(record *AuditRecord)
}

// AuditRecord describes one audited call: who made it, when, with what, and
// how it ended.
type AuditRecord struct {
	RequestID  string
	Method     string
	UserID     string
	ClientIP   string
	StartTime  time.Time
	EndTime    time.Time
	Duration   time.Duration
	Request    interface{}
	Response   interface{}
	Error      error
	StatusCode codes.Code
}

// TransactionManager starts transactions and attaches them to contexts.
type TransactionManager interface {
	Begin(ctx context.Context) (Transaction, error)
	ContextWithTransaction(ctx context.Context, tx Transaction) context.Context
}

// Transaction is an open unit of work that is either committed or rolled back.
type Transaction interface {
	Commit() error
	Rollback() error
}
// Helper functions.

// isReadOnlyMethod reports whether the method name (the part of the full gRPC
// method after the last "/") starts with Get, List, Search or Query.
//
// Only the method name is inspected, and only as a prefix. Substring matching
// on the full path would treat calls such as "/pkg.GetawayService/DeleteTrip"
// or "/pkg.S/DeleteListEntry" as read-only, and CacheInterceptor would then
// cache (and skip) mutating calls.
func isReadOnlyMethod(method string) bool {
	name := method[strings.LastIndex(method, "/")+1:]
	for _, prefix := range []string{"Get", "List", "Search", "Query"} {
		if strings.HasPrefix(name, prefix) {
			return true
		}
	}
	return false
}
// generateCacheKey returns "<method>:<hex of the request's JSON encoding>".
//
// NOTE(review): a json.Marshal error is discarded and leaves an empty payload,
// so every request of a method that fails to encode maps to the same
// "<method>:" key.
func generateCacheKey(method string, req interface{}) string {
	payload, _ := json.Marshal(req) // error intentionally ignored; see NOTE
	return method + ":" + fmt.Sprintf("%x", payload)
}
// getCacheTTL returns how long a cached response for method stays valid:
// 5 minutes for List* methods and 1 minute for everything else.
//
// Like isReadOnlyMethod, only the method name (after the last "/") is checked,
// and only as a prefix. Otherwise a service or method that merely contains
// "List" (e.g. "/pkg.ListingService/GetListing") would get the longer TTL.
func getCacheTTL(method string) time.Duration {
	name := method[strings.LastIndex(method, "/")+1:]
	if strings.HasPrefix(name, "List") {
		return 5 * time.Minute
	}
	return 1 * time.Minute
}
// isRetryableError reports whether err's gRPC status code is one of
// retryableCodes.
func isRetryableError(err error, retryableCodes []codes.Code) bool {
	got := status.Code(err)
	for i := range retryableCodes {
		if retryableCodes[i] == got {
			return true
		}
	}
	return false
}
func shouldCompress(ctx context.Context, resp interface{}) bool {
// 检查响应大小和类型
respBytes, _ := json.Marshal(resp)
return len(respBytes) > 1024 // 大于1KB时压缩
}
// getRequestID returns the string stored under the "request_id" context key,
// or "unknown" when it is absent or not a string.
func getRequestID(ctx context.Context) string {
	return stringFromContext(ctx, "request_id", "unknown")
}

// getUserID returns the string stored under the "user_id" context key, or ""
// when it is absent or not a string.
func getUserID(ctx context.Context) string {
	return stringFromContext(ctx, "user_id", "")
}

// getClientIP returns the string stored under the "client_ip" context key, or
// "" when it is absent or not a string.
func getClientIP(ctx context.Context) string {
	return stringFromContext(ctx, "client_ip", "")
}

// stringFromContext returns ctx.Value(key) when it holds a string, and
// fallback otherwise.
func stringFromContext(ctx context.Context, key string, fallback string) string {
	if v, ok := ctx.Value(key).(string); ok {
		return v
	}
	return fallback
}
// needsTransaction reports whether the full method path contains "Create",
// "Update", "Delete" or "Transfer" anywhere, service name included. Matching
// is deliberately broad: running an extra transaction is safer than missing
// one.
func needsTransaction(method string) bool {
	for _, verb := range [...]string{"Create", "Update", "Delete", "Transfer"} {
		if strings.Contains(method, verb) {
			return true
		}
	}
	return false
}
"""
def create_interceptor_chain(self) -> str:
"""创建拦截器链管理"""
return """
// interceptor_chain.go - 拦截器链管理
package middleware
import (
"context"
"fmt"
"log"
"sort"
"sync"
"google.golang.org/grpc"
)
// InterceptorChain 拦截器链
type InterceptorChain struct {
mu sync.RWMutex
interceptors []InterceptorWrapper
enabled bool
}
// InterceptorWrapper describes one registered interceptor: its name, ordering,
// on/off state, its unary and/or stream implementation (either may be nil),
// and descriptive metadata reported by ListInterceptors.
type InterceptorWrapper struct {
	Name     string
	Priority int // lower values are placed earlier (outermost) in the chain
	Enabled  bool

	Unary  grpc.UnaryServerInterceptor  // nil when there is no unary form
	Stream grpc.StreamServerInterceptor // nil when there is no stream form

	Description string
	Tags        []string
}
// NewInterceptorChain returns an empty chain that starts out enabled.
func NewInterceptorChain() *InterceptorChain {
	ic := new(InterceptorChain)
	ic.interceptors = []InterceptorWrapper{}
	ic.enabled = true
	return ic
}
// AddInterceptor registers wrapper and re-sorts the chain by ascending
// Priority, so lower values run first (outermost). Interceptors with equal
// priority keep their registration order.
func (ic *InterceptorChain) AddInterceptor(wrapper InterceptorWrapper) {
	ic.mu.Lock()
	defer ic.mu.Unlock()

	ic.interceptors = append(ic.interceptors, wrapper)
	// SliceStable, not Slice: sort.Slice is unstable, so ties in Priority
	// could otherwise be reordered arbitrarily on every insertion.
	sort.SliceStable(ic.interceptors, func(i, j int) bool {
		return ic.interceptors[i].Priority < ic.interceptors[j].Priority
	})

	log.Printf("Added interceptor: %s (priority: %d)", wrapper.Name, wrapper.Priority)
}
// RemoveInterceptor deletes the first interceptor registered under name and
// reports whether one was found.
func (ic *InterceptorChain) RemoveInterceptor(name string) bool {
	ic.mu.Lock()
	defer ic.mu.Unlock()

	idx := -1
	for i := range ic.interceptors {
		if ic.interceptors[i].Name == name {
			idx = i
			break
		}
	}
	if idx < 0 {
		return false
	}
	ic.interceptors = append(ic.interceptors[:idx], ic.interceptors[idx+1:]...)
	log.Printf("Removed interceptor: %s", name)
	return true
}
// EnableInterceptor switches on the first interceptor registered under name
// and reports whether one was found.
func (ic *InterceptorChain) EnableInterceptor(name string) bool {
	return ic.setInterceptorEnabled(name, true, "Enabled interceptor: %s")
}

// DisableInterceptor switches off the first interceptor registered under name
// and reports whether one was found.
func (ic *InterceptorChain) DisableInterceptor(name string) bool {
	return ic.setInterceptorEnabled(name, false, "Disabled interceptor: %s")
}

// setInterceptorEnabled sets Enabled on the first wrapper named name (under
// the write lock) and logs logFormat with the name when it succeeds.
func (ic *InterceptorChain) setInterceptorEnabled(name string, enabled bool, logFormat string) bool {
	ic.mu.Lock()
	defer ic.mu.Unlock()

	for i := range ic.interceptors {
		if ic.interceptors[i].Name != name {
			continue
		}
		ic.interceptors[i].Enabled = enabled
		log.Printf(logFormat, name)
		return true
	}
	return false
}
// GetUnaryInterceptor returns a single unary interceptor that, on each call,
// runs the chain's currently enabled unary interceptors in priority order. If
// the whole chain is disabled, the handler is called directly.
func (ic *InterceptorChain) GetUnaryInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// Read enabled and snapshot the list under the read lock: Enable and
		// Disable write enabled under the write lock, so an unlocked read is
		// a data race.
		ic.mu.RLock()
		if !ic.enabled {
			ic.mu.RUnlock()
			return handler(ctx, req)
		}
		interceptors := make([]grpc.UnaryServerInterceptor, 0, len(ic.interceptors))
		for _, wrapper := range ic.interceptors {
			if wrapper.Enabled && wrapper.Unary != nil {
				interceptors = append(interceptors, wrapper.Unary)
			}
		}
		ic.mu.RUnlock()

		return chainUnaryInterceptors(interceptors)(ctx, req, info, handler)
	}
}
// GetStreamInterceptor returns a single stream interceptor that, on each call,
// runs the chain's currently enabled stream interceptors in priority order. If
// the whole chain is disabled, the handler is called directly.
func (ic *InterceptorChain) GetStreamInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		// Read enabled and snapshot the list under the read lock: Enable and
		// Disable write enabled under the write lock, so an unlocked read is
		// a data race.
		ic.mu.RLock()
		if !ic.enabled {
			ic.mu.RUnlock()
			return handler(srv, ss)
		}
		interceptors := make([]grpc.StreamServerInterceptor, 0, len(ic.interceptors))
		for _, wrapper := range ic.interceptors {
			if wrapper.Enabled && wrapper.Stream != nil {
				interceptors = append(interceptors, wrapper.Stream)
			}
		}
		ic.mu.RUnlock()

		return chainStreamInterceptors(interceptors)(srv, ss, info, handler)
	}
}
// ListInterceptors returns a snapshot describing every registered
// interceptor, in chain order. The snapshot is independent of the chain:
// mutating it (including its Tags slices) does not affect the chain.
func (ic *InterceptorChain) ListInterceptors() []InterceptorInfo {
	ic.mu.RLock()
	defer ic.mu.RUnlock()

	infos := make([]InterceptorInfo, len(ic.interceptors))
	for i, wrapper := range ic.interceptors {
		infos[i] = InterceptorInfo{
			Name:        wrapper.Name,
			Priority:    wrapper.Priority,
			Enabled:     wrapper.Enabled,
			Description: wrapper.Description,
			// Copy: returning wrapper.Tags directly would let callers mutate
			// the chain's internal slice.
			Tags:      append([]string(nil), wrapper.Tags...),
			HasUnary:  wrapper.Unary != nil,
			HasStream: wrapper.Stream != nil,
		}
	}
	return infos
}
// InterceptorInfo is the JSON-serializable description of one registered
// interceptor, as returned by ListInterceptors.
type InterceptorInfo struct {
	Name        string   `json:"name"`
	Priority    int      `json:"priority"`
	Enabled     bool     `json:"enabled"`
	Description string   `json:"description"`
	Tags        []string `json:"tags"`
	HasUnary    bool     `json:"has_unary"`  // a unary implementation is registered
	HasStream   bool     `json:"has_stream"` // a stream implementation is registered
}
// Enable turns the whole chain on, so calls pass through its interceptors.
func (ic *InterceptorChain) Enable() {
	ic.setChainEnabled(true, "Interceptor chain enabled")
}

// Disable turns the whole chain off, so calls go straight to the handler.
func (ic *InterceptorChain) Disable() {
	ic.setChainEnabled(false, "Interceptor chain disabled")
}

// setChainEnabled stores on under the write lock and logs msg.
func (ic *InterceptorChain) setChainEnabled(on bool, msg string) {
	ic.mu.Lock()
	defer ic.mu.Unlock()
	ic.enabled = on
	log.Println(msg)
}
// chainUnaryInterceptors composes interceptors into one. interceptors[0] is
// outermost, and each one's "next" handler invokes the following interceptor.
// An empty list yields a pass-through; a single interceptor is returned as-is.
func chainUnaryInterceptors(interceptors []grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
	switch len(interceptors) {
	case 0:
		return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
			return handler(ctx, req)
		}
	case 1:
		return interceptors[0]
	default:
		return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
			return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
		}
	}
}

// chainStreamInterceptors is the streaming counterpart of
// chainUnaryInterceptors, with the same ordering and short-cuts.
func chainStreamInterceptors(interceptors []grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
	switch len(interceptors) {
	case 0:
		return func(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
			return handler(srv, ss)
		}
	case 1:
		return interceptors[0]
	default:
		return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
			return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
		}
	}
}
// getChainUnaryHandler returns the handler that interceptors[curr] should call
// as "next": finalHandler when curr is the last interceptor, otherwise
// interceptors[curr+1] wired to its own next handler.
func getChainUnaryHandler(interceptors []grpc.UnaryServerInterceptor, curr int, info *grpc.UnaryServerInfo, finalHandler grpc.UnaryHandler) grpc.UnaryHandler {
	following := curr + 1
	if following == len(interceptors) {
		return finalHandler
	}
	return func(ctx context.Context, req interface{}) (interface{}, error) {
		return interceptors[following](ctx, req, info, getChainUnaryHandler(interceptors, following, info, finalHandler))
	}
}

// getChainStreamHandler is the streaming counterpart of getChainUnaryHandler.
func getChainStreamHandler(interceptors []grpc.StreamServerInterceptor, curr int, info *grpc.StreamServerInfo, finalHandler grpc.StreamHandler) grpc.StreamHandler {
	following := curr + 1
	if following == len(interceptors) {
		return finalHandler
	}
	return func(srv interface{}, stream grpc.ServerStream) error {
		return interceptors[following](srv, stream, info, getChainStreamHandler(interceptors, following, info, finalHandler))
	}
}
// InterceptorBuilder assembles an InterceptorChain through a fluent API. Each
// With* method registers one enabled interceptor at the given priority and
// returns the builder so calls can be chained.
type InterceptorBuilder struct {
	chain *InterceptorChain
}

// NewInterceptorBuilder returns a builder around a fresh, enabled chain.
func NewInterceptorBuilder() *InterceptorBuilder {
	return &InterceptorBuilder{chain: NewInterceptorChain()}
}

// add registers w as enabled and returns the builder.
func (ib *InterceptorBuilder) add(w InterceptorWrapper) *InterceptorBuilder {
	w.Enabled = true
	ib.chain.AddInterceptor(w)
	return ib
}

// WithLogging registers the unary and stream logging interceptors.
func (ib *InterceptorBuilder) WithLogging(priority int) *InterceptorBuilder {
	return ib.add(InterceptorWrapper{
		Name:        "logging",
		Priority:    priority,
		Unary:       LoggingInterceptor(),
		Stream:      StreamLoggingInterceptor(),
		Description: "Request/response logging interceptor",
		Tags:        []string{"logging", "observability"},
	})
}

// WithMetrics registers the unary metrics interceptor reporting to collector.
func (ib *InterceptorBuilder) WithMetrics(collector MetricsCollector, priority int) *InterceptorBuilder {
	return ib.add(InterceptorWrapper{
		Name:        "metrics",
		Priority:    priority,
		Unary:       MetricsInterceptor(collector),
		Description: "Metrics collection interceptor",
		Tags:        []string{"metrics", "observability"},
	})
}

// WithRecovery registers the unary panic-recovery interceptor.
func (ib *InterceptorBuilder) WithRecovery(priority int) *InterceptorBuilder {
	return ib.add(InterceptorWrapper{
		Name:        "recovery",
		Priority:    priority,
		Unary:       RecoveryInterceptor(),
		Description: "Panic recovery interceptor",
		Tags:        []string{"recovery", "stability"},
	})
}

// WithTimeout registers the unary timeout interceptor with the given limit.
func (ib *InterceptorBuilder) WithTimeout(timeout time.Duration, priority int) *InterceptorBuilder {
	return ib.add(InterceptorWrapper{
		Name:        "timeout",
		Priority:    priority,
		Unary:       TimeoutInterceptor(timeout),
		Description: fmt.Sprintf("Request timeout interceptor (%v)", timeout),
		Tags:        []string{"timeout", "reliability"},
	})
}

// WithValidation registers the unary request-validation interceptor.
func (ib *InterceptorBuilder) WithValidation(priority int) *InterceptorBuilder {
	return ib.add(InterceptorWrapper{
		Name:        "validation",
		Priority:    priority,
		Unary:       ValidationInterceptor(),
		Description: "Request validation interceptor",
		Tags:        []string{"validation", "security"},
	})
}

// Build returns the chain assembled so far (not a copy).
func (ib *InterceptorBuilder) Build() *InterceptorChain {
	return ib.chain
}
"""
# Build the generator, render each group of Go sources, and print a checklist
# of what each group provides.


def _print_checklist(title, lines):
    """Print ``title`` followed by each pre-formatted checklist line."""
    print(title)
    for line in lines:
        print(line)


interceptor_mgr = InterceptorManager()

basic_interceptors = interceptor_mgr.create_basic_interceptors()
_print_checklist("=== 基础拦截器 ===", [
    "✓ 日志拦截器",
    "✓ 指标拦截器",
    "✓ 恢复拦截器",
    "✓ 超时拦截器",
    "✓ 验证拦截器",
    "✓ 速率限制拦截器",
    "✓ 链路追踪拦截器",
])

advanced_interceptors = interceptor_mgr.create_advanced_interceptors()
_print_checklist("\n=== 高级拦截器 ===", [
    "✓ 熔断器拦截器",
    "✓ 缓存拦截器",
    "✓ 重试拦截器",
    "✓ 舱壁隔离拦截器",
    "✓ 压缩拦截器",
    "✓ 审计拦截器",
    "✓ 事务拦截器",
])

interceptor_chain = interceptor_mgr.create_interceptor_chain()
_print_checklist("\n=== 拦截器链管理 ===", [
    "✓ 拦截器链组织",
    "✓ 优先级排序",
    "✓ 动态启用/禁用",
    "✓ 链式调用",
    "✓ 构建器模式",
])
自定义拦截器开发
1. 拦截器函数类型(grpc 包中的定义)
// Unary interceptor function type (mirrors grpc.UnaryServerInterceptor; shown
// for reference). It receives the call context, the request, the call info
// and the next handler, and returns the call's response and error.
type UnaryServerInterceptor func(
	ctx context.Context,
	req interface{},
	info *UnaryServerInfo,
	handler UnaryHandler,
) (resp interface{}, err error)

// Stream interceptor function type (mirrors grpc.StreamServerInterceptor). It
// receives the service implementation, the server stream, the stream info and
// the next handler, and returns the call's error.
type StreamServerInterceptor func(
	srv interface{},
	ss ServerStream,
	info *StreamServerInfo,
	handler StreamHandler,
) error
2. 自定义拦截器示例
// CustomBusinessInterceptor is a minimal template for a custom interceptor. It
// logs before the handler runs and again after it returns normally, and passes
// the handler's results through untouched.
func CustomBusinessInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		method := info.FullMethod

		log.Printf("Before: %s", method) // pre-processing hook

		resp, err := handler(ctx, req) // business logic

		log.Printf("After: %s", method) // post-processing hook
		return resp, err
	}
}
3. 拦截器最佳实践
// BestPracticeInterceptor shows a recommended interceptor skeleton: fail fast
// when the context is already done, keep cleanup in a defer, and on failure
// return only the error (never a partial response).
func BestPracticeInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// 1. Fail fast: don't start work for a call that is already cancelled
		//    or past its deadline.
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		// 2. Resource cleanup: release anything this interceptor acquires.
		defer func() {
			// cleanup goes here
		}()

		// 3. Error handling.
		resp, err := handler(ctx, req)
		if err != nil {
			// error-handling logic goes here
			return nil, err
		}
		return resp, nil
	}
}
总结
本章深入探讨了 gRPC 中间件与拦截器的开发和应用,主要内容包括:
核心要点
拦截器基础
- 拦截器类型和工作原理
- 一元和流拦截器
- 拦截器生命周期
- 上下文传递
常用中间件
- 日志记录
- 指标收集
- 错误恢复
- 超时控制
- 请求验证
- 速率限制
- 链路追踪
高级中间件
- 熔断器
- 缓存机制
- 重试逻辑
- 舱壁隔离
- 数据压缩
- 审计日志
- 事务管理
拦截器链管理
- 链式组织
- 优先级控制
- 动态管理
- 构建器模式
最佳实践
设计原则
- 单一职责
- 最小侵入
- 高性能
- 可配置
性能优化
- 避免重复计算
- 合理使用缓存
- 异步处理
- 资源池化
错误处理
- 优雅降级
- 错误传播
- 日志记录
- 监控告警
可观测性
- 指标暴露
- 链路追踪
- 日志结构化
- 健康检查
下一步学习
- 学习服务网格中间件
- 掌握分布式追踪
- 了解可观测性最佳实践
- 实践微服务治理
通过本章学习,你已经掌握了 gRPC 中间件与拦截器的核心技能,能够构建功能强大、可维护的服务治理体系。