package subscription import ( "context" "errors" "fmt" "time" "github.com/google/uuid" "go.uber.org/zap" "gorm.io/gorm" ) // Service errors var ( ErrPlanNotFound = errors.New("subscription plan not found") ErrSubscriptionNotFound = errors.New("subscription not found") ErrAlreadySubscribed = errors.New("user already has an active subscription to this plan") ErrCannotDowngradeDuring = errors.New("downgrade takes effect at end of current period") ErrNoActiveSubscription = errors.New("no active subscription found") ErrInvalidBillingCycle = errors.New("invalid billing cycle: must be 'monthly' or 'yearly'") ErrFreePlanNoBilling = errors.New("free plan does not require billing") ) // PaymentProvider defines the interface for subscription payments type PaymentProvider interface { CreateSubscriptionPayment(ctx context.Context, amountCents int, currency, subscriptionID, returnURL string, metadata map[string]string) (paymentID, clientSecret string, err error) GetPayment(ctx context.Context, paymentID string) (status string, err error) } // ServiceOption is a functional option for configuring the Service type ServiceOption func(*Service) // WithPaymentProvider sets the payment provider for the subscription service func WithPaymentProvider(p PaymentProvider) ServiceOption { return func(s *Service) { s.paymentProvider = p } } // Service handles subscription business logic type Service struct { db *gorm.DB logger *zap.Logger paymentProvider PaymentProvider } // NewService creates a new subscription service func NewService(db *gorm.DB, logger *zap.Logger, opts ...ServiceOption) *Service { s := &Service{ db: db, logger: logger, } for _, opt := range opts { opt(s) } return s } // ListPlans returns all active subscription plans ordered by sort_order func (s *Service) ListPlans(ctx context.Context) ([]Plan, error) { var plans []Plan if err := s.db.WithContext(ctx). Where("is_active = ?", true). Order("sort_order ASC"). Find(&plans).Error; err != nil { return nil, fmt.Errorf("failed to list plans: %w", err) } return plans, nil } // GetPlan returns a plan by ID func (s *Service) GetPlan(ctx context.Context, planID uuid.UUID) (*Plan, error) { var plan Plan if err := s.db.WithContext(ctx).First(&plan, "id = ?", planID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrPlanNotFound } return nil, fmt.Errorf("failed to get plan: %w", err) } return &plan, nil } // GetPlanByName returns a plan by its name func (s *Service) GetPlanByName(ctx context.Context, name PlanName) (*Plan, error) { var plan Plan if err := s.db.WithContext(ctx).First(&plan, "name = ?", string(name)).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrPlanNotFound } return nil, fmt.Errorf("failed to get plan by name: %w", err) } return &plan, nil } // GetUserSubscription returns the user's current active/trialing subscription func (s *Service) GetUserSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) { var sub UserSubscription err := s.db.WithContext(ctx). Preload("Plan"). Where("user_id = ? AND status IN ?", userID, []string{string(StatusActive), string(StatusTrialing)}). First(&sub).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNoActiveSubscription } return nil, fmt.Errorf("failed to get user subscription: %w", err) } return &sub, nil } // GetUserSubscriptionHistory returns all subscriptions for a user (including canceled/expired) func (s *Service) GetUserSubscriptionHistory(ctx context.Context, userID uuid.UUID, limit, offset int) ([]UserSubscription, error) { if limit <= 0 || limit > 100 { limit = 20 } if offset < 0 { offset = 0 } var subs []UserSubscription err := s.db.WithContext(ctx). Preload("Plan"). Where("user_id = ?", userID). Order("created_at DESC"). Limit(limit). Offset(offset). Find(&subs).Error if err != nil { return nil, fmt.Errorf("failed to get subscription history: %w", err) } return subs, nil } // SubscribeRequest holds the parameters for subscribing to a plan type SubscribeRequest struct { PlanID uuid.UUID `json:"plan_id" binding:"required"` BillingCycle BillingCycle `json:"billing_cycle" binding:"required"` } // SubscribeResponse holds the result of a subscription creation type SubscribeResponse struct { Subscription *UserSubscription `json:"subscription"` ClientSecret string `json:"client_secret,omitempty"` // For Hyperswitch payment PaymentID string `json:"payment_id,omitempty"` } // Subscribe creates a new subscription for a user func (s *Service) Subscribe(ctx context.Context, userID uuid.UUID, req SubscribeRequest) (*SubscribeResponse, error) { if req.BillingCycle != BillingMonthly && req.BillingCycle != BillingYearly { return nil, ErrInvalidBillingCycle } plan, err := s.GetPlan(ctx, req.PlanID) if err != nil { return nil, err } if plan.Name == PlanFree { return s.subscribeToFreePlan(ctx, userID, plan) } // Check for existing active subscription existing, err := s.GetUserSubscription(ctx, userID) if err != nil && !errors.Is(err, ErrNoActiveSubscription) { return nil, err } if existing != nil && existing.PlanID == req.PlanID { return nil, ErrAlreadySubscribed } // If upgrading from a lower plan, handle the transition if existing != nil { return s.changePlan(ctx, userID, existing, plan, req.BillingCycle) } return s.createNewSubscription(ctx, userID, plan, req.BillingCycle) } // subscribeToFreePlan assigns the free plan without payment func (s *Service) subscribeToFreePlan(ctx context.Context, userID uuid.UUID, plan *Plan) (*SubscribeResponse, error) { // Cancel any existing subscription first existing, err := s.GetUserSubscription(ctx, userID) if err != nil && !errors.Is(err, ErrNoActiveSubscription) { return nil, err } if existing != nil { if err := s.cancelImmediately(ctx, existing); err != nil { return nil, err } } now := time.Now() sub := &UserSubscription{ UserID: userID, PlanID: plan.ID, Status: StatusActive, BillingCycle: BillingMonthly, CurrentPeriodStart: now, CurrentPeriodEnd: now.AddDate(100, 0, 0), // effectively never expires } if err := s.db.WithContext(ctx).Create(sub).Error; err != nil { return nil, fmt.Errorf("failed to create free subscription: %w", err) } sub.Plan = *plan return &SubscribeResponse{Subscription: sub}, nil } // createNewSubscription creates a subscription for a paid plan func (s *Service) createNewSubscription(ctx context.Context, userID uuid.UUID, plan *Plan, cycle BillingCycle) (*SubscribeResponse, error) { now := time.Now() var periodEnd time.Time var amountCents int switch cycle { case BillingYearly: periodEnd = now.AddDate(1, 0, 0) amountCents = plan.PriceYearly default: periodEnd = now.AddDate(0, 1, 0) amountCents = plan.PriceMonthly } sub := &UserSubscription{ UserID: userID, PlanID: plan.ID, BillingCycle: cycle, CurrentPeriodStart: now, CurrentPeriodEnd: periodEnd, } var clientSecret, paymentID string // SECURITY(REM-015): Trial check + subscription creation in single transaction to prevent // race condition where two concurrent requests both see previousTrialCount=0. err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Apply trial if available — checked INSIDE transaction for atomicity if plan.TrialDays > 0 { var previousTrialCount int64 tx.Model(&UserSubscription{}). Where("user_id = ? AND trial_start IS NOT NULL", userID). Count(&previousTrialCount) if previousTrialCount > 0 { sub.Status = StatusActive } else { trialEnd := now.AddDate(0, 0, plan.TrialDays) sub.Status = StatusTrialing sub.TrialStart = &now sub.TrialEnd = &trialEnd sub.CurrentPeriodEnd = trialEnd } } else { sub.Status = StatusActive } if err := tx.Create(sub).Error; err != nil { return fmt.Errorf("failed to create subscription: %w", err) } // Create invoice (for paid plans, not during trial) if !sub.IsTrialing() && amountCents > 0 { invoice := &Invoice{ SubscriptionID: sub.ID, UserID: userID, AmountCents: amountCents, Currency: plan.Currency, Status: InvoicePending, BillingPeriodStart: now, BillingPeriodEnd: periodEnd, } if err := tx.Create(invoice).Error; err != nil { return fmt.Errorf("failed to create invoice: %w", err) } // Initiate payment if provider is configured if s.paymentProvider != nil { var err error paymentID, clientSecret, err = s.paymentProvider.CreateSubscriptionPayment( ctx, amountCents, plan.Currency, sub.ID.String(), "", // returnURL to be set by frontend map[string]string{ "user_id": userID.String(), "subscription_id": sub.ID.String(), "plan": string(plan.Name), "billing_cycle": string(cycle), }, ) if err != nil { return fmt.Errorf("failed to create payment: %w", err) } invoice.HyperswitchPaymentID = paymentID if err := tx.Save(invoice).Error; err != nil { return fmt.Errorf("failed to update invoice with payment ID: %w", err) } } } return nil }) if err != nil { return nil, err } sub.Plan = *plan return &SubscribeResponse{ Subscription: sub, ClientSecret: clientSecret, PaymentID: paymentID, }, nil } // changePlan handles upgrade or downgrade between plans func (s *Service) changePlan(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) { currentPlan := ¤t.Plan if currentPlan.ID == uuid.Nil { var err error currentPlan, err = s.GetPlan(ctx, current.PlanID) if err != nil { return nil, err } } isUpgrade := newPlan.SortOrder > currentPlan.SortOrder if isUpgrade { // Upgrade: takes effect immediately return s.upgradeSubscription(ctx, userID, current, newPlan, cycle) } // Downgrade: takes effect at end of current period return s.scheduleDowngrade(ctx, userID, current, newPlan, cycle) } // upgradeSubscription applies an immediate upgrade func (s *Service) upgradeSubscription(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) { now := time.Now() err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Expire the current subscription current.Status = StatusExpired if err := tx.Save(current).Error; err != nil { return fmt.Errorf("failed to expire current subscription: %w", err) } return nil }) if err != nil { return nil, err } s.logger.Info("User upgraded subscription", zap.String("user_id", userID.String()), zap.String("from_plan", string(current.Plan.Name)), zap.String("to_plan", string(newPlan.Name)), zap.Time("upgraded_at", now), ) return s.createNewSubscription(ctx, userID, newPlan, cycle) } // scheduleDowngrade schedules a downgrade at end of current period func (s *Service) scheduleDowngrade(ctx context.Context, userID uuid.UUID, current *UserSubscription, newPlan *Plan, cycle BillingCycle) (*SubscribeResponse, error) { s.logger.Info("User scheduled downgrade", zap.String("user_id", userID.String()), zap.String("from_plan", string(current.Plan.Name)), zap.String("to_plan", string(newPlan.Name)), zap.Time("effective_at", current.CurrentPeriodEnd), ) // For now, we mark the current as cancel_at_period_end and return info // The actual downgrade will happen when the period ends (via ProcessExpiredSubscriptions) current.CancelAtPeriodEnd = true now := time.Now() current.CanceledAt = &now if err := s.db.WithContext(ctx).Save(current).Error; err != nil { return nil, fmt.Errorf("failed to schedule downgrade: %w", err) } current.Plan = *newPlan // indicate the target plan in response return &SubscribeResponse{ Subscription: current, }, nil } // CancelSubscription cancels a user's subscription at the end of the current period func (s *Service) CancelSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) { sub, err := s.GetUserSubscription(ctx, userID) if err != nil { return nil, err } // Free plan: cancel immediately if sub.Plan.Name == PlanFree { return nil, ErrFreePlanNoBilling } now := time.Now() sub.CancelAtPeriodEnd = true sub.CanceledAt = &now if err := s.db.WithContext(ctx).Save(sub).Error; err != nil { return nil, fmt.Errorf("failed to cancel subscription: %w", err) } s.logger.Info("User canceled subscription", zap.String("user_id", userID.String()), zap.String("plan", string(sub.Plan.Name)), zap.Time("access_until", sub.CurrentPeriodEnd), ) return sub, nil } // ReactivateSubscription removes the cancellation flag if still within the period func (s *Service) ReactivateSubscription(ctx context.Context, userID uuid.UUID) (*UserSubscription, error) { sub, err := s.GetUserSubscription(ctx, userID) if err != nil { return nil, err } if !sub.CancelAtPeriodEnd { return sub, nil // not canceled, nothing to do } sub.CancelAtPeriodEnd = false sub.CanceledAt = nil if err := s.db.WithContext(ctx).Save(sub).Error; err != nil { return nil, fmt.Errorf("failed to reactivate subscription: %w", err) } s.logger.Info("User reactivated subscription", zap.String("user_id", userID.String()), zap.String("plan", string(sub.Plan.Name)), ) return sub, nil } // cancelImmediately expires a subscription right away (used internally for plan switches) func (s *Service) cancelImmediately(ctx context.Context, sub *UserSubscription) error { now := time.Now() sub.Status = StatusExpired sub.CanceledAt = &now return s.db.WithContext(ctx).Save(sub).Error } // GetUserInvoices returns invoices for a user func (s *Service) GetUserInvoices(ctx context.Context, userID uuid.UUID, limit, offset int) ([]Invoice, error) { if limit <= 0 || limit > 100 { limit = 20 } if offset < 0 { offset = 0 } var invoices []Invoice err := s.db.WithContext(ctx). Where("user_id = ?", userID). Order("created_at DESC"). Limit(limit). Offset(offset). Find(&invoices).Error if err != nil { return nil, fmt.Errorf("failed to get invoices: %w", err) } return invoices, nil } // ProcessExpiredSubscriptions checks for subscriptions past their period end and expires them // This should be called periodically (e.g., daily cron job) func (s *Service) ProcessExpiredSubscriptions(ctx context.Context) (int, error) { now := time.Now() var count int64 // Expire subscriptions that have cancel_at_period_end and period has ended result := s.db.WithContext(ctx). Model(&UserSubscription{}). Where("cancel_at_period_end = ? AND current_period_end < ? AND status IN ?", true, now, []string{string(StatusActive), string(StatusTrialing)}). Updates(map[string]interface{}{ "status": StatusExpired, "updated_at": now, }) if result.Error != nil { return 0, fmt.Errorf("failed to expire subscriptions: %w", result.Error) } count = result.RowsAffected // Expire trials that have ended without payment trialResult := s.db.WithContext(ctx). Model(&UserSubscription{}). Where("status = ? AND trial_end < ?", StatusTrialing, now). Updates(map[string]interface{}{ "status": StatusExpired, "updated_at": now, }) if trialResult.Error != nil { return int(count), fmt.Errorf("failed to expire trials: %w", trialResult.Error) } count += trialResult.RowsAffected if count > 0 { s.logger.Info("Processed expired subscriptions", zap.Int64("expired_count", count)) } return int(count), nil } // ChangeBillingCycle switches between monthly and yearly billing func (s *Service) ChangeBillingCycle(ctx context.Context, userID uuid.UUID, newCycle BillingCycle) (*UserSubscription, error) { if newCycle != BillingMonthly && newCycle != BillingYearly { return nil, ErrInvalidBillingCycle } sub, err := s.GetUserSubscription(ctx, userID) if err != nil { return nil, err } if sub.BillingCycle == newCycle { return sub, nil // already on this cycle } sub.BillingCycle = newCycle // The new cycle takes effect at the next renewal if err := s.db.WithContext(ctx).Save(sub).Error; err != nil { return nil, fmt.Errorf("failed to change billing cycle: %w", err) } s.logger.Info("User changed billing cycle", zap.String("user_id", userID.String()), zap.String("new_cycle", string(newCycle)), ) return sub, nil }