veza/veza-backend-api/internal/validators/validator.go

283 lines
9.6 KiB
Go

package validators
import (
	"fmt"
	"reflect"
	"strings"
	"time"

	"github.com/go-playground/validator/v10"

	"veza-backend-api/internal/dto"
)
// Validator is a wrapper around go-playground/validator.
// GO-013: Centralized input validation with go-validator.
//
// Construct it with NewValidator: the zero value has a nil underlying
// validator, so calling its methods would dereference nil.
type Validator struct {
// validate is the underlying go-playground validator instance, with this
// package's custom tags registered by NewValidator.
validate *validator.Validate
}
// NewValidator returns a ready-to-use Validator backed by a fresh
// go-playground validator instance. The package's custom tags (username,
// uuid_string, slug, phone, date_iso, not_empty) are registered on it before
// it is returned.
func NewValidator() *Validator {
	engine := validator.New()
	registerCustomValidations(engine)
	return &Validator{validate: engine}
}
// Validate runs struct-tag validation on s. Each failure becomes a
// dto.ValidationError carrying the camelCase field name, a human-readable
// message and the offending value. It returns nil when s is valid.
//
// If s cannot be validated at all (for example a nil pointer or a non-struct
// value, which the library reports as *validator.InvalidValidationError),
// a single entry describing the problem is returned. Such input is not
// reported as valid.
//
// NOTE(review): Value echoes the submitted input back to the caller. For
// sensitive fields such as passwords this may expose the value in API
// responses; consider redacting.
func (v *Validator) Validate(s interface{}) []dto.ValidationError {
	err := v.validate.Struct(s)
	if err == nil {
		return nil
	}

	fieldErrs, ok := err.(validator.ValidationErrors)
	if !ok {
		// Returning nothing here would make nil or non-struct input pass every rule.
		return []dto.ValidationError{{
			Message: fmt.Sprintf("The request could not be validated: %v", err),
		}}
	}

	validationErrors := make([]dto.ValidationError, 0, len(fieldErrs))
	for _, fieldErr := range fieldErrs {
		validationErrors = append(validationErrors, dto.ValidationError{
			Field:   getFieldName(fieldErr),
			Message: getErrorMessage(fieldErr),
			Value:   fmt.Sprintf("%v", fieldErr.Value()),
		})
	}
	return validationErrors
}
// ValidateVar checks a single value against a validator tag expression
// (for example "required,email"). It passes through the error produced by
// the underlying go-playground validator, or nil when the value satisfies
// the tag.
func (v *Validator) ValidateVar(field interface{}, tag string) error {
	if err := v.validate.Var(field, tag); err != nil {
		return err
	}
	return nil
}
// getFieldName returns the field name reported for a validation failure.
// It takes the last segment of the Go struct namespace (for example
// "CreateUserRequest.Email" -> "Email"). When no namespace is available it
// falls back to Field(). The first character is then lower-cased
// ("Email" -> "email").
// GO-013: camelCase names so errors line up with typical JSON payload keys.
//
// Note: this is the Go field name, not the `json` struct tag. A field UserID
// tagged json:"user_id" is reported as "userID".
func getFieldName(fieldErr validator.FieldError) string {
	name := fieldErr.StructNamespace()
	if name != "" {
		// Keep only the segment after the last '.'.
		name = name[strings.LastIndexByte(name, '.')+1:]
	} else {
		name = fieldErr.Field()
	}

	// Lower-case the first rune, not the first byte. Slicing name[:1] would
	// split a multi-byte UTF-8 character and corrupt the name. Ranging over
	// the string yields rune start offsets, so the second offset marks the
	// end of the first rune.
	for i := range name {
		if i > 0 {
			return strings.ToLower(name[:i]) + name[i:]
		}
	}
	// Zero or one rune.
	return strings.ToLower(name)
}
// getErrorMessage builds a human-readable English message for one validation
// failure, selected by the failing tag.
// BE-SVC-020: Improved, more descriptive error messages.
//
// The length-style tags (min, max, len) are worded by the field's kind:
//   - strings: characters
//   - slices, arrays and maps: items
//   - anything else: a plain value, since numbers are compared by value
//
// Tags without a dedicated message get a generic message naming the tag.
func getErrorMessage(fieldErr validator.FieldError) string {
	fieldName := getFieldName(fieldErr)
	param := fieldErr.Param()

	// sizeMessage words a min/max/len failure. bound is "at least",
	// "at most" or "exactly".
	sizeMessage := func(bound string) string {
		switch fieldErr.Kind() {
		case reflect.String:
			return fmt.Sprintf("The field '%s' must be %s %s characters long", fieldName, bound, param)
		case reflect.Slice, reflect.Array, reflect.Map:
			return fmt.Sprintf("The field '%s' must contain %s %s item(s)", fieldName, bound, param)
		default:
			return fmt.Sprintf("The field '%s' must be %s %s", fieldName, bound, param)
		}
	}

	switch fieldErr.Tag() {
	case "required":
		return fmt.Sprintf("The field '%s' is required and cannot be empty", fieldName)
	case "email":
		return fmt.Sprintf("The field '%s' must be a valid email address (e.g., user@example.com)", fieldName)
	case "min":
		return sizeMessage("at least")
	case "max":
		return sizeMessage("at most")
	case "len":
		return sizeMessage("exactly")
	case "oneof":
		return fmt.Sprintf("The field '%s' must be one of the following values: %s", fieldName, param)
	case "eqfield":
		return fmt.Sprintf("The field '%s' must equal the value of '%s'", fieldName, param)
	case "nefield":
		return fmt.Sprintf("The field '%s' must not equal the value of '%s'", fieldName, param)
	case "uuid":
		return fmt.Sprintf("The field '%s' must be a valid UUID format (e.g., 550e8400-e29b-41d4-a716-446655440000)", fieldName)
	case "url":
		return fmt.Sprintf("The field '%s' must be a valid URL (e.g., https://example.com)", fieldName)
	case "uri":
		return fmt.Sprintf("The field '%s' must be a valid URI", fieldName)
	case "numeric":
		return fmt.Sprintf("The field '%s' must be a numeric value", fieldName)
	case "alpha":
		return fmt.Sprintf("The field '%s' must contain only letters (a-z, A-Z)", fieldName)
	case "alphanum":
		return fmt.Sprintf("The field '%s' must contain only letters and numbers", fieldName)
	case "alphaunicode":
		return fmt.Sprintf("The field '%s' must contain only unicode letters", fieldName)
	case "alphanumunicode":
		return fmt.Sprintf("The field '%s' must contain only unicode letters and numbers", fieldName)
	case "number":
		return fmt.Sprintf("The field '%s' must be a valid number", fieldName)
	case "gte":
		return fmt.Sprintf("The field '%s' must be greater than or equal to %s", fieldName, param)
	case "lte":
		return fmt.Sprintf("The field '%s' must be less than or equal to %s", fieldName, param)
	case "gt":
		return fmt.Sprintf("The field '%s' must be greater than %s", fieldName, param)
	case "lt":
		return fmt.Sprintf("The field '%s' must be less than %s", fieldName, param)
	case "eq":
		return fmt.Sprintf("The field '%s' must equal %s", fieldName, param)
	case "ne":
		return fmt.Sprintf("The field '%s' must not equal %s", fieldName, param)
	case "contains":
		return fmt.Sprintf("The field '%s' must contain the substring '%s'", fieldName, param)
	case "excludes":
		return fmt.Sprintf("The field '%s' must not contain the substring '%s'", fieldName, param)
	case "startswith":
		return fmt.Sprintf("The field '%s' must start with '%s'", fieldName, param)
	case "endswith":
		return fmt.Sprintf("The field '%s' must end with '%s'", fieldName, param)
	case "ip":
		return fmt.Sprintf("The field '%s' must be a valid IP address", fieldName)
	case "ipv4":
		return fmt.Sprintf("The field '%s' must be a valid IPv4 address", fieldName)
	case "ipv6":
		return fmt.Sprintf("The field '%s' must be a valid IPv6 address", fieldName)
	case "datetime":
		return fmt.Sprintf("The field '%s' must be a valid datetime in format '%s'", fieldName, param)
	case "date":
		return fmt.Sprintf("The field '%s' must be a valid date", fieldName)
	case "timezone":
		return fmt.Sprintf("The field '%s' must be a valid timezone", fieldName)
	case "base64":
		return fmt.Sprintf("The field '%s' must be a valid base64 encoded string", fieldName)
	case "json":
		return fmt.Sprintf("The field '%s' must be a valid JSON string", fieldName)
	// Custom tags registered by this package.
	case "username":
		return fmt.Sprintf("The field '%s' must be a valid username (3-30 characters, alphanumeric and underscore only)", fieldName)
	case "uuid_string":
		return fmt.Sprintf("The field '%s' must be a valid UUID string", fieldName)
	case "slug":
		return fmt.Sprintf("The field '%s' must be a valid slug (1-100 characters, letters, numbers, hyphens and underscores only)", fieldName)
	case "phone":
		return fmt.Sprintf("The field '%s' must be a valid phone number (digits, optionally prefixed with '+')", fieldName)
	case "date_iso":
		return fmt.Sprintf("The field '%s' must be a valid date in YYYY-MM-DD format", fieldName)
	case "not_empty":
		return fmt.Sprintf("The field '%s' must not be empty or contain only whitespace", fieldName)
	default:
		// Generic but informative fallback for unknown or custom tags.
		return fmt.Sprintf("The field '%s' failed validation for tag '%s'", fieldName, fieldErr.Tag())
	}
}
// registerCustomValidations registers this package's custom tags on v:
// username, uuid_string, slug, phone, date_iso and not_empty.
// BE-SVC-020: Additional custom validations.
//
// It panics if the library rejects a registration. The tag names are
// constants, so that can only come from a programming error. Failing at
// startup is better than silently leaving a tag unregistered.
func registerCustomValidations(v *validator.Validate) {
	register := func(tag string, fn validator.Func) {
		if err := v.RegisterValidation(tag, fn); err != nil {
			panic(fmt.Sprintf("validators: registering custom tag %q: %v", tag, err))
		}
	}

	// username: 3-30 characters, ASCII letters, digits or underscores only.
	register("username", func(fl validator.FieldLevel) bool {
		username := fl.Field().String()
		if len(username) < 3 || len(username) > 30 {
			return false
		}
		for _, char := range username {
			if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
				(char >= '0' && char <= '9') || char == '_') {
				return false
			}
		}
		return true
	})

	// uuid_string: empty (optional), or a UUID accepted by the built-in "uuid" tag.
	// A separate instance keeps this check independent of the custom tags
	// registered on v. It is created once here instead of on every call;
	// *validator.Validate is documented as safe for concurrent use.
	uuidChecker := validator.New()
	register("uuid_string", func(fl validator.FieldLevel) bool {
		uuidStr := fl.Field().String()
		if uuidStr == "" {
			return true // optional
		}
		return uuidChecker.Var(uuidStr, "uuid") == nil
	})

	// slug: 1-100 characters, ASCII letters, digits, '-' or '_' (URL-safe).
	register("slug", func(fl validator.FieldLevel) bool {
		slug := fl.Field().String()
		if len(slug) < 1 || len(slug) > 100 {
			return false
		}
		for _, char := range slug {
			if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
				(char >= '0' && char <= '9') || char == '-' || char == '_') {
				return false
			}
		}
		return true
	})

	// phone: empty (optional) or a basic international number.
	// Spaces, dashes and parentheses are ignored. What remains must be 7-16
	// characters long: an optional single leading '+' followed only by digits.
	// That means '+' plus 6-15 digits, or 7-16 digits without '+'.
	// NOTE(review): E.164 caps numbers at 15 digits; the no-'+' form allows
	// 16. Confirm this is intended.
	phoneSeparators := strings.NewReplacer(" ", "", "-", "", "(", "", ")", "")
	register("phone", func(fl validator.FieldLevel) bool {
		phone := fl.Field().String()
		if phone == "" {
			return true // optional
		}
		cleaned := phoneSeparators.Replace(phone)
		if len(cleaned) < 7 || len(cleaned) > 16 {
			return false
		}
		cleaned = strings.TrimPrefix(cleaned, "+")
		for _, char := range cleaned {
			if char < '0' || char > '9' {
				return false
			}
		}
		return true
	})

	// date_iso: empty (optional), or a real calendar date in YYYY-MM-DD form.
	register("date_iso", func(fl validator.FieldLevel) bool {
		dateStr := fl.Field().String()
		if dateStr == "" {
			return true // optional
		}
		// Require the exact 4-2-2 digit layout explicitly, independent of
		// time.Parse's own leniency about individual fields.
		if len(dateStr) != 10 || dateStr[4] != '-' || dateStr[7] != '-' {
			return false
		}
		for i, char := range dateStr {
			if i == 4 || i == 7 {
				continue
			}
			if char < '0' || char > '9' {
				return false
			}
		}
		// Reject impossible dates such as 2024-13-01 or 2023-02-30.
		_, err := time.Parse("2006-01-02", dateStr)
		return err == nil
	})

	// not_empty: the field must be a string that is non-empty after trimming whitespace.
	register("not_empty", func(fl validator.FieldLevel) bool {
		field := fl.Field()
		if field.Kind() != reflect.String {
			return false
		}
		return strings.TrimSpace(field.String()) != ""
	})
}