217 lines
5.1 KiB
Go
217 lines
5.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"strconv"
|
|
|
|
"veza-backend-api/internal/validators"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// QueryParamValidation middleware pour valider les paramètres de requête (BE-SVC-020)
|
|
type QueryParamValidation struct {
|
|
logger *zap.Logger
|
|
validator *validators.Validator
|
|
}
|
|
|
|
// NewQueryParamValidation crée un nouveau middleware de validation des paramètres de requête
|
|
func NewQueryParamValidation(logger *zap.Logger) *QueryParamValidation {
|
|
return &QueryParamValidation{
|
|
logger: logger,
|
|
validator: validators.NewValidator(),
|
|
}
|
|
}
|
|
|
|
// ValidateQueryParams valide les paramètres de requête selon les règles spécifiées
|
|
// Usage:
|
|
//
|
|
// router.GET("/users", queryValidation.ValidateQueryParams(map[string]string{
|
|
// "page": "numeric,min=1",
|
|
// "limit": "numeric,min=1,max=100",
|
|
// "sort": "oneof=asc,desc",
|
|
// }), handler)
|
|
func (q *QueryParamValidation) ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
errors := make(map[string]string)
|
|
|
|
for param, rule := range rules {
|
|
value := c.Query(param)
|
|
if value == "" {
|
|
// Si le paramètre est requis, on vérifie avec "required" dans la règle
|
|
if containsRule(rule, "required") {
|
|
errors[param] = "This query parameter is required"
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Valider selon les règles
|
|
if err := q.validateParam(param, value, rule); err != nil {
|
|
errors[param] = err.Error()
|
|
}
|
|
}
|
|
|
|
if len(errors) > 0 {
|
|
q.logger.Warn("Query parameter validation failed",
|
|
zap.Any("errors", errors),
|
|
zap.String("path", c.Request.URL.Path),
|
|
)
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": "Invalid query parameters",
|
|
"details": errors,
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// validateParam valide un paramètre selon une règle
|
|
func (q *QueryParamValidation) validateParam(param, value, rule string) error {
|
|
rules := parseRules(rule)
|
|
|
|
for _, r := range rules {
|
|
switch r.name {
|
|
case "numeric":
|
|
if _, err := strconv.ParseFloat(value, 64); err != nil {
|
|
return &ValidationError{Field: param, Message: "must be a numeric value"}
|
|
}
|
|
case "integer":
|
|
if _, err := strconv.Atoi(value); err != nil {
|
|
return &ValidationError{Field: param, Message: "must be an integer"}
|
|
}
|
|
case "min":
|
|
if num, err := strconv.ParseFloat(value, 64); err == nil {
|
|
if min, err2 := strconv.ParseFloat(r.param, 64); err2 == nil {
|
|
if num < min {
|
|
return &ValidationError{Field: param, Message: "must be at least " + r.param}
|
|
}
|
|
}
|
|
}
|
|
case "max":
|
|
if num, err := strconv.ParseFloat(value, 64); err == nil {
|
|
if max, err2 := strconv.ParseFloat(r.param, 64); err2 == nil {
|
|
if num > max {
|
|
return &ValidationError{Field: param, Message: "must be at most " + r.param}
|
|
}
|
|
}
|
|
}
|
|
case "oneof":
|
|
validValues := splitComma(r.param)
|
|
found := false
|
|
for _, valid := range validValues {
|
|
valid = trimSpace(valid)
|
|
if value == valid {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return &ValidationError{Field: param, Message: "must be one of: " + r.param}
|
|
}
|
|
case "email":
|
|
if err := q.validator.ValidateVar(value, "email"); err != nil {
|
|
return &ValidationError{Field: param, Message: "must be a valid email address"}
|
|
}
|
|
case "uuid":
|
|
if err := q.validator.ValidateVar(value, "uuid"); err != nil {
|
|
return &ValidationError{Field: param, Message: "must be a valid UUID"}
|
|
}
|
|
case "url":
|
|
if err := q.validator.ValidateVar(value, "url"); err != nil {
|
|
return &ValidationError{Field: param, Message: "must be a valid URL"}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidationError représente une erreur de validation
|
|
type ValidationError struct {
|
|
Field string
|
|
Message string
|
|
}
|
|
|
|
func (e *ValidationError) Error() string {
|
|
return e.Message
|
|
}
|
|
|
|
// rule représente une règle de validation
|
|
type rule struct {
|
|
name string
|
|
param string
|
|
}
|
|
|
|
// parseRules parse une chaîne de règles (ex: "numeric,min=1,max=100")
|
|
func parseRules(ruleStr string) []rule {
|
|
var rules []rule
|
|
parts := splitComma(ruleStr)
|
|
for _, part := range parts {
|
|
part = trimSpace(part)
|
|
if idx := indexOf(part, "="); idx != -1 {
|
|
rules = append(rules, rule{
|
|
name: part[:idx],
|
|
param: part[idx+1:],
|
|
})
|
|
} else {
|
|
rules = append(rules, rule{
|
|
name: part,
|
|
param: "",
|
|
})
|
|
}
|
|
}
|
|
return rules
|
|
}
|
|
|
|
// Helper functions
|
|
func containsRule(ruleStr, ruleName string) bool {
|
|
rules := parseRules(ruleStr)
|
|
for _, r := range rules {
|
|
if r.name == ruleName {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func splitComma(s string) []string {
|
|
var result []string
|
|
start := 0
|
|
for i, char := range s {
|
|
if char == ',' {
|
|
if i > start {
|
|
result = append(result, s[start:i])
|
|
}
|
|
start = i + 1
|
|
}
|
|
}
|
|
if start < len(s) {
|
|
result = append(result, s[start:])
|
|
}
|
|
return result
|
|
}
|
|
|
|
func trimSpace(s string) string {
|
|
start := 0
|
|
end := len(s)
|
|
for start < end && s[start] == ' ' {
|
|
start++
|
|
}
|
|
for end > start && s[end-1] == ' ' {
|
|
end--
|
|
}
|
|
return s[start:end]
|
|
}
|
|
|
|
func indexOf(s string, substr string) int {
|
|
for i := 0; i <= len(s)-len(substr); i++ {
|
|
if s[i:i+len(substr)] == substr {
|
|
return i
|
|
}
|
|
}
|
|
return -1
|
|
}
|