veza/veza-backend-api/internal/middleware/validation.go
2026-03-05 23:03:43 +01:00

217 lines
5.1 KiB
Go

package middleware
import (
	"math"
	"net/http"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"

	"veza-backend-api/internal/validators"
)
// QueryParamValidation builds gin middleware that validates URL query
// parameters against per-parameter rule strings (BE-SVC-020).
//
// Build it with NewQueryParamValidation, which sets both fields; the zero
// value leaves them nil.
type QueryParamValidation struct {
// logger receives a warning each time a request fails validation.
logger *zap.Logger
// validator performs the "email", "uuid" and "url" rule checks.
validator *validators.Validator
}
// NewQueryParamValidation returns a QueryParamValidation that reports
// failures through logger and delegates the email/UUID/URL format checks to
// a freshly created validators.Validator.
func NewQueryParamValidation(logger *zap.Logger) *QueryParamValidation {
	qv := new(QueryParamValidation)
	qv.logger = logger
	qv.validator = validators.NewValidator()
	return qv
}
// ValidateQueryParams returns a gin handler that checks the request's query
// parameters against rules, a map from parameter name to a comma-separated
// rule string.
//
// Recognised rule names are required, numeric, integer, min=N, max=N,
// oneof=…, email, uuid and url; unknown names are ignored. A parameter that
// is absent or empty is reported only when its rules contain "required" and
// is otherwise skipped entirely.
//
// rules is read once when the middleware is built; later changes to the map
// have no effect.
//
// On failure the handler responds 400 with
// {"error": "Invalid query parameters", "details": {param: message}} and
// aborts the chain; otherwise it calls c.Next().
//
// Usage:
//
//	router.GET("/users", queryValidation.ValidateQueryParams(map[string]string{
//		"page":  "numeric,min=1",
//		"limit": "numeric,min=1,max=100",
//		"sort":  "oneof=asc,desc",
//	}), handler)
func (q *QueryParamValidation) ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
	// paramSpec is one parameter's rule string plus its precomputed
	// "required" flag.
	type paramSpec struct {
		name     string
		rule     string
		required bool
	}

	// Rule strings are fixed at registration, so parse the "required" flag
	// once here instead of on every request.
	specs := make([]paramSpec, 0, len(rules))
	for name, ruleStr := range rules {
		specs = append(specs, paramSpec{
			name:     name,
			rule:     ruleStr,
			required: containsRule(ruleStr, "required"),
		})
	}

	return func(c *gin.Context) {
		failures := make(map[string]string)
		for _, spec := range specs {
			value := c.Query(spec.name)
			if value == "" {
				// c.Query cannot tell a missing parameter from "?name=", so
				// both count as absent.
				if spec.required {
					failures[spec.name] = "This query parameter is required"
				}
				continue
			}
			if err := q.validateParam(spec.name, value, spec.rule); err != nil {
				failures[spec.name] = err.Error()
			}
		}

		if len(failures) > 0 {
			q.logger.Warn("Query parameter validation failed",
				zap.Any("errors", failures),
				zap.String("path", c.Request.URL.Path),
			)
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
				"error":   "Invalid query parameters",
				"details": failures,
			})
			return
		}
		c.Next()
	}
}
// validateParam checks value, the raw query-string value of param, against
// each rule in the comma-separated rule string, in the order written. It
// returns a *ValidationError describing the first violated rule, or nil.
//
// "required" is not evaluated here (the caller handles missing values) and
// unrecognised rule names are ignored. "min" and "max" compare numerically
// and are skipped when the value is not a number or the bound does not
// parse, so combine them with "numeric" or "integer" to reject non-numeric
// input.
func (q *QueryParamValidation) validateParam(param, value, rule string) error {
	// Several rules need the numeric form of value; parse it once.
	num, numErr := strconv.ParseFloat(value, 64)
	isNumber := numErr == nil

	for _, r := range parseRules(rule) {
		switch r.name {
		case "numeric":
			// strconv.ParseFloat also accepts "NaN", "Inf" and "Infinity".
			// Those are not usable query values, and NaN compares false
			// against every bound, so it would otherwise slip past min/max.
			if !isNumber || math.IsNaN(num) || math.IsInf(num, 0) {
				return &ValidationError{Field: param, Message: "must be a numeric value"}
			}
		case "integer":
			if _, err := strconv.Atoi(value); err != nil {
				return &ValidationError{Field: param, Message: "must be an integer"}
			}
		case "min":
			// NaN is rejected explicitly because num < bound is always false
			// for it.
			if bound, err := strconv.ParseFloat(r.param, 64); isNumber && err == nil && (math.IsNaN(num) || num < bound) {
				return &ValidationError{Field: param, Message: "must be at least " + r.param}
			}
		case "max":
			// Same NaN guard as "min": num > bound is always false for NaN.
			if bound, err := strconv.ParseFloat(r.param, 64); isNumber && err == nil && (math.IsNaN(num) || num > bound) {
				return &ValidationError{Field: param, Message: "must be at most " + r.param}
			}
		case "oneof":
			found := false
			for _, allowed := range splitComma(r.param) {
				if value == trimSpace(allowed) {
					found = true
					break
				}
			}
			if !found {
				return &ValidationError{Field: param, Message: "must be one of: " + r.param}
			}
		case "email":
			if err := q.validator.ValidateVar(value, "email"); err != nil {
				return &ValidationError{Field: param, Message: "must be a valid email address"}
			}
		case "uuid":
			if err := q.validator.ValidateVar(value, "uuid"); err != nil {
				return &ValidationError{Field: param, Message: "must be a valid UUID"}
			}
		case "url":
			if err := q.validator.ValidateVar(value, "url"); err != nil {
				return &ValidationError{Field: param, Message: "must be a valid URL"}
			}
		}
	}
	return nil
}
// ValidationError explains why a single query parameter was rejected.
// Field names the offending parameter and Message holds the human-readable
// reason.
type ValidationError struct {
	Field   string
	Message string
}

// Error implements the error interface. It returns only Message; the
// parameter name stays available separately in Field.
func (v *ValidationError) Error() string {
	return v.Message
}
// rule is one parsed validation rule: name is the text before "=" and param
// is the text after it ("" when the rule has no "=").
type rule struct {
	name  string
	param string
}

// parseRules splits a rule string such as "numeric,min=1,max=100" into its
// individual rules. It trims surrounding spaces from each entry and drops
// empty entries.
//
// Rules are comma-separated, but so are the choices of "oneof"
// ("oneof=asc,desc"). A bare token (one without "=") that follows a "oneof"
// rule and is not itself a recognised rule name is therefore appended to that
// rule's param. So "oneof=asc,desc" yields the single rule {oneof "asc,desc"}
// rather than {oneof "asc"} plus a bogus rule named "desc". A consequence is
// that a oneof choice spelled like a rule name (e.g. "email") is parsed as
// that rule instead.
func parseRules(ruleStr string) []rule {
	var rules []rule
	for _, part := range strings.Split(ruleStr, ",") {
		part = strings.Trim(part, " ")
		if part == "" {
			continue
		}

		name, param, hasParam := part, "", false
		if i := strings.IndexByte(part, '='); i >= 0 {
			name, param, hasParam = part[:i], part[i+1:], true
		}

		if !hasParam && len(rules) > 0 && rules[len(rules)-1].name == "oneof" {
			switch name {
			case "required", "numeric", "integer", "email", "uuid", "url":
				// A genuine rule; record it normally below.
			default:
				// Another choice of the preceding oneof rule.
				last := &rules[len(rules)-1]
				if last.param != "" {
					last.param += ","
				}
				last.param += name
				continue
			}
		}

		rules = append(rules, rule{name: name, param: param})
	}
	return rules
}
// Helper functions

// containsRule reports whether ruleStr, once parsed by parseRules, includes a
// rule whose name is exactly ruleName. Any "=param" part is ignored.
func containsRule(ruleStr, ruleName string) bool {
	parsed := parseRules(ruleStr)
	for i := range parsed {
		if parsed[i].name != ruleName {
			continue
		}
		return true
	}
	return false
}
// splitComma splits s on commas and drops empty pieces, so "a,,b," yields
// ["a" "b"]. Pieces are not trimmed: "a, b" yields ["a" " b"].
func splitComma(s string) []string {
	return strings.FieldsFunc(s, func(r rune) bool { return r == ',' })
}
// trimSpace removes leading and trailing ASCII space characters (' ') from s.
// Unlike strings.TrimSpace it leaves tabs, newlines and other whitespace
// untouched.
func trimSpace(s string) string {
	return strings.Trim(s, " ")
}
// indexOf returns the byte index of the first occurrence of substr in s, or
// -1 if substr does not occur. An empty substr matches at index 0.
func indexOf(s string, substr string) int {
	return strings.Index(s, substr)
}