veza/veza-backend-api/internal/services/session_service.go
2025-12-16 11:23:49 -05:00

421 lines
11 KiB
Go

package services
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"time"
"veza-backend-api/internal/database"
"github.com/google/uuid"
"go.uber.org/zap"
)
// SessionService manages user sessions stored in the sessions table.
// Raw tokens are never persisted: only their hex-encoded SHA-256 digest
// (produced by hashToken) is written to and matched against token_hash.
type SessionService struct {
	// db is the database handle used for every sessions query.
	db *database.Database
	// logger receives error and informational events from the service.
	logger *zap.Logger
}
// Session represents a user session row from the sessions table.
// MIGRATION UUID: ID migrated to uuid.UUID.
type Session struct {
	ID     uuid.UUID `json:"id" db:"id"`
	UserID uuid.UUID `json:"user_id" db:"user_id"`
	// TokenHash is the hex SHA-256 digest of the raw token; it is excluded from JSON output.
	TokenHash string    `json:"-" db:"token_hash"`
	CreatedAt time.Time `json:"created_at" db:"created_at"`
	ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
	// RevokedAt stays nil until the session is revoked.
	RevokedAt *time.Time `json:"revoked_at" db:"revoked_at"`
	IPAddress string     `json:"ip_address" db:"ip_address"`
	UserAgent string     `json:"user_agent" db:"user_agent"`
}
// SessionCreateRequest holds the input for CreateSession.
// MIGRATION UUID: UserID migrated to uuid.UUID.
type SessionCreateRequest struct {
	UserID uuid.UUID `json:"user_id"`
	// Token is the raw session token; only its SHA-256 digest is persisted.
	Token     string `json:"token"`
	IPAddress string `json:"ip_address"`
	UserAgent string `json:"user_agent"`
	Metadata  string `json:"metadata"` // Ignored by DB, kept for compatibility if needed
	// ExpiresIn is the session lifetime; zero selects CreateSession's default (24h).
	ExpiresIn time.Duration `json:"expires_in"`
}
// NewSessionService builds a SessionService backed by db.
//
// A nil logger is replaced with zap.NewNop() so that the service's logging
// calls never dereference a nil *zap.Logger (which would panic).
func NewSessionService(db *database.Database, logger *zap.Logger) *SessionService {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &SessionService{
		db:     db,
		logger: logger,
	}
}
// CreateSession persists a new session for req.UserID and returns it.
//
// Only the SHA-256 hex digest of req.Token is stored. A zero or negative
// req.ExpiresIn falls back to a 24-hour lifetime. CreatedAt and ExpiresAt are
// derived from a single clock reading, so the stored lifetime is exactly the
// chosen duration.
//
// It returns an error if req is nil, if req.Token is empty, or if the INSERT
// fails (the failure is also logged).
func (ss *SessionService) CreateSession(ctx context.Context, req *SessionCreateRequest) (*Session, error) {
	if req == nil {
		return nil, fmt.Errorf("failed to create session: nil request")
	}
	// Every empty token hashes to the same well-known digest, which would make
	// such sessions indistinguishable and trivially matchable; refuse it.
	if req.Token == "" {
		return nil, fmt.Errorf("failed to create session: empty token")
	}

	expiresIn := req.ExpiresIn
	if expiresIn <= 0 {
		// Zero means "unset"; a negative value would create a session that is
		// already expired. Both use the default lifetime.
		expiresIn = 24 * time.Hour
	}

	now := time.Now()
	session := &Session{
		ID:        uuid.New(),
		UserID:    req.UserID,
		TokenHash: ss.hashToken(req.Token),
		CreatedAt: now,
		ExpiresAt: now.Add(expiresIn),
		IPAddress: req.IPAddress,
		UserAgent: req.UserAgent,
	}

	query := `
INSERT INTO sessions (id, user_id, token_hash, created_at, expires_at, ip_address, user_agent)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	_, err := ss.db.ExecContext(ctx, query,
		session.ID,
		session.UserID,
		session.TokenHash,
		session.CreatedAt,
		session.ExpiresAt,
		session.IPAddress,
		session.UserAgent,
	)
	if err != nil {
		ss.logger.Error("Failed to create session",
			zap.Error(err),
			zap.String("user_id", req.UserID.String()),
		)
		return nil, fmt.Errorf("failed to create session: %w", err)
	}

	ss.logger.Info("Session created",
		zap.String("session_id", session.ID.String()),
		zap.String("user_id", req.UserID.String()),
		zap.Time("expires_at", session.ExpiresAt),
	)
	return session, nil
}
// ValidateSession returns the active session that belongs to a raw token.
//
// The token is hashed and matched against token_hash. Only rows that are not
// revoked and whose expires_at is still in the future qualify. A missing row
// yields the error "session not found or expired". Any other database error
// is logged and returned wrapped.
func (ss *SessionService) ValidateSession(ctx context.Context, token string) (*Session, error) {
	const query = `
SELECT id, user_id, token_hash, created_at, expires_at, revoked_at, ip_address, user_agent
FROM sessions
WHERE token_hash = $1 AND expires_at > $2 AND revoked_at IS NULL
`
	hashed := ss.hashToken(token)
	found := new(Session)

	scanErr := ss.db.QueryRowContext(ctx, query, hashed, time.Now()).Scan(
		&found.ID, &found.UserID, &found.TokenHash, &found.CreatedAt,
		&found.ExpiresAt, &found.RevokedAt, &found.IPAddress, &found.UserAgent,
	)
	switch {
	case scanErr == sql.ErrNoRows:
		return nil, fmt.Errorf("session not found or expired")
	case scanErr != nil:
		ss.logger.Error("Failed to validate session",
			zap.Error(scanErr),
			zap.String("token_hash", hashed),
		)
		return nil, fmt.Errorf("failed to validate session: %w", scanErr)
	}
	return found, nil
}
// RevokeSession marks the session identified by a raw token as revoked.
//
// revoked_at is set to the current time on the matching row that is not
// already revoked. It returns an error if the UPDATE fails (also logged), if
// the affected-row count cannot be read, or if no row matched (unknown token
// or already revoked).
func (ss *SessionService) RevokeSession(ctx context.Context, token string) error {
	hashed := ss.hashToken(token)
	const query = `
UPDATE sessions
SET revoked_at = $1
WHERE token_hash = $2 AND revoked_at IS NULL
`
	res, execErr := ss.db.ExecContext(ctx, query, time.Now(), hashed)
	if execErr != nil {
		ss.logger.Error("Failed to revoke session",
			zap.Error(execErr),
			zap.String("token_hash", hashed),
		)
		return fmt.Errorf("failed to revoke session: %w", execErr)
	}

	changed, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("failed to get rows affected: %w", countErr)
	case changed == 0:
		return fmt.Errorf("session not found or already revoked")
	}

	ss.logger.Info("Session revoked", zap.String("token_hash", hashed))
	return nil
}
// RevokeAllUserSessions revokes every not-yet-revoked session owned by userID.
// It returns the number of sessions it revoked.
//
// The UPDATE filters only on user_id and revoked_at, so sessions that have
// expired but were never revoked are stamped as well. Database errors are
// returned wrapped; an UPDATE failure is also logged.
func (ss *SessionService) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) (int64, error) {
	const query = `
UPDATE sessions
SET revoked_at = $2
WHERE user_id = $1 AND revoked_at IS NULL
`
	res, execErr := ss.db.ExecContext(ctx, query, userID, time.Now())
	if execErr != nil {
		ss.logger.Error("Failed to revoke user sessions",
			zap.Error(execErr),
			zap.String("user_id", userID.String()),
		)
		return 0, fmt.Errorf("failed to revoke user sessions: %w", execErr)
	}

	revoked, countErr := res.RowsAffected()
	if countErr != nil {
		return 0, fmt.Errorf("failed to get rows affected: %w", countErr)
	}
	return revoked, nil
}
// RevokeAllUserSessionsByUserID is an alias for RevokeAllUserSessions. It
// exists so SessionService matches the method set AuthService expects.
// Its results are exactly those of RevokeAllUserSessions.
func (ss *SessionService) RevokeAllUserSessionsByUserID(ctx context.Context, userID uuid.UUID) (int64, error) {
	revoked, err := ss.RevokeAllUserSessions(ctx, userID)
	return revoked, err
}
// RefreshSession extends an active session so it expires newExpiresIn from now.
//
// Only a session that matches the token's hash, is not revoked, and has not
// yet expired is extended. It returns an error in these cases:
//   - newExpiresIn is not positive (this would otherwise silently expire the
//     session while reporting success);
//   - the UPDATE fails (also logged);
//   - the affected-row count cannot be read;
//   - no active session matched.
func (ss *SessionService) RefreshSession(ctx context.Context, token string, newExpiresIn time.Duration) error {
	if newExpiresIn <= 0 {
		return fmt.Errorf("failed to refresh session: duration must be positive, got %s", newExpiresIn)
	}

	tokenHash := ss.hashToken(token)
	// One clock reading serves both the "still active" predicate and the new
	// expiry, so they are evaluated against the same instant.
	now := time.Now()
	newExpiresAt := now.Add(newExpiresIn)

	query := `
UPDATE sessions
SET expires_at = $1
WHERE token_hash = $2 AND revoked_at IS NULL AND expires_at > $3
`
	result, err := ss.db.ExecContext(ctx, query, newExpiresAt, tokenHash, now)
	if err != nil {
		ss.logger.Error("Failed to refresh session",
			zap.Error(err),
			zap.String("token_hash", tokenHash),
		)
		return fmt.Errorf("failed to refresh session: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("session not found or expired")
	}

	ss.logger.Info("Session refreshed",
		zap.String("token_hash", tokenHash),
		zap.Time("new_expires_at", newExpiresAt),
	)
	return nil
}
// CleanupExpiredSessions hard-deletes every session that has expired or has
// been revoked. When at least one row is removed, the count is logged.
// Database errors are returned wrapped; a DELETE failure is also logged.
func (ss *SessionService) CleanupExpiredSessions(ctx context.Context) error {
	const query = `
DELETE FROM sessions
WHERE expires_at < $1 OR revoked_at IS NOT NULL
`
	res, execErr := ss.db.ExecContext(ctx, query, time.Now())
	if execErr != nil {
		ss.logger.Error("Failed to cleanup expired sessions", zap.Error(execErr))
		return fmt.Errorf("failed to cleanup expired sessions: %w", execErr)
	}

	deleted, countErr := res.RowsAffected()
	if countErr != nil {
		return fmt.Errorf("failed to get rows affected: %w", countErr)
	}
	if deleted <= 0 {
		return nil
	}

	ss.logger.Info("Expired sessions cleaned up", zap.Int64("count", deleted))
	return nil
}
// hashToken returns the lowercase hex encoding of the SHA-256 digest of token.
// This is the value stored in, and compared against, the token_hash column.
func (ss *SessionService) hashToken(token string) string {
	h := sha256.New()
	// hash.Hash.Write never returns an error, so the result is safe to ignore.
	h.Write([]byte(token))
	return hex.EncodeToString(h.Sum(nil))
}
// GetSessionStats reports counts over currently active sessions, meaning
// sessions that are not revoked and not yet expired.
//
// The returned map has two keys, both holding int64 values:
//   - "total_active": the number of active sessions;
//   - "unique_users": the number of distinct users owning them.
//
// A query error is returned wrapped.
func (ss *SessionService) GetSessionStats(ctx context.Context) (map[string]interface{}, error) {
	const query = `
SELECT
COUNT(*) as total_active,
COUNT(DISTINCT user_id) as unique_users
FROM sessions
WHERE expires_at > $1 AND revoked_at IS NULL
`
	var active, users int64
	if err := ss.db.QueryRowContext(ctx, query, time.Now()).Scan(&active, &users); err != nil {
		return nil, fmt.Errorf("failed to get session stats: %w", err)
	}

	stats := make(map[string]interface{}, 2)
	stats["total_active"] = active
	stats["unique_users"] = users
	return stats, nil
}
// GetSessionByID loads a session by its primary key. The lookup applies no
// filter on expiry or revocation, so expired and revoked sessions are
// returned too.
//
// A missing row yields the error "session not found". Other database errors
// are logged and returned wrapped.
//
// NOTE(review): the query runs under context.Background(), so callers cannot
// cancel it or apply a deadline to it.
func (ss *SessionService) GetSessionByID(sessionID uuid.UUID) (*Session, error) {
	const query = `
SELECT id, user_id, token_hash, created_at, expires_at, revoked_at, ip_address, user_agent
FROM sessions
WHERE id = $1
`
	found := new(Session)
	scanErr := ss.db.QueryRowContext(context.Background(), query, sessionID).Scan(
		&found.ID, &found.UserID, &found.TokenHash, &found.CreatedAt,
		&found.ExpiresAt, &found.RevokedAt, &found.IPAddress, &found.UserAgent,
	)
	if scanErr == nil {
		return found, nil
	}
	if scanErr == sql.ErrNoRows {
		return nil, fmt.Errorf("session not found")
	}

	ss.logger.Error("Failed to get session by ID",
		zap.Error(scanErr),
		zap.String("session_id", sessionID.String()),
	)
	return nil, fmt.Errorf("failed to get session by ID: %w", scanErr)
}
// GetUserSessions returns the active sessions of userID, newest first.
// Active means not revoked and not yet expired.
//
// The result is nil when the user has no active session. Errors are returned
// wrapped in these cases:
//   - the query fails (also logged);
//   - a row fails to scan;
//   - iteration stops early, which rows.Err reports. Without this check a
//     mid-stream failure would be returned as a truncated but "successful"
//     list.
//
// NOTE(review): the query runs under context.Background(), so callers cannot
// cancel it.
func (ss *SessionService) GetUserSessions(userID uuid.UUID) ([]*Session, error) {
	ctx := context.Background()
	query := `
SELECT id, user_id, token_hash, created_at, expires_at, revoked_at, ip_address, user_agent
FROM sessions
WHERE user_id = $1 AND expires_at > $2 AND revoked_at IS NULL
ORDER BY created_at DESC
`
	rows, err := ss.db.QueryContext(ctx, query, userID, time.Now())
	if err != nil {
		ss.logger.Error("Failed to get user sessions",
			zap.Error(err),
			zap.String("user_id", userID.String()),
		)
		return nil, fmt.Errorf("failed to get user sessions: %w", err)
	}
	defer rows.Close()

	var sessions []*Session
	for rows.Next() {
		var session Session
		if err := rows.Scan(
			&session.ID,
			&session.UserID,
			&session.TokenHash,
			&session.CreatedAt,
			&session.ExpiresAt,
			&session.RevokedAt,
			&session.IPAddress,
			&session.UserAgent,
		); err != nil {
			return nil, fmt.Errorf("failed to scan session: %w", err)
		}
		sessions = append(sessions, &session)
	}
	// rows.Next returns false both at the end of the result set and on
	// error, so the two must be told apart explicitly.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate user sessions: %w", err)
	}
	return sessions, nil
}
// HashTokenForMiddleware exposes the service's token hashing to middleware and
// handlers. It returns the same hex SHA-256 digest that is stored in
// token_hash, which is the form DeleteSession expects as its argument.
func (ss *SessionService) HashTokenForMiddleware(token string) string {
	digest := ss.hashToken(token)
	return digest
}
// DeleteSession revokes the session whose stored token_hash equals tokenHash.
// Despite its name it performs a soft delete (it sets revoked_at); the row is
// kept.
//
// Unlike RevokeSession, the argument must already be hashed, for example with
// HashTokenForMiddleware. It returns an error in these cases:
//   - the UPDATE fails (also logged);
//   - the affected-row count cannot be read;
//   - no unrevoked session matched.
//
// NOTE(review): the update runs under context.Background(), so callers cannot
// cancel it.
func (ss *SessionService) DeleteSession(tokenHash string) error {
	const query = `
UPDATE sessions
SET revoked_at = $2
WHERE token_hash = $1 AND revoked_at IS NULL
`
	res, execErr := ss.db.ExecContext(context.Background(), query, tokenHash, time.Now())
	if execErr != nil {
		ss.logger.Error("Failed to revoke session by hash",
			zap.Error(execErr),
			zap.String("token_hash", tokenHash),
		)
		return fmt.Errorf("failed to revoke session: %w", execErr)
	}

	changed, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("failed to get rows affected: %w", countErr)
	case changed == 0:
		return fmt.Errorf("session not found or already revoked")
	}

	ss.logger.Info("Session revoked by hash", zap.String("token_hash", tokenHash))
	return nil
}