veza/veza-backend-api/cmd/tools/encrypt_oauth_tokens/main.go
2026-03-05 23:03:43 +01:00

149 lines
3.4 KiB
Go

// encrypt_oauth_tokens encrypts existing plaintext OAuth provider tokens
// (access_token and refresh_token) in the federated_identities table (v0.902).
// Idempotent: tokens already prefixed with "veza_enc_v1:" are skipped.
// Usage: DATABASE_URL=... OAUTH_ENCRYPTION_KEY=... go run ./cmd/tools/encrypt_oauth_tokens [-dry-run]
package main
import (
"context"
"database/sql"
"flag"
"log"
"os"
"strings"
"time"
"veza-backend-api/internal/services"
_ "github.com/lib/pq"
)
// encryptedPrefix marks a stored token as already encrypted. Tokens that
// start with it are skipped, which is what lets the tool be re-run safely.
const (
	encryptedPrefix = "veza_enc_v1:"
)
// main runs the migration and exits with the status code returned by run.
// Keeping os.Exit out of run lets its deferred cleanup (cancel, rows.Close,
// db.Close) execute on every path, which log.Fatal would skip.
func main() {
	os.Exit(run())
}

// run encrypts every plaintext access_token / refresh_token in
// federated_identities and returns the process exit code. It returns 1 on a
// setup failure or when at least one row could not be processed, else 0.
//
// Configuration comes from the environment (DATABASE_URL, OAUTH_ENCRYPTION_KEY)
// and from flags (-dry-run, -timeout).
func run() int {
	dryRun := flag.Bool("dry-run", false, "Show what would be updated without making changes")
	timeout := flag.Duration("timeout", 5*time.Minute, "Overall deadline for the whole migration")
	flag.Parse()

	dbURL := os.Getenv("DATABASE_URL")
	if dbURL == "" {
		log.Print("DATABASE_URL is required")
		return 1
	}
	encKey := os.Getenv("OAUTH_ENCRYPTION_KEY")
	if encKey == "" {
		log.Print("OAUTH_ENCRYPTION_KEY is required (32+ bytes, base64 or raw)")
		return 1
	}

	// Try the key as base64 first. If that fails and the raw string is at least
	// 32 bytes, use its bytes directly. A shorter raw key reports the base64 error.
	cryptoService, err := services.NewCryptoServiceFromBase64(encKey)
	if err != nil && len(encKey) >= 32 {
		cryptoService, err = services.NewCryptoService([]byte(encKey))
	}
	if err != nil {
		log.Printf("CryptoService: %v", err)
		return 1
	}
	// The closure decouples the row logic from the service's concrete type.
	encrypt := func(s string) (string, error) {
		return cryptoService.EncryptString(s)
	}

	db, err := sql.Open("postgres", dbURL)
	if err != nil {
		log.Printf("DB connect: %v", err)
		return 1
	}
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), *timeout)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		log.Printf("DB ping: %v", err)
		return 1
	}

	rows, err := db.QueryContext(ctx, `
SELECT id::text, access_token, refresh_token
FROM federated_identities
WHERE (access_token IS NOT NULL AND access_token != '')
OR (refresh_token IS NOT NULL AND refresh_token != '')
`)
	if err != nil {
		log.Printf("Query: %v", err)
		return 1
	}
	defer rows.Close()

	updated, skipped, failed := 0, 0, 0
	for rows.Next() {
		var id string
		var accessToken, refreshToken sql.NullString
		if err := rows.Scan(&id, &accessToken, &refreshToken); err != nil {
			log.Printf("Scan error: %v", err)
			failed++
			continue
		}

		newAccess, accessChanged, err := encryptIfPlain(accessToken, encrypt)
		if err != nil {
			log.Printf("Encrypt access_token for id %s: %v", id, err)
			failed++
			continue
		}
		newRefresh, refreshChanged, err := encryptIfPlain(refreshToken, encrypt)
		if err != nil {
			log.Printf("Encrypt refresh_token for id %s: %v", id, err)
			failed++
			continue
		}
		if !accessChanged && !refreshChanged {
			skipped++
			continue
		}
		if *dryRun {
			log.Printf("[dry-run] Would encrypt tokens for id %s", id)
			updated++
			continue
		}

		// Compare-and-swap: only overwrite the row if both tokens still hold
		// the values read above. Without this, a row modified concurrently
		// between the SELECT and this UPDATE (e.g. a token refresh) would
		// have its new token replaced by the encrypted old one.
		res, err := db.ExecContext(ctx, `
UPDATE federated_identities
SET access_token = $1, refresh_token = $2, updated_at = NOW()
WHERE id = $3
AND access_token IS NOT DISTINCT FROM $4
AND refresh_token IS NOT DISTINCT FROM $5
`, nullIfEmpty(newAccess), nullIfEmpty(newRefresh), id, accessToken, refreshToken)
		if err != nil {
			log.Printf("UPDATE id %s: %v", id, err)
			failed++
			continue
		}
		n, err := res.RowsAffected()
		if err != nil {
			log.Printf("UPDATE id %s: %v", id, err)
			failed++
			continue
		}
		if n == 0 {
			// Row changed or vanished after it was read. Counted as a failure so
			// the exit code prompts a (safe, idempotent) re-run.
			log.Printf("UPDATE id %s: row changed since it was read; not updated (re-run to retry)", id)
			failed++
			continue
		}
		updated++
	}
	if err := rows.Err(); err != nil {
		log.Printf("Rows: %v", err)
		return 1
	}

	mode := ""
	if *dryRun {
		mode = " [dry-run]"
	}
	log.Printf("Done%s: updated=%d skipped=%d errors=%d", mode, updated, skipped, failed)
	if failed > 0 {
		return 1
	}
	return 0
}

// encryptIfPlain encrypts tok when it is non-NULL, non-empty and not yet
// prefixed with encryptedPrefix. It returns the value to store and whether it
// changed. An untouched token is returned as tok.String, which is "" for NULL.
//
// The ciphertext is required to start with encryptedPrefix. Re-runs rely on
// that prefix to skip migrated rows, so unprefixed output would be encrypted
// again on the next run. Such output is rejected instead of stored.
func encryptIfPlain(tok sql.NullString, encrypt func(string) (string, error)) (string, bool, error) {
	if !tok.Valid || tok.String == "" || strings.HasPrefix(tok.String, encryptedPrefix) {
		return tok.String, false, nil
	}
	enc, err := encrypt(tok.String)
	if err != nil {
		return "", false, err
	}
	if !strings.HasPrefix(enc, encryptedPrefix) {
		return "", false, unprefixedCiphertextError{}
	}
	return enc, true, nil
}

// unprefixedCiphertextError reports encryption output lacking encryptedPrefix.
type unprefixedCiphertextError struct{}

func (unprefixedCiphertextError) Error() string {
	return "encrypted value does not start with " + encryptedPrefix + "; refusing to store it"
}
// nullIfEmpty converts s into a query argument. The empty string becomes a nil
// interface value, which database/sql sends as SQL NULL. Any other string is
// passed through as-is.
func nullIfEmpty(s string) interface{} {
	var arg interface{}
	if s != "" {
		arg = s
	}
	return arg
}