微服务安全挑战
微服务架构带来了独特的安全挑战:
1. 攻击面扩大
- 服务数量增加:更多的服务意味着更多的潜在攻击点
- 网络通信复杂:服务间通信增加了网络攻击风险
- API暴露增多:每个服务都可能暴露API端点
2. 身份认证复杂
- 多服务认证:需要在多个服务间传递和验证身份
- 令牌管理:JWT、OAuth等令牌的生成、验证和刷新
- 会话管理:分布式环境下的会话状态管理
3. 授权控制困难
- 细粒度权限:需要实现服务级、方法级的权限控制
- 动态授权:根据上下文动态调整权限
- 权限传播:在服务调用链中传播权限信息
4. 数据安全
- 传输加密:服务间通信的加密
- 存储加密:敏感数据的存储加密
- 数据脱敏:日志和监控中的敏感信息处理
安全架构概览
# Microservice security architecture configuration (ConfigMap).
# The nested architecture.yaml document must be indented under its block
# scalar key; without indentation the literal is empty and its keys collide
# at the top level (e.g. the two api_security entries).
apiVersion: v1
kind: ConfigMap
metadata:
  name: security-architecture-config
data:
  architecture.yaml: |
    # 微服务安全架构
    security_architecture:
      # 边界安全层
      perimeter_security:
        api_gateway:
          - rate_limiting
          - ddos_protection
          - ssl_termination
          - request_validation
        load_balancer:
          - ssl_offloading
          - health_checks
          - traffic_filtering
        firewall:
          - network_segmentation
          - port_filtering
          - ip_whitelisting
      # 身份认证层
      authentication:
        identity_providers:
          - oauth2_server
          - openid_connect
          - ldap_integration
          - social_login
        token_management:
          - jwt_tokens
          - refresh_tokens
          - token_validation
          - token_revocation
        multi_factor_auth:
          - sms_verification
          - email_verification
          - totp_authentication
      # 授权控制层
      authorization:
        access_control:
          - rbac (role-based)
          - abac (attribute-based)
          - policy_engine
          - permission_matrix
        api_security:
          - endpoint_protection
          - method_level_auth
          - resource_level_auth
          - scope_validation
      # 通信安全层
      communication_security:
        service_mesh:
          - mtls_encryption
          - traffic_encryption
          - certificate_management
          - identity_verification
        api_security:
          - request_signing
          - response_validation
          - api_versioning
          - cors_policy
      # 数据安全层
      data_security:
        encryption:
          - data_at_rest
          - data_in_transit
          - key_management
          - encryption_algorithms
        data_protection:
          - data_masking
          - data_anonymization
          - pii_protection
          - gdpr_compliance
      # 监控审计层
      monitoring_audit:
        security_monitoring:
          - intrusion_detection
          - anomaly_detection
          - threat_intelligence
          - security_alerts
        audit_logging:
          - access_logs
          - security_events
          - compliance_reports
          - forensic_analysis
数据加密与传输安全
TLS/SSL 配置
package security
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"time"
)
// TLSConfig holds the file paths and protocol settings from which TLSManager
// builds client and server *tls.Config values.
type TLSConfig struct {
	// CertFile and KeyFile are the PEM certificate and private-key paths passed
	// to tls.LoadX509KeyPair; a certificate is only loaded when both are set.
	CertFile string
	KeyFile  string
	// CAFile is a PEM bundle used as the trust pool (RootCAs on the client
	// side, ClientCAs on the server side).
	CAFile string
	// ServerName is copied into the client tls.Config for host-name checks.
	ServerName string
	// InsecureSkip maps to tls.Config.InsecureSkipVerify on the client side and
	// disables server certificate verification. Never enable it outside tests.
	InsecureSkip bool
	// MinVersion/MaxVersion take tls.VersionTLS* constants; 0 means "unset"
	// (see the TLSManager getters for how unset values are handled).
	MinVersion uint16
	MaxVersion uint16
	// CipherSuites is copied verbatim into the generated tls.Configs.
	CipherSuites []uint16
}
// TLSManager holds the key material loaded by NewTLSManager and derives
// client/server *tls.Config values and an HTTP client from it.
type TLSManager struct {
	config *TLSConfig
	// cert is the zero tls.Certificate when no CertFile/KeyFile pair was loaded.
	cert tls.Certificate
	// caPool is nil when TLSConfig.CAFile is empty.
	caPool *x509.CertPool
}
// NewTLSManager builds a TLSManager from config. It eagerly loads the
// certificate/key pair and the CA bundle, so misconfiguration surfaces at
// start-up rather than on the first handshake.
//
// CertFile and KeyFile must be set together. Setting only one of them is
// reported as an error instead of being silently ignored, because a lone
// cert or key is almost always a deployment mistake. When CAFile is empty,
// the manager's CA pool stays nil.
//
// Errors wrap the underlying cause (%w), so callers can use errors.Is/As
// (e.g. to detect fs.ErrNotExist).
func NewTLSManager(config *TLSConfig) (*TLSManager, error) {
	if config == nil {
		return nil, fmt.Errorf("tls config must not be nil")
	}
	manager := &TLSManager{config: config}

	// Load the certificate/key pair.
	switch {
	case config.CertFile != "" && config.KeyFile != "":
		cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
		if err != nil {
			return nil, fmt.Errorf("failed to load certificate: %w", err)
		}
		manager.cert = cert
	case config.CertFile != "" || config.KeyFile != "":
		return nil, fmt.Errorf("both CertFile and KeyFile must be set (CertFile=%q, KeyFile=%q)",
			config.CertFile, config.KeyFile)
	}

	// Load the CA bundle.
	if config.CAFile != "" {
		caData, err := ioutil.ReadFile(config.CAFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read CA file: %w", err)
		}
		caPool := x509.NewCertPool()
		if !caPool.AppendCertsFromPEM(caData) {
			return nil, fmt.Errorf("failed to parse CA certificate")
		}
		manager.caPool = caPool
	}

	return manager, nil
}
// GetServerTLSConfig returns a server-side *tls.Config that requires and
// verifies a client certificate (mutual TLS).
//
// The server certificate is only installed when one was actually loaded. An
// empty tls.Certificate in Certificates would be presented to clients and fail
// the handshake with an opaque error.
//
// When MinVersion is unset it defaults to TLS 1.2, unless MaxVersion is
// explicitly lower. Go toolchains before 1.22 otherwise let servers negotiate
// TLS 1.0/1.1.
//
// NOTE(review): when no CAFile was configured, ClientCAs is nil and crypto/tls
// verifies client certificates against the system roots, i.e. any publicly
// issued certificate is accepted. Configure CAFile for real mTLS deployments.
func (tm *TLSManager) GetServerTLSConfig() *tls.Config {
	cfg := &tls.Config{
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    tm.caPool,
		MinVersion:   tm.config.MinVersion,
		MaxVersion:   tm.config.MaxVersion,
		CipherSuites: tm.config.CipherSuites,
	}
	if cfg.MinVersion == 0 && (cfg.MaxVersion == 0 || cfg.MaxVersion >= tls.VersionTLS12) {
		cfg.MinVersion = tls.VersionTLS12
	}
	if len(tm.cert.Certificate) > 0 {
		cfg.Certificates = []tls.Certificate{tm.cert}
	}
	return cfg
}
// GetClientTLSConfig returns a client-side *tls.Config that trusts the
// configured CA pool and, when a certificate was loaded, presents it for
// mutual TLS.
//
// The client certificate is only included when one was actually loaded.
// MinVersion defaults to TLS 1.2 when unset, unless MaxVersion is explicitly
// lower.
//
// NOTE(review): InsecureSkip disables verification of the server's
// certificate chain and host name, which defeats the purpose of TLS; it
// should only ever be set in tests.
func (tm *TLSManager) GetClientTLSConfig() *tls.Config {
	cfg := &tls.Config{
		RootCAs:            tm.caPool,
		ServerName:         tm.config.ServerName,
		InsecureSkipVerify: tm.config.InsecureSkip,
		MinVersion:         tm.config.MinVersion,
		MaxVersion:         tm.config.MaxVersion,
		CipherSuites:       tm.config.CipherSuites,
	}
	if cfg.MinVersion == 0 && (cfg.MaxVersion == 0 || cfg.MaxVersion >= tls.VersionTLS12) {
		cfg.MinVersion = tls.VersionTLS12
	}
	if len(tm.cert.Certificate) > 0 {
		cfg.Certificates = []tls.Certificate{tm.cert}
	}
	return cfg
}
// CreateSecureHTTPClient returns an *http.Client with a 30-second overall
// request timeout, whose transport uses GetClientTLSConfig.
//
// The transport is cloned from http.DefaultTransport, so it keeps the
// standard defaults: ProxyFromEnvironment, dial and TLS-handshake timeouts,
// idle-connection limits and ForceAttemptHTTP2. A zero-value http.Transport
// has none of these. In particular, setting TLSClientConfig on it silently
// disables HTTP/2.
func (tm *TLSManager) CreateSecureHTTPClient() *http.Client {
	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	} else {
		// DefaultTransport was replaced by something else; fall back to a
		// plain transport rather than panicking on the type assertion.
		transport = &http.Transport{}
	}
	transport.TLSClientConfig = tm.GetClientTLSConfig()

	return &http.Client{
		Timeout:   30 * time.Second,
		Transport: transport,
	}
}
数据加密服务
package security
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io"
)
// EncryptionService bundles an RSA key pair (OAEP/SHA-256 encryption) and an
// AES key (GCM encryption) for protecting small payloads.
//
// NOTE(review): NewEncryptionService generates all keys in memory at start-up.
// Data encrypted by one instance therefore cannot be decrypted after a restart
// or by another replica. Load keys from a KMS/secret store for persisted data.
type EncryptionService struct {
	privateKey *rsa.PrivateKey
	publicKey  *rsa.PublicKey
	// aesKey is the symmetric key for EncryptAES/DecryptAES (NewEncryptionService
	// creates a 32-byte key, i.e. AES-256).
	aesKey []byte
}
// NewEncryptionService creates an EncryptionService with a freshly generated
// 2048-bit RSA key pair and a random 32-byte AES-256 key from crypto/rand.
//
// The keys exist only in memory: ciphertexts produced by this instance cannot
// be decrypted by any other instance or after a restart. Errors wrap the
// underlying cause with %w.
func NewEncryptionService() (*EncryptionService, error) {
	// Generate the RSA key pair.
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("failed to generate RSA key: %w", err)
	}

	// Generate the AES-256 key.
	aesKey := make([]byte, 32)
	if _, err := rand.Read(aesKey); err != nil {
		return nil, fmt.Errorf("failed to generate AES key: %w", err)
	}

	return &EncryptionService{
		privateKey: privateKey,
		publicKey:  &privateKey.PublicKey,
		aesKey:     aesKey,
	}, nil
}
// EncryptAES seals plaintext with AES-GCM under the service's AES key.
//
// A fresh random nonce is drawn for every call and prepended to the output.
// The result is standard base64 of nonce || ciphertext || tag. Encrypting the
// same plaintext twice therefore yields different strings.
func (es *EncryptionService) EncryptAES(plaintext []byte) (string, error) {
	var (
		blk  cipher.Block
		aead cipher.AEAD
		err  error
	)
	if blk, err = aes.NewCipher(es.aesKey); err != nil {
		return "", err
	}
	if aead, err = cipher.NewGCM(blk); err != nil {
		return "", err
	}

	nonce := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}

	// Seal appends ciphertext||tag to its first argument, so passing the nonce
	// as dst yields nonce||ciphertext||tag in a single slice.
	sealed := aead.Seal(nonce, nonce, plaintext, nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// DecryptAES reverses EncryptAES. It base64-decodes ciphertext, splits off the
// leading GCM nonce and authenticates/decrypts the remainder.
//
// It returns an error for invalid base64, for input shorter than the nonce,
// or when GCM authentication fails (wrong key or tampered data).
func (es *EncryptionService) DecryptAES(ciphertext string) ([]byte, error) {
	raw, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return nil, err
	}

	blk, err := aes.NewCipher(es.aesKey)
	if err != nil {
		return nil, err
	}
	aead, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}

	n := aead.NonceSize()
	if len(raw) < n {
		return nil, errors.New("ciphertext too short")
	}

	opened, err := aead.Open(nil, raw[:n], raw[n:], nil)
	if err != nil {
		return nil, err
	}
	return opened, nil
}
// EncryptRSA encrypts plaintext with RSA-OAEP (SHA-256, no label) under the
// service's public key and returns it as standard base64.
//
// OAEP limits the plaintext to k - 2*32 - 2 bytes for a k-byte modulus
// (190 bytes for the 2048-bit key from NewEncryptionService). Longer inputs
// make rsa.EncryptOAEP return an error.
func (es *EncryptionService) EncryptRSA(plaintext []byte) (string, error) {
	digest := sha256.New()
	encrypted, err := rsa.EncryptOAEP(digest, rand.Reader, es.publicKey, plaintext, nil)
	if err == nil {
		return base64.StdEncoding.EncodeToString(encrypted), nil
	}
	return "", err
}
// DecryptRSA reverses EncryptRSA. It base64-decodes ciphertext and decrypts it
// with RSA-OAEP (SHA-256, no label) using the service's private key.
func (es *EncryptionService) DecryptRSA(ciphertext string) ([]byte, error) {
	raw, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return nil, err
	}

	digest := sha256.New()
	decrypted, err := rsa.DecryptOAEP(digest, rand.Reader, es.privateKey, raw, nil)
	if err != nil {
		return nil, err
	}
	return decrypted, nil
}
// GetPublicKeyPEM returns the service's RSA public key as a PEM "PUBLIC KEY"
// block (PKIX / SubjectPublicKeyInfo encoding), suitable for sharing with
// peers that call EncryptRSA-compatible code.
func (es *EncryptionService) GetPublicKeyPEM() (string, error) {
	der, err := x509.MarshalPKIXPublicKey(es.publicKey)
	if err != nil {
		return "", err
	}
	block := &pem.Block{Type: "PUBLIC KEY", Bytes: der}
	return string(pem.EncodeToMemory(block)), nil
}
安全审计与监控
安全事件监控
package security
import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
)
// SecurityEventType classifies a SecurityEvent. Its string value is used as the
// "type" label of the security_events_total metric.
type SecurityEventType string

// Known security event types.
const (
	EventTypeLogin          SecurityEventType = "login"
	EventTypeLogout         SecurityEventType = "logout"
	EventTypeLoginFailed    SecurityEventType = "login_failed"
	EventTypeAccessDenied   SecurityEventType = "access_denied"
	EventTypeTokenExpired   SecurityEventType = "token_expired"
	EventTypeDataAccess     SecurityEventType = "data_access"
	EventTypeDataModify     SecurityEventType = "data_modify"
	EventTypeSecurityBreach SecurityEventType = "security_breach"
)
// SecurityEvent is a single audit record passed to SecurityAuditor.LogEvent and
// on to every registered SecurityEventHandler.
type SecurityEvent struct {
	// ID and Timestamp are populated by SecurityAuditor.LogEvent when the event
	// is recorded (see LogEvent for the exact rules).
	ID        string            `json:"id"`
	Type      SecurityEventType `json:"type"`
	UserID    string            `json:"user_id"`
	IP        string            `json:"ip"`
	UserAgent string            `json:"user_agent"`
	Resource  string            `json:"resource"`
	Action    string            `json:"action"`
	// Result is a free-form outcome such as "success"; it is used as a metric
	// label.
	Result string `json:"result"`
	// NOTE(review): Message and Metadata are free-form and end up in logs via
	// LogEventHandler; avoid putting secrets or unmasked PII in them.
	Message   string            `json:"message"`
	Metadata  map[string]string `json:"metadata"`
	Timestamp time.Time         `json:"timestamp"`
}
// SecurityAuditor queues security events, records Prometheus metrics for them
// and fans them out to registered handlers on a background goroutine.
//
// NOTE(review): nothing in this file closes events, so the processing
// goroutine started by NewSecurityAuditor runs for the life of the process.
type SecurityAuditor struct {
	// events is the buffered queue drained by processEvents; LogEvent drops an
	// event instead of blocking when it is full.
	events chan SecurityEvent
	// handlers is guarded by mu.
	handlers []SecurityEventHandler
	mu       sync.RWMutex
	metrics  *SecurityMetrics
}
// SecurityEventHandler consumes audit events. Handle is invoked sequentially
// from the auditor's single processing goroutine. A returned error is logged
// and otherwise ignored, and later handlers still run.
type SecurityEventHandler interface {
	Handle(event SecurityEvent) error
}
// SecurityMetrics holds the Prometheus counters updated for each logged event.
type SecurityMetrics struct {
	// SecurityEvents is labelled by event type and result.
	SecurityEvents *prometheus.CounterVec
	// LoginAttempts and AccessDenied are labelled by user_id.
	// NOTE(review): a per-user label creates one time series per user; consider
	// removing it if the user base is large.
	LoginAttempts *prometheus.CounterVec
	AccessDenied  *prometheus.CounterVec
}
// NewSecurityAuditor creates a SecurityAuditor with a 1000-event queue,
// registers its counters with the default Prometheus registry, and starts the
// goroutine that dispatches queued events to handlers.
//
// Registration is idempotent. If an identical collector is already registered
// (for example because a second auditor is created in the same process, as
// tests commonly do), the existing collector is reused. prometheus.MustRegister
// would panic in that case instead. Any other registration error still panics.
//
// NOTE(review): login_attempts_total and access_denied_total carry a user_id
// label, i.e. one time series per user.
func NewSecurityAuditor() *SecurityAuditor {
	metrics := &SecurityMetrics{
		SecurityEvents: registerOrReuseCounterVec(prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "security_events_total",
				Help: "Total number of security events",
			},
			[]string{"type", "result"},
		)),
		LoginAttempts: registerOrReuseCounterVec(prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "login_attempts_total",
				Help: "Total number of login attempts",
			},
			[]string{"result", "user_id"},
		)),
		AccessDenied: registerOrReuseCounterVec(prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "access_denied_total",
				Help: "Total number of access denied events",
			},
			[]string{"resource", "user_id"},
		)),
	}

	auditor := &SecurityAuditor{
		events:  make(chan SecurityEvent, 1000),
		metrics: metrics,
	}

	// Start the event dispatcher.
	go auditor.processEvents()

	return auditor
}

// registerOrReuseCounterVec registers cv with the default registry and returns
// it. If an equivalent collector is already registered, that one is returned.
// It panics on any other error, matching prometheus.MustRegister.
func registerOrReuseCounterVec(cv *prometheus.CounterVec) *prometheus.CounterVec {
	if err := prometheus.Register(cv); err != nil {
		if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
			if existing, ok := are.ExistingCollector.(*prometheus.CounterVec); ok {
				return existing
			}
		}
		panic(err)
	}
	return cv
}
// AddHandler registers handler; it receives every event dispatched after this
// call. It is safe to call while events are being processed, since the handler
// list is guarded by the auditor's mutex.
func (sa *SecurityAuditor) AddHandler(handler SecurityEventHandler) {
	sa.mu.Lock()
	sa.handlers = append(sa.handlers, handler)
	sa.mu.Unlock()
}
// LogEvent records a security event without blocking.
//
// Missing fields are filled in:
//   - Timestamp defaults to the current time. A caller-supplied timestamp,
//     e.g. one propagated from an upstream service, is preserved.
//   - ID gets a generated value when empty.
//
// Metrics are updated for every event, including events dropped because the
// queue is full. Counters such as login_attempts_total matter most exactly when
// the system is under load, for example during a brute-force attempt. On a
// drop, only identifying fields are logged, so free-form data (Message,
// Metadata) is not spilled into the application log.
func (sa *SecurityAuditor) LogEvent(event SecurityEvent) {
	if event.Timestamp.IsZero() {
		event.Timestamp = time.Now()
	}
	if event.ID == "" {
		event.ID = generateEventID()
	}

	sa.updateMetrics(event)

	select {
	case sa.events <- event:
	default:
		log.Printf("Security event queue full, dropping event: id=%s type=%s user_id=%q",
			event.ID, event.Type, event.UserID)
	}
}
// processEvents drains the event queue until it is closed. For each event it
// takes a snapshot of the handler list under the read lock, then calls every
// handler in order. Handler errors are logged and do not stop later handlers.
func (sa *SecurityAuditor) processEvents() {
	currentHandlers := func() []SecurityEventHandler {
		sa.mu.RLock()
		defer sa.mu.RUnlock()
		return sa.handlers
	}

	for {
		event, open := <-sa.events
		if !open {
			return
		}
		for _, h := range currentHandlers() {
			if err := h.Handle(event); err != nil {
				log.Printf("Error handling security event: %v", err)
			}
		}
	}
}
// updateMetrics increments the per-type/result event counter. For login and
// failed-login events it also bumps the login counter; for access-denied
// events it bumps the access-denied counter.
func (sa *SecurityAuditor) updateMetrics(event SecurityEvent) {
	m := sa.metrics
	m.SecurityEvents.WithLabelValues(string(event.Type), event.Result).Inc()

	if event.Type == EventTypeLogin || event.Type == EventTypeLoginFailed {
		m.LoginAttempts.WithLabelValues(event.Result, event.UserID).Inc()
	} else if event.Type == EventTypeAccessDenied {
		m.AccessDenied.WithLabelValues(event.Resource, event.UserID).Inc()
	}
}
// eventIDMu guards lastEventID, the most recently issued event ID.
var (
	eventIDMu   sync.Mutex
	lastEventID int64
)

// generateEventID returns a process-unique, strictly increasing event ID,
// formatted as a decimal string based on the current Unix time in
// nanoseconds.
//
// A raw time.Now().UnixNano() collides when two events are logged in the same
// clock tick, which happens easily with concurrent callers or coarse clocks.
// If the clock has not advanced past the previous ID, the previous ID + 1 is
// used instead.
func generateEventID() string {
	eventIDMu.Lock()
	defer eventIDMu.Unlock()

	id := time.Now().UnixNano()
	if id <= lastEventID {
		id = lastEventID + 1
	}
	lastEventID = id
	return fmt.Sprintf("%d", id)
}
// LogEventHandler writes every security event to the standard logger as JSON.
type LogEventHandler struct{}

// Handle marshals event to JSON and logs it with a "Security Event: " prefix.
// A marshalling error is returned unchanged, and nothing is logged in that case.
func (leh *LogEventHandler) Handle(event SecurityEvent) error {
	payload, err := json.Marshal(event)
	if err == nil {
		log.Printf("Security Event: %s", payload)
	}
	return err
}
// DatabaseEventHandler is a placeholder for persisting security events to a
// database.
type DatabaseEventHandler struct {
	// TODO: database connection/handle goes here.
}

// Handle is currently a no-op that discards the event and returns nil.
//
// NOTE(review): until storage is implemented, registering this handler
// silently drops events rather than persisting them.
func (deh *DatabaseEventHandler) Handle(_ SecurityEvent) error {
	return nil
}
安全最佳实践
1. 身份认证最佳实践
- 多因素认证(MFA):为敏感操作启用多因素认证
- 密码策略:实施强密码策略(足够的最小长度、拦截常见及已泄露的密码);仅在怀疑密码泄露时强制更换。NIST SP 800-63B 不建议定期强制更换密码,因为这往往促使用户选择更弱的密码
- 会话管理:合理设置会话超时和安全注销
- 令牌安全:使用短期访问令牌和长期刷新令牌
2. 授权控制最佳实践
- 最小权限原则:只授予必要的最小权限
- 权限分离:关键操作需要多人授权
- 定期审计:定期审查和清理权限
- 动态授权:基于上下文的动态权限控制
3. 数据安全最佳实践
- 数据分类:对数据进行敏感性分类
- 加密存储:敏感数据必须加密存储
- 传输加密:所有数据传输使用TLS(1.2及以上版本);SSL协议已被废弃,不应再启用
- 数据脱敏:非生产环境使用脱敏数据
4. 网络安全最佳实践
- 网络隔离:使用VPC和安全组隔离网络
- API网关:统一API入口和安全控制
- DDoS防护:部署DDoS防护措施
- 入侵检测:部署入侵检测和防护系统
5. 监控审计最佳实践
- 全面日志:记录所有安全相关事件
- 实时监控:实时监控异常行为
- 告警机制:建立安全事件告警机制
- 合规审计:满足行业合规要求
使用示例
package main
import (
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
// main wires the JWT, RBAC and auditing components from this chapter into a
// Gin HTTP server listening on :8080.
//
// NOTE(review): this example does not match the APIs shown in this chapter.
// auth.JWTConfig has PrivateKey/PublicKey/AccessTokenTTL/RefreshTokenTTL
// fields (not PrivateKeyPath/PublicKeyPath/AccessTTL/RefreshTTL).
// auth.NewJWTService returns no error. rbac.NewRBAC requires an RBACStore.
// AuthMiddleware, SecurityHeadersMiddleware, CORSMiddleware and
// RateLimitMiddleware are not defined in the code shown here. Adjust before
// compiling.
func main() {
	// Create the JWT service.
	jwtConfig := &JWTConfig{
		PrivateKeyPath: "private.pem",
		PublicKeyPath:  "public.pem",
		Issuer:         "microservice-auth",
		AccessTTL:      15 * time.Minute,
		RefreshTTL:     7 * 24 * time.Hour,
	}

	jwtService, err := NewJWTService(jwtConfig)
	if err != nil {
		log.Fatal("Failed to create JWT service:", err)
	}

	// Create the RBAC manager.
	rbac := NewRBAC()

	// Create the security auditor and log every event.
	auditor := NewSecurityAuditor()
	auditor.AddHandler(&LogEventHandler{})

	// Create the authentication middleware.
	authMiddleware := &AuthMiddleware{
		JWTService: jwtService,
		RBAC:       rbac,
		Auditor:    auditor,
	}

	r := gin.Default()

	// Global security middleware.
	r.Use(SecurityHeadersMiddleware())
	r.Use(CORSMiddleware())
	r.Use(RateLimitMiddleware(100, time.Minute))

	// Public routes.
	public := r.Group("/api/v1")
	{
		public.POST("/login", loginHandler(jwtService, auditor))
		public.POST("/register", registerHandler(auditor))
	}

	// Routes that require an authenticated caller.
	auth := r.Group("/api/v1")
	auth.Use(authMiddleware.RequireAuth())
	{
		auth.GET("/profile", profileHandler)
		auth.POST("/refresh", refreshHandler(jwtService))
	}

	// Routes that additionally require the admin:read permission.
	admin := r.Group("/api/v1/admin")
	admin.Use(authMiddleware.RequireAuth())
	admin.Use(authMiddleware.RequirePermission("admin:read"))
	{
		admin.GET("/users", listUsersHandler)
	}

	// Serve through an explicit http.Server with timeouts. http.ListenAndServe
	// sets none, so slow or idle clients (e.g. Slowloris-style attacks) could
	// hold connections open indefinitely.
	srv := &http.Server{
		Addr:              ":8080",
		Handler:           r,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	// NOTE(review): this serves plain HTTP. Terminate TLS here
	// (srv.ListenAndServeTLS) or at the gateway.
	log.Println("Starting server on :8080")
	log.Fatal(srv.ListenAndServe())
}
// loginHandler returns a placeholder login endpoint. It performs no
// authentication yet: it records a successful login event for the hard-coded
// user "user123" from the caller's IP, then responds 200.
func loginHandler(jwtService *JWTService, auditor *SecurityAuditor) gin.HandlerFunc {
	handle := func(c *gin.Context) {
		// TODO: verify credentials and issue tokens via jwtService.

		evt := SecurityEvent{
			Type:   EventTypeLogin,
			UserID: "user123",
			IP:     c.ClientIP(),
			Result: "success",
		}
		auditor.LogEvent(evt)

		c.JSON(http.StatusOK, gin.H{"message": "Login successful"})
	}
	return handle
}
// registerHandler returns a placeholder registration endpoint that always
// responds 200 with a success message.
//
// NOTE(review): auditor is accepted but not used yet, so registration attempts
// are not audited.
func registerHandler(auditor *SecurityAuditor) gin.HandlerFunc {
	respond := func(c *gin.Context) {
		// TODO: implement registration.
		c.JSON(http.StatusOK, gin.H{"message": "Registration successful"})
	}
	return respond
}
// profileHandler is a placeholder profile endpoint returning a fixed message.
func profileHandler(c *gin.Context) {
	body := gin.H{"message": "User profile"}
	c.JSON(http.StatusOK, body)
}
// refreshHandler returns a placeholder token-refresh endpoint that always
// responds 200. jwtService is not used yet.
func refreshHandler(jwtService *JWTService) gin.HandlerFunc {
	return gin.HandlerFunc(func(c *gin.Context) {
		// TODO: exchange the refresh token using jwtService.RefreshToken.
		c.JSON(http.StatusOK, gin.H{"message": "Token refreshed"})
	})
}
// listUsersHandler is a placeholder admin endpoint returning a fixed user list.
func listUsersHandler(c *gin.Context) {
	users := []string{"user1", "user2"}
	c.JSON(http.StatusOK, gin.H{"users": users})
}
本章总结
本章详细介绍了微服务安全与认证的核心概念和实现方法:
安全挑战:分析了微服务架构带来的安全挑战,包括攻击面扩大、身份认证复杂、授权控制困难等
JWT认证:实现了完整的JWT认证服务,包括令牌生成、验证、刷新等功能
OAuth2协议:实现了OAuth2授权服务器,支持多种授权模式
RBAC权限控制:实现了基于角色的访问控制系统,支持细粒度权限管理
数据加密:提供了TLS/SSL配置和数据加密服务,确保数据传输和存储安全
安全审计:实现了安全事件监控和审计系统,支持实时监控和告警
最佳实践:总结了微服务安全的最佳实践和实施建议
通过本章的学习,你将能够构建安全可靠的微服务系统,有效防范各种安全威胁。
下一章我们将探讨微服务的测试策略与质量保证。
JWT服务实现
// auth/jwt.go
package auth
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
)
// Claims is the JWT payload issued by JWTService: application-specific user
// fields plus the standard registered claims (iss, sub, aud, exp, nbf, iat,
// jti) embedded from jwt.RegisteredClaims.
//
// NOTE(review): JWT payloads are only base64url-encoded, not encrypted;
// Email and the other fields are readable by anyone holding the token.
type Claims struct {
	UserID   string `json:"user_id"`
	Username string `json:"username"`
	Email    string `json:"email"`
	// Roles and Scopes are copied into access tokens by GenerateTokenPair.
	Roles  []string `json:"roles"`
	Scopes []string `json:"scopes"`
	jwt.RegisteredClaims
}
// TokenPair is the result of GenerateTokenPair/RefreshToken.
type TokenPair struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	// TokenType is always "Bearer".
	TokenType string `json:"token_type"`
	// ExpiresIn is the access-token lifetime in whole seconds.
	ExpiresIn int64 `json:"expires_in"`
}
// JWTConfig configures JWTService.
type JWTConfig struct {
	// PrivateKey signs tokens (RS256); PublicKey verifies them. See LoadRSAKeys.
	PrivateKey *rsa.PrivateKey
	PublicKey  *rsa.PublicKey
	// AccessTokenTTL and RefreshTokenTTL set the exp claim of each token type.
	AccessTokenTTL  time.Duration
	RefreshTokenTTL time.Duration
	// Issuer and Audience are written into every token and must match exactly
	// when a token is validated.
	Issuer   string
	Audience string
}
// JWTService issues, validates and refreshes RS256-signed JWTs using the keys
// and settings in its JWTConfig.
type JWTService struct {
	config *JWTConfig
}
// NewJWTService returns a JWTService backed by config. The pointer is stored
// as-is; config is neither copied nor validated.
func NewJWTService(config *JWTConfig) *JWTService {
	svc := new(JWTService)
	svc.config = config
	return svc
}
// LoadRSAKeys parses a PEM-encoded RSA private key and a PEM-encoded RSA
// public key, and checks that they form a matching pair.
//
// The private key may use either of two encodings:
//   - PKCS#1 ("RSA PRIVATE KEY"), e.g. from OpenSSL 1.x genrsa.
//   - PKCS#8 ("PRIVATE KEY"), the default of `openssl genpkey` and OpenSSL 3.
//
// The public key must be a PKIX/SubjectPublicKeyInfo ("PUBLIC KEY") block
// holding an RSA key.
//
// A mismatched pair is rejected. Otherwise every token signed with the
// private key would fail verification against the public key at run time.
func LoadRSAKeys(privateKeyPEM, publicKeyPEM string) (*rsa.PrivateKey, *rsa.PublicKey, error) {
	// Parse the private key.
	privateBlock, _ := pem.Decode([]byte(privateKeyPEM))
	if privateBlock == nil {
		return nil, nil, errors.New("failed to decode private key PEM")
	}
	privateKey, err := parseRSAPrivateKeyDER(privateBlock.Bytes)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to parse private key: %w", err)
	}

	// Parse the public key.
	publicBlock, _ := pem.Decode([]byte(publicKeyPEM))
	if publicBlock == nil {
		return nil, nil, errors.New("failed to decode public key PEM")
	}
	publicKeyInterface, err := x509.ParsePKIXPublicKey(publicBlock.Bytes)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to parse public key: %w", err)
	}
	publicKey, ok := publicKeyInterface.(*rsa.PublicKey)
	if !ok {
		return nil, nil, errors.New("public key is not RSA")
	}

	if !privateKey.PublicKey.Equal(publicKey) {
		return nil, nil, errors.New("public key does not match private key")
	}

	return privateKey, publicKey, nil
}

// parseRSAPrivateKeyDER decodes a DER RSA private key in PKCS#1 or PKCS#8 form.
func parseRSAPrivateKeyDER(der []byte) (*rsa.PrivateKey, error) {
	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
		return key, nil
	}
	parsed, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, err
	}
	key, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("PKCS#8 key is %T, not RSA", parsed)
	}
	return key, nil
}
// refreshTokenTyp is the JOSE "typ" header value stamped on refresh tokens.
//
// Access and refresh tokens share the same signing key, issuer and audience.
// Without a distinguishing marker, two kinds of token confusion are possible:
//   - A long-lived refresh token passes ValidateToken and can be used as a
//     bearer access token.
//   - A short-lived access token can be exchanged for a fresh pair through
//     RefreshToken, extending a stolen access token indefinitely.
//
// The header is covered by the signature, so it cannot be altered. Access
// tokens keep the library default typ ("JWT").
const refreshTokenTyp = "refresh+jwt"

// GenerateTokenPair issues an RS256-signed access token and refresh token for
// a user.
//
// Both tokens carry the user's identity, roles and scopes, so RefreshToken can
// mint an equivalent pair later. A refresh token without them would produce
// refreshed access tokens that silently lose their roles and scopes. The
// trade-off is that role changes take effect only when the refresh token
// expires or the user logs in again.
//
// The refresh token is marked with the refreshTokenTyp header. Token IDs (jti)
// use nanosecond resolution, so pairs issued for the same user within one
// second get distinct IDs.
func (js *JWTService) GenerateTokenPair(userID, username, email string, roles, scopes []string) (*TokenPair, error) {
	now := time.Now()

	// Access token.
	accessClaims := &Claims{
		UserID:   userID,
		Username: username,
		Email:    email,
		Roles:    roles,
		Scopes:   scopes,
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    js.config.Issuer,
			Audience:  jwt.ClaimStrings{js.config.Audience},
			Subject:   userID,
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(js.config.AccessTokenTTL)),
			NotBefore: jwt.NewNumericDate(now),
			ID:        fmt.Sprintf("%s-%d", userID, now.UnixNano()),
		},
	}
	accessToken := jwt.NewWithClaims(jwt.SigningMethodRS256, accessClaims)
	accessTokenString, err := accessToken.SignedString(js.config.PrivateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to sign access token: %w", err)
	}

	// Refresh token: same identity, roles and scopes, longer TTL, distinct typ.
	refreshClaims := &Claims{
		UserID:   userID,
		Username: username,
		Email:    email,
		Roles:    roles,
		Scopes:   scopes,
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    js.config.Issuer,
			Audience:  jwt.ClaimStrings{js.config.Audience},
			Subject:   userID,
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(js.config.RefreshTokenTTL)),
			NotBefore: jwt.NewNumericDate(now),
			ID:        fmt.Sprintf("%s-refresh-%d", userID, now.UnixNano()),
		},
	}
	refreshToken := jwt.NewWithClaims(jwt.SigningMethodRS256, refreshClaims)
	refreshToken.Header["typ"] = refreshTokenTyp
	refreshTokenString, err := refreshToken.SignedString(js.config.PrivateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to sign refresh token: %w", err)
	}

	return &TokenPair{
		AccessToken:  accessTokenString,
		RefreshToken: refreshTokenString,
		TokenType:    "Bearer",
		ExpiresIn:    int64(js.config.AccessTokenTTL.Seconds()),
	}, nil
}

// ValidateToken verifies an access token and returns its claims.
//
// It checks the following, then rejects refresh tokens:
//   - The signature, made with an RSA PKCS#1 v1.5 method (RS256/384/512),
//     against the configured public key.
//   - Expiry and not-before, enforced by jwt.ParseWithClaims.
//   - The issuer and the audience.
func (js *JWTService) ValidateToken(tokenString string) (*Claims, error) {
	claims, typ, err := js.parseSignedToken(tokenString)
	if err != nil {
		return nil, err
	}
	if typ == refreshTokenTyp {
		return nil, errors.New("refresh token cannot be used as an access token")
	}
	return claims, nil
}

// RefreshToken exchanges a valid refresh token for a new token pair carrying
// the same identity, roles and scopes.
//
// Only tokens marked with refreshTokenTyp are accepted; access tokens are
// rejected. Refresh tokens issued without that marker are also rejected, so
// those users must log in again.
func (js *JWTService) RefreshToken(refreshTokenString string) (*TokenPair, error) {
	claims, typ, err := js.parseSignedToken(refreshTokenString)
	if err != nil {
		return nil, fmt.Errorf("invalid refresh token: %w", err)
	}
	if typ != refreshTokenTyp {
		return nil, errors.New("invalid refresh token: not a refresh token")
	}
	return js.GenerateTokenPair(claims.UserID, claims.Username, claims.Email, claims.Roles, claims.Scopes)
}

// parseSignedToken performs the checks shared by ValidateToken and
// RefreshToken: signature, time validity, issuer and audience. It returns the
// claims and the token's "typ" header ("" when absent or not a string).
func (js *JWTService) parseSignedToken(tokenString string) (*Claims, string, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Reject non-RSA algorithms such as "none" or HS256 key confusion.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return js.config.PublicKey, nil
	})
	if err != nil {
		return nil, "", fmt.Errorf("failed to parse token: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, "", errors.New("invalid token")
	}

	if claims.Issuer != js.config.Issuer {
		return nil, "", errors.New("invalid issuer")
	}

	validAudience := false
	for _, aud := range claims.Audience {
		if aud == js.config.Audience {
			validAudience = true
			break
		}
	}
	if !validAudience {
		return nil, "", errors.New("invalid audience")
	}

	typ, _ := token.Header["typ"].(string)
	return claims, typ, nil
}
// ExtractTokenFromHeader returns the token from an Authorization header of the
// form "Bearer <token>".
//
// The scheme name is matched case-insensitively, as RFC 7235 §2.1 requires for
// authentication schemes, so "bearer abc" is accepted. Whitespace around the
// token is trimmed. A header with the scheme but no token is rejected, rather
// than returning an empty token with a nil error.
func ExtractTokenFromHeader(authHeader string) (string, error) {
	if authHeader == "" {
		return "", errors.New("authorization header is empty")
	}

	const bearerPrefix = "Bearer "
	if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		return "", errors.New("authorization header must start with 'Bearer '")
	}

	token := strings.TrimSpace(authHeader[len(bearerPrefix):])
	if token == "" {
		return "", errors.New("authorization header contains no token")
	}
	return token, nil
}
OAuth2服务实现
// auth/oauth2.go
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// GrantType is an OAuth2 grant type as sent in the grant_type parameter of a
// token request (RFC 6749).
type GrantType string

// Grant types dispatched by TokenHandler.
const (
	GrantTypeAuthorizationCode GrantType = "authorization_code"
	GrantTypeClientCredentials GrantType = "client_credentials"
	GrantTypeRefreshToken      GrantType = "refresh_token"
	GrantTypePassword          GrantType = "password"
)
// Client is a registered OAuth2 client application.
type Client struct {
	ID string `json:"id"`
	// Secret is checked by ClientStore.ValidateClient.
	// NOTE(review): its json tag means it is included whenever a Client is
	// marshalled; nothing here shows it being hashed — confirm the store does.
	Secret string `json:"secret"`
	Name   string `json:"name"`
	// RedirectURIs must contain the exact redirect_uri of an authorization request.
	RedirectURIs []string `json:"redirect_uris"`
	// Scopes is the upper bound of scopes the client may be granted.
	Scopes []string `json:"scopes"`
	// GrantTypes lists the grant types the client may use at the token endpoint.
	GrantTypes []GrantType `json:"grant_types"`
	CreatedAt  time.Time   `json:"created_at"`
	UpdatedAt  time.Time   `json:"updated_at"`
}
// AuthorizationCode is a short-lived code issued by AuthorizeHandler and
// redeemed at the token endpoint.
type AuthorizationCode struct {
	Code     string `json:"code"`
	ClientID string `json:"client_id"`
	UserID   string `json:"user_id"`
	// RedirectURI must be repeated exactly in the token request.
	RedirectURI string    `json:"redirect_uri"`
	Scopes      []string  `json:"scopes"`
	ExpiresAt   time.Time `json:"expires_at"`
	// Used is checked at redemption time.
	// NOTE(review): the redemption path shown deletes codes rather than setting Used.
	Used bool `json:"used"`
}
// TokenRequest is the token-endpoint request body. The json/form tags let it
// bind from either a JSON or a form-encoded body; which fields are required
// depends on GrantType.
type TokenRequest struct {
	GrantType    GrantType `json:"grant_type" form:"grant_type"`
	Code         string    `json:"code" form:"code"`
	RedirectURI  string    `json:"redirect_uri" form:"redirect_uri"`
	ClientID     string    `json:"client_id" form:"client_id"`
	ClientSecret string    `json:"client_secret" form:"client_secret"`
	Username     string    `json:"username" form:"username"`
	Password     string    `json:"password" form:"password"`
	RefreshToken string    `json:"refresh_token" form:"refresh_token"`
	// Scope is a space-separated list of requested scopes.
	Scope string `json:"scope" form:"scope"`
}
// TokenResponse is the successful token-endpoint response (RFC 6749 §5.1).
type TokenResponse struct {
	AccessToken string `json:"access_token"`
	TokenType   string `json:"token_type"`
	// ExpiresIn is the access-token lifetime in seconds.
	ExpiresIn int64 `json:"expires_in"`
	// RefreshToken is omitted for grants that do not return one.
	RefreshToken string `json:"refresh_token,omitempty"`
	// Scope is the space-separated list of granted scopes.
	Scope string `json:"scope,omitempty"`
}
// ErrorResponse is an OAuth2 error body (RFC 6749 §5.2): a machine-readable
// error code plus optional human-readable description and URI.
type ErrorResponse struct {
	Error            string `json:"error"`
	ErrorDescription string `json:"error_description,omitempty"`
	ErrorURI         string `json:"error_uri,omitempty"`
}
// ClientStore looks up registered clients. GetClient is used at the
// authorization endpoint (no secret available); ValidateClient authenticates a
// client at the token endpoint and must return an error for bad credentials.
type ClientStore interface {
	GetClient(ctx context.Context, clientID string) (*Client, error)
	ValidateClient(ctx context.Context, clientID, clientSecret string) (*Client, error)
}
// CodeStore persists authorization codes between the authorization and token
// endpoints.
// NOTE(review): redemption is implemented as GetCode followed by DeleteCode;
// an implementation should make DeleteCode fail for an already-deleted code
// (or offer an atomic consume) so a code cannot be redeemed twice concurrently.
type CodeStore interface {
	StoreCode(ctx context.Context, code *AuthorizationCode) error
	GetCode(ctx context.Context, code string) (*AuthorizationCode, error)
	DeleteCode(ctx context.Context, code string) error
}
// UserStore provides user lookup: ValidateUser checks a username/password
// (password grant) and GetUser loads a user by ID (authorization-code grant).
type UserStore interface {
	ValidateUser(ctx context.Context, username, password string) (*User, error)
	GetUser(ctx context.Context, userID string) (*User, error)
}
// User is the subset of user data copied into issued tokens.
type User struct {
	ID       string   `json:"id"`
	Username string   `json:"username"`
	Email    string   `json:"email"`
	Roles    []string `json:"roles"`
}
// OAuth2Server implements the authorization and token endpoints on top of
// pluggable stores, issuing tokens through JWTService.
type OAuth2Server struct {
	clientStore ClientStore
	codeStore   CodeStore
	userStore   UserStore
	jwtService  *JWTService
	config      *OAuth2Config
}
// OAuth2Config configures OAuth2Server.
type OAuth2Config struct {
	// CodeTTL is the lifetime of authorization codes.
	CodeTTL time.Duration
	// NOTE(review): in the code shown, token lifetimes, issuer and audience come
	// from JWTService's own JWTConfig; these four fields are not read here.
	AccessTokenTTL  time.Duration
	RefreshTokenTTL time.Duration
	Issuer          string
	Audience        string
}
// NewOAuth2Server assembles an OAuth2Server from its stores, token service and
// configuration. The arguments are stored as given; none are validated.
func NewOAuth2Server(clientStore ClientStore, codeStore CodeStore, userStore UserStore, jwtService *JWTService, config *OAuth2Config) *OAuth2Server {
	srv := new(OAuth2Server)
	srv.clientStore = clientStore
	srv.codeStore = codeStore
	srv.userStore = userStore
	srv.jwtService = jwtService
	srv.config = config
	return srv
}
// generateRandomString returns a random string of exactly length characters
// drawn from the URL-safe base64 alphabet (A-Z a-z 0-9 - _).
//
// It base64url-encodes length bytes from crypto/rand and keeps the first
// length characters of the encoding, which yields 6 bits of entropy per
// character.
func generateRandomString(length int) (string, error) {
	raw := make([]byte, length)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(raw)
	return encoded[:length], nil
}
// AuthorizeHandler implements the OAuth2 authorization endpoint for the
// authorization-code flow (RFC 6749 §4.1.1).
//
// Some errors are detected before redirect_uri is known to be registered for
// the client and parsable: unsupported response_type, unknown client, or an
// unregistered or malformed redirect URI. These are answered with a JSON 400
// instead of a redirect, so the endpoint cannot be abused as an open
// redirector. Later errors are reported to the client by redirecting with
// error/error_description/state (RFC 6749 §4.1.2.1).
//
// On success a code valid for config.CodeTTL is stored, and the user agent is
// redirected to redirect_uri with code and, if supplied, state.
func (s *OAuth2Server) AuthorizeHandler(c *gin.Context) {
	ctx := c.Request.Context()
	responseType := c.Query("response_type")
	clientID := c.Query("client_id")
	redirectURI := c.Query("redirect_uri")
	scope := c.Query("scope")
	state := c.Query("state")

	// Only the authorization-code flow is supported.
	if responseType != "code" {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "unsupported_response_type",
			ErrorDescription: "Only 'code' response type is supported",
		})
		return
	}

	// Look up the client; a nil client with a nil error is treated as unknown.
	client, err := s.clientStore.GetClient(ctx, clientID)
	if err != nil || client == nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "invalid_client",
			ErrorDescription: "Invalid client ID",
		})
		return
	}

	// redirect_uri must exactly match one registered for the client, and must
	// parse, since it is used to build the final redirect.
	validRedirectURI := false
	for _, uri := range client.RedirectURIs {
		if uri == redirectURI {
			validRedirectURI = true
			break
		}
	}
	if !validRedirectURI {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "invalid_redirect_uri",
			ErrorDescription: "Invalid redirect URI",
		})
		return
	}
	redirectURL, err := url.Parse(redirectURI)
	if err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "invalid_redirect_uri",
			ErrorDescription: "Invalid redirect URI",
		})
		return
	}

	// The client must be registered for this grant type, as the password and
	// client-credentials grants already require.
	grantAllowed := false
	for _, gt := range client.GrantTypes {
		if gt == GrantTypeAuthorizationCode {
			grantAllowed = true
			break
		}
	}
	if !grantAllowed {
		redirectWithError(c, redirectURI, "unauthorized_client", "Client is not authorized for this grant type", state)
		return
	}

	// Keep only the requested scopes the client is allowed to use.
	requestedScopes := strings.Fields(scope)
	validScopes := s.validateScopes(client.Scopes, requestedScopes)
	if len(validScopes) == 0 {
		redirectWithError(c, redirectURI, "invalid_scope", "Invalid scope", state)
		return
	}

	// A consent page should be shown here. For simplicity the user is assumed
	// to be logged in and to have approved the request.
	//
	// NOTE(review): the user ID is read from the X-User-ID request header. This
	// is only safe if a trusted gateway or middleware sets that header and
	// strips any client-supplied value. Otherwise any caller can obtain a code
	// for an arbitrary user.
	userID := c.GetHeader("X-User-ID")
	if userID == "" {
		// Not logged in: send the user to the login page, then back here.
		loginURL := fmt.Sprintf("/login?redirect_uri=%s", url.QueryEscape(c.Request.URL.String()))
		c.Redirect(http.StatusFound, loginURL)
		return
	}

	code, err := generateRandomString(32)
	if err != nil {
		redirectWithError(c, redirectURI, "server_error", "Failed to generate authorization code", state)
		return
	}

	authCode := &AuthorizationCode{
		Code:        code,
		ClientID:    clientID,
		UserID:      userID,
		RedirectURI: redirectURI,
		Scopes:      validScopes,
		ExpiresAt:   time.Now().Add(s.config.CodeTTL),
		Used:        false,
	}
	if err := s.codeStore.StoreCode(ctx, authCode); err != nil {
		redirectWithError(c, redirectURI, "server_error", "Failed to store authorization code", state)
		return
	}

	// Redirect back to the client with the code (and state, if any).
	query := redirectURL.Query()
	query.Set("code", code)
	if state != "" {
		query.Set("state", state)
	}
	redirectURL.RawQuery = query.Encode()
	c.Redirect(http.StatusFound, redirectURL.String())
}
// TokenHandler implements the OAuth2 token endpoint (RFC 6749 §3.2). It binds
// the request and dispatches on grant_type.
//
// Client credentials are accepted either in the body (client_id and
// client_secret) or through HTTP Basic authentication (client_secret_basic).
// RFC 6749 §2.3.1 requires token endpoints to support Basic authentication.
//
// Every response is marked non-cacheable (Cache-Control: no-store,
// Pragma: no-cache), as RFC 6749 §5.1 requires for responses that carry
// tokens.
func (s *OAuth2Server) TokenHandler(c *gin.Context) {
	c.Header("Cache-Control", "no-store")
	c.Header("Pragma", "no-cache")

	var req TokenRequest
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "invalid_request",
			ErrorDescription: "Invalid request format",
		})
		return
	}

	// Fall back to HTTP Basic client authentication when the body carries no
	// client credentials. Per RFC 6749 §2.3.1, the id and secret are
	// form-urlencoded before being placed in the Basic credentials.
	if req.ClientID == "" && req.ClientSecret == "" {
		if id, secret, ok := c.Request.BasicAuth(); ok {
			decodedID, errID := url.QueryUnescape(id)
			decodedSecret, errSecret := url.QueryUnescape(secret)
			if errID != nil || errSecret != nil {
				c.JSON(http.StatusBadRequest, ErrorResponse{
					Error:            "invalid_request",
					ErrorDescription: "Malformed client credentials in Authorization header",
				})
				return
			}
			req.ClientID, req.ClientSecret = decodedID, decodedSecret
		}
	}

	switch req.GrantType {
	case GrantTypeAuthorizationCode:
		s.handleAuthorizationCodeGrant(c, &req)
	case GrantTypeClientCredentials:
		s.handleClientCredentialsGrant(c, &req)
	case GrantTypeRefreshToken:
		s.handleRefreshTokenGrant(c, &req)
	case GrantTypePassword:
		s.handlePasswordGrant(c, &req)
	default:
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "unsupported_grant_type",
			ErrorDescription: "Unsupported grant type",
		})
	}
}
// handleAuthorizationCodeGrant exchanges an authorization code for an access
// and refresh token (RFC 6749 §4.1.3).
//
// The following must all hold before tokens are issued:
//   - The client authenticates successfully.
//   - The client is registered for the authorization_code grant.
//   - The code exists, is unexpired and unused, and was issued to this client
//     for the same redirect_uri.
//
// The code is deleted before tokens are issued so that it cannot be redeemed
// twice.
func (s *OAuth2Server) handleAuthorizationCodeGrant(c *gin.Context, req *TokenRequest) {
	ctx := c.Request.Context()

	// Authenticate the client.
	client, err := s.clientStore.ValidateClient(ctx, req.ClientID, req.ClientSecret)
	if err != nil || client == nil {
		c.JSON(http.StatusUnauthorized, ErrorResponse{
			Error:            "invalid_client",
			ErrorDescription: "Invalid client credentials",
		})
		return
	}

	// The client must be allowed to use this grant type.
	grantAllowed := false
	for _, gt := range client.GrantTypes {
		if gt == GrantTypeAuthorizationCode {
			grantAllowed = true
			break
		}
	}
	if !grantAllowed {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "unauthorized_client",
			ErrorDescription: "Client is not authorized for this grant type",
		})
		return
	}

	// Load the code; a nil code with a nil error is treated as missing.
	authCode, err := s.codeStore.GetCode(ctx, req.Code)
	if err != nil || authCode == nil || authCode.Used || time.Now().After(authCode.ExpiresAt) {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "invalid_grant",
			ErrorDescription: "Invalid or expired authorization code",
		})
		return
	}

	// The code must belong to this client and this redirect URI.
	if authCode.ClientID != req.ClientID || authCode.RedirectURI != req.RedirectURI {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            "invalid_grant",
			ErrorDescription: "Authorization code was issued to another client",
		})
		return
	}

	// Invalidate the code before issuing tokens.
	// NOTE(review): GetCode + DeleteCode is not atomic. Two concurrent requests
	// with the same code can both get this far, unless DeleteCode fails for an
	// already-deleted code.
	if err := s.codeStore.DeleteCode(ctx, req.Code); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{
			Error:            "server_error",
			ErrorDescription: "Failed to invalidate authorization code",
		})
		return
	}

	user, err := s.userStore.GetUser(ctx, authCode.UserID)
	if err != nil || user == nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{
			Error:            "server_error",
			ErrorDescription: "Failed to get user information",
		})
		return
	}

	tokenPair, err := s.jwtService.GenerateTokenPair(user.ID, user.Username, user.Email, user.Roles, authCode.Scopes)
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{
			Error:            "server_error",
			ErrorDescription: "Failed to generate tokens",
		})
		return
	}

	c.JSON(http.StatusOK, TokenResponse{
		AccessToken:  tokenPair.AccessToken,
		TokenType:    tokenPair.TokenType,
		ExpiresIn:    tokenPair.ExpiresIn,
		RefreshToken: tokenPair.RefreshToken,
		Scope:        strings.Join(authCode.Scopes, " "),
	})
}
// handleClientCredentialsGrant issues an access token to a client acting on
// its own behalf (RFC 6749 §4.4).
//
// The client must authenticate and be registered for client_credentials.
// Requested scopes are narrowed to those the client is allowed. The token's
// subject is the client ID and it carries the single role "client". No refresh
// token is returned for this grant.
func (s *OAuth2Server) handleClientCredentialsGrant(c *gin.Context, req *TokenRequest) {
	ctx := c.Request.Context()

	client, err := s.clientStore.ValidateClient(ctx, req.ClientID, req.ClientSecret)
	if err != nil {
		c.JSON(http.StatusUnauthorized, ErrorResponse{Error: "invalid_client", ErrorDescription: "Invalid client credentials"})
		return
	}

	permitted := func() bool {
		for _, granted := range client.GrantTypes {
			if granted == GrantTypeClientCredentials {
				return true
			}
		}
		return false
	}
	if !permitted() {
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: "unauthorized_client", ErrorDescription: "Client is not authorized for this grant type"})
		return
	}

	scopes := s.validateScopes(client.Scopes, strings.Fields(req.Scope))

	// GenerateTokenPair also mints a refresh token; it is deliberately dropped.
	pair, err := s.jwtService.GenerateTokenPair(client.ID, client.Name, "", []string{"client"}, scopes)
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{Error: "server_error", ErrorDescription: "Failed to generate token"})
		return
	}

	resp := TokenResponse{
		AccessToken: pair.AccessToken,
		TokenType:   pair.TokenType,
		ExpiresIn:   pair.ExpiresIn,
		Scope:       strings.Join(scopes, " "),
	}
	c.JSON(http.StatusOK, resp)
}
// handleRefreshTokenGrant exchanges a refresh token for a new token pair
// (RFC 6749 §6) after authenticating the client.
//
// NOTE(review): the refresh token is not bound to the authenticating client
// (it carries no client_id claim), and the client's GrantTypes are not
// checked here, unlike the password and client-credentials grants.
func (s *OAuth2Server) handleRefreshTokenGrant(c *gin.Context, req *TokenRequest) {
	ctx := c.Request.Context()

	if _, err := s.clientStore.ValidateClient(ctx, req.ClientID, req.ClientSecret); err != nil {
		c.JSON(http.StatusUnauthorized, ErrorResponse{Error: "invalid_client", ErrorDescription: "Invalid client credentials"})
		return
	}

	pair, err := s.jwtService.RefreshToken(req.RefreshToken)
	if err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid_grant", ErrorDescription: "Invalid refresh token"})
		return
	}

	resp := TokenResponse{
		AccessToken:  pair.AccessToken,
		TokenType:    pair.TokenType,
		ExpiresIn:    pair.ExpiresIn,
		RefreshToken: pair.RefreshToken,
	}
	c.JSON(http.StatusOK, resp)
}
// handlePasswordGrant implements the resource-owner password credentials
// grant (RFC 6749 §4.3). The client must authenticate and be registered for
// the password grant, and the username/password must validate against
// UserStore. Requested scopes are narrowed to those the client is allowed.
//
// NOTE(review): the OAuth 2.0 Security Best Current Practice advises against
// this grant; prefer the authorization-code flow for new clients.
func (s *OAuth2Server) handlePasswordGrant(c *gin.Context, req *TokenRequest) {
	ctx := c.Request.Context()

	client, err := s.clientStore.ValidateClient(ctx, req.ClientID, req.ClientSecret)
	if err != nil {
		c.JSON(http.StatusUnauthorized, ErrorResponse{Error: "invalid_client", ErrorDescription: "Invalid client credentials"})
		return
	}

	grantAllowed := false
	for i := range client.GrantTypes {
		if client.GrantTypes[i] == GrantTypePassword {
			grantAllowed = true
			break
		}
	}
	if !grantAllowed {
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: "unauthorized_client", ErrorDescription: "Client is not authorized for this grant type"})
		return
	}

	// The same message is returned for unknown users and wrong passwords.
	user, err := s.userStore.ValidateUser(ctx, req.Username, req.Password)
	if err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid_grant", ErrorDescription: "Invalid username or password"})
		return
	}

	scopes := s.validateScopes(client.Scopes, strings.Fields(req.Scope))

	pair, err := s.jwtService.GenerateTokenPair(user.ID, user.Username, user.Email, user.Roles, scopes)
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{Error: "server_error", ErrorDescription: "Failed to generate tokens"})
		return
	}

	resp := TokenResponse{
		AccessToken:  pair.AccessToken,
		TokenType:    pair.TokenType,
		ExpiresIn:    pair.ExpiresIn,
		RefreshToken: pair.RefreshToken,
		Scope:        strings.Join(scopes, " "),
	}
	c.JSON(http.StatusOK, resp)
}
// validateScopes returns the requested scopes that appear in clientScopes.
// The result keeps request order, removes duplicates, and is nil when nothing
// matches.
//
// Unknown scopes are dropped silently; each caller decides whether an empty
// result is an error. Deduplication stops a request such as "read read" from
// producing a token, and a "scope" response field, with repeated entries. A
// set lookup replaces the previous nested loop.
func (s *OAuth2Server) validateScopes(clientScopes, requestedScopes []string) []string {
	allowed := make(map[string]struct{}, len(clientScopes))
	for _, sc := range clientScopes {
		allowed[sc] = struct{}{}
	}

	var validScopes []string
	seen := make(map[string]struct{}, len(requestedScopes))
	for _, requested := range requestedScopes {
		if _, ok := allowed[requested]; !ok {
			continue
		}
		if _, dup := seen[requested]; dup {
			continue
		}
		seen[requested] = struct{}{}
		validScopes = append(validScopes, requested)
	}
	return validScopes
}
// redirectWithError reports an OAuth2 error to the client (RFC 6749 §4.1.2.1).
// It redirects the user agent to redirectURI after adding errCode,
// error_description (when non-empty) and state (when non-empty) to the query
// string.
//
// redirectURI must already have been checked against the client's registered
// URIs. If it does not parse, the error is returned as a JSON 400 instead:
// url.Parse yields a nil *url.URL on error, and using it would panic.
func redirectWithError(c *gin.Context, redirectURI, errCode, errorDescription, state string) {
	redirectURL, err := url.Parse(redirectURI)
	if err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{
			Error:            errCode,
			ErrorDescription: errorDescription,
		})
		return
	}

	query := redirectURL.Query()
	query.Set("error", errCode)
	if errorDescription != "" {
		query.Set("error_description", errorDescription)
	}
	if state != "" {
		query.Set("state", state)
	}
	redirectURL.RawQuery = query.Encode()
	c.Redirect(http.StatusFound, redirectURL.String())
}
RBAC权限控制
RBAC实现
// rbac/rbac.go
package rbac
import (
"context"
"errors"
"fmt"
"strings"
"sync"
)
// Permission grants one action on one resource.
type Permission struct {
	ID          string `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description"`
	// Resource and Action are matched against the resource/action arguments of
	// RBAC.CheckPermission via matchPermission (not shown here).
	Resource string `json:"resource"`
	Action   string `json:"action"`
}
// Role is a named bundle of permissions.
//
// The Permissions are embedded by value, and CheckPermission and
// GetSubjectPermissions read them directly from here without a further store
// lookup.
type Role struct {
	ID          string       `json:"id"`          // compared with the roleID argument of HasRole
	Name        string       `json:"name"`        // human-readable label
	Description string       `json:"description"` // free text
	Permissions []Permission `json:"permissions"` // permissions granted to every holder of this role
}
// Subject is the principal (user or service) whose access is being checked.
type Subject struct {
	ID    string `json:"id"`
	Type  string `json:"type"`  // user, service, etc.
	Roles []Role `json:"roles"` // roles held by the subject, each carrying its permissions inline
}
// PolicyRule applies to a request when at least one entry in each of
// Subjects, Resources and Actions matches the request's subject ID, resource
// and action respectively (see RBAC.matchRule). Entries are patterns and may
// use "*" wildcards.
type PolicyRule struct {
	Subjects  []string `json:"subjects"`  // patterns matched against the subject ID (not role names)
	Resources []string `json:"resources"` // patterns matched against the requested resource
	Actions   []string `json:"actions"`   // patterns matched against the requested action
	Effect    string   `json:"effect"`    // allow, deny — CheckPermission compares it with the literal "allow"
}
// Policy is a named group of rules. Only Rules takes part in access decisions;
// ID, Name and Description are descriptive metadata.
type Policy struct {
	ID          string       `json:"id"`
	Name        string       `json:"name"`
	Description string       `json:"description"`
	Rules       []PolicyRule `json:"rules"`
}
// RBACStore is the persistence backend the RBAC manager reads from.
//
// Within rbac.go only GetSubject and ListPolicies are called. CheckPermission,
// HasRole and GetSubjectPermissions read roles and permissions straight from
// the returned Subject, so GetSubject is expected to return roles with their
// Permissions populated. GetRole, GetPermission and GetPolicy are not called
// by RBAC in this file.
type RBACStore interface {
	GetSubject(ctx context.Context, subjectID string) (*Subject, error)
	GetRole(ctx context.Context, roleID string) (*Role, error)
	GetPermission(ctx context.Context, permissionID string) (*Permission, error)
	GetPolicy(ctx context.Context, policyID string) (*Policy, error)
	// ListPolicies returns every policy; CheckPermission consults it on access checks.
	ListPolicies(ctx context.Context) ([]Policy, error)
}
// RBAC answers access-control questions (permission, role, required
// permissions) using data loaded from an RBACStore on every call.
type RBAC struct {
	store RBACStore
	// NOTE(review): mu and cache are not read or written by any method in
	// rbac.go (NewRBAC only allocates cache); presumably reserved for a lookup
	// cache — confirm, or remove them.
	mu    sync.RWMutex
	cache map[string]interface{}
}
// NewRBAC returns an RBAC manager backed by store, with an empty (non-nil)
// cache map and a zero-value mutex.
func NewRBAC(store RBACStore) *RBAC {
	manager := new(RBAC)
	manager.store = store
	manager.cache = map[string]interface{}{}
	return manager
}
// CheckPermission reports whether subjectID may perform action on resource.
//
// Decision procedure (deny overrides allow, default deny):
//  1. Every policy rule that matches (subjectID, resource, action) is
//     examined. A matching rule whose Effect is anything other than "allow"
//     denies the request immediately. This holds even if one of the
//     subject's roles, or another rule, would grant it.
//  2. Otherwise the request is allowed if at least one matching rule is
//     "allow", or a permission on one of the subject's roles matches.
//  3. Otherwise it is denied.
//
// A non-nil error means the subject or the policy list could not be loaded;
// the boolean is then false and callers should treat the request as denied.
func (r *RBAC) CheckPermission(ctx context.Context, subjectID, resource, action string) (bool, error) {
	subject, err := r.store.GetSubject(ctx, subjectID)
	if err != nil {
		return false, fmt.Errorf("failed to get subject: %w", err)
	}

	// Policies must be loaded even when a role would already grant access;
	// otherwise an explicit deny rule could never take effect.
	policies, err := r.store.ListPolicies(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to get policies: %w", err)
	}

	policyAllows := false
	for _, policy := range policies {
		for _, rule := range policy.Rules {
			if !r.matchRule(rule, subjectID, resource, action) {
				continue
			}
			// Fail closed: "deny", an empty effect or an unknown effect all deny.
			if rule.Effect != "allow" {
				return false, nil
			}
			policyAllows = true
		}
	}
	if policyAllows {
		return true, nil
	}

	for _, role := range subject.Roles {
		for _, permission := range role.Permissions {
			if r.matchPermission(permission, resource, action) {
				return true, nil
			}
		}
	}
	return false, nil
}
// matchPermission reports whether permission covers the requested resource
// and action. Both of its fields are treated as patterns via matchPattern.
func (r *RBAC) matchPermission(permission Permission, resource, action string) bool {
	if !r.matchPattern(permission.Resource, resource) {
		return false
	}
	return r.matchPattern(permission.Action, action)
}
// matchRule reports whether rule applies to a request.
//
// Each of rule.Subjects, rule.Resources and rule.Actions must contain at
// least one pattern that matches (via matchPattern) the subject ID, resource
// and action respectively. An empty list in any of the three therefore never
// matches. Dimensions are checked in that order and stop at the first miss.
func (r *RBAC) matchRule(rule PolicyRule, subjectID, resource, action string) bool {
	anyMatches := func(patterns []string, value string) bool {
		for _, p := range patterns {
			if r.matchPattern(p, value) {
				return true
			}
		}
		return false
	}

	return anyMatches(rule.Subjects, subjectID) &&
		anyMatches(rule.Resources, resource) &&
		anyMatches(rule.Actions, action)
}
// matchPattern reports whether value matches pattern, where each "*" in
// pattern matches any run of characters (including an empty run and
// including "/"). A pattern without "*" must equal value exactly.
//
// Examples:
//   - "*" matches everything.
//   - "orders/*" matches "orders/1/items" (prefix).
//   - "*.read" matches "orders.read" (suffix).
//   - "*admin*" matches "superadmin-panel" (contains).
//   - "users/*/profile" matches "users/42/profile" (interior wildcard).
//
// Earlier versions supported only a single leading or trailing "*". A
// pattern such as "*admin*" was then read as the literal prefix "*admin"
// and never matched. All previously supported forms behave as before.
func (r *RBAC) matchPattern(pattern, value string) bool {
	if pattern == "*" || pattern == value {
		return true
	}

	segments := strings.Split(pattern, "*")
	if len(segments) == 1 {
		// No wildcard present and not an exact match.
		return false
	}

	head, tail := segments[0], segments[len(segments)-1]
	if !strings.HasPrefix(value, head) {
		return false
	}
	rest := value[len(head):]

	// Consume interior literal segments left to right, each at its earliest
	// occurrence. Taking the earliest match leaves the most room for the rest,
	// so this greedy scan is exact for "*"-only globs.
	for _, mid := range segments[1 : len(segments)-1] {
		i := strings.Index(rest, mid)
		if i < 0 {
			return false
		}
		rest = rest[i+len(mid):]
	}

	// The final segment must end the value without overlapping consumed text.
	return strings.HasSuffix(rest, tail)
}
// HasRole reports whether the subject identified by subjectID holds the role
// whose ID equals roleID. It returns an error only when the subject cannot be
// loaded from the store.
func (r *RBAC) HasRole(ctx context.Context, subjectID, roleID string) (bool, error) {
	subject, err := r.store.GetSubject(ctx, subjectID)
	if err != nil {
		return false, fmt.Errorf("failed to get subject: %w", err)
	}

	found := false
	for i := 0; i < len(subject.Roles) && !found; i++ {
		found = subject.Roles[i].ID == roleID
	}
	return found, nil
}
// GetSubjectPermissions returns the union of the permissions granted by all of
// the subject's roles, de-duplicated by Permission.ID.
//
// The result is in first-appearance order (role order, then permission order
// within each role), so unchanged data always yields the same order. The
// previous map-based version returned a random order on every call. When the
// same ID appears more than once, the last occurrence's value is kept, at the
// position of the first occurrence. A subject with no permissions yields an
// empty, non-nil slice (it encodes as [] rather than null in JSON).
func (r *RBAC) GetSubjectPermissions(ctx context.Context, subjectID string) ([]Permission, error) {
	subject, err := r.store.GetSubject(ctx, subjectID)
	if err != nil {
		return nil, fmt.Errorf("failed to get subject: %w", err)
	}

	position := make(map[string]int) // Permission.ID -> index in permissions
	permissions := make([]Permission, 0)
	for _, role := range subject.Roles {
		for _, permission := range role.Permissions {
			if i, seen := position[permission.ID]; seen {
				permissions[i] = permission // later duplicates win, as with the old map
				continue
			}
			position[permission.ID] = len(permissions)
			permissions = append(permissions, permission)
		}
	}
	return permissions, nil
}
// ValidateAccess returns nil when subjectID may perform action on resource
// (per CheckPermission) and also holds every permission ID listed in
// requiredPermissions. Otherwise it returns an "access denied" error naming
// the reason, or a wrapped error if the underlying lookups fail. An empty
// requiredPermissions skips the second check entirely.
func (r *RBAC) ValidateAccess(ctx context.Context, subjectID, resource, action string, requiredPermissions []string) error {
	ok, err := r.CheckPermission(ctx, subjectID, resource, action)
	switch {
	case err != nil:
		return fmt.Errorf("failed to check permission: %w", err)
	case !ok:
		return errors.New("access denied: insufficient permissions")
	case len(requiredPermissions) == 0:
		return nil
	}

	granted, err := r.GetSubjectPermissions(ctx, subjectID)
	if err != nil {
		return fmt.Errorf("failed to get subject permissions: %w", err)
	}

	held := make(map[string]bool, len(granted))
	for _, p := range granted {
		held[p.ID] = true
	}
	for _, id := range requiredPermissions {
		if !held[id] {
			return fmt.Errorf("access denied: missing required permission '%s'", id)
		}
	}
	return nil
}
认证中间件
// middleware/auth.go
package middleware
import (
	"context"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"your-project/auth"
	"your-project/rbac"
)
// AuthMiddleware 认证中间件 type AuthMiddleware struct { jwtService *auth.JWTService rbac *rbac.RBAC }
// NewAuthMiddleware 创建认证中间件 func NewAuthMiddleware(jwtService *auth.JWTService, rbacManager *rbac.RBAC) *AuthMiddleware { return &AuthMiddleware{ jwtService: jwtService, rbac: rbacManager, } }
// RequireAuth 要求认证 func (am *AuthMiddleware) RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { // 提取令牌 authHeader := c.GetHeader(“Authorization”) token, err := auth.ExtractTokenFromHeader(authHeader) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ “error”: “missing or invalid authorization header”, }) c.Abort() return }
// 验证令牌
claims, err := am.jwtService.ValidateToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "invalid token",
})
c.Abort()
return
}
// 设置用户信息到上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
c.Set("roles", claims.Roles)
c.Set("scopes", claims.Scopes)
c.Set("claims", claims)
c.Next()
}
}
// RequirePermission 要求特定权限 func (am *AuthMiddleware) RequirePermission(resource, action string) gin.HandlerFunc { return func(c *gin.Context) { // 获取用户ID userID, exists := c.Get(“user_id”) if !exists { c.JSON(http.StatusUnauthorized, gin.H{ “error”: “user not authenticated”, }) c.Abort() return }
// 检查权限
allowed, err := am.rbac.CheckPermission(c.Request.Context(), userID.(string), resource, action)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "failed to check permission",
})
c.Abort()
return
}
if !allowed {
c.JSON(http.StatusForbidden, gin.H{
"error": "insufficient permissions",
})
c.Abort()
return
}
c.Next()
}
}
// RequireRole 要求特定角色 func (am *AuthMiddleware) RequireRole(roleID string) gin.HandlerFunc { return func(c *gin.Context) { // 获取用户ID userID, exists := c.Get(“user_id”) if !exists { c.JSON(http.StatusUnauthorized, gin.H{ “error”: “user not authenticated”, }) c.Abort() return }
// 检查角色
hasRole, err := am.rbac.HasRole(c.Request.Context(), userID.(string), roleID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "failed to check role",
})
c.Abort()
return
}
if !hasRole {
c.JSON(http.StatusForbidden, gin.H{
"error": "insufficient role",
})
c.Abort()
return
}
c.Next()
}
}
// RequireScope 要求特定作用域 func (am *AuthMiddleware) RequireScope(requiredScopes …string) gin.HandlerFunc { return func(c *gin.Context) { // 获取作用域 scopes, exists := c.Get(“scopes”) if !exists { c.JSON(http.StatusUnauthorized, gin.H{ “error”: “user not authenticated”, }) c.Abort() return }
userScopes := scopes.([]string)
scopeSet := make(map[string]bool)
for _, scope := range userScopes {
scopeSet[scope] = true
}
// 检查是否具有所需作用域
for _, required := range requiredScopes {
if !scopeSet[required] {
c.JSON(http.StatusForbidden, gin.H{
"error": "insufficient scope",
"required_scope": required,
})
c.Abort()
return
}
}
c.Next()
}
}
// OptionalAuth 可选认证 func (am *AuthMiddleware) OptionalAuth() gin.HandlerFunc { return func(c *gin.Context) { // 提取令牌 authHeader := c.GetHeader(“Authorization”) if authHeader == “” { c.Next() return }
token, err := auth.ExtractTokenFromHeader(authHeader)
if err != nil {
c.Next()
return
}
// 验证令牌
claims, err := am.jwtService.ValidateToken(token)
if err != nil {
c.Next()
return
}
// 设置用户信息到上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
c.Set("roles", claims.Roles)
c.Set("scopes", claims.Scopes)
c.Set("claims", claims)
c.Next()
}
}
// RateLimitMiddleware 限流中间件 func RateLimitMiddleware(requestsPerMinute int) gin.HandlerFunc { // 这里应该实现基于令牌桶或滑动窗口的限流算法 // 为了简化,这里只是一个示例 return func(c *gin.Context) { // 获取客户端IP或用户ID clientID := c.ClientIP() if userID, exists := c.Get(“user_id”); exists { clientID = userID.(string) }
// 检查限流
// 这里应该实现实际的限流逻辑
c.Next()
}
}
// CORSMiddleware CORS中间件 func CORSMiddleware() gin.HandlerFunc { return func(c gin.Context) { c.Header(“Access-Control-Allow-Origin”, “”) c.Header(“Access-Control-Allow-Methods”, “GET, POST, PUT, DELETE, OPTIONS”) c.Header(“Access-Control-Allow-Headers”, “Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization”)
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
// SecurityHeadersMiddleware 安全头中间件 func SecurityHeadersMiddleware() gin.HandlerFunc { return func(c *gin.Context) { c.Header(“X-Content-Type-Options”, “nosniff”) c.Header(“X-Frame-Options”, “DENY”) c.Header(“X-XSS-Protection”, “1; mode=block”) c.Header(“Strict-Transport-Security”, “max-age=31536000; includeSubDomains”) c.Header(“Content-Security-Policy”, “default-src ‘self’”