mirror of
https://github.com/ollama/ollama.git
synced 2026-03-28 03:08:44 +07:00
Removes 5-minute HTTP client timeout that caused "context deadline exceeded" errors on large file downloads. Stall detection (10s) already handles unresponsive connections. Fixes progress bar total going down on resume by calculating total from all blobs upfront and reporting already-downloaded bytes as completed immediately.
330 lines
7.4 KiB
Go
330 lines
7.4 KiB
Go
package transfer
|
|
|
|
import (
|
|
"cmp"
|
|
"context"
|
|
"crypto/sha256"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/sync/semaphore"
|
|
)
|
|
|
|
// Sentinel errors used as the cancellation cause when the per-transfer
// watchdog aborts a blob download.
var (
	// errStalled is reported when no bytes have been read for longer than the
	// downloader's stall timeout.
	errStalled = errors.New("download stalled")

	// errSlow is reported when a transfer's average throughput drops below a
	// tenth of the median speed of previously completed blobs.
	errSlow = errors.New("download too slow")
)
|
|
|
|
type downloader struct {
|
|
client *http.Client
|
|
baseURL string
|
|
destDir string
|
|
repository string // Repository path for blob URLs (e.g., "library/model")
|
|
token *string
|
|
getToken func(context.Context, AuthChallenge) (string, error)
|
|
userAgent string
|
|
stallTimeout time.Duration
|
|
progress *progressTracker
|
|
speeds *speedTracker
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func download(ctx context.Context, opts DownloadOptions) error {
|
|
if len(opts.Blobs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Calculate total from all blobs (for accurate progress reporting on resume)
|
|
var total int64
|
|
for _, b := range opts.Blobs {
|
|
total += b.Size
|
|
}
|
|
|
|
// Filter out already-downloaded blobs and track completed bytes
|
|
var blobs []Blob
|
|
var alreadyCompleted int64
|
|
for _, b := range opts.Blobs {
|
|
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
|
|
if opts.Logger != nil {
|
|
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
|
|
}
|
|
alreadyCompleted += b.Size
|
|
continue
|
|
}
|
|
blobs = append(blobs, b)
|
|
}
|
|
if len(blobs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
token := opts.Token
|
|
progress := newProgressTracker(total, opts.Progress)
|
|
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront
|
|
|
|
d := &downloader{
|
|
client: cmp.Or(opts.Client, defaultClient),
|
|
baseURL: opts.BaseURL,
|
|
destDir: opts.DestDir,
|
|
repository: cmp.Or(opts.Repository, "library/_"),
|
|
token: &token,
|
|
getToken: opts.GetToken,
|
|
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
|
|
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
|
|
progress: progress,
|
|
speeds: &speedTracker{},
|
|
logger: opts.Logger,
|
|
}
|
|
|
|
concurrency := cmp.Or(opts.Concurrency, DefaultDownloadConcurrency)
|
|
sem := semaphore.NewWeighted(int64(concurrency))
|
|
|
|
g, ctx := errgroup.WithContext(ctx)
|
|
for _, blob := range blobs {
|
|
g.Go(func() error {
|
|
if err := sem.Acquire(ctx, 1); err != nil {
|
|
return err
|
|
}
|
|
defer sem.Release(1)
|
|
return d.download(ctx, blob)
|
|
})
|
|
}
|
|
return g.Wait()
|
|
}
|
|
|
|
func (d *downloader) download(ctx context.Context, blob Blob) error {
|
|
var lastErr error
|
|
var slowRetries int
|
|
attempt := 0
|
|
|
|
for attempt < maxRetries {
|
|
if attempt > 0 {
|
|
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
start := time.Now()
|
|
n, err := d.downloadOnce(ctx, blob)
|
|
if err == nil {
|
|
if s := time.Since(start).Seconds(); s > 0 {
|
|
d.speeds.record(float64(blob.Size) / s)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
d.progress.add(-n) // rollback
|
|
|
|
switch {
|
|
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
|
return err
|
|
case errors.Is(err, errStalled):
|
|
// Don't count stall retries against limit
|
|
case errors.Is(err, errSlow):
|
|
if slowRetries++; slowRetries >= 3 {
|
|
attempt++ // Only count after 3 slow retries
|
|
}
|
|
default:
|
|
attempt++
|
|
}
|
|
lastErr = err
|
|
}
|
|
return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
|
|
}
|
|
|
|
func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error) {
|
|
if d.logger != nil {
|
|
d.logger.Debug("downloading blob", "digest", blob.Digest, "size", blob.Size)
|
|
}
|
|
|
|
baseURL, _ := url.Parse(d.baseURL)
|
|
u, err := d.resolve(ctx, fmt.Sprintf("%s/v2/%s/blobs/%s", d.baseURL, d.repository, blob.Digest))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
|
req.Header.Set("User-Agent", d.userAgent)
|
|
// Add auth only for same-host (not CDN)
|
|
if u.Host == baseURL.Host && *d.token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+*d.token)
|
|
}
|
|
|
|
resp, err := d.client.Do(req)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return 0, fmt.Errorf("status %d", resp.StatusCode)
|
|
}
|
|
|
|
return d.save(ctx, blob, resp.Body)
|
|
}
|
|
|
|
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) {
|
|
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
|
|
tmp := dest + ".tmp"
|
|
os.MkdirAll(filepath.Dir(dest), 0o755)
|
|
|
|
f, err := os.Create(tmp)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer f.Close()
|
|
setSparse(f)
|
|
|
|
h := sha256.New()
|
|
n, err := d.copy(ctx, f, r, h)
|
|
if err != nil {
|
|
os.Remove(tmp)
|
|
return n, err
|
|
}
|
|
f.Close()
|
|
|
|
if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest {
|
|
os.Remove(tmp)
|
|
return n, fmt.Errorf("digest mismatch")
|
|
}
|
|
if n != blob.Size {
|
|
os.Remove(tmp)
|
|
return n, fmt.Errorf("size mismatch")
|
|
}
|
|
return n, os.Rename(tmp, dest)
|
|
}
|
|
|
|
func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) {
|
|
var n int64
|
|
var lastRead atomic.Int64
|
|
lastRead.Store(time.Now().UnixNano())
|
|
start := time.Now()
|
|
|
|
ctx, cancel := context.WithCancelCause(ctx)
|
|
defer cancel(nil)
|
|
|
|
go func() {
|
|
tick := time.NewTicker(time.Second)
|
|
defer tick.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-tick.C:
|
|
if time.Since(time.Unix(0, lastRead.Load())) > d.stallTimeout {
|
|
cancel(errStalled)
|
|
return
|
|
}
|
|
if e := time.Since(start); e > 5*time.Second {
|
|
if m := d.speeds.median(); m > 0 && float64(n)/e.Seconds() < m*0.1 {
|
|
cancel(errSlow)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
buf := make([]byte, 32*1024)
|
|
for {
|
|
if err := ctx.Err(); err != nil {
|
|
if c := context.Cause(ctx); c != nil {
|
|
return n, c
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
nr, err := src.Read(buf)
|
|
if nr > 0 {
|
|
lastRead.Store(time.Now().UnixNano())
|
|
dst.Write(buf[:nr])
|
|
h.Write(buf[:nr])
|
|
d.progress.add(int64(nr))
|
|
n += int64(nr)
|
|
}
|
|
if err == io.EOF {
|
|
return n, nil
|
|
}
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) {
|
|
u, _ := url.Parse(rawURL)
|
|
for range 10 {
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
|
req.Header.Set("User-Agent", d.userAgent)
|
|
if *d.token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+*d.token)
|
|
}
|
|
|
|
resp, err := d.client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp.Body.Close()
|
|
|
|
switch resp.StatusCode {
|
|
case http.StatusOK:
|
|
return u, nil
|
|
case http.StatusUnauthorized:
|
|
if d.getToken == nil {
|
|
return nil, fmt.Errorf("unauthorized")
|
|
}
|
|
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
|
if *d.token, err = d.getToken(ctx, ch); err != nil {
|
|
return nil, err
|
|
}
|
|
case http.StatusTemporaryRedirect, http.StatusFound, http.StatusMovedPermanently:
|
|
loc, _ := resp.Location()
|
|
if loc.Host != u.Host {
|
|
return loc, nil
|
|
}
|
|
u = loc
|
|
default:
|
|
return nil, fmt.Errorf("status %d", resp.StatusCode)
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("too many redirects")
|
|
}
|
|
|
|
// speedTracker keeps a sliding window of the most recent throughput samples.
// All methods are guarded by mu.
type speedTracker struct {
	mu     sync.Mutex
	speeds []float64 // oldest sample first, at most 30 entries
}

// record appends v to the window and drops the oldest sample once more than
// 30 are held.
func (s *speedTracker) record(v float64) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.speeds = append(s.speeds, v)
	if len(s.speeds) > 30 {
		s.speeds = s.speeds[1:]
	}
}

// median returns the middle sample of the window (the upper of the two middle
// samples when the count is even). It returns 0 while fewer than 5 samples
// have been recorded. The stored order of the samples is left untouched.
func (s *speedTracker) median() float64 {
	s.mu.Lock()
	defer s.mu.Unlock()

	count := len(s.speeds)
	if count < 5 {
		return 0
	}

	// Sort a copy so the insertion order of the window is preserved.
	ordered := slices.Clone(s.speeds)
	slices.Sort(ordered)
	return ordered[count/2]
}
|
|
|
|
// defaultStallTimeout is the stall timeout applied when
// DownloadOptions.StallTimeout is left at its zero value.
const defaultStallTimeout = 10 * time.Second
|