package services

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"io"
	"os"
	"path/filepath"
	"strconv"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
)

// failingWriter is an io.Writer that accepts the first okBytes bytes and
// then fails with errAfter. The Write call that crosses the okBytes
// boundary performs a short write (it returns the number of bytes it
// accepted together with errAfter); every later Write returns
// (0, errAfter). Used to prove StreamChunkedUpload aborts cleanly when the
// destination (typically the io.Pipe wired to S3) fails.
type failingWriter struct {
	okBytes  int   // total number of bytes accepted before failing
	written  int   // bytes accepted so far
	errAfter error // error returned once okBytes has been reached
}

// Write implements io.Writer. It honours the io.Writer contract that a
// short write (n < len(p)) is always accompanied by a non-nil error.
func (f *failingWriter) Write(p []byte) (int, error) {
	remaining := f.okBytes - f.written
	if remaining <= 0 {
		return 0, f.errAfter
	}
	n := len(p)
	if n > remaining {
		n = remaining
	}
	f.written += n
	if n < len(p) {
		return n, f.errAfter
	}
	return n, nil
}

// makeChunksOnDisk writes total chunk files of chunkSize bytes each to
// chunksDir/uploadID/chunk_1 … chunk_<total>. Chunk numbers are 1-based
// and chunk i is filled with the byte value byte(i) (which wraps for
// i >= 256), so chunks are distinguishable for small totals.
//
// It returns the matching ChunkUploadInfo (every chunk marked Received,
// TotalSize = total*chunkSize) so tests can register it in the store, and
// the expected assembled payload: all chunks concatenated in order.
func makeChunksOnDisk(t *testing.T, chunksDir, uploadID string, userID uuid.UUID, total int, chunkSize int) (*ChunkUploadInfo, []byte) {
	t.Helper()

	uploadDir := filepath.Join(chunksDir, uploadID)
	require.NoError(t, os.MkdirAll(uploadDir, 0o755))

	// Use a single timestamp so CreatedAt and UpdatedAt are identical.
	now := time.Now()
	info := &ChunkUploadInfo{
		UploadID:    uploadID,
		UserID:      userID,
		TotalChunks: total,
		Filename:    "song.mp3",
		Chunks:      make(map[int]ChunkInfo, total),
		CreatedAt:   now,
		UpdatedAt:   now,
	}

	var assembled bytes.Buffer
	for i := 1; i <= total; i++ {
		path := filepath.Join(uploadDir, "chunk_"+itoa(i))
		buf := make([]byte, chunkSize)
		for j := range buf {
			buf[j] = byte(i)
		}
		require.NoError(t, os.WriteFile(path, buf, 0o644))
		assembled.Write(buf)
		info.Chunks[i] = ChunkInfo{
			ChunkNumber: i,
			Size:        int64(chunkSize),
			FilePath:    path,
			Received:    true,
		}
	}
	// Widen before multiplying so the product cannot overflow a 32-bit int.
	info.TotalSize = int64(total) * int64(chunkSize)
	return info, assembled.Bytes()
}

// itoa formats i as a base-10 string. It delegates to strconv.Itoa, which
// (unlike a hand-rolled digit loop) is correct for every int including
// math.MinInt.
func itoa(i int) string {
	return strconv.Itoa(i)
}

// newStreamTestService builds a TrackChunkService backed by a fresh
// MockStore and a per-test temporary chunks directory, which the testing
// package removes automatically when the test finishes. It returns the
// service, the store (so tests can seed and inspect upload state), and the
// chunks directory.
func newStreamTestService(t *testing.T) (*TrackChunkService, *MockStore, string) {
	t.Helper()
	chunksDir := t.TempDir()
	store := NewMockStore()
	svc := &TrackChunkService{
		chunksDir: chunksDir,
		store:     store,
		logger:    zap.NewNop(),
	}
	return svc, store, chunksDir
}

// TestStreamChunkedUpload_AssemblesIntoWriterAndCleansUp pins the happy
// path: chunks are streamed into the writer in order, the returned
// filename, size and SHA-256 checksum match, and both the on-disk chunk
// directory and the store state are removed afterwards.
func TestStreamChunkedUpload_AssemblesIntoWriterAndCleansUp(t *testing.T) {
	svc, store, chunksDir := newStreamTestService(t)
	uploadID := "stream-happy"
	userID := uuid.New()

	info, want := makeChunksOnDisk(t, chunksDir, uploadID, userID, 4, 1024)
	require.NoError(t, store.SetState(context.Background(), info))

	var dst bytes.Buffer
	filename, totalSize, checksum, err := svc.StreamChunkedUpload(context.Background(), uploadID, &dst)
	require.NoError(t, err)

	assert.Equal(t, "song.mp3", filename)
	assert.Equal(t, int64(4*1024), totalSize)
	wantSum := sha256.Sum256(want)
	assert.Equal(t, hex.EncodeToString(wantSum[:]), checksum)
	assert.Equal(t, want, dst.Bytes(), "assembled bytes must equal concatenated chunks in order")

	// Cleanup expectations: chunks dir and Redis state both gone on success.
	if _, err := os.Stat(filepath.Join(chunksDir, uploadID)); !errors.Is(err, os.ErrNotExist) {
		t.Fatalf("expected chunk dir to be removed, got err=%v", err)
	}
	if _, err := store.GetState(context.Background(), uploadID); err == nil {
		t.Fatalf("expected store state to be deleted on success")
	}
}

// TestStreamChunkedUpload_PreservesStateOnWriterError checks that a
// failing destination writer surfaces its error to the caller and leaves
// both the chunk files and the store state in place.
func TestStreamChunkedUpload_PreservesStateOnWriterError(t *testing.T) {
	svc, store, chunksDir := newStreamTestService(t)
	uploadID := "stream-writer-fail"
	userID := uuid.New()

	info, _ := makeChunksOnDisk(t, chunksDir, uploadID, userID, 3, 512)
	require.NoError(t, store.SetState(context.Background(), info))

	sentinel := errors.New("downstream pipe closed")
	dst := &failingWriter{okBytes: 100, errAfter: sentinel} // fail mid-first-chunk

	_, _, _, err := svc.StreamChunkedUpload(context.Background(), uploadID, dst)
	require.Error(t, err)
	assert.ErrorContains(t, err, "downstream pipe closed", "writer error must surface to caller")

	// On error, we deliberately keep the chunks on disk and the state in
	// Redis so the upload can be resumed. v1.0.9 item 1.5.
	if _, err := os.Stat(filepath.Join(chunksDir, uploadID)); err != nil {
		t.Fatalf("expected chunk dir to be preserved on failure, got err=%v", err)
	}
	if _, err := store.GetState(context.Background(), uploadID); err != nil {
		t.Fatalf("expected store state to be preserved on failure, got err=%v", err)
	}
}

// TestStreamChunkedUpload_DetectsSizeMismatch checks that a TotalSize that
// disagrees with the bytes on disk yields a "size mismatch" error and
// leaves the store state in place.
func TestStreamChunkedUpload_DetectsSizeMismatch(t *testing.T) {
	svc, store, chunksDir := newStreamTestService(t)
	uploadID := "stream-size-mismatch"
	userID := uuid.New()

	info, _ := makeChunksOnDisk(t, chunksDir, uploadID, userID, 2, 256)
	info.TotalSize = 99999 // lie about the expected size
	require.NoError(t, store.SetState(context.Background(), info))

	_, _, _, err := svc.StreamChunkedUpload(context.Background(), uploadID, io.Discard)
	require.Error(t, err)
	assert.ErrorContains(t, err, "size mismatch")

	// Mismatch is a recoverable client/server disagreement — keep state so
	// the user can re-upload the missing chunks rather than start over.
	if _, err := store.GetState(context.Background(), uploadID); err != nil {
		t.Fatalf("expected store state to be preserved on size mismatch, got err=%v", err)
	}
}

// TestCompleteChunkedUpload_DelegatesToStream checks that
// CompleteChunkedUpload writes the assembled chunks to finalPath (in a
// not-yet-existing subdirectory) and reports the same filename, size and
// SHA-256 checksum as the streaming path.
func TestCompleteChunkedUpload_DelegatesToStream(t *testing.T) {
	svc, store, chunksDir := newStreamTestService(t)
	uploadID := "complete-via-stream"
	userID := uuid.New()

	info, want := makeChunksOnDisk(t, chunksDir, uploadID, userID, 5, 200)
	require.NoError(t, store.SetState(context.Background(), info))

	finalPath := filepath.Join(chunksDir, "out", "assembled.bin")
	filename, totalSize, checksum, err := svc.CompleteChunkedUpload(context.Background(), uploadID, finalPath)
	require.NoError(t, err)

	assert.Equal(t, "song.mp3", filename)
	assert.Equal(t, int64(5*200), totalSize)

	got, err := os.ReadFile(finalPath)
	require.NoError(t, err)
	assert.Equal(t, want, got)

	wantSum := sha256.Sum256(want)
	assert.Equal(t, hex.EncodeToString(wantSum[:]), checksum)
}