// Mirror of https://github.com/ollama/ollama.git (synced 2026-03-28 03:08:44 +07:00).
//
// Removes the 5-minute HTTP client timeout that caused "context deadline exceeded"
// errors on large file downloads; stall detection (10s) already handles
// unresponsive connections. Fixes the progress bar total going down on resume by
// calculating the total from all blobs upfront and reporting already-downloaded
// bytes as completed immediately.
//
// 1778 lines · 46 KiB · Go

package transfer
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// createTestBlob creates a blob with deterministic content and returns its digest
|
|
func createTestBlob(t *testing.T, dir string, size int) (Blob, []byte) {
|
|
t.Helper()
|
|
|
|
// Create deterministic content
|
|
data := make([]byte, size)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
|
|
// Write to file
|
|
path := filepath.Join(dir, digestToPath(digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return Blob{Digest: digest, Size: int64(size)}, data
|
|
}
|
|
|
|
func TestDownload(t *testing.T) {
|
|
// Create test blobs on "server"
|
|
serverDir := t.TempDir()
|
|
blob1, data1 := createTestBlob(t, serverDir, 1024)
|
|
blob2, data2 := createTestBlob(t, serverDir, 2048)
|
|
blob3, data3 := createTestBlob(t, serverDir, 512)
|
|
|
|
// Mock server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract digest from URL: /v2/library/_/blobs/sha256:...
|
|
digest := filepath.Base(r.URL.Path)
|
|
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Download to client dir
|
|
clientDir := t.TempDir()
|
|
|
|
var progressCalls atomic.Int32
|
|
var lastCompleted, lastTotal atomic.Int64
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob1, blob2, blob3},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 2,
|
|
Progress: func(completed, total int64) {
|
|
progressCalls.Add(1)
|
|
lastCompleted.Store(completed)
|
|
lastTotal.Store(total)
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Verify files
|
|
verifyBlob(t, clientDir, blob1, data1)
|
|
verifyBlob(t, clientDir, blob2, data2)
|
|
verifyBlob(t, clientDir, blob3, data3)
|
|
|
|
// Verify progress was called
|
|
if progressCalls.Load() == 0 {
|
|
t.Error("Progress callback never called")
|
|
}
|
|
if lastTotal.Load() != blob1.Size+blob2.Size+blob3.Size {
|
|
t.Errorf("Wrong total: got %d, want %d", lastTotal.Load(), blob1.Size+blob2.Size+blob3.Size)
|
|
}
|
|
}
|
|
|
|
func TestDownloadWithRedirect(t *testing.T) {
|
|
// Create test blob on "CDN"
|
|
cdnDir := t.TempDir()
|
|
blob, data := createTestBlob(t, cdnDir, 1024)
|
|
|
|
// CDN server (the redirect target)
|
|
cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Serve the blob content
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(cdnDir, digestToPath(digest))
|
|
blobData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(blobData)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(blobData)
|
|
}))
|
|
defer cdn.Close()
|
|
|
|
// Registry server (redirects to CDN)
|
|
registry := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Redirect to CDN
|
|
cdnURL := cdn.URL + r.URL.Path
|
|
http.Redirect(w, r, cdnURL, http.StatusTemporaryRedirect)
|
|
}))
|
|
defer registry.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: registry.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download with redirect failed: %v", err)
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob, data)
|
|
}
|
|
|
|
func TestDownloadWithRetry(t *testing.T) {
|
|
// Create test blob
|
|
serverDir := t.TempDir()
|
|
blob, data := createTestBlob(t, serverDir, 1024)
|
|
|
|
var requestCount atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
count := requestCount.Add(1)
|
|
|
|
// Fail first 2 attempts, succeed on 3rd
|
|
if count < 3 {
|
|
http.Error(w, "temporary error", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
blobData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(blobData)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download with retry failed: %v", err)
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob, data)
|
|
|
|
// Should have made 3 requests (2 failures + 1 success)
|
|
if requestCount.Load() < 3 {
|
|
t.Errorf("Expected at least 3 requests for retry, got %d", requestCount.Load())
|
|
}
|
|
}
|
|
|
|
func TestDownloadWithAuth(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob, data := createTestBlob(t, serverDir, 1024)
|
|
|
|
var authCalled atomic.Bool
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Require auth
|
|
auth := r.Header.Get("Authorization")
|
|
if auth != "Bearer valid-token" {
|
|
w.Header().Set("WWW-Authenticate", `Bearer realm="https://auth.example.com",service="registry",scope="repository:library:pull"`)
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
blobData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(blobData)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
GetToken: func(ctx context.Context, challenge AuthChallenge) (string, error) {
|
|
authCalled.Store(true)
|
|
if challenge.Realm != "https://auth.example.com" {
|
|
t.Errorf("Wrong realm: %s", challenge.Realm)
|
|
}
|
|
if challenge.Service != "registry" {
|
|
t.Errorf("Wrong service: %s", challenge.Service)
|
|
}
|
|
return "valid-token", nil
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download with auth failed: %v", err)
|
|
}
|
|
|
|
if !authCalled.Load() {
|
|
t.Error("GetToken was never called")
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob, data)
|
|
}
|
|
|
|
func TestDownloadSkipsExisting(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob1, data1 := createTestBlob(t, serverDir, 1024)
|
|
|
|
// Pre-populate client dir
|
|
clientDir := t.TempDir()
|
|
path := filepath.Join(clientDir, digestToPath(blob1.Digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, data1, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var requestCount atomic.Int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestCount.Add(1)
|
|
http.NotFound(w, r)
|
|
}))
|
|
defer server.Close()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob1},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Should not have made any requests (blob already exists)
|
|
if requestCount.Load() != 0 {
|
|
t.Errorf("Made %d requests, expected 0 (blob should be skipped)", requestCount.Load())
|
|
}
|
|
}
|
|
|
|
func TestDownloadResumeProgressTotal(t *testing.T) {
|
|
// Test that when resuming a download with some blobs already present:
|
|
// 1. Total reflects ALL blob sizes (not just remaining)
|
|
// 2. Completed starts at the size of already-downloaded blobs
|
|
serverDir := t.TempDir()
|
|
blob1, data1 := createTestBlob(t, serverDir, 1000)
|
|
blob2, data2 := createTestBlob(t, serverDir, 2000)
|
|
blob3, data3 := createTestBlob(t, serverDir, 3000)
|
|
|
|
// Pre-populate client with blob1 and blob2 (simulating partial download)
|
|
clientDir := t.TempDir()
|
|
for _, b := range []struct {
|
|
blob Blob
|
|
data []byte
|
|
}{{blob1, data1}, {blob2, data2}} {
|
|
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, b.data, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
var firstCompleted, firstTotal int64
|
|
var gotFirstProgress bool
|
|
var mu sync.Mutex
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob1, blob2, blob3},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 1,
|
|
Progress: func(completed, total int64) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if !gotFirstProgress {
|
|
firstCompleted = completed
|
|
firstTotal = total
|
|
gotFirstProgress = true
|
|
}
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Total should be sum of ALL blobs, not just blob3
|
|
expectedTotal := blob1.Size + blob2.Size + blob3.Size
|
|
if firstTotal != expectedTotal {
|
|
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
|
|
}
|
|
|
|
// First progress call should show already-completed bytes from blob1+blob2
|
|
expectedCompleted := blob1.Size + blob2.Size
|
|
if firstCompleted < expectedCompleted {
|
|
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
|
|
}
|
|
|
|
// Verify blob3 was downloaded
|
|
verifyBlob(t, clientDir, blob3, data3)
|
|
}
|
|
|
|
func TestDownloadDigestMismatch(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Return wrong data
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("wrong data"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{{Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000", Size: 10}},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err == nil {
|
|
t.Fatal("Expected error for digest mismatch")
|
|
}
|
|
}
|
|
|
|
func TestUpload(t *testing.T) {
|
|
// Create test blobs
|
|
clientDir := t.TempDir()
|
|
blob1, _ := createTestBlob(t, clientDir, 1024)
|
|
blob2, _ := createTestBlob(t, clientDir, 2048)
|
|
|
|
var uploadedBlobs sync.Map
|
|
uploadID := 0
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
// Blob doesn't exist
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost && r.URL.Path == "/v2/library/_/blobs/uploads/":
|
|
// Initiate upload
|
|
uploadID++
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, uploadID))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPut:
|
|
// Complete upload
|
|
digest := r.URL.Query().Get("digest")
|
|
data, _ := io.ReadAll(r.Body)
|
|
uploadedBlobs.Store(digest, data)
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
var progressCalls atomic.Int32
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob1, blob2},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Concurrency: 2,
|
|
Progress: func(completed, total int64) {
|
|
progressCalls.Add(1)
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
// Verify both blobs were uploaded
|
|
if _, ok := uploadedBlobs.Load(blob1.Digest); !ok {
|
|
t.Error("Blob 1 not uploaded")
|
|
}
|
|
if _, ok := uploadedBlobs.Load(blob2.Digest); !ok {
|
|
t.Error("Blob 2 not uploaded")
|
|
}
|
|
|
|
if progressCalls.Load() == 0 {
|
|
t.Error("Progress callback never called")
|
|
}
|
|
}
|
|
|
|
func TestUploadWithRedirect(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var uploadedBlobs sync.Map
|
|
var cdnCalled atomic.Bool
|
|
|
|
// CDN server (redirect target)
|
|
cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
cdnCalled.Store(true)
|
|
if r.Method == http.MethodPut {
|
|
digest := r.URL.Query().Get("digest")
|
|
data, _ := io.ReadAll(r.Body)
|
|
uploadedBlobs.Store(digest, data)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer cdn.Close()
|
|
|
|
var serverURL string
|
|
uploadID := 0
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost && r.URL.Path == "/v2/library/_/blobs/uploads/":
|
|
uploadID++
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, uploadID))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPut:
|
|
// Redirect to CDN
|
|
cdnURL := cdn.URL + r.URL.Path + "?" + r.URL.RawQuery
|
|
http.Redirect(w, r, cdnURL, http.StatusTemporaryRedirect)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with redirect failed: %v", err)
|
|
}
|
|
|
|
if !cdnCalled.Load() {
|
|
t.Error("CDN was never called (redirect not followed)")
|
|
}
|
|
|
|
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Error("Blob not uploaded to CDN")
|
|
}
|
|
}
|
|
|
|
func TestUploadWithAuth(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var uploadedBlobs sync.Map
|
|
var authCalled atomic.Bool
|
|
uploadID := 0
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Require auth for all requests
|
|
auth := r.Header.Get("Authorization")
|
|
if auth != "Bearer valid-token" {
|
|
w.Header().Set("WWW-Authenticate", `Bearer realm="https://auth.example.com",service="registry",scope="repository:library:push"`)
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost && r.URL.Path == "/v2/library/_/blobs/uploads/":
|
|
uploadID++
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, uploadID))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPut:
|
|
digest := r.URL.Query().Get("digest")
|
|
data, _ := io.ReadAll(r.Body)
|
|
uploadedBlobs.Store(digest, data)
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
GetToken: func(ctx context.Context, challenge AuthChallenge) (string, error) {
|
|
authCalled.Store(true)
|
|
return "valid-token", nil
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with auth failed: %v", err)
|
|
}
|
|
|
|
if !authCalled.Load() {
|
|
t.Error("GetToken was never called")
|
|
}
|
|
|
|
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Error("Blob not uploaded")
|
|
}
|
|
}
|
|
|
|
func TestUploadSkipsExisting(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob1, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var headChecked atomic.Bool
|
|
var putCalled atomic.Bool
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
// HEAD check for blob existence - return 200 OK to indicate blob exists
|
|
headChecked.Store(true)
|
|
w.WriteHeader(http.StatusOK)
|
|
case http.MethodPost:
|
|
http.NotFound(w, r)
|
|
case http.MethodPut:
|
|
putCalled.Store(true)
|
|
w.WriteHeader(http.StatusCreated)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob1},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
// Verify HEAD check was used
|
|
if !headChecked.Load() {
|
|
t.Error("HEAD check was never made")
|
|
}
|
|
|
|
// Should not have attempted PUT (blob already exists)
|
|
if putCalled.Load() {
|
|
t.Error("PUT was called even though blob exists (HEAD returned 200)")
|
|
}
|
|
|
|
t.Log("HEAD-based existence check verified")
|
|
}
|
|
|
|
// TestUploadWithCustomRepository verifies that custom repository paths are used
|
|
func TestUploadWithCustomRepository(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob1, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var headPath, postPath string
|
|
var mu sync.Mutex
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
headPath = r.URL.Path
|
|
w.WriteHeader(http.StatusNotFound) // Blob doesn't exist
|
|
case http.MethodPost:
|
|
postPath = r.URL.Path
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/myorg/mymodel/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPut:
|
|
io.Copy(io.Discard, r.Body)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
mu.Unlock()
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob1},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Repository: "myorg/mymodel", // Custom repository
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Verify HEAD used custom repository path
|
|
expectedHeadPath := fmt.Sprintf("/v2/myorg/mymodel/blobs/%s", blob1.Digest)
|
|
if headPath != expectedHeadPath {
|
|
t.Errorf("HEAD path mismatch: got %s, want %s", headPath, expectedHeadPath)
|
|
}
|
|
|
|
// Verify POST used custom repository path
|
|
expectedPostPath := "/v2/myorg/mymodel/blobs/uploads/"
|
|
if postPath != expectedPostPath {
|
|
t.Errorf("POST path mismatch: got %s, want %s", postPath, expectedPostPath)
|
|
}
|
|
|
|
t.Logf("Custom repository paths verified: HEAD=%s, POST=%s", headPath, postPath)
|
|
}
|
|
|
|
// TestDownloadWithCustomRepository verifies that custom repository paths are used
|
|
func TestDownloadWithCustomRepository(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob, data := createTestBlob(t, serverDir, 1024)
|
|
|
|
var requestPath string
|
|
var mu sync.Mutex
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
requestPath = r.URL.Path
|
|
mu.Unlock()
|
|
|
|
// Serve blob from any path
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
blobData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(blobData)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Repository: "myorg/mymodel", // Custom repository
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob, data)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Verify request used custom repository path
|
|
expectedPath := fmt.Sprintf("/v2/myorg/mymodel/blobs/%s", blob.Digest)
|
|
if requestPath != expectedPath {
|
|
t.Errorf("Request path mismatch: got %s, want %s", requestPath, expectedPath)
|
|
}
|
|
|
|
t.Logf("Custom repository path verified: %s", requestPath)
|
|
}
|
|
|
|
func TestDigestToPath(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want string
|
|
}{
|
|
{"sha256:abc123", "sha256-abc123"},
|
|
{"sha256-abc123", "sha256-abc123"},
|
|
{"other", "other"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
got := digestToPath(tt.input)
|
|
if got != tt.want {
|
|
t.Errorf("digestToPath(%q) = %q, want %q", tt.input, got, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestParseAuthChallenge(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want AuthChallenge
|
|
}{
|
|
{
|
|
input: `Bearer realm="https://auth.example.com/token",service="registry",scope="repository:library/test:pull"`,
|
|
want: AuthChallenge{
|
|
Realm: "https://auth.example.com/token",
|
|
Service: "registry",
|
|
Scope: "repository:library/test:pull",
|
|
},
|
|
},
|
|
{
|
|
input: `Bearer realm="https://auth.example.com"`,
|
|
want: AuthChallenge{
|
|
Realm: "https://auth.example.com",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
got := parseAuthChallenge(tt.input)
|
|
if got.Realm != tt.want.Realm {
|
|
t.Errorf("parseAuthChallenge(%q).Realm = %q, want %q", tt.input, got.Realm, tt.want.Realm)
|
|
}
|
|
if got.Service != tt.want.Service {
|
|
t.Errorf("parseAuthChallenge(%q).Service = %q, want %q", tt.input, got.Service, tt.want.Service)
|
|
}
|
|
if got.Scope != tt.want.Scope {
|
|
t.Errorf("parseAuthChallenge(%q).Scope = %q, want %q", tt.input, got.Scope, tt.want.Scope)
|
|
}
|
|
}
|
|
}
|
|
|
|
func verifyBlob(t *testing.T, dir string, blob Blob, expected []byte) {
|
|
t.Helper()
|
|
|
|
path := filepath.Join(dir, digestToPath(blob.Digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
t.Errorf("Failed to read %s: %v", blob.Digest[:19], err)
|
|
return
|
|
}
|
|
|
|
if len(data) != len(expected) {
|
|
t.Errorf("Size mismatch for %s: got %d, want %d", blob.Digest[:19], len(data), len(expected))
|
|
return
|
|
}
|
|
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
if digest != blob.Digest {
|
|
t.Errorf("Digest mismatch for %s: got %s", blob.Digest[:19], digest[:19])
|
|
}
|
|
}
|
|
|
|
// ==================== Parallelism Tests ====================
|
|
|
|
func TestDownloadParallelism(t *testing.T) {
|
|
// Create many blobs to test parallelism
|
|
serverDir := t.TempDir()
|
|
numBlobs := 10
|
|
blobs := make([]Blob, numBlobs)
|
|
blobData := make([][]byte, numBlobs)
|
|
|
|
for i := range numBlobs {
|
|
blobs[i], blobData[i] = createTestBlob(t, serverDir, 1024+i*100)
|
|
}
|
|
|
|
var activeRequests atomic.Int32
|
|
var maxConcurrent atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
current := activeRequests.Add(1)
|
|
defer activeRequests.Add(-1)
|
|
|
|
// Track max concurrent requests
|
|
for {
|
|
old := maxConcurrent.Load()
|
|
if current <= old || maxConcurrent.CompareAndSwap(old, current) {
|
|
break
|
|
}
|
|
}
|
|
|
|
// Simulate network latency to ensure parallelism is visible
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
start := time.Now()
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: blobs,
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 4,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Verify all blobs downloaded
|
|
for i, blob := range blobs {
|
|
verifyBlob(t, clientDir, blob, blobData[i])
|
|
}
|
|
|
|
// Verify parallelism was used
|
|
if maxConcurrent.Load() < 2 {
|
|
t.Errorf("Max concurrent requests was %d, expected at least 2 for parallelism", maxConcurrent.Load())
|
|
}
|
|
|
|
// With 10 blobs at 50ms each, sequential would take ~500ms
|
|
// Parallel with 4 workers should take ~150ms (relax to 1s for CI variance)
|
|
if elapsed > time.Second {
|
|
t.Errorf("Downloads took %v, expected faster with parallelism", elapsed)
|
|
}
|
|
|
|
t.Logf("Downloaded %d blobs in %v with max %d concurrent requests", numBlobs, elapsed, maxConcurrent.Load())
|
|
}
|
|
|
|
func TestUploadParallelism(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
numBlobs := 10
|
|
blobs := make([]Blob, numBlobs)
|
|
|
|
for i := range numBlobs {
|
|
blobs[i], _ = createTestBlob(t, clientDir, 1024+i*100)
|
|
}
|
|
|
|
var activeRequests atomic.Int32
|
|
var maxConcurrent atomic.Int32
|
|
var uploadedBlobs sync.Map
|
|
var uploadID atomic.Int32
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
current := activeRequests.Add(1)
|
|
defer activeRequests.Add(-1)
|
|
|
|
// Track max concurrent
|
|
for {
|
|
old := maxConcurrent.Load()
|
|
if current <= old || maxConcurrent.CompareAndSwap(old, current) {
|
|
break
|
|
}
|
|
}
|
|
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost:
|
|
id := uploadID.Add(1)
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, id))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPut:
|
|
time.Sleep(50 * time.Millisecond) // Simulate upload time
|
|
digest := r.URL.Query().Get("digest")
|
|
data, _ := io.ReadAll(r.Body)
|
|
uploadedBlobs.Store(digest, data)
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
start := time.Now()
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: blobs,
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Concurrency: 4,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
// Verify all blobs uploaded
|
|
for _, blob := range blobs {
|
|
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Errorf("Blob %s not uploaded", blob.Digest[:19])
|
|
}
|
|
}
|
|
|
|
if maxConcurrent.Load() < 2 {
|
|
t.Errorf("Max concurrent requests was %d, expected at least 2", maxConcurrent.Load())
|
|
}
|
|
|
|
t.Logf("Uploaded %d blobs in %v with max %d concurrent requests", numBlobs, elapsed, maxConcurrent.Load())
|
|
}
|
|
|
|
// ==================== Stall Detection Test ====================
|
|
|
|
func TestDownloadStallDetection(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping stall detection test in short mode")
|
|
}
|
|
|
|
serverDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, serverDir, 10*1024) // 10KB
|
|
|
|
var requestCount atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
count := requestCount.Add(1)
|
|
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
if count == 1 {
|
|
// First request: send partial data then stall
|
|
w.Write(data[:1024]) // Send first 1KB
|
|
if f, ok := w.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
// Stall for longer than stall timeout (test uses 200ms)
|
|
time.Sleep(500 * time.Millisecond)
|
|
return
|
|
}
|
|
|
|
// Subsequent requests: send full data
|
|
w.Write(data)
|
|
}))
|
|
defer func() {
|
|
server.CloseClientConnections()
|
|
server.Close()
|
|
}()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
start := time.Now()
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
StallTimeout: 200 * time.Millisecond, // Short timeout for testing
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Should have retried after stall detection
|
|
if requestCount.Load() < 2 {
|
|
t.Errorf("Expected at least 2 requests (stall + retry), got %d", requestCount.Load())
|
|
}
|
|
|
|
// Should complete quickly with short stall timeout
|
|
if elapsed > 3*time.Second {
|
|
t.Errorf("Download took %v, stall detection should have triggered faster", elapsed)
|
|
}
|
|
|
|
t.Logf("Stall detection worked: %d requests in %v", requestCount.Load(), elapsed)
|
|
}
|
|
|
|
// ==================== Context Cancellation Tests ====================
|
|
|
|
func TestDownloadCancellation(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, serverDir, 100*1024) // 100KB (smaller for faster test)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, _ := os.ReadFile(path)
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
// Send data slowly
|
|
for i := 0; i < len(data); i += 1024 {
|
|
end := i + 1024
|
|
if end > len(data) {
|
|
end = len(data)
|
|
}
|
|
w.Write(data[i:end])
|
|
if f, ok := w.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
}))
|
|
defer func() {
|
|
server.CloseClientConnections()
|
|
server.Close()
|
|
}()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Cancel after 50ms
|
|
go func() {
|
|
time.Sleep(50 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
|
|
start := time.Now()
|
|
err := Download(ctx, DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err == nil {
|
|
t.Fatal("Expected error from cancellation")
|
|
}
|
|
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Errorf("Expected context.Canceled error, got: %v", err)
|
|
}
|
|
|
|
// Should cancel quickly, not wait for full download
|
|
if elapsed > 500*time.Millisecond {
|
|
t.Errorf("Cancellation took %v, expected faster response", elapsed)
|
|
}
|
|
|
|
t.Logf("Cancellation worked in %v", elapsed)
|
|
}
|
|
|
|
func TestUploadCancellation(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 100*1024) // 100KB (smaller for faster test)
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPut:
|
|
// Read slowly
|
|
buf := make([]byte, 1024)
|
|
for {
|
|
_, err := r.Body.Read(buf)
|
|
if err != nil {
|
|
break
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer func() {
|
|
server.CloseClientConnections()
|
|
server.Close()
|
|
}()
|
|
serverURL = server.URL
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
go func() {
|
|
time.Sleep(50 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
|
|
start := time.Now()
|
|
err := Upload(ctx, UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err == nil {
|
|
t.Fatal("Expected error from cancellation")
|
|
}
|
|
|
|
if elapsed > 500*time.Millisecond {
|
|
t.Errorf("Cancellation took %v, expected faster", elapsed)
|
|
}
|
|
|
|
t.Logf("Upload cancellation worked in %v", elapsed)
|
|
}
|
|
|
|
// ==================== Progress Tracking Tests ====================
|
|
|
|
func TestProgressTracking(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob1, data1 := createTestBlob(t, serverDir, 5000)
|
|
blob2, data2 := createTestBlob(t, serverDir, 3000)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, _ := os.ReadFile(path)
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
var progressHistory []struct{ completed, total int64 }
|
|
var mu sync.Mutex
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob1, blob2},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 1, // Sequential to make progress predictable
|
|
Progress: func(completed, total int64) {
|
|
mu.Lock()
|
|
progressHistory = append(progressHistory, struct{ completed, total int64 }{completed, total})
|
|
mu.Unlock()
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob1, data1)
|
|
verifyBlob(t, clientDir, blob2, data2)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if len(progressHistory) == 0 {
|
|
t.Fatal("No progress callbacks received")
|
|
}
|
|
|
|
// Total should always be sum of blob sizes
|
|
expectedTotal := blob1.Size + blob2.Size
|
|
for _, p := range progressHistory {
|
|
if p.total != expectedTotal {
|
|
t.Errorf("Total changed during download: got %d, want %d", p.total, expectedTotal)
|
|
}
|
|
}
|
|
|
|
// Completed should be monotonically increasing
|
|
var lastCompleted int64
|
|
for _, p := range progressHistory {
|
|
if p.completed < lastCompleted {
|
|
t.Errorf("Progress went backwards: %d -> %d", lastCompleted, p.completed)
|
|
}
|
|
lastCompleted = p.completed
|
|
}
|
|
|
|
// Final completed should equal total
|
|
final := progressHistory[len(progressHistory)-1]
|
|
if final.completed != expectedTotal {
|
|
t.Errorf("Final completed %d != total %d", final.completed, expectedTotal)
|
|
}
|
|
|
|
t.Logf("Progress tracked correctly: %d callbacks, final %d/%d", len(progressHistory), final.completed, final.total)
|
|
}
|
|
|
|
// ==================== Edge Cases ====================
|
|
|
|
func TestDownloadEmptyBlobList(t *testing.T) {
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{},
|
|
BaseURL: "http://unused",
|
|
DestDir: t.TempDir(),
|
|
})
|
|
if err != nil {
|
|
t.Errorf("Expected no error for empty blob list, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestUploadEmptyBlobList(t *testing.T) {
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{},
|
|
BaseURL: "http://unused",
|
|
SrcDir: t.TempDir(),
|
|
})
|
|
if err != nil {
|
|
t.Errorf("Expected no error for empty blob list, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDownloadManyBlobs(t *testing.T) {
|
|
// Test with many blobs to verify high concurrency works
|
|
serverDir := t.TempDir()
|
|
numBlobs := 50
|
|
blobs := make([]Blob, numBlobs)
|
|
blobData := make([][]byte, numBlobs)
|
|
|
|
for i := range numBlobs {
|
|
blobs[i], blobData[i] = createTestBlob(t, serverDir, 512) // Small blobs
|
|
}
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
start := time.Now()
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: blobs,
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 16,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Verify all blobs
|
|
for i, blob := range blobs {
|
|
verifyBlob(t, clientDir, blob, blobData[i])
|
|
}
|
|
|
|
t.Logf("Downloaded %d blobs in %v", numBlobs, elapsed)
|
|
}
|
|
|
|
func TestUploadRetryOnFailure(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var putCount atomic.Int32
|
|
var uploadedBlobs sync.Map
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPut:
|
|
count := putCount.Add(1)
|
|
if count < 3 {
|
|
// Fail first 2 attempts
|
|
http.Error(w, "server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
digest := r.URL.Query().Get("digest")
|
|
data, _ := io.ReadAll(r.Body)
|
|
uploadedBlobs.Store(digest, data)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with retry failed: %v", err)
|
|
}
|
|
|
|
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Error("Blob not uploaded after retry")
|
|
}
|
|
|
|
if putCount.Load() < 3 {
|
|
t.Errorf("Expected at least 3 PUT attempts, got %d", putCount.Load())
|
|
}
|
|
}
|
|
|
|
// TestProgressRollback verifies that progress is rolled back on retry
|
|
func TestProgressRollback(t *testing.T) {
|
|
content := []byte("test content for rollback test")
|
|
digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))
|
|
blob := Blob{Digest: digest, Size: int64(len(content))}
|
|
|
|
clientDir := t.TempDir()
|
|
path := filepath.Join(clientDir, digestToPath(digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, content, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var putCount atomic.Int32
|
|
var progressValues []int64
|
|
var mu sync.Mutex
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPut:
|
|
// Read some data before failing
|
|
io.CopyN(io.Discard, r.Body, 10)
|
|
count := putCount.Add(1)
|
|
if count < 3 {
|
|
http.Error(w, "server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
io.Copy(io.Discard, r.Body)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Progress: func(completed, total int64) {
|
|
mu.Lock()
|
|
progressValues = append(progressValues, completed)
|
|
mu.Unlock()
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with retry failed: %v", err)
|
|
}
|
|
|
|
// Check that progress was rolled back (should have negative values or drops)
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Final progress should equal blob size
|
|
if len(progressValues) > 0 {
|
|
final := progressValues[len(progressValues)-1]
|
|
if final != blob.Size {
|
|
t.Errorf("Final progress %d != blob size %d", final, blob.Size)
|
|
}
|
|
}
|
|
|
|
t.Logf("Progress rollback test: %d progress callbacks", len(progressValues))
|
|
}
|
|
|
|
// TestUserAgentHeader verifies User-Agent header is set on requests
|
|
func TestUserAgentHeader(t *testing.T) {
|
|
content := []byte("test content")
|
|
digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))
|
|
blob := Blob{Digest: digest, Size: int64(len(content))}
|
|
|
|
destDir := t.TempDir()
|
|
var userAgents []string
|
|
var mu sync.Mutex
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
ua := r.Header.Get("User-Agent")
|
|
userAgents = append(userAgents, ua)
|
|
mu.Unlock()
|
|
|
|
if r.Method == http.MethodGet {
|
|
w.Write(content)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Test with custom User-Agent
|
|
customUA := "test-agent/1.0"
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: destDir,
|
|
UserAgent: customUA,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Verify custom User-Agent was used
|
|
for _, ua := range userAgents {
|
|
if ua != customUA {
|
|
t.Errorf("User-Agent %q != expected %q", ua, customUA)
|
|
}
|
|
}
|
|
t.Logf("User-Agent header test: %d requests with correct User-Agent", len(userAgents))
|
|
}
|
|
|
|
// TestDefaultUserAgent verifies default User-Agent is used when not specified
|
|
func TestDefaultUserAgent(t *testing.T) {
|
|
content := []byte("test content")
|
|
digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))
|
|
blob := Blob{Digest: digest, Size: int64(len(content))}
|
|
|
|
destDir := t.TempDir()
|
|
var userAgent string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
userAgent = r.Header.Get("User-Agent")
|
|
if r.Method == http.MethodGet {
|
|
w.Write(content)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: destDir,
|
|
// No UserAgent specified - should use default
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
if userAgent == "" {
|
|
t.Error("User-Agent header was empty")
|
|
}
|
|
if userAgent != defaultUserAgent {
|
|
t.Errorf("Default User-Agent %q != expected %q", userAgent, defaultUserAgent)
|
|
}
|
|
}
|
|
|
|
// TestManifestPush verifies that manifest is pushed after blobs
|
|
func TestManifestPush(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1000)
|
|
|
|
testManifest := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`)
|
|
testRepo := "library/test-model"
|
|
testRef := "latest"
|
|
|
|
var manifestReceived []byte
|
|
var manifestPath string
|
|
var manifestContentType string
|
|
var serverURL string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Handle blob check (HEAD)
|
|
if r.Method == http.MethodHead {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
// Handle blob upload initiate (POST)
|
|
if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/blobs/uploads") {
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
// Handle blob upload (PUT to blobs)
|
|
if r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/blobs/") {
|
|
w.WriteHeader(http.StatusCreated)
|
|
return
|
|
}
|
|
|
|
// Handle manifest push (PUT to manifests)
|
|
if r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/manifests/") {
|
|
manifestPath = r.URL.Path
|
|
manifestContentType = r.Header.Get("Content-Type")
|
|
manifestReceived, _ = io.ReadAll(r.Body)
|
|
w.WriteHeader(http.StatusCreated)
|
|
return
|
|
}
|
|
|
|
http.NotFound(w, r)
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Manifest: testManifest,
|
|
ManifestRef: testRef,
|
|
Repository: testRepo,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
// Verify manifest was pushed
|
|
if manifestReceived == nil {
|
|
t.Fatal("Manifest was not received by server")
|
|
}
|
|
|
|
if !bytes.Equal(manifestReceived, testManifest) {
|
|
t.Errorf("Manifest content mismatch: got %s, want %s", manifestReceived, testManifest)
|
|
}
|
|
|
|
expectedPath := fmt.Sprintf("/v2/%s/manifests/%s", testRepo, testRef)
|
|
if manifestPath != expectedPath {
|
|
t.Errorf("Manifest path mismatch: got %s, want %s", manifestPath, expectedPath)
|
|
}
|
|
|
|
if manifestContentType != "application/vnd.docker.distribution.manifest.v2+json" {
|
|
t.Errorf("Manifest content type mismatch: got %s", manifestContentType)
|
|
}
|
|
|
|
t.Logf("Manifest push test passed: received %d bytes at %s", len(manifestReceived), manifestPath)
|
|
}
|
|
|
|
// ==================== Throughput Benchmarks ====================
|
|
|
|
func BenchmarkDownloadThroughput(b *testing.B) {
|
|
// Create test data - 1MB blob
|
|
data := make([]byte, 1024*1024)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(len(data))}
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
b.SetBytes(int64(len(data)))
|
|
b.ResetTimer()
|
|
|
|
for range b.N {
|
|
clientDir := b.TempDir()
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 1,
|
|
})
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkUploadThroughput(b *testing.B) {
|
|
// Create test data - 1MB blob
|
|
data := make([]byte, 1024*1024)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(len(data))}
|
|
|
|
// Create source file once
|
|
srcDir := b.TempDir()
|
|
path := filepath.Join(srcDir, digestToPath(digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
case http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPut:
|
|
io.Copy(io.Discard, r.Body)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
b.SetBytes(int64(len(data)))
|
|
b.ResetTimer()
|
|
|
|
for range b.N {
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: srcDir,
|
|
Concurrency: 1,
|
|
})
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestThroughput is a quick throughput test that reports MB/s
|
|
func TestThroughput(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping throughput test in short mode")
|
|
}
|
|
|
|
// Test parameters - 5MB total across 5 blobs
|
|
const blobSize = 1024 * 1024 // 1MB per blob
|
|
const numBlobs = 5
|
|
const concurrency = 5
|
|
|
|
// Create test blobs
|
|
serverDir := t.TempDir()
|
|
blobs := make([]Blob, numBlobs)
|
|
for i := range numBlobs {
|
|
data := make([]byte, blobSize)
|
|
// Different seed per blob for unique digests
|
|
for j := range data {
|
|
data[j] = byte((i*256 + j) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blobs[i] = Blob{Digest: digest, Size: int64(len(data))}
|
|
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(path), 0o755)
|
|
os.WriteFile(path, data, 0o644)
|
|
}
|
|
|
|
totalBytes := int64(blobSize * numBlobs)
|
|
|
|
// Download server
|
|
dlServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer dlServer.Close()
|
|
|
|
// Measure download throughput
|
|
clientDir := t.TempDir()
|
|
start := time.Now()
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: blobs,
|
|
BaseURL: dlServer.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: concurrency,
|
|
})
|
|
dlElapsed := time.Since(start)
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
dlThroughput := float64(totalBytes) / dlElapsed.Seconds() / (1024 * 1024)
|
|
t.Logf("Download: %.2f MB/s (%d bytes in %v)", dlThroughput, totalBytes, dlElapsed)
|
|
|
|
// Upload server
|
|
var ulServerURL string
|
|
ulServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
case http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", ulServerURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPut:
|
|
io.Copy(io.Discard, r.Body)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer ulServer.Close()
|
|
ulServerURL = ulServer.URL
|
|
|
|
// Measure upload throughput
|
|
start = time.Now()
|
|
err = Upload(context.Background(), UploadOptions{
|
|
Blobs: blobs,
|
|
BaseURL: ulServer.URL,
|
|
SrcDir: serverDir,
|
|
Concurrency: concurrency,
|
|
})
|
|
ulElapsed := time.Since(start)
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
ulThroughput := float64(totalBytes) / ulElapsed.Seconds() / (1024 * 1024)
|
|
t.Logf("Upload: %.2f MB/s (%d bytes in %v)", ulThroughput, totalBytes, ulElapsed)
|
|
|
|
// Sanity check - local transfers should be fast (>50 MB/s is reasonable for local)
|
|
// This ensures the implementation isn't artificially throttled
|
|
if dlThroughput < 10 {
|
|
t.Errorf("Download throughput unexpectedly low: %.2f MB/s", dlThroughput)
|
|
}
|
|
if ulThroughput < 10 {
|
|
t.Errorf("Upload throughput unexpectedly low: %.2f MB/s", ulThroughput)
|
|
}
|
|
|
|
// Overall time check - should complete in <500ms for local transfers
|
|
if dlElapsed+ulElapsed > 500*time.Millisecond {
|
|
t.Logf("Warning: total time %v exceeds 500ms target", dlElapsed+ulElapsed)
|
|
}
|
|
}
|