263 lines
7.6 KiB
Go
263 lines
7.6 KiB
Go
package services
|
|
|
|
import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base32"
	"encoding/json"
	"errors"
	"fmt"
	mathrand "math/rand"
	"net/url"

	"github.com/google/uuid"
	"github.com/pquerna/otp/totp"
	"go.uber.org/zap"

	"veza-backend-api/internal/database"
	"veza-backend-api/internal/models"
)
|
|
|
|
// TwoFactorService handles 2FA operations
|
|
type TwoFactorService struct {
|
|
db *database.Database
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewTwoFactorService creates a new 2FA service
|
|
func NewTwoFactorService(db *database.Database, logger *zap.Logger) *TwoFactorService {
|
|
return &TwoFactorService{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// TwoFactorSetup bundles the enrollment material produced when a user starts
// turning on two-factor authentication.
type TwoFactorSetup struct {
	// Secret is the base32-encoded TOTP shared secret.
	Secret string `json:"secret"`
	// QRCodeURL is the otpauth:// provisioning URI, intended to be rendered as a QR code.
	QRCodeURL string `json:"qr_code_url"`
	// RecoveryCodes are the plain-text recovery codes generated for the user.
	RecoveryCodes []string `json:"recovery_codes"`
}
|
|
|
|
// TwoFactorVerification carries a user-submitted 2FA code. The `binding` tag
// suggests it is bound from an HTTP request body (TODO confirm against the handler).
type TwoFactorVerification struct {
	// Code is the submitted code; the binding tag marks it as required.
	Code string `json:"code" binding:"required"`
	// RecoveryCode is optional and is omitted from JSON output when empty.
	RecoveryCode string `json:"recovery_code,omitempty"`
}
|
|
|
|
// GenerateSecret generates a new TOTP secret
|
|
func (s *TwoFactorService) GenerateSecret(user *models.User) (*TwoFactorSetup, error) {
|
|
// Generate a random secret
|
|
secret := make([]byte, 20)
|
|
if _, err := rand.Read(secret); err != nil {
|
|
return nil, fmt.Errorf("failed to generate secret: %w", err)
|
|
}
|
|
|
|
// Encode as base32
|
|
secretBase32 := base32.StdEncoding.EncodeToString(secret)
|
|
|
|
// Generate QR code URL
|
|
qrCodeURL := fmt.Sprintf("otpauth://totp/Veza:%s?secret=%s&issuer=Veza&algorithm=SHA1&digits=6&period=30",
|
|
user.Email, secretBase32)
|
|
|
|
// Generate recovery codes
|
|
recoveryCodes := s.generateRecoveryCodes()
|
|
|
|
setup := &TwoFactorSetup{
|
|
Secret: secretBase32,
|
|
QRCodeURL: qrCodeURL,
|
|
RecoveryCodes: recoveryCodes,
|
|
}
|
|
|
|
return setup, nil
|
|
}
|
|
|
|
// EnableTwoFactor enables 2FA for a user
|
|
func (s *TwoFactorService) EnableTwoFactor(ctx context.Context, userID uuid.UUID, secret string, recoveryCodes []string) error {
|
|
// Hash the recovery codes before storing
|
|
hashedCodes := make([]string, len(recoveryCodes))
|
|
for i, code := range recoveryCodes {
|
|
hashedCodes[i] = s.hashRecoveryCode(code)
|
|
}
|
|
|
|
// Serialize backup_codes as JSON (column is TEXT/JSONB; driver does not accept []string)
|
|
backupCodesJSON, err := json.Marshal(hashedCodes)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal backup codes: %w", err)
|
|
}
|
|
|
|
query := `
|
|
UPDATE users
|
|
SET two_factor_enabled = true,
|
|
two_factor_secret = $1,
|
|
backup_codes = $2,
|
|
updated_at = CURRENT_TIMESTAMP
|
|
WHERE id = $3
|
|
`
|
|
|
|
_, err = s.db.ExecContext(ctx, query, secret, string(backupCodesJSON), userID)
|
|
if err != nil {
|
|
s.logger.Error("Failed to enable 2FA", zap.Error(err), zap.String("user_id", userID.String()))
|
|
return fmt.Errorf("failed to enable 2FA: %w", err)
|
|
}
|
|
|
|
s.logger.Info("2FA enabled successfully", zap.String("user_id", userID.String()))
|
|
return nil
|
|
}
|
|
|
|
// DisableTwoFactor disables 2FA for a user
|
|
func (s *TwoFactorService) DisableTwoFactor(ctx context.Context, userID uuid.UUID) error {
|
|
query := `
|
|
UPDATE users
|
|
SET two_factor_enabled = false,
|
|
two_factor_secret = '',
|
|
backup_codes = '{}',
|
|
updated_at = CURRENT_TIMESTAMP
|
|
WHERE id = $1
|
|
`
|
|
|
|
_, err := s.db.ExecContext(ctx, query, userID)
|
|
if err != nil {
|
|
s.logger.Error("Failed to disable 2FA", zap.Error(err), zap.String("user_id", userID.String()))
|
|
return fmt.Errorf("failed to disable 2FA: %w", err)
|
|
}
|
|
|
|
s.logger.Info("2FA disabled successfully", zap.String("user_id", userID.String()))
|
|
return nil
|
|
}
|
|
|
|
// VerifyTwoFactor verifies a 2FA code
|
|
func (s *TwoFactorService) VerifyTwoFactor(ctx context.Context, userID uuid.UUID, code string) (bool, error) {
|
|
// Get user's 2FA secret; backup_codes stored as JSON (TEXT/JSONB)
|
|
var secret string
|
|
var backupCodesRaw []byte
|
|
query := `SELECT two_factor_secret, backup_codes FROM users WHERE id = $1 AND two_factor_enabled = true`
|
|
|
|
err := s.db.QueryRowContext(ctx, query, userID).Scan(&secret, &backupCodesRaw)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return false, fmt.Errorf("2FA not enabled for user")
|
|
}
|
|
return false, fmt.Errorf("failed to get 2FA secret: %w", err)
|
|
}
|
|
|
|
var recoveryCodes []string
|
|
if len(backupCodesRaw) > 0 {
|
|
if err := json.Unmarshal(backupCodesRaw, &recoveryCodes); err != nil {
|
|
return false, fmt.Errorf("failed to unmarshal backup codes: %w", err)
|
|
}
|
|
}
|
|
|
|
// Check if it's a recovery code
|
|
if s.isRecoveryCode(code, recoveryCodes) {
|
|
// Remove the used recovery code
|
|
s.removeRecoveryCode(ctx, userID, code)
|
|
return true, nil
|
|
}
|
|
|
|
// Verify TOTP code
|
|
valid := totp.Validate(code, secret)
|
|
if !valid {
|
|
s.logger.Warn("Invalid 2FA code", zap.String("user_id", userID.String()))
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// VerifyTOTPCode verifies a TOTP code against a secret
|
|
// BE-API-001: Helper method for 2FA verification
|
|
func (s *TwoFactorService) VerifyTOTPCode(secret, code string) bool {
|
|
return totp.Validate(code, secret)
|
|
}
|
|
|
|
// GetTwoFactorStatus gets the 2FA status for a user
|
|
func (s *TwoFactorService) GetTwoFactorStatus(ctx context.Context, userID uuid.UUID) (bool, error) {
|
|
var enabled bool
|
|
query := `SELECT two_factor_enabled FROM users WHERE id = $1`
|
|
|
|
err := s.db.QueryRowContext(ctx, query, userID).Scan(&enabled)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to get 2FA status: %w", err)
|
|
}
|
|
|
|
return enabled, nil
|
|
}
|
|
|
|
// GenerateRecoveryCodes generates 8 recovery codes (public method)
|
|
// BE-API-001: Public method for generating recovery codes
|
|
func (s *TwoFactorService) GenerateRecoveryCodes() []string {
|
|
return s.generateRecoveryCodes()
|
|
}
|
|
|
|
// generateRecoveryCodes generates 8 recovery codes (internal)
|
|
func (s *TwoFactorService) generateRecoveryCodes() []string {
|
|
codes := make([]string, 8)
|
|
for i := 0; i < 8; i++ {
|
|
// Generate 8-character alphanumeric code
|
|
code := make([]byte, 8)
|
|
for j := 0; j < 8; j++ {
|
|
code[j] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"[mathrand.Intn(36)]
|
|
}
|
|
codes[i] = string(code)
|
|
}
|
|
return codes
|
|
}
|
|
|
|
// hashRecoveryCode hashes a recovery code for storage
|
|
func (s *TwoFactorService) hashRecoveryCode(code string) string {
|
|
// In production, use proper hashing (bcrypt, argon2, etc.)
|
|
// For now, using a simple hash for demonstration
|
|
return fmt.Sprintf("hashed_%s", code)
|
|
}
|
|
|
|
// isRecoveryCode checks if a code is a valid recovery code
|
|
func (s *TwoFactorService) isRecoveryCode(code string, storedCodes []string) bool {
|
|
for _, storedCode := range storedCodes {
|
|
if s.hashRecoveryCode(code) == storedCode {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// removeRecoveryCode removes a used recovery code
|
|
func (s *TwoFactorService) removeRecoveryCode(ctx context.Context, userID uuid.UUID, usedCode string) {
|
|
var backupCodesRaw []byte
|
|
query := `SELECT backup_codes FROM users WHERE id = $1`
|
|
|
|
err := s.db.QueryRowContext(ctx, query, userID).Scan(&backupCodesRaw)
|
|
if err != nil {
|
|
s.logger.Error("Failed to get recovery codes", zap.Error(err))
|
|
return
|
|
}
|
|
|
|
var recoveryCodes []string
|
|
if len(backupCodesRaw) > 0 {
|
|
if err := json.Unmarshal(backupCodesRaw, &recoveryCodes); err != nil {
|
|
s.logger.Error("Failed to unmarshal backup codes", zap.Error(err))
|
|
return
|
|
}
|
|
}
|
|
|
|
// Remove the used code
|
|
newCodes := make([]string, 0)
|
|
hashedUsedCode := s.hashRecoveryCode(usedCode)
|
|
for _, code := range recoveryCodes {
|
|
if code != hashedUsedCode {
|
|
newCodes = append(newCodes, code)
|
|
}
|
|
}
|
|
|
|
newCodesJSON, err := json.Marshal(newCodes)
|
|
if err != nil {
|
|
s.logger.Error("Failed to marshal backup codes", zap.Error(err))
|
|
return
|
|
}
|
|
|
|
updateQuery := `UPDATE users SET backup_codes = $1, updated_at = CURRENT_TIMESTAMP WHERE id = $2`
|
|
_, err = s.db.ExecContext(ctx, updateQuery, string(newCodesJSON), userID)
|
|
if err != nil {
|
|
s.logger.Error("Failed to remove recovery code", zap.Error(err))
|
|
}
|
|
}
|