veza/veza-backend-api/internal/repositories/track_repository.go

142 lines
4.5 KiB
Go

package repositories
import (
"context"
"veza-backend-api/internal/models"
"github.com/google/uuid"
"gorm.io/gorm"
)
// TrackListParams holds the filter, sort and paging options accepted by
// TrackRepository.List. It mirrors track.TrackListParams at the repository layer.
type TrackListParams struct {
	// Paging: Page is the 1-based page number and Limit the page size.
	Page, Limit int

	// Optional filters. A nil pointer (and, for Genre/Format, an empty
	// string) means the filter is not applied.
	UserID *uuid.UUID
	Genre  *string
	Format *string

	// Sorting: SortBy names the sort key, SortOrder its direction.
	SortBy, SortOrder string
}
// TrackRepository is the data-access contract for tracks.
type TrackRepository interface {
	// Create inserts a new track.
	Create(ctx context.Context, track *models.Track) error
	// Update applies the given column/value pairs to track.
	Update(ctx context.Context, track *models.Track, updates map[string]interface{}) error
	// UpdateStatus sets the status and status message of the track with trackID.
	UpdateStatus(ctx context.Context, trackID uuid.UUID, status models.TrackStatus, message string) error
	// Delete deletes track.
	Delete(ctx context.Context, track *models.Track) error

	// GetByID loads a single track by its ID.
	GetByID(ctx context.Context, trackID uuid.UUID) (*models.Track, error)
	// FindByIDs loads every track whose ID is in trackIDs.
	FindByIDs(ctx context.Context, trackIDs []uuid.UUID) ([]*models.Track, error)
	// List returns one page of tracks plus the total number of matches.
	List(ctx context.Context, params TrackListParams) ([]*models.Track, int64, error)

	// CountByCreatorID returns how many tracks userID has created.
	CountByCreatorID(ctx context.Context, userID uuid.UUID) (int64, error)
	// SumStorageByCreatorID returns the summed file size of userID's tracks.
	SumStorageByCreatorID(ctx context.Context, userID uuid.UUID) (int64, error)
}
// trackRepository is the GORM-backed implementation of TrackRepository.
type trackRepository struct {
	db *gorm.DB // base handle; every method scopes it with WithContext(ctx)
}
// NewTrackRepository returns a TrackRepository whose queries run through db.
func NewTrackRepository(db *gorm.DB) TrackRepository {
	repo := new(trackRepository)
	repo.db = db
	return repo
}
// Create inserts track into the database, scoping the query to ctx.
func (r *trackRepository) Create(ctx context.Context, track *models.Track) error {
	result := r.db.WithContext(ctx).Create(track)
	return result.Error
}
// GetByID loads the track whose id equals trackID, preloading its User
// association. Any error from GORM's First (gorm.ErrRecordNotFound when no
// row matches) is returned unchanged with a nil track.
func (r *trackRepository) GetByID(ctx context.Context, trackID uuid.UUID) (*models.Track, error) {
	track := new(models.Track)
	err := r.db.WithContext(ctx).Preload("User").First(track, "id = ?", trackID).Error
	if err != nil {
		return nil, err
	}
	return track, nil
}
// List returns one page of completed tracks (with their User association
// preloaded) together with the total number of completed tracks matching the
// filters across all pages.
//
// Filters: UserID restricts to a creator; Genre and Format are skipped when
// nil or empty. SortBy accepts "created_at" (the default), "title" or
// "popularity" (play_count + like_count); any other value falls back to
// "created_at". SortOrder "asc" sorts ascending, anything else descending.
// Limit defaults to 20 and is capped at 100; Page defaults to 1.
func (r *trackRepository) List(ctx context.Context, params TrackListParams) ([]*models.Track, int64, error) {
	query := r.db.WithContext(ctx).Model(&models.Track{}).Where("status = ?", models.TrackStatusCompleted)
	if params.UserID != nil {
		query = query.Where("creator_id = ?", *params.UserID)
	}
	if params.Genre != nil && *params.Genre != "" {
		query = query.Where("genre = ?", *params.Genre)
	}
	if params.Format != nil && *params.Format != "" {
		query = query.Where("format = ?", *params.Format)
	}

	// Count before ORDER BY / OFFSET / LIMIT are attached so total spans every page.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	page, limit := normalizeTrackListPaging(params.Page, params.Limit)

	var tracks []*models.Track
	err := query.
		Order(trackListOrderClause(params.SortBy, params.SortOrder)).
		Offset((page - 1) * limit).
		Limit(limit).
		Preload("User").
		Find(&tracks).Error
	if err != nil {
		return nil, 0, err
	}
	return tracks, total, nil
}

// trackListOrderClause builds List's ORDER BY expression. The sort key comes
// from a fixed whitelist, so no caller-supplied text reaches the SQL. An "id"
// tie-breaker is appended because OFFSET/LIMIT paging over a non-unique sort
// key can repeat or skip rows between pages.
func trackListOrderClause(sortBy, sortOrder string) string {
	direction := "DESC"
	if sortOrder == "asc" {
		direction = "ASC"
	}

	var key string
	switch sortBy {
	case "title":
		key = "title"
	case "popularity":
		key = "(play_count + like_count)"
	default: // "", "created_at", or any unrecognised value
		key = "created_at"
	}
	return key + " " + direction + ", id " + direction
}

// normalizeTrackListPaging applies List's paging defaults: a non-positive
// limit becomes 20, a limit above 100 is capped at 100, and a non-positive
// page becomes 1.
func normalizeTrackListPaging(page, limit int) (int, int) {
	if limit <= 0 {
		limit = 20
	}
	if limit > 100 {
		limit = 100
	}
	if page <= 0 {
		page = 1
	}
	return page, limit
}
// Update writes the column/value pairs in updates to the row that track
// identifies (GORM resolves the row from the model passed to Model).
func (r *trackRepository) Update(ctx context.Context, track *models.Track, updates map[string]interface{}) error {
	scoped := r.db.WithContext(ctx).Model(track)
	return scoped.Updates(updates).Error
}
// UpdateStatus sets the status and status_message columns of the track with
// id trackID. RowsAffected is not inspected, so an unknown trackID is not
// reported as an error.
func (r *trackRepository) UpdateStatus(ctx context.Context, trackID uuid.UUID, status models.TrackStatus, message string) error {
	changes := map[string]interface{}{
		"status":         status,
		"status_message": message,
	}
	return r.db.WithContext(ctx).
		Model(&models.Track{}).
		Where("id = ?", trackID).
		Updates(changes).Error
}
// Delete deletes track through GORM. Whether this is a soft delete depends on
// whether models.Track declares a DeletedAt field — not visible here, confirm.
func (r *trackRepository) Delete(ctx context.Context, track *models.Track) error {
	if err := r.db.WithContext(ctx).Delete(track).Error; err != nil {
		return err
	}
	return nil
}
// CountByCreatorID returns how many tracks have creator_id == userID,
// regardless of their status.
func (r *trackRepository) CountByCreatorID(ctx context.Context, userID uuid.UUID) (int64, error) {
	var n int64
	byCreator := r.db.WithContext(ctx).Model(&models.Track{}).Where("creator_id = ?", userID)
	if err := byCreator.Count(&n).Error; err != nil {
		return n, err
	}
	return n, nil
}
// SumStorageByCreatorID returns the sum of file_size over every track with
// creator_id == userID (any status). COALESCE turns the NULL that SUM yields
// for zero rows into 0.
func (r *trackRepository) SumStorageByCreatorID(ctx context.Context, userID uuid.UUID) (int64, error) {
	var total int64
	err := r.db.WithContext(ctx).
		Model(&models.Track{}).
		Select("COALESCE(SUM(file_size), 0)").
		Where("creator_id = ?", userID).
		Scan(&total).Error
	return total, err
}
// FindByIDs returns the tracks whose id is in trackIDs. IDs with no matching
// row are simply absent from the result. No ORDER BY is applied, so the result
// order is whatever the database returns and need not follow trackIDs. Unlike
// GetByID and List, the User association is not preloaded.
//
// An empty (or nil) trackIDs returns an empty, non-nil slice without a
// database round trip.
func (r *trackRepository) FindByIDs(ctx context.Context, trackIDs []uuid.UUID) ([]*models.Track, error) {
	if len(trackIDs) == 0 {
		// Nothing can match; skip the "id IN (NULL)" query GORM would issue.
		return []*models.Track{}, nil
	}
	var tracks []*models.Track
	if err := r.db.WithContext(ctx).Where("id IN ?", trackIDs).Find(&tracks).Error; err != nil {
		return nil, err
	}
	return tracks, nil
}