package middleware

import (
	"math"
	"net/http"
	"strconv"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"

	"veza-backend-api/internal/validators"
)

// QueryParamValidation is a middleware that validates query-string parameters
// (BE-SVC-020).
type QueryParamValidation struct {
	logger    *zap.Logger
	validator *validators.Validator
}

// NewQueryParamValidation creates a query-parameter validation middleware that
// logs through logger and uses a validator obtained from validators.NewValidator.
func NewQueryParamValidation(logger *zap.Logger) *QueryParamValidation {
	return &QueryParamValidation{
		logger:    logger,
		validator: validators.NewValidator(),
	}
}

// ValidateQueryParams returns a handler that checks the request's query
// parameters against rules, a map from parameter name to a comma-separated
// rule string.
//
// Recognized rules: required, numeric, integer, min=N, max=N,
// oneof=a,b,c, email, uuid, url. Unrecognized rule names are ignored.
//
// A parameter that is absent or empty is only an error when its rule string
// contains "required"; otherwise its remaining rules are skipped. When at
// least one parameter fails, the handler logs a warning, responds with
// 400 and a JSON body of the form
// {"error": "Invalid query parameters", "details": {param: message}},
// and aborts the chain. Otherwise it calls c.Next().
//
// Usage:
//
//	router.GET("/users", queryValidation.ValidateQueryParams(map[string]string{
//		"page":  "numeric,min=1",
//		"limit": "numeric,min=1,max=100",
//		"sort":  "oneof=asc,desc",
//	}), handler)
func (q *QueryParamValidation) ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
	return func(c *gin.Context) {
		fieldErrors := make(map[string]string)

		for param, rule := range rules {
			value := c.Query(param)
			if value == "" {
				// An empty value is only rejected when the rule asks for it.
				if containsRule(rule, "required") {
					fieldErrors[param] = "This query parameter is required"
				}
				continue
			}

			if err := q.validateParam(param, value, rule); err != nil {
				fieldErrors[param] = err.Error()
			}
		}

		if len(fieldErrors) > 0 {
			q.logger.Warn("Query parameter validation failed",
				zap.Any("errors", fieldErrors),
				zap.String("path", c.Request.URL.Path),
			)
			c.JSON(http.StatusBadRequest, gin.H{
				"error":   "Invalid query parameters",
				"details": fieldErrors,
			})
			c.Abort()
			return
		}

		c.Next()
	}
}

// validateParam checks value against every rule in the comma-separated rule
// string and returns a *ValidationError for the first rule that fails, or nil
// when all rules pass. Rule names it does not recognize (including "required",
// which the caller handles) are skipped.
func (q *QueryParamValidation) validateParam(param, value, rule string) error {
	for _, r := range parseRules(rule) {
		switch r.name {
		case "numeric":
			// strconv.ParseFloat accepts "NaN" and "Inf" without error; those
			// are not usable numbers and NaN would slip past min/max checks.
			num, err := strconv.ParseFloat(value, 64)
			if err != nil || math.IsNaN(num) || math.IsInf(num, 0) {
				return &ValidationError{Field: param, Message: "must be a numeric value"}
			}
		case "integer":
			if _, err := strconv.Atoi(value); err != nil {
				return &ValidationError{Field: param, Message: "must be an integer"}
			}
		case "min":
			// The bound is skipped when either side is not a number; combine
			// with "numeric" or "integer" to reject non-numeric input.
			// NaN compares false against everything, so treat it as a violation.
			if num, bound, ok := parseBound(value, r.param); ok && (math.IsNaN(num) || num < bound) {
				return &ValidationError{Field: param, Message: "must be at least " + r.param}
			}
		case "max":
			if num, bound, ok := parseBound(value, r.param); ok && (math.IsNaN(num) || num > bound) {
				return &ValidationError{Field: param, Message: "must be at most " + r.param}
			}
		case "oneof":
			found := false
			for _, candidate := range splitComma(r.param) {
				if value == trimSpace(candidate) {
					found = true
					break
				}
			}
			if !found {
				return &ValidationError{Field: param, Message: "must be one of: " + r.param}
			}
		case "email":
			if err := q.validator.ValidateVar(value, "email"); err != nil {
				return &ValidationError{Field: param, Message: "must be a valid email address"}
			}
		case "uuid":
			if err := q.validator.ValidateVar(value, "uuid"); err != nil {
				return &ValidationError{Field: param, Message: "must be a valid UUID"}
			}
		case "url":
			if err := q.validator.ValidateVar(value, "url"); err != nil {
				return &ValidationError{Field: param, Message: "must be a valid URL"}
			}
		}
	}
	return nil
}

// parseBound parses a parameter value and a rule limit as float64. ok is false
// when either string is not accepted by strconv.ParseFloat.
func parseBound(value, limit string) (float64, float64, bool) {
	num, err := strconv.ParseFloat(value, 64)
	if err != nil {
		return 0, 0, false
	}
	bound, err := strconv.ParseFloat(limit, 64)
	if err != nil {
		return 0, 0, false
	}
	return num, bound, true
}

// ValidationError describes a failed validation of a single field.
type ValidationError struct {
	Field   string
	Message string
}

// Error returns the validation message (the field name is not included).
func (e *ValidationError) Error() string {
	return e.Message
}

// rule is one parsed validation rule: its name and the text after "=", if any.
type rule struct {
	name  string
	param string
}

// isBareRule reports whether name is a rule that is written without "=value".
// These names are reserved: they are never taken as a value of a preceding
// "oneof" list.
func isBareRule(name string) bool {
	switch name {
	case "required", "numeric", "integer", "email", "uuid", "url":
		return true
	}
	return false
}

// parseRules parses a rule string such as "numeric,min=1,max=100" into
// individual rules.
//
// Because the values of "oneof" are themselves comma-separated
// ("oneof=asc,desc"), a token without "=" that directly follows a oneof rule
// and is not a reserved bare rule name is appended to that rule's value list
// instead of being treated as a separate (unknown) rule. Tokens that are empty
// after trimming spaces are dropped.
func parseRules(ruleStr string) []rule {
	parts := splitComma(ruleStr)
	rules := make([]rule, 0, len(parts))

	for _, part := range parts {
		part = trimSpace(part)
		if part == "" {
			continue
		}

		if idx := indexOf(part, "="); idx != -1 {
			rules = append(rules, rule{name: part[:idx], param: part[idx+1:]})
			continue
		}

		// Continuation of a "oneof=a,b,c" value list.
		if n := len(rules); n > 0 && rules[n-1].name == "oneof" && !isBareRule(part) {
			rules[n-1].param += "," + part
			continue
		}

		rules = append(rules, rule{name: part})
	}

	return rules
}
// Helper functions.

// containsRule reports whether the comma-separated rule string ruleStr holds a
// rule named ruleName. A rule's name is the token with surrounding spaces
// removed, cut at the first "=" when one is present.
func containsRule(ruleStr, ruleName string) bool {
	for _, token := range splitComma(ruleStr) {
		token = trimSpace(token)
		name := token
		if eq := indexOf(token, "="); eq != -1 {
			name = token[:eq]
		}
		if name == ruleName {
			return true
		}
	}
	return false
}

// splitComma splits s on ',' and returns the non-empty pieces in order.
// Empty pieces (leading, trailing or doubled commas) are dropped, and the
// result is nil when there is nothing to return.
func splitComma(s string) []string {
	var fields []string
	begin := 0
	// ',' is a single ASCII byte, so scanning bytes finds the same separators
	// as scanning runes.
	for pos := 0; pos < len(s); pos++ {
		if s[pos] != ',' {
			continue
		}
		if pos > begin {
			fields = append(fields, s[begin:pos])
		}
		begin = pos + 1
	}
	if begin < len(s) {
		fields = append(fields, s[begin:])
	}
	return fields
}

// trimSpace returns s without leading and trailing ' ' characters. Only the
// ASCII space is removed; tabs and other whitespace are kept.
func trimSpace(s string) string {
	for len(s) > 0 && s[0] == ' ' {
		s = s[1:]
	}
	for len(s) > 0 && s[len(s)-1] == ' ' {
		s = s[:len(s)-1]
	}
	return s
}

// indexOf returns the byte index of the first occurrence of substr in s, or -1
// when substr does not occur. An empty substr is found at index 0.
func indexOf(s string, substr string) int {
	width := len(substr)
	for at := 0; at+width <= len(s); at++ {
		if s[at:at+width] == substr {
			return at
		}
	}
	return -1
}