veza/veza-backend-api/internal/core/subscription/service.go
senke 9cd0da0046 fix(v0.12.6): apply all pentest remediations — 36 findings across 36 files
CRITICAL fixes:
- Race condition (TOCTOU) in payout/refund with SELECT FOR UPDATE (CRITICAL-001/002)
- IDOR on analytics endpoint — ownership check enforced (CRITICAL-003)
- CSWSH on all WebSocket endpoints — origin whitelist (CRITICAL-004)
- Mass assignment on user self-update — strip privileged fields (CRITICAL-005)

HIGH fixes:
- Path traversal in marketplace upload — UUID filenames (HIGH-001)
- IP spoofing — use Gin trusted proxy c.ClientIP() (HIGH-002)
- Popularity metrics (followers, likes) set to json:"-" (HIGH-003)
- bcrypt cost hardened to 12 everywhere (HIGH-004)
- Refresh token lock made mandatory (HIGH-005)
- Stream token replay prevention with access_count (HIGH-006)
- Subscription trial race condition fixed (HIGH-007)
- License download expiration check (HIGH-008)
- Webhook amount validation (HIGH-009)
- pprof endpoint removed from production (HIGH-010)

MEDIUM fixes:
- WebSocket message size limit 64KB (MEDIUM-010)
- HSTS header in nginx production (MEDIUM-001)
- CORS origin restricted in nginx-rtmp (MEDIUM-002)
- Docker alpine pinned to 3.21 (MEDIUM-003/004)
- Redis authentication enforced (MEDIUM-005)
- GDPR account deletion expanded (MEDIUM-006)
- .gitignore hardened (MEDIUM-007)

LOW/INFO fixes:
- GitHub Actions SHA pinning on all workflows (LOW-001)
- .env.example security documentation (INFO-001)
- Production CORS set to HTTPS (LOW-002)

All tests pass. Go and Rust compile clean.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 00:44:46 +01:00

543 lines
16 KiB
Go

package subscription
import (
"context"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"gorm.io/gorm"
)
// Sentinel errors reported by Service. Compare against them with errors.Is
// rather than by matching message text.
var (
	// ErrPlanNotFound reports that no plan matched the requested ID or name.
	ErrPlanNotFound = errors.New("subscription plan not found")

	// ErrSubscriptionNotFound reports a missing subscription record.
	ErrSubscriptionNotFound = errors.New("subscription not found")

	// ErrAlreadySubscribed reports that the user's current subscription is
	// already on the requested plan.
	ErrAlreadySubscribed = errors.New("user already has an active subscription to this plan")

	// ErrCannotDowngradeDuring signals that a downgrade is deferred to the end
	// of the current billing period.
	ErrCannotDowngradeDuring = errors.New("downgrade takes effect at end of current period")

	// ErrNoActiveSubscription reports that the user has no subscription in an
	// active or trialing state.
	ErrNoActiveSubscription = errors.New("no active subscription found")

	// ErrInvalidBillingCycle reports a billing cycle other than monthly or yearly.
	ErrInvalidBillingCycle = errors.New("invalid billing cycle: must be 'monthly' or 'yearly'")

	// ErrFreePlanNoBilling reports a billing operation attempted on the free plan.
	ErrFreePlanNoBilling = errors.New("free plan does not require billing")
)
// PaymentProvider is the payment-gateway abstraction the subscription service
// relies on to charge for paid plans.
type PaymentProvider interface {
	// CreateSubscriptionPayment starts a payment of amountCents in currency for
	// subscriptionID, tagging it with metadata. It returns the gateway's
	// payment identifier and a client secret (both are relayed to API callers
	// through SubscribeResponse).
	CreateSubscriptionPayment(
		ctx context.Context,
		amountCents int,
		currency, subscriptionID, returnURL string,
		metadata map[string]string,
	) (paymentID, clientSecret string, err error)

	// GetPayment returns the gateway-reported status of paymentID.
	GetPayment(ctx context.Context, paymentID string) (status string, err error)
}
// ServiceOption mutates a Service while it is being constructed; NewService
// applies the supplied options in order.
type ServiceOption func(*Service)

// WithPaymentProvider returns an option that installs p as the gateway used
// to initiate payments for paid subscriptions. When no provider is installed,
// subscriptions are created without starting a payment.
func WithPaymentProvider(p PaymentProvider) ServiceOption {
	return func(svc *Service) { svc.paymentProvider = p }
}
// Service implements subscription business logic on top of a GORM database
// handle, with an optional payment gateway.
type Service struct {
	db              *gorm.DB        // storage for plans, subscriptions and invoices
	logger          *zap.Logger     // structured logging of lifecycle events
	paymentProvider PaymentProvider // nil disables payment initiation
}

// NewService builds a Service backed by db and logger, then applies each
// option in the order given.
func NewService(db *gorm.DB, logger *zap.Logger, opts ...ServiceOption) *Service {
	svc := &Service{db: db, logger: logger}
	for _, apply := range opts {
		apply(svc)
	}
	return svc
}
// ListPlans returns the plans whose is_active flag is set, ordered by
// ascending sort_order. Inactive plans are excluded.
func (s *Service) ListPlans(ctx context.Context) ([]Plan, error) {
	query := s.db.WithContext(ctx).
		Where("is_active = ?", true).
		Order("sort_order ASC")

	var active []Plan
	if res := query.Find(&active); res.Error != nil {
		return nil, fmt.Errorf("failed to list plans: %w", res.Error)
	}
	return active, nil
}
// GetPlan loads a plan by primary key, whether or not it is active. A missing
// row is reported as ErrPlanNotFound; other database failures are wrapped.
func (s *Service) GetPlan(ctx context.Context, planID uuid.UUID) (*Plan, error) {
	plan := new(Plan)
	err := s.db.WithContext(ctx).First(plan, "id = ?", planID).Error
	switch {
	case err == nil:
		return plan, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrPlanNotFound
	default:
		return nil, fmt.Errorf("failed to get plan: %w", err)
	}
}
// GetPlanByName loads the plan whose name column equals name. A missing row is
// reported as ErrPlanNotFound; other database failures are wrapped.
func (s *Service) GetPlanByName(ctx context.Context, name PlanName) (*Plan, error) {
	plan := new(Plan)
	err := s.db.WithContext(ctx).First(plan, "name = ?", string(name)).Error
	switch {
	case err == nil:
		return plan, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrPlanNotFound
	default:
		return nil, fmt.Errorf("failed to get plan by name: %w", err)
	}
}
// GetUserSubscription fetches the user's subscription in active or trialing
// status, with its Plan preloaded. ErrNoActiveSubscription is returned when no
// such row exists.
func (s *Service) GetUserSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) {
	liveStatuses := []string{string(StatusActive), string(StatusTrialing)}
	sub := new(UserSubscription)
	err := s.db.WithContext(ctx).
		Preload("Plan").
		Where("user_id = ? AND status IN ?", userID, liveStatuses).
		First(sub).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, ErrNoActiveSubscription
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get user subscription: %w", err)
	}
	return sub, nil
}
// GetUserSubscriptionHistory pages through every subscription the user has
// ever held (any status), newest first, with Plan preloaded. A limit outside
// 1..100 falls back to 20 and a negative offset is treated as 0.
func (s *Service) GetUserSubscriptionHistory(ctx context.Context, userID uuid.UUID, limit, offset int) ([]UserSubscription, error) {
	const defaultPageSize, maxPageSize = 20, 100
	if limit < 1 || limit > maxPageSize {
		limit = defaultPageSize
	}
	if offset < 0 {
		offset = 0
	}

	newestFirst := s.db.WithContext(ctx).
		Preload("Plan").
		Where("user_id = ?", userID).
		Order("created_at DESC")

	var history []UserSubscription
	if err := newestFirst.Limit(limit).Offset(offset).Find(&history).Error; err != nil {
		return nil, fmt.Errorf("failed to get subscription history: %w", err)
	}
	return history, nil
}
// SubscribeRequest is the payload accepted by Subscribe: the target plan and
// the billing cadence (Subscribe rejects anything but monthly or yearly).
type SubscribeRequest struct {
	PlanID       uuid.UUID    `json:"plan_id" binding:"required"`
	BillingCycle BillingCycle `json:"billing_cycle" binding:"required"`
}

// SubscribeResponse describes the outcome of a subscription change. The
// payment fields are filled only when a gateway payment was initiated; they
// stay empty for, e.g., the free plan or a trial.
type SubscribeResponse struct {
	Subscription *UserSubscription `json:"subscription"`
	ClientSecret string            `json:"client_secret,omitempty"` // gateway client secret (Hyperswitch)
	PaymentID    string            `json:"payment_id,omitempty"`
}
// Subscribe enrolls userID in the plan identified by req.PlanID.
//
// The billing cycle must be monthly or yearly (ErrInvalidBillingCycle
// otherwise). Only plans flagged is_active — the same set ListPlans exposes —
// can be subscribed to; any other ID yields ErrPlanNotFound, so retired plans
// cannot be reached by guessing their ID. Free-plan requests are delegated to
// subscribeToFreePlan. For paid plans:
//   - no current subscription: a new one is created;
//   - current subscription on the same plan: ErrAlreadySubscribed;
//   - current subscription on another plan: changePlan decides between an
//     immediate upgrade and a downgrade deferred to period end.
func (s *Service) Subscribe(ctx context.Context, userID uuid.UUID, req SubscribeRequest) (*SubscribeResponse, error) {
	if req.BillingCycle != BillingMonthly && req.BillingCycle != BillingYearly {
		return nil, ErrInvalidBillingCycle
	}
	plan, err := s.getActivePlan(ctx, req.PlanID)
	if err != nil {
		return nil, err
	}
	if plan.Name == PlanFree {
		return s.subscribeToFreePlan(ctx, userID, plan)
	}

	existing, err := s.GetUserSubscription(ctx, userID)
	if err != nil && !errors.Is(err, ErrNoActiveSubscription) {
		return nil, err
	}
	switch {
	case existing == nil:
		return s.createNewSubscription(ctx, userID, plan, req.BillingCycle)
	case existing.PlanID == req.PlanID:
		return nil, ErrAlreadySubscribed
	default:
		// Plan switch: handles both upgrades and downgrades.
		return s.changePlan(ctx, userID, existing, plan, req.BillingCycle)
	}
}

// getActivePlan is GetPlan restricted to plans with is_active = true; a plan
// that is missing or inactive is reported as ErrPlanNotFound.
func (s *Service) getActivePlan(ctx context.Context, planID uuid.UUID) (*Plan, error) {
	var plan Plan
	err := s.db.WithContext(ctx).
		Where("is_active = ?", true).
		First(&plan, "id = ?", planID).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrPlanNotFound
		}
		return nil, fmt.Errorf("failed to get plan: %w", err)
	}
	return &plan, nil
}
// subscribeToFreePlan moves userID onto the free plan without any payment.
//
// If the user's current subscription is already this plan, ErrAlreadySubscribed
// is returned (matching Subscribe's handling of paid plans) instead of
// expiring and re-creating the same subscription. Otherwise any current
// active/trialing subscription is expired immediately and a free, monthly,
// active subscription is created whose period ends 100 years from now.
//
// The expiry and the insert share one transaction, so a failed insert cannot
// leave the user with the old subscription expired and nothing in its place.
func (s *Service) subscribeToFreePlan(ctx context.Context, userID uuid.UUID, plan *Plan) (*SubscribeResponse, error) {
	existing, err := s.GetUserSubscription(ctx, userID)
	if err != nil && !errors.Is(err, ErrNoActiveSubscription) {
		return nil, err
	}
	if existing != nil && existing.PlanID == plan.ID {
		return nil, ErrAlreadySubscribed
	}

	now := time.Now()
	sub := &UserSubscription{
		UserID:             userID,
		PlanID:             plan.ID,
		Status:             StatusActive,
		BillingCycle:       BillingMonthly,
		CurrentPeriodStart: now,
		CurrentPeriodEnd:   now.AddDate(100, 0, 0), // effectively never expires
	}

	err = s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if existing != nil {
			// Run cancelImmediately against the transaction handle so the
			// expiry is rolled back together with a failed insert below.
			txSvc := *s
			txSvc.db = tx
			if err := txSvc.cancelImmediately(ctx, existing); err != nil {
				return err
			}
		}
		if err := tx.Create(sub).Error; err != nil {
			return fmt.Errorf("failed to create free subscription: %w", err)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	sub.Plan = *plan
	return &SubscribeResponse{Subscription: sub}, nil
}
// createNewSubscription inserts a subscription to the paid plan for userID.
//
// The period runs one year from now for BillingYearly and one month otherwise.
// If plan.TrialDays > 0 and the user has no earlier subscription row with
// trial_start set, the subscription starts in trialing status and its period
// ends when the trial does. Outside a trial, when the cycle's price is
// positive, a pending invoice is recorded. If a PaymentProvider is also
// configured, a payment is initiated, its ID is stored on the invoice, and
// the payment ID and client secret are returned to the caller.
//
// All database writes run in a single transaction. If the trial-eligibility
// lookup fails, the whole operation fails instead of granting a trial.
func (s *Service) createNewSubscription(ctx context.Context, userID uuid.UUID, plan *Plan, cycle BillingCycle) (*SubscribeResponse, error) {
	now := time.Now()
	var periodEnd time.Time
	var amountCents int
	switch cycle {
	case BillingYearly:
		periodEnd = now.AddDate(1, 0, 0)
		amountCents = plan.PriceYearly
	default:
		periodEnd = now.AddDate(0, 1, 0)
		amountCents = plan.PriceMonthly
	}
	sub := &UserSubscription{
		UserID:             userID,
		PlanID:             plan.ID,
		BillingCycle:       cycle,
		CurrentPeriodStart: now,
		CurrentPeriodEnd:   periodEnd,
	}
	var clientSecret, paymentID string
	// SECURITY(REM-015): Trial check + subscription creation in single transaction to prevent
	// race condition where two concurrent requests both see previousTrialCount=0.
	// NOTE(review): a plain COUNT inside a transaction does not serialize
	// concurrent callers if the database runs at READ COMMITTED (PostgreSQL's
	// default); two requests can still both observe zero prior trials. A
	// unique constraint or an explicit lock is needed for a hard guarantee —
	// confirm the isolation level / constraints in use.
	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Apply trial if available — checked INSIDE transaction for atomicity
		if plan.TrialDays > 0 {
			var previousTrialCount int64
			if err := tx.Model(&UserSubscription{}).
				Where("user_id = ? AND trial_start IS NOT NULL", userID).
				Count(&previousTrialCount).Error; err != nil {
				// Fail closed: without the count we cannot tell whether the
				// user already consumed a trial, so do not grant one.
				return fmt.Errorf("failed to check trial eligibility: %w", err)
			}
			if previousTrialCount > 0 {
				sub.Status = StatusActive
			} else {
				trialEnd := now.AddDate(0, 0, plan.TrialDays)
				sub.Status = StatusTrialing
				sub.TrialStart = &now
				sub.TrialEnd = &trialEnd
				sub.CurrentPeriodEnd = trialEnd
			}
		} else {
			sub.Status = StatusActive
		}
		if err := tx.Create(sub).Error; err != nil {
			return fmt.Errorf("failed to create subscription: %w", err)
		}
		// Create invoice (for paid plans, not during trial)
		if !sub.IsTrialing() && amountCents > 0 {
			invoice := &Invoice{
				SubscriptionID:     sub.ID,
				UserID:             userID,
				AmountCents:        amountCents,
				Currency:           plan.Currency,
				Status:             InvoicePending,
				BillingPeriodStart: now,
				BillingPeriodEnd:   periodEnd,
			}
			if err := tx.Create(invoice).Error; err != nil {
				return fmt.Errorf("failed to create invoice: %w", err)
			}
			// Initiate payment if provider is configured.
			// NOTE(review): this external call happens while the transaction
			// is open; if the invoice update below fails, the rows roll back
			// but the gateway-side payment is not undone.
			if s.paymentProvider != nil {
				var err error
				paymentID, clientSecret, err = s.paymentProvider.CreateSubscriptionPayment(
					ctx, amountCents, plan.Currency, sub.ID.String(),
					"", // returnURL to be set by frontend
					map[string]string{
						"user_id":         userID.String(),
						"subscription_id": sub.ID.String(),
						"plan":            string(plan.Name),
						"billing_cycle":   string(cycle),
					},
				)
				if err != nil {
					return fmt.Errorf("failed to create payment: %w", err)
				}
				invoice.HyperswitchPaymentID = paymentID
				if err := tx.Save(invoice).Error; err != nil {
					return fmt.Errorf("failed to update invoice with payment ID: %w", err)
				}
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	sub.Plan = *plan
	return &SubscribeResponse{
		Subscription: sub,
		ClientSecret: clientSecret,
		PaymentID:    paymentID,
	}, nil
}
// changePlan routes a plan switch. When newPlan ranks strictly higher by
// SortOrder than the current plan, it is an immediate upgrade. Otherwise
// (including equal SortOrder) it is a downgrade scheduled for period end.
// The current plan is loaded by ID when it was not preloaded on current.
func (s *Service) changePlan(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) {
	fromPlan := &current.Plan
	if fromPlan.ID == uuid.Nil {
		loaded, err := s.GetPlan(ctx, current.PlanID)
		if err != nil {
			return nil, err
		}
		fromPlan = loaded
	}

	if newPlan.SortOrder <= fromPlan.SortOrder {
		// Downgrade (or lateral move): deferred to the end of the current period.
		return s.scheduleDowngrade(ctx, userID, current, newPlan, cycle)
	}
	// Upgrade: applied immediately.
	return s.upgradeSubscription(ctx, userID, current, newPlan, cycle)
}
// upgradeSubscription moves the user from current to newPlan immediately.
//
// The current subscription is marked expired and a replacement is created via
// createNewSubscription, which may start a trial, record an invoice and
// initiate a payment. Both steps run inside one outer transaction:
// createNewSubscription is invoked on a copy of the service bound to that
// transaction, so its own Transaction call nests inside it (GORM uses a
// savepoint). If creation fails, for example because the payment provider
// rejects the charge, the expiry is rolled back too, and the user is never
// left without an active subscription.
func (s *Service) upgradeSubscription(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) {
	now := time.Now()
	previousStatus := current.Status

	var resp *SubscribeResponse
	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Expire the current subscription
		current.Status = StatusExpired
		if err := tx.Save(current).Error; err != nil {
			return fmt.Errorf("failed to expire current subscription: %w", err)
		}

		txSvc := *s
		txSvc.db = tx
		var err error
		resp, err = txSvc.createNewSubscription(ctx, userID, newPlan, cycle)
		return err
	})
	if err != nil {
		// The transaction rolled back; keep the caller's copy in sync with the DB.
		current.Status = previousStatus
		return nil, err
	}

	s.logger.Info("User upgraded subscription",
		zap.String("user_id", userID.String()),
		zap.String("from_plan", string(current.Plan.Name)),
		zap.String("to_plan", string(newPlan.Name)),
		zap.Time("upgraded_at", now),
	)
	return resp, nil
}
// scheduleDowngrade records a downgrade from current to newPlan that should
// take effect at the end of the current billing period.
//
// It sets CancelAtPeriodEnd and stamps CanceledAt on the current subscription
// and saves it. The event is logged only after the save succeeds, so the
// audit trail never reports a downgrade that was not persisted. In the
// returned subscription, Plan holds newPlan to show the target, while PlanID
// still references the current plan.
//
// NOTE(review): neither newPlan nor cycle is persisted here, and
// ProcessExpiredSubscriptions only expires flagged subscriptions; it does
// not create the target-plan subscription. Unless something else handles
// that, the downgrade ends as a plain cancellation. Confirm, and persist the
// pending plan/cycle if needed.
func (s *Service) scheduleDowngrade(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) {
	now := time.Now()
	current.CancelAtPeriodEnd = true
	current.CanceledAt = &now
	if err := s.db.WithContext(ctx).Save(current).Error; err != nil {
		return nil, fmt.Errorf("failed to schedule downgrade: %w", err)
	}

	s.logger.Info("User scheduled downgrade",
		zap.String("user_id", userID.String()),
		zap.String("from_plan", string(current.Plan.Name)),
		zap.String("to_plan", string(newPlan.Name)),
		zap.Time("effective_at", current.CurrentPeriodEnd),
	)

	current.Plan = *newPlan // indicate the target plan in response
	return &SubscribeResponse{
		Subscription: current,
	}, nil
}
// CancelSubscription schedules the user's active/trialing subscription to end
// when its current period closes. Access continues until CurrentPeriodEnd.
//
// Free-plan subscriptions yield ErrFreePlanNoBilling. If the subscription is
// already flagged to end at period close (by an earlier cancel or a scheduled
// downgrade), it is returned unchanged. The original CanceledAt timestamp is
// kept rather than overwritten.
func (s *Service) CancelSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) {
	sub, err := s.GetUserSubscription(ctx, userID)
	if err != nil {
		return nil, err
	}
	// Free plan: there is no billing to stop, so cancellation is rejected.
	if sub.Plan.Name == PlanFree {
		return nil, ErrFreePlanNoBilling
	}
	if sub.CancelAtPeriodEnd {
		return sub, nil // already scheduled; nothing to do
	}

	now := time.Now()
	sub.CancelAtPeriodEnd = true
	sub.CanceledAt = &now
	if err := s.db.WithContext(ctx).Save(sub).Error; err != nil {
		return nil, fmt.Errorf("failed to cancel subscription: %w", err)
	}
	s.logger.Info("User canceled subscription",
		zap.String("user_id", userID.String()),
		zap.String("plan", string(sub.Plan.Name)),
		zap.Time("access_until", sub.CurrentPeriodEnd),
	)
	return sub, nil
}
// ReactivateSubscription clears a pending end-of-period cancellation on the
// user's active/trialing subscription. A subscription that is not pending
// cancellation is returned as-is without any database write.
func (s *Service) ReactivateSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) {
	sub, err := s.GetUserSubscription(ctx, userID)
	switch {
	case err != nil:
		return nil, err
	case !sub.CancelAtPeriodEnd:
		return sub, nil // not canceled, nothing to do
	}

	sub.CanceledAt, sub.CancelAtPeriodEnd = nil, false
	if saveErr := s.db.WithContext(ctx).Save(sub).Error; saveErr != nil {
		return nil, fmt.Errorf("failed to reactivate subscription: %w", saveErr)
	}

	s.logger.Info("User reactivated subscription",
		zap.String("user_id", userID.String()),
		zap.String("plan", string(sub.Plan.Name)),
	)
	return sub, nil
}
// cancelImmediately ends sub on the spot: it stamps CanceledAt with the
// current time, sets the status to expired, and saves the row. It is used
// internally when the user switches plans.
func (s *Service) cancelImmediately(ctx context.Context, sub *UserSubscription) error {
	stamp := time.Now()
	sub.CanceledAt = &stamp
	sub.Status = StatusExpired
	return s.db.WithContext(ctx).Save(sub).Error
}
// GetUserInvoices pages through the user's invoices, newest first. A limit
// outside 1..100 falls back to 20 and a negative offset is treated as 0.
func (s *Service) GetUserInvoices(ctx context.Context, userID uuid.UUID, limit, offset int) ([]Invoice, error) {
	const defaultPageSize, maxPageSize = 20, 100
	if limit < 1 || limit > maxPageSize {
		limit = defaultPageSize
	}
	if offset < 0 {
		offset = 0
	}

	newestFirst := s.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("created_at DESC")

	var invoices []Invoice
	if err := newestFirst.Limit(limit).Offset(offset).Find(&invoices).Error; err != nil {
		return nil, fmt.Errorf("failed to get invoices: %w", err)
	}
	return invoices, nil
}
// ProcessExpiredSubscriptions marks subscriptions as expired in two sweeps:
//  1. active/trialing subscriptions flagged cancel_at_period_end whose
//     current_period_end has passed;
//  2. trialing subscriptions whose trial_end has passed.
//
// It returns the total number of rows updated. If the second sweep fails,
// the count from the first sweep is returned alongside the error. It is
// intended to be run periodically, e.g. from a daily job.
func (s *Service) ProcessExpiredSubscriptions(ctx context.Context) (int, error) {
	now := time.Now()

	// Each sweep gets a fresh handle and a fresh column map.
	subscriptions := func() *gorm.DB {
		return s.db.WithContext(ctx).Model(&UserSubscription{})
	}
	expiredColumns := func() map[string]interface{} {
		return map[string]interface{}{
			"status":     StatusExpired,
			"updated_at": now,
		}
	}

	canceled := subscriptions().
		Where("cancel_at_period_end = ? AND current_period_end < ? AND status IN ?",
			true, now, []string{string(StatusActive), string(StatusTrialing)}).
		Updates(expiredColumns())
	if canceled.Error != nil {
		return 0, fmt.Errorf("failed to expire subscriptions: %w", canceled.Error)
	}
	total := canceled.RowsAffected

	trials := subscriptions().
		Where("status = ? AND trial_end < ?", StatusTrialing, now).
		Updates(expiredColumns())
	if trials.Error != nil {
		return int(total), fmt.Errorf("failed to expire trials: %w", trials.Error)
	}
	total += trials.RowsAffected

	if total > 0 {
		s.logger.Info("Processed expired subscriptions", zap.Int64("expired_count", total))
	}
	return int(total), nil
}
// ChangeBillingCycle switches the user's active/trialing subscription between
// monthly and yearly billing. The new cycle is stored right away and is
// intended to take effect at the next renewal; the current period end is not
// modified here.
//
// It returns ErrInvalidBillingCycle for any other cycle. It returns
// ErrFreePlanNoBilling for free-plan subscriptions, which have no billing to
// change; this matches CancelSubscription's handling of the free plan.
// Requesting the cycle already in place is a no-op.
func (s *Service) ChangeBillingCycle(ctx context.Context, userID uuid.UUID, newCycle BillingCycle) (*UserSubscription, error) {
	if newCycle != BillingMonthly && newCycle != BillingYearly {
		return nil, ErrInvalidBillingCycle
	}
	sub, err := s.GetUserSubscription(ctx, userID)
	if err != nil {
		return nil, err
	}
	// The free plan is never billed, so its cycle has no meaning.
	if sub.Plan.Name == PlanFree {
		return nil, ErrFreePlanNoBilling
	}
	if sub.BillingCycle == newCycle {
		return sub, nil // already on this cycle
	}
	sub.BillingCycle = newCycle
	// The new cycle takes effect at the next renewal
	if err := s.db.WithContext(ctx).Save(sub).Error; err != nil {
		return nil, fmt.Errorf("failed to change billing cycle: %w", err)
	}
	s.logger.Info("User changed billing cycle",
		zap.String("user_id", userID.String()),
		zap.String("new_cycle", string(newCycle)),
	)
	return sub, nil
}