// veza/veza-backend-api/internal/services/two_factor_service.go
// (263 lines, 7.6 KiB, Go)

package services
import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base32"
	"encoding/json"
	"errors"
	"fmt"
	mathrand "math/rand"
	"net/url"
	"strings"

	"github.com/google/uuid"
	"github.com/pquerna/otp/totp"
	"go.uber.org/zap"

	"veza-backend-api/internal/database"
	"veza-backend-api/internal/models"
)
// TwoFactorService implements TOTP-based two-factor authentication: secret
// and recovery-code generation, enabling/disabling 2FA, and code
// verification. All state lives in the users table, accessed through db.
type TwoFactorService struct {
	db     *database.Database // handle used for every users-table query
	logger *zap.Logger        // structured logger for info/warn/error events
}
// NewTwoFactorService returns a TwoFactorService that persists 2FA state
// through db and reports through logger. Neither argument is validated here.
func NewTwoFactorService(db *database.Database, logger *zap.Logger) *TwoFactorService {
	svc := new(TwoFactorService)
	svc.db = db
	svc.logger = logger
	return svc
}
// TwoFactorSetup is the enrollment material handed to a client that is about
// to enable 2FA: the base32 TOTP secret, an otpauth:// key URI suitable for
// rendering as a QR code, and the plaintext recovery codes (only their
// transformed form is written to the database when 2FA is enabled).
type TwoFactorSetup struct {
	Secret        string   `json:"secret"`
	QRCodeURL     string   `json:"qr_code_url"`
	RecoveryCodes []string `json:"recovery_codes"`
}
// TwoFactorVerification is the request payload for submitting a 2FA code.
// Code carries a `binding:"required"` tag (presumably enforced by the HTTP
// binding layer); RecoveryCode is optional and omitted from JSON when empty.
type TwoFactorVerification struct {
	Code         string `json:"code" binding:"required"`
	RecoveryCode string `json:"recovery_code,omitempty"`
}
// GenerateSecret creates fresh 2FA enrollment material for user without
// persisting anything. It returns a random 160-bit TOTP secret (base32), an
// otpauth:// key URI for QR-code enrollment, and a set of plaintext recovery
// codes. Presumably the caller later passes Secret and RecoveryCodes to
// EnableTwoFactor once the user confirms enrollment.
//
// The account e-mail is percent-encoded before it is placed in the URI label.
// Without that, addresses containing '?', '#', '%' or spaces would corrupt
// the URI.
//
// Returns an error if user is nil or the system random source fails.
func (s *TwoFactorService) GenerateSecret(user *models.User) (*TwoFactorSetup, error) {
	if user == nil {
		return nil, fmt.Errorf("failed to generate secret: user is nil")
	}

	// 20 bytes = 160 bits, the shared-secret length RFC 4226 recommends.
	secret := make([]byte, 20)
	if _, err := rand.Read(secret); err != nil {
		return nil, fmt.Errorf("failed to generate secret: %w", err)
	}

	// 20 bytes encode to exactly 32 base32 characters, so no '=' padding appears.
	secretBase32 := base32.StdEncoding.EncodeToString(secret)

	// Key URI layout: otpauth://totp/ISSUER:ACCOUNT?secret=...&issuer=...
	// Only the account part is user-controlled, so only it is escaped.
	qrCodeURL := fmt.Sprintf("otpauth://totp/Veza:%s?secret=%s&issuer=Veza&algorithm=SHA1&digits=6&period=30",
		url.PathEscape(user.Email), secretBase32)

	return &TwoFactorSetup{
		Secret:        secretBase32,
		QRCodeURL:     qrCodeURL,
		RecoveryCodes: s.generateRecoveryCodes(),
	}, nil
}
// EnableTwoFactor turns on 2FA for userID. It stores the base32 TOTP secret
// and the recovery codes, after passing them through hashRecoveryCode, as a
// JSON array in backup_codes.
//
// A blank (empty or whitespace-only) secret is rejected. Enabling 2FA with it
// would protect the account with an empty, publicly known TOTP key.
//
// NOTE(review): the UPDATE result is not inspected, so no error is returned
// when userID matches no row.
func (s *TwoFactorService) EnableTwoFactor(ctx context.Context, userID uuid.UUID, secret string, recoveryCodes []string) error {
	if strings.TrimSpace(secret) == "" {
		return fmt.Errorf("failed to enable 2FA: TOTP secret is empty")
	}

	// Codes are transformed by hashRecoveryCode before storage; the plaintext
	// list is never written.
	hashedCodes := make([]string, len(recoveryCodes))
	for i, code := range recoveryCodes {
		hashedCodes[i] = s.hashRecoveryCode(code)
	}

	// Serialize backup_codes as JSON (column is TEXT/JSONB; driver does not accept []string).
	// hashedCodes is never nil, so this always encodes a JSON array ("[]" when empty).
	backupCodesJSON, err := json.Marshal(hashedCodes)
	if err != nil {
		return fmt.Errorf("failed to marshal backup codes: %w", err)
	}

	query := `
UPDATE users
SET two_factor_enabled = true,
two_factor_secret = $1,
backup_codes = $2,
updated_at = CURRENT_TIMESTAMP
WHERE id = $3
`
	_, err = s.db.ExecContext(ctx, query, secret, string(backupCodesJSON), userID)
	if err != nil {
		s.logger.Error("Failed to enable 2FA", zap.Error(err), zap.String("user_id", userID.String()))
		return fmt.Errorf("failed to enable 2FA: %w", err)
	}

	s.logger.Info("2FA enabled successfully", zap.String("user_id", userID.String()))
	return nil
}
// DisableTwoFactor turns off 2FA for userID. It clears the stored TOTP secret
// and resets backup_codes to an empty JSON array.
//
// backup_codes is written as '[]' rather than '{}'. EnableTwoFactor stores a
// JSON array, and VerifyTwoFactor and removeRecoveryCode unmarshal the column
// into []string. '{}' is a JSON object (or a PostgreSQL array literal) and
// would fail to decode as []string.
//
// NOTE(review): as with EnableTwoFactor, a userID matching no row is not
// reported as an error.
func (s *TwoFactorService) DisableTwoFactor(ctx context.Context, userID uuid.UUID) error {
	query := `
UPDATE users
SET two_factor_enabled = false,
two_factor_secret = '',
backup_codes = '[]',
updated_at = CURRENT_TIMESTAMP
WHERE id = $1
`
	_, err := s.db.ExecContext(ctx, query, userID)
	if err != nil {
		s.logger.Error("Failed to disable 2FA", zap.Error(err), zap.String("user_id", userID.String()))
		return fmt.Errorf("failed to disable 2FA: %w", err)
	}

	s.logger.Info("2FA disabled successfully", zap.String("user_id", userID.String()))
	return nil
}
// VerifyTwoFactor reports whether code is valid for userID. It accepts either
// one of the user's stored recovery codes or a TOTP code for the stored
// secret.
//
// A matching recovery code is removed from backup_codes (via
// removeRecoveryCode) before true is returned. The method returns an error if
// 2FA is not enabled for the user (or the user does not exist), or if the
// stored data cannot be read or decoded. An incorrect code yields
// (false, nil). A blank stored secret is never used for TOTP validation,
// because an empty key would make every TOTP code predictable.
//
// NOTE(review): recovery-code consumption is best-effort. removeRecoveryCode
// only logs failures, and its read-modify-write is not atomic. Under
// concurrent requests or a failed UPDATE, a code may therefore be accepted
// more than once.
func (s *TwoFactorService) VerifyTwoFactor(ctx context.Context, userID uuid.UUID, code string) (bool, error) {
	// Get user's 2FA secret; backup_codes stored as JSON (TEXT/JSONB)
	var secret string
	var backupCodesRaw []byte
	query := `SELECT two_factor_secret, backup_codes FROM users WHERE id = $1 AND two_factor_enabled = true`
	err := s.db.QueryRowContext(ctx, query, userID).Scan(&secret, &backupCodesRaw)
	if err != nil {
		// errors.Is also matches sql.ErrNoRows if the database wrapper wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			return false, fmt.Errorf("2FA not enabled for user")
		}
		return false, fmt.Errorf("failed to get 2FA secret: %w", err)
	}

	// A NULL/empty column decodes to no recovery codes.
	var recoveryCodes []string
	if len(backupCodesRaw) > 0 {
		if err := json.Unmarshal(backupCodesRaw, &recoveryCodes); err != nil {
			return false, fmt.Errorf("failed to unmarshal backup codes: %w", err)
		}
	}

	// Recovery codes are checked first; a match is single-use.
	if s.isRecoveryCode(code, recoveryCodes) {
		s.removeRecoveryCode(ctx, userID, code)
		return true, nil
	}

	if strings.TrimSpace(secret) == "" {
		s.logger.Warn("2FA enabled but TOTP secret is empty", zap.String("user_id", userID.String()))
		return false, nil
	}

	if !totp.Validate(code, secret) {
		s.logger.Warn("Invalid 2FA code", zap.String("user_id", userID.String()))
		return false, nil
	}
	return true, nil
}
// VerifyTOTPCode reports whether code is a valid TOTP code for the base32
// secret, as judged by totp.Validate. A blank (empty or whitespace-only)
// secret always yields false. Otherwise it would act as a publicly known
// empty key.
// BE-API-001: Helper method for 2FA verification
func (s *TwoFactorService) VerifyTOTPCode(secret, code string) bool {
	if strings.TrimSpace(secret) == "" {
		return false
	}
	return totp.Validate(code, secret)
}
// GetTwoFactorStatus returns the two_factor_enabled flag stored for userID.
// Any query failure is returned wrapped, including the no-rows case for an
// unknown user; false accompanies every error.
func (s *TwoFactorService) GetTwoFactorStatus(ctx context.Context, userID uuid.UUID) (bool, error) {
	const query = `SELECT two_factor_enabled FROM users WHERE id = $1`

	var enabled bool
	if scanErr := s.db.QueryRowContext(ctx, query, userID).Scan(&enabled); scanErr != nil {
		return false, fmt.Errorf("failed to get 2FA status: %w", scanErr)
	}
	return enabled, nil
}
// GenerateRecoveryCodes returns a fresh set of plaintext recovery codes. It is
// the exported entry point to the package-private generator, which produces
// eight 8-character codes.
// BE-API-001: Public method for generating recovery codes
func (s *TwoFactorService) GenerateRecoveryCodes() []string {
	codes := s.generateRecoveryCodes()
	return codes
}
// recoveryCodeRandSource adapts crypto/rand to the math/rand Source
// interface. This lets math/rand's unbiased helpers (Intn) be driven by a
// cryptographically secure generator. It is stateless, so Seed is a no-op.
type recoveryCodeRandSource struct{}

// Int63 returns a uniformly distributed non-negative int64 read from
// crypto/rand. It panics if the system random source fails: quietly falling
// back to a predictable generator for authentication secrets would be worse.
func (recoveryCodeRandSource) Int63() int64 {
	var buf [8]byte
	if _, err := rand.Read(buf[:]); err != nil {
		panic(fmt.Sprintf("two-factor: crypto/rand failed: %v", err))
	}
	var v uint64
	for _, b := range buf {
		v = v<<8 | uint64(b)
	}
	return int64(v >> 1) // drop one bit so the result is in [0, 1<<63)
}

// Seed satisfies math/rand.Source; a crypto-backed source cannot be seeded.
func (recoveryCodeRandSource) Seed(int64) {}

// generateRecoveryCodes returns eight random 8-character codes over [A-Z0-9].
//
// Recovery codes are authentication secrets, so the randomness comes from
// crypto/rand through recoveryCodeRandSource. math/rand's default generator
// is deterministic and its output is predictable.
func (s *TwoFactorService) generateRecoveryCodes() []string {
	const (
		count    = 8
		length   = 8
		alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	)

	// A fresh local Rand per call: no shared state, so no locking is needed.
	rng := mathrand.New(recoveryCodeRandSource{})

	codes := make([]string, count)
	for i := range codes {
		code := make([]byte, length)
		for j := range code {
			code[j] = alphabet[rng.Intn(len(alphabet))]
		}
		codes[i] = string(code)
	}
	return codes
}
// hashRecoveryCode maps a plaintext recovery code to the value stored in
// backup_codes. EnableTwoFactor, isRecoveryCode and removeRecoveryCode all
// rely on this exact transformation to match codes.
//
// NOTE(review): this is NOT a hash. It only prefixes the code with "hashed_",
// so the stored values reveal the plaintext codes. Switching to a real hash
// (e.g. a keyed SHA-256 or a slow KDF) changes the stored format. That needs
// a data migration, or a dual-format check in isRecoveryCode and
// removeRecoveryCode, so existing users' codes keep working.
func (s *TwoFactorService) hashRecoveryCode(code string) string {
	const placeholderPrefix = "hashed_"
	return placeholderPrefix + code
}
// isRecoveryCode reports whether code, after hashRecoveryCode, equals any
// entry of storedCodes. A nil or empty storedCodes never matches.
func (s *TwoFactorService) isRecoveryCode(code string, storedCodes []string) bool {
	// hashRecoveryCode is a pure function of its input, so compute it once.
	want := s.hashRecoveryCode(code)
	for i := range storedCodes {
		if storedCodes[i] == want {
			return true
		}
	}
	return false
}
// removeRecoveryCode deletes the stored form of usedCode (as produced by
// hashRecoveryCode) from the user's backup_codes JSON array. It is
// best-effort: failures are logged with the user's ID and not returned.
//
// If usedCode is no longer present (e.g. another request consumed it between
// the caller's read and this one), nothing is written. This avoids
// overwriting a concurrent change with a stale list.
//
// NOTE(review): the SELECT-then-UPDATE is still not atomic. Two concurrent
// removals of different codes can lose one of the removals. A transaction with
// SELECT ... FOR UPDATE, or a conditional UPDATE, would close that gap.
func (s *TwoFactorService) removeRecoveryCode(ctx context.Context, userID uuid.UUID, usedCode string) {
	userField := zap.String("user_id", userID.String())

	var backupCodesRaw []byte
	query := `SELECT backup_codes FROM users WHERE id = $1`
	if err := s.db.QueryRowContext(ctx, query, userID).Scan(&backupCodesRaw); err != nil {
		s.logger.Error("Failed to get recovery codes", zap.Error(err), userField)
		return
	}

	var recoveryCodes []string
	if len(backupCodesRaw) > 0 {
		if err := json.Unmarshal(backupCodesRaw, &recoveryCodes); err != nil {
			s.logger.Error("Failed to unmarshal backup codes", zap.Error(err), userField)
			return
		}
	}

	// Keep every code except the used one. newCodes is non-nil, so it
	// marshals to "[]" rather than "null" when the last code is consumed.
	hashedUsedCode := s.hashRecoveryCode(usedCode)
	newCodes := make([]string, 0, len(recoveryCodes))
	for _, code := range recoveryCodes {
		if code != hashedUsedCode {
			newCodes = append(newCodes, code)
		}
	}
	if len(newCodes) == len(recoveryCodes) {
		s.logger.Warn("Recovery code already consumed or not found", userField)
		return
	}

	newCodesJSON, err := json.Marshal(newCodes)
	if err != nil {
		s.logger.Error("Failed to marshal backup codes", zap.Error(err), userField)
		return
	}

	updateQuery := `UPDATE users SET backup_codes = $1, updated_at = CURRENT_TIMESTAMP WHERE id = $2`
	if _, err := s.db.ExecContext(ctx, updateQuery, string(newCodesJSON), userID); err != nil {
		s.logger.Error("Failed to remove recovery code", zap.Error(err), userField)
	}
}