package services import ( "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/base64" "errors" "fmt" "io" ) const ( // NonceSize for AES-GCM (12 bytes recommended by NIST) gcmNonceSize = 12 // Prefix for encrypted tokens stored in DB (detect already-encrypted) encryptedTokenPrefix = "veza_enc_v1:" ) // CryptoService provides AES-256-GCM encryption for sensitive data (e.g. OAuth tokens at rest) type CryptoService struct { aead cipher.AEAD } // NewCryptoService creates a CryptoService with the given key (32 bytes for AES-256) func NewCryptoService(key []byte) (*CryptoService, error) { if len(key) < 32 { return nil, errors.New("encryption key must be at least 32 bytes for AES-256") } // Use first 32 bytes key32 := key if len(key) > 32 { key32 = key[:32] } block, err := aes.NewCipher(key32) if err != nil { return nil, fmt.Errorf("aes new cipher: %w", err) } aead, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("gcm: %w", err) } return &CryptoService{aead: aead}, nil } // NewCryptoServiceFromBase64 creates a CryptoService from a base64-encoded key func NewCryptoServiceFromBase64(keyBase64 string) (*CryptoService, error) { if keyBase64 == "" { return nil, errors.New("encryption key must not be empty") } key, err := base64.RawStdEncoding.DecodeString(keyBase64) if err != nil { // Try standard base64 key, err = base64.StdEncoding.DecodeString(keyBase64) if err != nil { return nil, fmt.Errorf("decode key: %w", err) } } return NewCryptoService(key) } // NewCryptoServiceFromHex creates a CryptoService from a hex-encoded key (optional, for future) // For now we use base64. Key can also be raw bytes if passed as string - we'll decode. // Encrypt encrypts plaintext with AES-256-GCM. Returns base64-encoded result. func (c *CryptoService) Encrypt(plaintext []byte) ([]byte, error) { nonce := make([]byte, gcmNonceSize) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, fmt.Errorf("rand nonce: %w", err) } ciphertext := c.aead.Seal(nil, nonce, plaintext, nil) // Prepend nonce: nonce || ciphertext (ciphertext includes tag) out := make([]byte, 0, len(nonce)+len(ciphertext)) out = append(out, nonce...) out = append(out, ciphertext...) return out, nil } // Decrypt decrypts ciphertext (format: nonce || sealed). Returns plaintext. func (c *CryptoService) Decrypt(ciphertext []byte) ([]byte, error) { if len(ciphertext) < gcmNonceSize { return nil, errors.New("ciphertext too short") } nonce := ciphertext[:gcmNonceSize] sealed := ciphertext[gcmNonceSize:] return c.aead.Open(nil, nonce, sealed, nil) } // EncryptString encrypts a string and returns the prefixed base64 result for DB storage func (c *CryptoService) EncryptString(plaintext string) (string, error) { if plaintext == "" { return "", nil } enc, err := c.Encrypt([]byte(plaintext)) if err != nil { return "", err } return encryptedTokenPrefix + base64.RawStdEncoding.EncodeToString(enc), nil } // DecryptString decrypts a string stored with EncryptString (checks prefix) func (c *CryptoService) DecryptString(stored string) (string, error) { if stored == "" { return "", nil } if len(stored) < len(encryptedTokenPrefix) || stored[:len(encryptedTokenPrefix)] != encryptedTokenPrefix { // Not encrypted (legacy plaintext) return stored, nil } b64 := stored[len(encryptedTokenPrefix):] enc, err := base64.RawStdEncoding.DecodeString(b64) if err != nil { enc, err = base64.StdEncoding.DecodeString(b64) if err != nil { return "", fmt.Errorf("decode stored: %w", err) } } dec, err := c.Decrypt(enc) if err != nil { return "", err } return string(dec), nil } // EncryptedTokenPrefix returns the prefix used for encrypted tokens func EncryptedTokenPrefix() string { return encryptedTokenPrefix }