CRITICAL fixes: - Race condition (TOCTOU) in payout/refund with SELECT FOR UPDATE (CRITICAL-001/002) - IDOR on analytics endpoint — ownership check enforced (CRITICAL-003) - CSWSH on all WebSocket endpoints — origin whitelist (CRITICAL-004) - Mass assignment on user self-update — strip privileged fields (CRITICAL-005) HIGH fixes: - Path traversal in marketplace upload — UUID filenames (HIGH-001) - IP spoofing — use Gin trusted proxy c.ClientIP() (HIGH-002) - Popularity metrics (followers, likes) set to json:"-" (HIGH-003) - bcrypt cost hardened to 12 everywhere (HIGH-004) - Refresh token lock made mandatory (HIGH-005) - Stream token replay prevention with access_count (HIGH-006) - Subscription trial race condition fixed (HIGH-007) - License download expiration check (HIGH-008) - Webhook amount validation (HIGH-009) - pprof endpoint removed from production (HIGH-010) MEDIUM fixes: - WebSocket message size limit 64KB (MEDIUM-010) - HSTS header in nginx production (MEDIUM-001) - CORS origin restricted in nginx-rtmp (MEDIUM-002) - Docker alpine pinned to 3.21 (MEDIUM-003/004) - Redis authentication enforced (MEDIUM-005) - GDPR account deletion expanded (MEDIUM-006) - .gitignore hardened (MEDIUM-007) LOW/INFO fixes: - GitHub Actions SHA pinning on all workflows (LOW-001) - .env.example security documentation (INFO-001) - Production CORS set to HTTPS (LOW-002) All tests pass. Go and Rust compile clean. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
543 lines
16 KiB
Go
543 lines
16 KiB
Go
package subscription
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Sentinel errors returned by the subscription Service. Compare against them
// with errors.Is rather than matching on message text.
var (
	// ErrPlanNotFound is returned when no plan matches the requested ID or name.
	ErrPlanNotFound = errors.New("subscription plan not found")

	// ErrSubscriptionNotFound reports a missing subscription record.
	ErrSubscriptionNotFound = errors.New("subscription not found")

	// ErrAlreadySubscribed is returned when the user already holds an active
	// or trialing subscription to the requested plan.
	ErrAlreadySubscribed = errors.New("user already has an active subscription to this plan")

	// ErrCannotDowngradeDuring describes the policy that downgrades are
	// deferred to the end of the current billing period.
	ErrCannotDowngradeDuring = errors.New("downgrade takes effect at end of current period")

	// ErrNoActiveSubscription is returned when the user has no subscription in
	// the active or trialing state.
	ErrNoActiveSubscription = errors.New("no active subscription found")

	// ErrInvalidBillingCycle rejects any billing cycle other than monthly or yearly.
	ErrInvalidBillingCycle = errors.New("invalid billing cycle: must be 'monthly' or 'yearly'")

	// ErrFreePlanNoBilling is returned by billing operations, such as
	// cancellation, that do not apply to the free plan.
	ErrFreePlanNoBilling = errors.New("free plan does not require billing")
)
|
|
|
|
// PaymentProvider abstracts the payment gateway used to charge for paid
// subscriptions. It is optional: when the Service has no provider configured,
// invoices are recorded without a payment being initiated.
type PaymentProvider interface {
	// CreateSubscriptionPayment starts a payment of amountCents in currency for
	// subscriptionID. It returns the provider's payment ID and a client secret
	// for the frontend to complete the payment.
	CreateSubscriptionPayment(ctx context.Context, amountCents int, currency, subscriptionID, returnURL string, metadata map[string]string) (paymentID, clientSecret string, err error)

	// GetPayment reports the provider-side status of paymentID.
	GetPayment(ctx context.Context, paymentID string) (status string, err error)
}
|
|
|
|
// ServiceOption is a functional option for configuring the Service
|
|
type ServiceOption func(*Service)
|
|
|
|
// WithPaymentProvider sets the payment provider for the subscription service
|
|
func WithPaymentProvider(p PaymentProvider) ServiceOption {
|
|
return func(s *Service) {
|
|
s.paymentProvider = p
|
|
}
|
|
}
|
|
|
|
// Service handles subscription business logic
|
|
type Service struct {
|
|
db *gorm.DB
|
|
logger *zap.Logger
|
|
paymentProvider PaymentProvider
|
|
}
|
|
|
|
// NewService creates a new subscription service
|
|
func NewService(db *gorm.DB, logger *zap.Logger, opts ...ServiceOption) *Service {
|
|
s := &Service{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(s)
|
|
}
|
|
return s
|
|
}
|
|
|
|
// ListPlans returns all active subscription plans ordered by sort_order
|
|
func (s *Service) ListPlans(ctx context.Context) ([]Plan, error) {
|
|
var plans []Plan
|
|
if err := s.db.WithContext(ctx).
|
|
Where("is_active = ?", true).
|
|
Order("sort_order ASC").
|
|
Find(&plans).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to list plans: %w", err)
|
|
}
|
|
return plans, nil
|
|
}
|
|
|
|
// GetPlan returns a plan by ID
|
|
func (s *Service) GetPlan(ctx context.Context, planID uuid.UUID) (*Plan, error) {
|
|
var plan Plan
|
|
if err := s.db.WithContext(ctx).First(&plan, "id = ?", planID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrPlanNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get plan: %w", err)
|
|
}
|
|
return &plan, nil
|
|
}
|
|
|
|
// GetPlanByName returns a plan by its name
|
|
func (s *Service) GetPlanByName(ctx context.Context, name PlanName) (*Plan, error) {
|
|
var plan Plan
|
|
if err := s.db.WithContext(ctx).First(&plan, "name = ?", string(name)).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrPlanNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get plan by name: %w", err)
|
|
}
|
|
return &plan, nil
|
|
}
|
|
|
|
// GetUserSubscription returns the user's current active/trialing subscription
|
|
func (s *Service) GetUserSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) {
|
|
var sub UserSubscription
|
|
err := s.db.WithContext(ctx).
|
|
Preload("Plan").
|
|
Where("user_id = ? AND status IN ?", userID, []string{string(StatusActive), string(StatusTrialing)}).
|
|
First(&sub).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNoActiveSubscription
|
|
}
|
|
return nil, fmt.Errorf("failed to get user subscription: %w", err)
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// GetUserSubscriptionHistory returns all subscriptions for a user (including canceled/expired)
|
|
func (s *Service) GetUserSubscriptionHistory(ctx context.Context, userID uuid.UUID, limit, offset int) ([]UserSubscription, error) {
|
|
if limit <= 0 || limit > 100 {
|
|
limit = 20
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
var subs []UserSubscription
|
|
err := s.db.WithContext(ctx).
|
|
Preload("Plan").
|
|
Where("user_id = ?", userID).
|
|
Order("created_at DESC").
|
|
Limit(limit).
|
|
Offset(offset).
|
|
Find(&subs).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get subscription history: %w", err)
|
|
}
|
|
return subs, nil
|
|
}
|
|
|
|
// SubscribeRequest holds the parameters for subscribing to a plan
|
|
type SubscribeRequest struct {
|
|
PlanID uuid.UUID `json:"plan_id" binding:"required"`
|
|
BillingCycle BillingCycle `json:"billing_cycle" binding:"required"`
|
|
}
|
|
|
|
// SubscribeResponse holds the result of a subscription creation
|
|
type SubscribeResponse struct {
|
|
Subscription *UserSubscription `json:"subscription"`
|
|
ClientSecret string `json:"client_secret,omitempty"` // For Hyperswitch payment
|
|
PaymentID string `json:"payment_id,omitempty"`
|
|
}
|
|
|
|
// Subscribe creates a new subscription for a user
|
|
func (s *Service) Subscribe(ctx context.Context, userID uuid.UUID, req SubscribeRequest) (*SubscribeResponse, error) {
|
|
if req.BillingCycle != BillingMonthly && req.BillingCycle != BillingYearly {
|
|
return nil, ErrInvalidBillingCycle
|
|
}
|
|
|
|
plan, err := s.GetPlan(ctx, req.PlanID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if plan.Name == PlanFree {
|
|
return s.subscribeToFreePlan(ctx, userID, plan)
|
|
}
|
|
|
|
// Check for existing active subscription
|
|
existing, err := s.GetUserSubscription(ctx, userID)
|
|
if err != nil && !errors.Is(err, ErrNoActiveSubscription) {
|
|
return nil, err
|
|
}
|
|
|
|
if existing != nil && existing.PlanID == req.PlanID {
|
|
return nil, ErrAlreadySubscribed
|
|
}
|
|
|
|
// If upgrading from a lower plan, handle the transition
|
|
if existing != nil {
|
|
return s.changePlan(ctx, userID, existing, plan, req.BillingCycle)
|
|
}
|
|
|
|
return s.createNewSubscription(ctx, userID, plan, req.BillingCycle)
|
|
}
|
|
|
|
// subscribeToFreePlan assigns the free plan without payment
|
|
func (s *Service) subscribeToFreePlan(ctx context.Context, userID uuid.UUID, plan *Plan) (*SubscribeResponse, error) {
|
|
// Cancel any existing subscription first
|
|
existing, err := s.GetUserSubscription(ctx, userID)
|
|
if err != nil && !errors.Is(err, ErrNoActiveSubscription) {
|
|
return nil, err
|
|
}
|
|
if existing != nil {
|
|
if err := s.cancelImmediately(ctx, existing); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
now := time.Now()
|
|
sub := &UserSubscription{
|
|
UserID: userID,
|
|
PlanID: plan.ID,
|
|
Status: StatusActive,
|
|
BillingCycle: BillingMonthly,
|
|
CurrentPeriodStart: now,
|
|
CurrentPeriodEnd: now.AddDate(100, 0, 0), // effectively never expires
|
|
}
|
|
|
|
if err := s.db.WithContext(ctx).Create(sub).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to create free subscription: %w", err)
|
|
}
|
|
|
|
sub.Plan = *plan
|
|
return &SubscribeResponse{Subscription: sub}, nil
|
|
}
|
|
|
|
// createNewSubscription creates a subscription for a paid plan
|
|
func (s *Service) createNewSubscription(ctx context.Context, userID uuid.UUID, plan *Plan, cycle BillingCycle) (*SubscribeResponse, error) {
|
|
now := time.Now()
|
|
var periodEnd time.Time
|
|
var amountCents int
|
|
|
|
switch cycle {
|
|
case BillingYearly:
|
|
periodEnd = now.AddDate(1, 0, 0)
|
|
amountCents = plan.PriceYearly
|
|
default:
|
|
periodEnd = now.AddDate(0, 1, 0)
|
|
amountCents = plan.PriceMonthly
|
|
}
|
|
|
|
sub := &UserSubscription{
|
|
UserID: userID,
|
|
PlanID: plan.ID,
|
|
BillingCycle: cycle,
|
|
CurrentPeriodStart: now,
|
|
CurrentPeriodEnd: periodEnd,
|
|
}
|
|
|
|
var clientSecret, paymentID string
|
|
|
|
// SECURITY(REM-015): Trial check + subscription creation in single transaction to prevent
|
|
// race condition where two concurrent requests both see previousTrialCount=0.
|
|
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// Apply trial if available — checked INSIDE transaction for atomicity
|
|
if plan.TrialDays > 0 {
|
|
var previousTrialCount int64
|
|
tx.Model(&UserSubscription{}).
|
|
Where("user_id = ? AND trial_start IS NOT NULL", userID).
|
|
Count(&previousTrialCount)
|
|
if previousTrialCount > 0 {
|
|
sub.Status = StatusActive
|
|
} else {
|
|
trialEnd := now.AddDate(0, 0, plan.TrialDays)
|
|
sub.Status = StatusTrialing
|
|
sub.TrialStart = &now
|
|
sub.TrialEnd = &trialEnd
|
|
sub.CurrentPeriodEnd = trialEnd
|
|
}
|
|
} else {
|
|
sub.Status = StatusActive
|
|
}
|
|
|
|
if err := tx.Create(sub).Error; err != nil {
|
|
return fmt.Errorf("failed to create subscription: %w", err)
|
|
}
|
|
|
|
// Create invoice (for paid plans, not during trial)
|
|
if !sub.IsTrialing() && amountCents > 0 {
|
|
invoice := &Invoice{
|
|
SubscriptionID: sub.ID,
|
|
UserID: userID,
|
|
AmountCents: amountCents,
|
|
Currency: plan.Currency,
|
|
Status: InvoicePending,
|
|
BillingPeriodStart: now,
|
|
BillingPeriodEnd: periodEnd,
|
|
}
|
|
if err := tx.Create(invoice).Error; err != nil {
|
|
return fmt.Errorf("failed to create invoice: %w", err)
|
|
}
|
|
|
|
// Initiate payment if provider is configured
|
|
if s.paymentProvider != nil {
|
|
var err error
|
|
paymentID, clientSecret, err = s.paymentProvider.CreateSubscriptionPayment(
|
|
ctx, amountCents, plan.Currency, sub.ID.String(),
|
|
"", // returnURL to be set by frontend
|
|
map[string]string{
|
|
"user_id": userID.String(),
|
|
"subscription_id": sub.ID.String(),
|
|
"plan": string(plan.Name),
|
|
"billing_cycle": string(cycle),
|
|
},
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create payment: %w", err)
|
|
}
|
|
invoice.HyperswitchPaymentID = paymentID
|
|
if err := tx.Save(invoice).Error; err != nil {
|
|
return fmt.Errorf("failed to update invoice with payment ID: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sub.Plan = *plan
|
|
return &SubscribeResponse{
|
|
Subscription: sub,
|
|
ClientSecret: clientSecret,
|
|
PaymentID: paymentID,
|
|
}, nil
|
|
}
|
|
|
|
// changePlan handles upgrade or downgrade between plans
|
|
func (s *Service) changePlan(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) {
|
|
currentPlan := ¤t.Plan
|
|
if currentPlan.ID == uuid.Nil {
|
|
var err error
|
|
currentPlan, err = s.GetPlan(ctx, current.PlanID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
isUpgrade := newPlan.SortOrder > currentPlan.SortOrder
|
|
|
|
if isUpgrade {
|
|
// Upgrade: takes effect immediately
|
|
return s.upgradeSubscription(ctx, userID, current, newPlan, cycle)
|
|
}
|
|
|
|
// Downgrade: takes effect at end of current period
|
|
return s.scheduleDowngrade(ctx, userID, current, newPlan, cycle)
|
|
}
|
|
|
|
// upgradeSubscription applies an immediate upgrade
|
|
func (s *Service) upgradeSubscription(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) {
|
|
now := time.Now()
|
|
|
|
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// Expire the current subscription
|
|
current.Status = StatusExpired
|
|
if err := tx.Save(current).Error; err != nil {
|
|
return fmt.Errorf("failed to expire current subscription: %w", err)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.logger.Info("User upgraded subscription",
|
|
zap.String("user_id", userID.String()),
|
|
zap.String("from_plan", string(current.Plan.Name)),
|
|
zap.String("to_plan", string(newPlan.Name)),
|
|
zap.Time("upgraded_at", now),
|
|
)
|
|
|
|
return s.createNewSubscription(ctx, userID, newPlan, cycle)
|
|
}
|
|
|
|
// scheduleDowngrade schedules a downgrade at end of current period
|
|
func (s *Service) scheduleDowngrade(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) {
|
|
s.logger.Info("User scheduled downgrade",
|
|
zap.String("user_id", userID.String()),
|
|
zap.String("from_plan", string(current.Plan.Name)),
|
|
zap.String("to_plan", string(newPlan.Name)),
|
|
zap.Time("effective_at", current.CurrentPeriodEnd),
|
|
)
|
|
|
|
// For now, we mark the current as cancel_at_period_end and return info
|
|
// The actual downgrade will happen when the period ends (via ProcessExpiredSubscriptions)
|
|
current.CancelAtPeriodEnd = true
|
|
now := time.Now()
|
|
current.CanceledAt = &now
|
|
|
|
if err := s.db.WithContext(ctx).Save(current).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to schedule downgrade: %w", err)
|
|
}
|
|
|
|
current.Plan = *newPlan // indicate the target plan in response
|
|
return &SubscribeResponse{
|
|
Subscription: current,
|
|
}, nil
|
|
}
|
|
|
|
// CancelSubscription cancels a user's subscription at the end of the current period
|
|
func (s *Service) CancelSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) {
|
|
sub, err := s.GetUserSubscription(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Free plan: cancel immediately
|
|
if sub.Plan.Name == PlanFree {
|
|
return nil, ErrFreePlanNoBilling
|
|
}
|
|
|
|
now := time.Now()
|
|
sub.CancelAtPeriodEnd = true
|
|
sub.CanceledAt = &now
|
|
|
|
if err := s.db.WithContext(ctx).Save(sub).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to cancel subscription: %w", err)
|
|
}
|
|
|
|
s.logger.Info("User canceled subscription",
|
|
zap.String("user_id", userID.String()),
|
|
zap.String("plan", string(sub.Plan.Name)),
|
|
zap.Time("access_until", sub.CurrentPeriodEnd),
|
|
)
|
|
|
|
return sub, nil
|
|
}
|
|
|
|
// ReactivateSubscription removes the cancellation flag if still within the period
|
|
func (s *Service) ReactivateSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) {
|
|
sub, err := s.GetUserSubscription(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !sub.CancelAtPeriodEnd {
|
|
return sub, nil // not canceled, nothing to do
|
|
}
|
|
|
|
sub.CancelAtPeriodEnd = false
|
|
sub.CanceledAt = nil
|
|
|
|
if err := s.db.WithContext(ctx).Save(sub).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to reactivate subscription: %w", err)
|
|
}
|
|
|
|
s.logger.Info("User reactivated subscription",
|
|
zap.String("user_id", userID.String()),
|
|
zap.String("plan", string(sub.Plan.Name)),
|
|
)
|
|
|
|
return sub, nil
|
|
}
|
|
|
|
// cancelImmediately expires a subscription right away (used internally for plan switches)
|
|
func (s *Service) cancelImmediately(ctx context.Context, sub *UserSubscription) error {
|
|
now := time.Now()
|
|
sub.Status = StatusExpired
|
|
sub.CanceledAt = &now
|
|
return s.db.WithContext(ctx).Save(sub).Error
|
|
}
|
|
|
|
// GetUserInvoices returns invoices for a user
|
|
func (s *Service) GetUserInvoices(ctx context.Context, userID uuid.UUID, limit, offset int) ([]Invoice, error) {
|
|
if limit <= 0 || limit > 100 {
|
|
limit = 20
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
var invoices []Invoice
|
|
err := s.db.WithContext(ctx).
|
|
Where("user_id = ?", userID).
|
|
Order("created_at DESC").
|
|
Limit(limit).
|
|
Offset(offset).
|
|
Find(&invoices).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get invoices: %w", err)
|
|
}
|
|
return invoices, nil
|
|
}
|
|
|
|
// ProcessExpiredSubscriptions checks for subscriptions past their period end and expires them
|
|
// This should be called periodically (e.g., daily cron job)
|
|
func (s *Service) ProcessExpiredSubscriptions(ctx context.Context) (int, error) {
|
|
now := time.Now()
|
|
var count int64
|
|
|
|
// Expire subscriptions that have cancel_at_period_end and period has ended
|
|
result := s.db.WithContext(ctx).
|
|
Model(&UserSubscription{}).
|
|
Where("cancel_at_period_end = ? AND current_period_end < ? AND status IN ?",
|
|
true, now, []string{string(StatusActive), string(StatusTrialing)}).
|
|
Updates(map[string]interface{}{
|
|
"status": StatusExpired,
|
|
"updated_at": now,
|
|
})
|
|
|
|
if result.Error != nil {
|
|
return 0, fmt.Errorf("failed to expire subscriptions: %w", result.Error)
|
|
}
|
|
count = result.RowsAffected
|
|
|
|
// Expire trials that have ended without payment
|
|
trialResult := s.db.WithContext(ctx).
|
|
Model(&UserSubscription{}).
|
|
Where("status = ? AND trial_end < ?", StatusTrialing, now).
|
|
Updates(map[string]interface{}{
|
|
"status": StatusExpired,
|
|
"updated_at": now,
|
|
})
|
|
|
|
if trialResult.Error != nil {
|
|
return int(count), fmt.Errorf("failed to expire trials: %w", trialResult.Error)
|
|
}
|
|
count += trialResult.RowsAffected
|
|
|
|
if count > 0 {
|
|
s.logger.Info("Processed expired subscriptions", zap.Int64("expired_count", count))
|
|
}
|
|
|
|
return int(count), nil
|
|
}
|
|
|
|
// ChangeBillingCycle switches between monthly and yearly billing
|
|
func (s *Service) ChangeBillingCycle(ctx context.Context, userID uuid.UUID, newCycle BillingCycle) (*UserSubscription, error) {
|
|
if newCycle != BillingMonthly && newCycle != BillingYearly {
|
|
return nil, ErrInvalidBillingCycle
|
|
}
|
|
|
|
sub, err := s.GetUserSubscription(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if sub.BillingCycle == newCycle {
|
|
return sub, nil // already on this cycle
|
|
}
|
|
|
|
sub.BillingCycle = newCycle
|
|
// The new cycle takes effect at the next renewal
|
|
if err := s.db.WithContext(ctx).Save(sub).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to change billing cycle: %w", err)
|
|
}
|
|
|
|
s.logger.Info("User changed billing cycle",
|
|
zap.String("user_id", userID.String()),
|
|
zap.String("new_cycle", string(newCycle)),
|
|
)
|
|
|
|
return sub, nil
|
|
}
|