package middleware import ( "context" "crypto/rand" "encoding/hex" "fmt" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/redis/go-redis/v9" "go.uber.org/zap" ) // CSRFMiddleware crée un middleware pour la protection CSRF // Utilise Redis pour stocker les tokens CSRF associés aux utilisateurs type CSRFMiddleware struct { redisClient *redis.Client logger *zap.Logger ttl time.Duration // TTL pour les tokens CSRF (défaut: 1 heure) } // NewCSRFMiddleware crée une nouvelle instance du middleware CSRF func NewCSRFMiddleware(redisClient *redis.Client, logger *zap.Logger) *CSRFMiddleware { return &CSRFMiddleware{ redisClient: redisClient, logger: logger, ttl: 1 * time.Hour, // Tokens CSRF valides pendant 1 heure } } // Middleware retourne le handler Gin pour la protection CSRF func (m *CSRFMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { // Ignorer GET, HEAD, OPTIONS (méthodes sûres) method := c.Request.Method if method == "GET" || method == "HEAD" || method == "OPTIONS" { c.Next() return } // Récupérer le userID depuis le contexte (défini par AuthMiddleware) userIDInterface, exists := c.Get("user_id") if !exists { // Si pas d'utilisateur authentifié, pas besoin de CSRF // (les routes publiques comme login/register sont exclues) c.Next() return } userID, ok := userIDInterface.(uuid.UUID) if !ok { m.logger.Warn("Invalid user_id type in context for CSRF check") c.Next() return } // Récupérer le token CSRF depuis le header token := c.GetHeader("X-CSRF-Token") if token == "" { c.JSON(403, gin.H{ "success": false, "error": gin.H{ "code": 403, "message": "CSRF token required", }, }) c.Abort() return } // Vérifier le token dans Redis ctx := c.Request.Context() key := m.getCSRFKey(userID) storedToken, err := m.redisClient.Get(ctx, key).Result() if err != nil { if err == redis.Nil { m.logger.Warn("CSRF token not found in Redis", zap.String("user_id", userID.String()), zap.String("ip", c.ClientIP()), ) c.JSON(403, gin.H{ "success": false, "error": gin.H{ "code": 403, "message": "Invalid or expired CSRF token", }, }) c.Abort() return } m.logger.Error("Failed to get CSRF token from Redis", zap.Error(err), zap.String("user_id", userID.String()), ) c.JSON(500, gin.H{ "success": false, "error": gin.H{ "code": 500, "message": "Internal server error", }, }) c.Abort() return } // Comparer les tokens if storedToken != token { m.logger.Warn("CSRF token mismatch", zap.String("user_id", userID.String()), zap.String("ip", c.ClientIP()), ) c.JSON(403, gin.H{ "success": false, "error": gin.H{ "code": 403, "message": "Invalid CSRF token", }, }) c.Abort() return } // Token valide, continuer c.Next() } } // getCSRFKey génère la clé Redis pour un token CSRF func (m *CSRFMiddleware) getCSRFKey(userID uuid.UUID) string { return fmt.Sprintf("csrf:token:%s", userID.String()) } // GenerateToken génère un nouveau token CSRF et le stocke dans Redis func (m *CSRFMiddleware) GenerateToken(ctx context.Context, userID uuid.UUID) (string, error) { // Générer un token aléatoire de 32 bytes (64 caractères hex) tokenBytes := make([]byte, 32) if _, err := rand.Read(tokenBytes); err != nil { return "", fmt.Errorf("failed to generate CSRF token: %w", err) } token := hex.EncodeToString(tokenBytes) // Stocker le token dans Redis avec TTL key := m.getCSRFKey(userID) if err := m.redisClient.Set(ctx, key, token, m.ttl).Err(); err != nil { return "", fmt.Errorf("failed to store CSRF token: %w", err) } return token, nil } // GetToken récupère le token CSRF actuel pour un utilisateur func (m *CSRFMiddleware) GetToken(ctx context.Context, userID uuid.UUID) (string, error) { key := m.getCSRFKey(userID) token, err := m.redisClient.Get(ctx, key).Result() if err != nil { if err == redis.Nil { // Pas de token existant, en générer un nouveau return m.GenerateToken(ctx, userID) } return "", fmt.Errorf("failed to get CSRF token: %w", err) } return token, nil }