2026-02-26 18:49:15 +00:00
|
|
|
// encrypt_oauth_tokens encrypts existing OAuth provider tokens in federated_identities (v0.902).
|
|
|
|
|
// Idempotent: tokens already carrying the "veza_enc_v1:" prefix are skipped, so the tool can be re-run safely.
|
|
|
|
|
// Usage: DATABASE_URL=... OAUTH_ENCRYPTION_KEY=... go run ./cmd/tools/encrypt_oauth_tokens [-dry-run]
|
|
|
|
|
package main
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"database/sql"
|
|
|
|
|
"flag"
|
|
|
|
|
"log"
|
|
|
|
|
"os"
|
|
|
|
|
"strings"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"veza-backend-api/internal/services"
|
2026-03-05 22:03:43 +00:00
|
|
|
|
|
|
|
|
_ "github.com/lib/pq"
|
2026-02-26 18:49:15 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// encryptedPrefix marks token values that are already encrypted. Values that
// start with it are skipped, which is what lets the tool be re-run safely.
const (
	encryptedPrefix = "veza_enc_v1:"
)
|
|
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
|
dryRun := flag.Bool("dry-run", false, "Show what would be updated without making changes")
|
|
|
|
|
flag.Parse()
|
|
|
|
|
|
|
|
|
|
dbURL := os.Getenv("DATABASE_URL")
|
|
|
|
|
if dbURL == "" {
|
|
|
|
|
log.Fatal("DATABASE_URL is required")
|
|
|
|
|
}
|
|
|
|
|
encKey := os.Getenv("OAUTH_ENCRYPTION_KEY")
|
|
|
|
|
if encKey == "" {
|
|
|
|
|
log.Fatal("OAUTH_ENCRYPTION_KEY is required (32+ bytes, base64 or raw)")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cryptoService, err := services.NewCryptoServiceFromBase64(encKey)
|
|
|
|
|
if err != nil {
|
|
|
|
|
keyBytes := []byte(encKey)
|
|
|
|
|
if len(keyBytes) >= 32 {
|
|
|
|
|
cryptoService, err = services.NewCryptoService(keyBytes)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatalf("CryptoService: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
db, err := sql.Open("postgres", dbURL)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatalf("DB connect: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer db.Close()
|
|
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
|
|
|
log.Fatalf("DB ping: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rows, err := db.QueryContext(ctx, `
|
|
|
|
|
SELECT id::text, access_token, refresh_token
|
|
|
|
|
FROM federated_identities
|
|
|
|
|
WHERE (access_token IS NOT NULL AND access_token != '')
|
|
|
|
|
OR (refresh_token IS NOT NULL AND refresh_token != '')
|
|
|
|
|
`)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatalf("Query: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer rows.Close()
|
|
|
|
|
|
|
|
|
|
var id string
|
|
|
|
|
var accessToken, refreshToken sql.NullString
|
|
|
|
|
updated := 0
|
|
|
|
|
skipped := 0
|
|
|
|
|
errors := 0
|
|
|
|
|
|
|
|
|
|
for rows.Next() {
|
|
|
|
|
if err := rows.Scan(&id, &accessToken, &refreshToken); err != nil {
|
|
|
|
|
log.Printf("Scan error: %v", err)
|
|
|
|
|
errors++
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
needUpdate := false
|
|
|
|
|
newAccess := accessToken.String
|
|
|
|
|
newRefresh := refreshToken.String
|
|
|
|
|
|
|
|
|
|
if accessToken.Valid && accessToken.String != "" && !strings.HasPrefix(accessToken.String, encryptedPrefix) {
|
|
|
|
|
newAccess, err = cryptoService.EncryptString(accessToken.String)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("Encrypt access_token for id %s: %v", id, err)
|
|
|
|
|
errors++
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
needUpdate = true
|
|
|
|
|
}
|
|
|
|
|
if refreshToken.Valid && refreshToken.String != "" && !strings.HasPrefix(refreshToken.String, encryptedPrefix) {
|
|
|
|
|
newRefresh, err = cryptoService.EncryptString(refreshToken.String)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("Encrypt refresh_token for id %s: %v", id, err)
|
|
|
|
|
errors++
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
needUpdate = true
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !needUpdate {
|
|
|
|
|
skipped++
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if *dryRun {
|
|
|
|
|
log.Printf("[dry-run] Would encrypt tokens for id %s", id)
|
|
|
|
|
updated++
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = db.ExecContext(ctx, `
|
|
|
|
|
UPDATE federated_identities
|
|
|
|
|
SET access_token = $1, refresh_token = $2, updated_at = NOW()
|
|
|
|
|
WHERE id = $3
|
|
|
|
|
`, nullIfEmpty(newAccess), nullIfEmpty(newRefresh), id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("UPDATE id %s: %v", id, err)
|
|
|
|
|
errors++
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
updated++
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := rows.Err(); err != nil {
|
|
|
|
|
log.Fatalf("Rows: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mode := ""
|
|
|
|
|
if *dryRun {
|
|
|
|
|
mode = " [dry-run]"
|
|
|
|
|
}
|
|
|
|
|
log.Printf("Done%s: updated=%d skipped=%d errors=%d", mode, updated, skipped, errors)
|
|
|
|
|
if errors > 0 {
|
|
|
|
|
os.Exit(1)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// nullIfEmpty converts a string into a query argument: the empty string
// becomes a nil argument (SQL NULL), and any other string is returned as-is.
func nullIfEmpty(s string) interface{} {
	if s != "" {
		return s
	}
	return nil
}
|