421 lines
11 KiB
Go
421 lines
11 KiB
Go
package services
|
|
|
|
import (
	"context"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"time"

	"veza-backend-api/internal/database"

	"github.com/google/uuid"
	"go.uber.org/zap"
)
|
|
|
|
// SessionService gère les sessions utilisateur
|
|
type SessionService struct {
|
|
db *database.Database
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// Session représente une session utilisateur
|
|
// MIGRATION UUID: ID migré vers uuid.UUID
|
|
type Session struct {
|
|
ID uuid.UUID `json:"id" db:"id"`
|
|
UserID uuid.UUID `json:"user_id" db:"user_id"`
|
|
TokenHash string `json:"-" db:"token_hash"`
|
|
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
|
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
|
|
RevokedAt *time.Time `json:"revoked_at" db:"revoked_at"`
|
|
IPAddress string `json:"ip_address" db:"ip_address"`
|
|
UserAgent string `json:"user_agent" db:"user_agent"`
|
|
}
|
|
|
|
// SessionCreateRequest données pour créer une session
|
|
// MIGRATION UUID: UserID migré vers uuid.UUID
|
|
type SessionCreateRequest struct {
|
|
UserID uuid.UUID `json:"user_id"`
|
|
Token string `json:"token"`
|
|
IPAddress string `json:"ip_address"`
|
|
UserAgent string `json:"user_agent"`
|
|
Metadata string `json:"metadata"` // Ignored by DB, kept for compatibility if needed
|
|
ExpiresIn time.Duration `json:"expires_in"`
|
|
}
|
|
|
|
// NewSessionService crée un nouveau service de session
|
|
func NewSessionService(db *database.Database, logger *zap.Logger) *SessionService {
|
|
return &SessionService{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// CreateSession crée une nouvelle session
|
|
func (ss *SessionService) CreateSession(ctx context.Context, req *SessionCreateRequest) (*Session, error) {
|
|
// Hasher le token pour le stockage
|
|
tokenHash := ss.hashToken(req.Token)
|
|
|
|
// Calculer la date d'expiration
|
|
// If ExpiresIn is 0, default to 24 hours
|
|
expiresIn := req.ExpiresIn
|
|
if expiresIn == 0 {
|
|
expiresIn = 24 * time.Hour
|
|
}
|
|
expiresAt := time.Now().Add(expiresIn)
|
|
|
|
// Créer la session struct
|
|
session := &Session{
|
|
ID: uuid.New(),
|
|
UserID: req.UserID,
|
|
TokenHash: tokenHash,
|
|
CreatedAt: time.Now(),
|
|
ExpiresAt: expiresAt,
|
|
IPAddress: req.IPAddress,
|
|
UserAgent: req.UserAgent,
|
|
}
|
|
|
|
// Insérer en base
|
|
query := `
|
|
INSERT INTO sessions (id, user_id, token_hash, created_at, expires_at, ip_address, user_agent)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
|
`
|
|
|
|
_, err := ss.db.ExecContext(ctx, query,
|
|
session.ID,
|
|
session.UserID,
|
|
session.TokenHash,
|
|
session.CreatedAt,
|
|
session.ExpiresAt,
|
|
session.IPAddress,
|
|
session.UserAgent,
|
|
)
|
|
|
|
if err != nil {
|
|
ss.logger.Error("Failed to create session",
|
|
zap.Error(err),
|
|
zap.String("user_id", req.UserID.String()),
|
|
)
|
|
return nil, fmt.Errorf("failed to create session: %w", err)
|
|
}
|
|
|
|
ss.logger.Info("Session created",
|
|
zap.String("session_id", session.ID.String()),
|
|
zap.String("user_id", req.UserID.String()),
|
|
zap.Time("expires_at", session.ExpiresAt),
|
|
)
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// ValidateSession valide une session par token hash
|
|
func (ss *SessionService) ValidateSession(ctx context.Context, token string) (*Session, error) {
|
|
tokenHash := ss.hashToken(token)
|
|
|
|
query := `
|
|
SELECT id, user_id, token_hash, created_at, expires_at, revoked_at, ip_address, user_agent
|
|
FROM sessions
|
|
WHERE token_hash = $1 AND expires_at > $2 AND revoked_at IS NULL
|
|
`
|
|
|
|
var session Session
|
|
err := ss.db.QueryRowContext(ctx, query, tokenHash, time.Now()).Scan(
|
|
&session.ID,
|
|
&session.UserID,
|
|
&session.TokenHash,
|
|
&session.CreatedAt,
|
|
&session.ExpiresAt,
|
|
&session.RevokedAt,
|
|
&session.IPAddress,
|
|
&session.UserAgent,
|
|
)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("session not found or expired")
|
|
}
|
|
ss.logger.Error("Failed to validate session",
|
|
zap.Error(err),
|
|
zap.String("token_hash", tokenHash),
|
|
)
|
|
return nil, fmt.Errorf("failed to validate session: %w", err)
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
// RevokeSession révoque une session par token
|
|
func (ss *SessionService) RevokeSession(ctx context.Context, token string) error {
|
|
tokenHash := ss.hashToken(token)
|
|
|
|
query := `
|
|
UPDATE sessions
|
|
SET revoked_at = $1
|
|
WHERE token_hash = $2 AND revoked_at IS NULL
|
|
`
|
|
|
|
result, err := ss.db.ExecContext(ctx, query, time.Now(), tokenHash)
|
|
if err != nil {
|
|
ss.logger.Error("Failed to revoke session",
|
|
zap.Error(err),
|
|
zap.String("token_hash", tokenHash),
|
|
)
|
|
return fmt.Errorf("failed to revoke session: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("session not found or already revoked")
|
|
}
|
|
|
|
ss.logger.Info("Session revoked",
|
|
zap.String("token_hash", tokenHash),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
// RevokeAllUserSessions révoque toutes les sessions d'un utilisateur
|
|
func (ss *SessionService) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) (int64, error) {
|
|
query := `
|
|
UPDATE sessions
|
|
SET revoked_at = $2
|
|
WHERE user_id = $1 AND revoked_at IS NULL
|
|
`
|
|
|
|
result, err := ss.db.ExecContext(ctx, query, userID, time.Now())
|
|
if err != nil {
|
|
ss.logger.Error("Failed to revoke user sessions",
|
|
zap.Error(err),
|
|
zap.String("user_id", userID.String()),
|
|
)
|
|
return 0, fmt.Errorf("failed to revoke user sessions: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
return rowsAffected, nil
|
|
}
|
|
|
|
// RevokeAllUserSessionsByUserID est un alias pour satisfaire l'interface attendue par AuthService
|
|
func (ss *SessionService) RevokeAllUserSessionsByUserID(ctx context.Context, userID uuid.UUID) (int64, error) {
|
|
return ss.RevokeAllUserSessions(ctx, userID)
|
|
}
|
|
|
|
// RefreshSession étend la durée d'une session
|
|
func (ss *SessionService) RefreshSession(ctx context.Context, token string, newExpiresIn time.Duration) error {
|
|
tokenHash := ss.hashToken(token)
|
|
newExpiresAt := time.Now().Add(newExpiresIn)
|
|
|
|
query := `
|
|
UPDATE sessions
|
|
SET expires_at = $1
|
|
WHERE token_hash = $2 AND revoked_at IS NULL AND expires_at > $3
|
|
`
|
|
|
|
result, err := ss.db.ExecContext(ctx, query, newExpiresAt, tokenHash, time.Now())
|
|
if err != nil {
|
|
ss.logger.Error("Failed to refresh session",
|
|
zap.Error(err),
|
|
zap.String("token_hash", tokenHash),
|
|
)
|
|
return fmt.Errorf("failed to refresh session: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("session not found or expired")
|
|
}
|
|
|
|
ss.logger.Info("Session refreshed",
|
|
zap.String("token_hash", tokenHash),
|
|
zap.Time("new_expires_at", newExpiresAt),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
// CleanupExpiredSessions supprime les sessions expirées
|
|
func (ss *SessionService) CleanupExpiredSessions(ctx context.Context) error {
|
|
query := `
|
|
DELETE FROM sessions
|
|
WHERE expires_at < $1 OR revoked_at IS NOT NULL
|
|
`
|
|
|
|
result, err := ss.db.ExecContext(ctx, query, time.Now())
|
|
if err != nil {
|
|
ss.logger.Error("Failed to cleanup expired sessions", zap.Error(err))
|
|
return fmt.Errorf("failed to cleanup expired sessions: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected > 0 {
|
|
ss.logger.Info("Expired sessions cleaned up", zap.Int64("count", rowsAffected))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// hashToken hashe un token pour le stockage
|
|
func (ss *SessionService) hashToken(token string) string {
|
|
hash := sha256.Sum256([]byte(token))
|
|
return hex.EncodeToString(hash[:])
|
|
}
|
|
|
|
// GetSessionStats retourne les statistiques des sessions
|
|
func (ss *SessionService) GetSessionStats(ctx context.Context) (map[string]interface{}, error) {
|
|
query := `
|
|
SELECT
|
|
COUNT(*) as total_active,
|
|
COUNT(DISTINCT user_id) as unique_users
|
|
FROM sessions
|
|
WHERE expires_at > $1 AND revoked_at IS NULL
|
|
`
|
|
|
|
var totalActive, uniqueUsers int64
|
|
err := ss.db.QueryRowContext(ctx, query, time.Now()).Scan(&totalActive, &uniqueUsers)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get session stats: %w", err)
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"total_active": totalActive,
|
|
"unique_users": uniqueUsers,
|
|
}, nil
|
|
}
|
|
|
|
// GetSessionByID récupère une session par ID
|
|
func (ss *SessionService) GetSessionByID(sessionID uuid.UUID) (*Session, error) {
|
|
ctx := context.Background()
|
|
query := `
|
|
SELECT id, user_id, token_hash, created_at, expires_at, revoked_at, ip_address, user_agent
|
|
FROM sessions
|
|
WHERE id = $1
|
|
`
|
|
|
|
var session Session
|
|
err := ss.db.QueryRowContext(ctx, query, sessionID).Scan(
|
|
&session.ID,
|
|
&session.UserID,
|
|
&session.TokenHash,
|
|
&session.CreatedAt,
|
|
&session.ExpiresAt,
|
|
&session.RevokedAt,
|
|
&session.IPAddress,
|
|
&session.UserAgent,
|
|
)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("session not found")
|
|
}
|
|
ss.logger.Error("Failed to get session by ID",
|
|
zap.Error(err),
|
|
zap.String("session_id", sessionID.String()),
|
|
)
|
|
return nil, fmt.Errorf("failed to get session by ID: %w", err)
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
// GetUserSessions récupère toutes les sessions d'un utilisateur
|
|
func (ss *SessionService) GetUserSessions(userID uuid.UUID) ([]*Session, error) {
|
|
ctx := context.Background()
|
|
query := `
|
|
SELECT id, user_id, token_hash, created_at, expires_at, revoked_at, ip_address, user_agent
|
|
FROM sessions
|
|
WHERE user_id = $1 AND expires_at > $2 AND revoked_at IS NULL
|
|
ORDER BY created_at DESC
|
|
`
|
|
|
|
rows, err := ss.db.QueryContext(ctx, query, userID, time.Now())
|
|
if err != nil {
|
|
ss.logger.Error("Failed to get user sessions",
|
|
zap.Error(err),
|
|
zap.String("user_id", userID.String()),
|
|
)
|
|
return nil, fmt.Errorf("failed to get user sessions: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sessions []*Session
|
|
for rows.Next() {
|
|
var session Session
|
|
if err := rows.Scan(
|
|
&session.ID,
|
|
&session.UserID,
|
|
&session.TokenHash,
|
|
&session.CreatedAt,
|
|
&session.ExpiresAt,
|
|
&session.RevokedAt,
|
|
&session.IPAddress,
|
|
&session.UserAgent,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("failed to scan session: %w", err)
|
|
}
|
|
sessions = append(sessions, &session)
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
// HashTokenForMiddleware hashe un token (pour usage middleware/handler)
|
|
func (ss *SessionService) HashTokenForMiddleware(token string) string {
|
|
return ss.hashToken(token)
|
|
}
|
|
|
|
// DeleteSession révoque une session (alias pour RevokeSession, utilisé par les handlers)
|
|
func (ss *SessionService) DeleteSession(tokenHash string) error {
|
|
// Note: tokenHash is already hashed. RevokeSession expects raw token.
|
|
// But DeleteSession takes tokenHash.
|
|
// We need a method to revoke by hash.
|
|
|
|
ctx := context.Background()
|
|
query := `
|
|
UPDATE sessions
|
|
SET revoked_at = $2
|
|
WHERE token_hash = $1 AND revoked_at IS NULL
|
|
`
|
|
|
|
result, err := ss.db.ExecContext(ctx, query, tokenHash, time.Now())
|
|
if err != nil {
|
|
ss.logger.Error("Failed to revoke session by hash",
|
|
zap.Error(err),
|
|
zap.String("token_hash", tokenHash),
|
|
)
|
|
return fmt.Errorf("failed to revoke session: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("session not found or already revoked")
|
|
}
|
|
|
|
ss.logger.Info("Session revoked by hash",
|
|
zap.String("token_hash", tokenHash),
|
|
)
|
|
|
|
return nil
|
|
}
|