From 9e2003f88a943906113d99c2cf9a15b5bfd4dde9 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Fri, 6 Feb 2026 13:19:47 -0500 Subject: [PATCH] cmd/config: offer to pull missing models instead of erroring (#14113) --- cmd/config/integrations.go | 43 ++++++++-- cmd/config/integrations_test.go | 138 ++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 7 deletions(-) diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go index 69bf55a62..7fb3667df 100644 --- a/cmd/config/integrations.go +++ b/cmd/config/integrations.go @@ -194,6 +194,20 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st return nil } +// showOrPull checks if a model exists via client.Show and offers to pull it if not found. +func showOrPull(ctx context.Context, client *api.Client, model string) error { + if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil { + return nil + } + if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil { + return err + } else if !ok { + return errCancelled + } + fmt.Fprintf(os.Stderr, "\n") + return pullModel(ctx, client, model) +} + func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) { client, err := api.ClientFromEnvironment() if err != nil { @@ -397,8 +411,11 @@ Examples: // Validate --model flag if provided if modelFlag != "" { - if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil { - return fmt.Errorf("model %q not found", modelFlag) + if err := showOrPull(cmd.Context(), client, modelFlag); err != nil { + if errors.Is(err, errCancelled) { + return nil + } + return err } } @@ -424,9 +441,11 @@ Examples: // Validate saved model still exists if model != "" && modelFlag == "" { - if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil { + if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil { fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset) - model = "" + if err := showOrPull(cmd.Context(), client, model); err != nil { + model = "" + } } } @@ -443,6 +462,13 @@ Examples: existingAliases = aliases } + // Ensure cloud models are authenticated + if isCloudModel(cmd.Context(), client, model) { + if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil { + return err + } + } + // Sync aliases and save if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil { fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset) @@ -467,8 +493,11 @@ Examples: if err != nil { return err } - if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil { - return fmt.Errorf("model %q not found", modelFlag) + if err := showOrPull(cmd.Context(), client, modelFlag); err != nil { + if errors.Is(err, errCancelled) { + return nil + } + return err } } @@ -650,7 +679,7 @@ func isCloudModel(ctx context.Context, client *api.Client, name string) bool { if client == nil { return false } - resp, err := client.Show(ctx, &api.ShowRequest{Name: name}) + resp, err := client.Show(ctx, &api.ShowRequest{Model: name}) if err != nil { return false } diff --git a/cmd/config/integrations_test.go b/cmd/config/integrations_test.go index 4c0eb4a05..14d89e59a 100644 --- a/cmd/config/integrations_test.go +++ b/cmd/config/integrations_test.go @@ -2,12 +2,17 @@ package config import ( "context" + "encoding/json" "fmt" + "net/http" + "net/http/httptest" + "net/url" "slices" "strings" "testing" "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" "github.com/spf13/cobra" ) @@ -539,3 +544,136 @@ func TestAliasConfigurerInterface(t *testing.T) { } }) } + +func TestShowOrPull_ModelExists(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"model":"test-model"}`) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := showOrPull(context.Background(), client, "test-model") + if err != nil { + t.Errorf("showOrPull should return nil when model exists, got: %v", err) + } +} + +func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"model not found"}`) + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + // confirmPrompt will fail in test (no terminal), so showOrPull should return an error + err := showOrPull(context.Background(), client, "missing-model") + if err == nil { + t.Error("showOrPull should return error when model not found and no terminal available") + } +} + +func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) { + var receivedModel string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" { + var req api.ShowRequest + if err := json.NewDecoder(r.Body).Decode(&req); err == nil { + receivedModel = req.Model + } + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"model":"%s"}`, receivedModel) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + _ = showOrPull(context.Background(), client, "qwen3:8b") + if receivedModel != "qwen3:8b" { + t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel) + } +} + +func TestEnsureAuth_NoCloudModels(t *testing.T) { + // ensureAuth should be a no-op when no cloud models are selected + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("no API calls expected when no cloud models selected") + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"}) + if err != nil { + t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err) + } +} + +func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) { + // ensureAuth should only care about models in cloudModels map + var whoamiCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/me" { + whoamiCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"name":"testuser"}`) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + cloudModels := map[string]bool{"cloud-model:cloud": true} + selected := []string{"cloud-model:cloud", "local-model"} + + err := ensureAuth(context.Background(), client, cloudModels, selected) + if err != nil { + t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err) + } + if !whoamiCalled { + t.Error("expected whoami to be called for cloud model") + } +} + +func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) { + var whoamiCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/me" { + whoamiCalled = true + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + // cloudModels has entries but none are in selected + cloudModels := map[string]bool{"cloud-model:cloud": true} + selected := []string{"local-model"} + + err := ensureAuth(context.Background(), client, cloudModels, selected) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if whoamiCalled { + t.Error("whoami should not be called when no cloud models are selected") + } +}