diff --git a/server/cloud_proxy.go b/server/cloud_proxy.go index 70fe215f9..8b490c78d 100644 --- a/server/cloud_proxy.go +++ b/server/cloud_proxy.go @@ -176,16 +176,10 @@ func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) { proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation) } -func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) { - if disabled, _ := internalcloud.Status(); disabled { - c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)}) - return - } - +func buildCloudProxyRequest(c *gin.Context, path string, body []byte) (*http.Request, error) { baseURL, err := url.Parse(cloudProxyBaseURL) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return nil, err } targetURL := baseURL.ResolveReference(&url.URL{ @@ -195,8 +189,7 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body)) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return nil, err } copyProxyRequestHeaders(outReq.Header, c.Request.Header) @@ -207,22 +200,10 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable outReq.Header.Set("Content-Type", "application/json") } - if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil { - slog.Warn("cloud proxy signing failed", "error", err) - writeCloudUnauthorized(c) - return - } - - // TODO(drifkin): Add phase-specific proxy timeouts. - // Connect/TLS/TTFB should have bounded timeouts, but once streaming starts - // we should not enforce a short total timeout for long-lived responses. - resp, err := http.DefaultClient.Do(outReq) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) - return - } - defer resp.Body.Close() + return outReq, nil +} +func writeCloudProxyResponse(c *gin.Context, path string, resp *http.Response) { copyProxyResponseHeaders(c.Writer.Header(), resp.Header) c.Status(resp.StatusCode) @@ -239,7 +220,7 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable bodyWriter = framedWriter } - err = copyProxyResponseBody(bodyWriter, resp.Body) + err := copyProxyResponseBody(bodyWriter, resp.Body) if err == nil && framedWriter != nil { err = framedWriter.FlushPending() } @@ -263,8 +244,38 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable "request_context_err", ctxErr, "error", err, ) + } +} + +func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) { + if disabled, _ := internalcloud.Status(); disabled { + c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)}) return } + + outReq, err := buildCloudProxyRequest(c, path, body) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil { + slog.Warn("cloud proxy signing failed", "error", err) + writeCloudUnauthorized(c) + return + } + + // TODO(drifkin): Add phase-specific proxy timeouts. + // Connect/TLS/TTFB should have bounded timeouts, but once streaming starts + // we should not enforce a short total timeout for long-lived responses. + resp, err := http.DefaultClient.Do(outReq) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + defer resp.Body.Close() + + writeCloudProxyResponse(c, path, resp) } func replaceJSONModelField(body []byte, model string) ([]byte, error) { diff --git a/server/routes.go b/server/routes.go index 32adef36c..27b2e6daf 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1129,8 +1129,30 @@ func (s *Server) ShowHandler(c *gin.Context) { } if modelRef.Source == modelSourceCloud { - req.Model = modelRef.Base - proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable) + cloudReq := req + cloudReq.Model = modelRef.Base + cloudShow, cloudResp := s.resolveCloudShow(c, cloudReq) + + // Best-effort compatibility: if cloud alias lookup fails, retry local + // lookup using the original verbatim model string (including source + // suffix/tag) so legacy/stub manifest names continue to resolve. + if cloudResp != nil && cloudResp.StatusCode == http.StatusNotFound { + localReq := req + localReq.Model = modelRef.Original + localReq.Name = "" + + if resp, err := GetModelInfo(localReq); err == nil { + c.JSON(http.StatusOK, resp) + return + } + } + + if cloudShow != nil { + c.JSON(http.StatusOK, cloudShow) + return + } + + writeCloudShowResponse(c, cloudResp) return } diff --git a/server/routes_cloud_test.go b/server/routes_cloud_test.go index aaaf5b73d..c45c93ca8 100644 --- a/server/routes_cloud_test.go +++ b/server/routes_cloud_test.go @@ -797,6 +797,86 @@ func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) { }) } +func TestShowExplicitCloudFallsBackToLocalOnCloud404(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + type upstreamCapture struct { + path string + body string + } + + capture := &upstreamCapture{} + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + payload, _ := io.ReadAll(r.Body) + capture.path = r.URL.Path + capture.body = string(payload) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"model not found"}`)) + })) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "fallback-alias:cloud", + From: "fallback-upstream:cloud", + Info: map[string]any{ + "capabilities": []string{"completion"}, + }, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected create status 200, got %d (%s)", w.Code, w.Body.String()) + } + + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"fallback-alias:cloud"}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/show", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + var showResp api.ShowResponse + if err := json.Unmarshal(body, &showResp); err != nil { + t.Fatalf("failed to decode show response: %v", err) + } + + if showResp.RemoteModel != "fallback-upstream" { + t.Fatalf("expected remote_model fallback-upstream, got %q", showResp.RemoteModel) + } + + if capture.path != "/api/show" { + t.Fatalf("expected upstream path /api/show, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"model":"fallback-alias"`) { + t.Fatalf("expected normalized cloud lookup model in upstream body, got %q", capture.body) + } +} + func TestCloudDisabledBlocksExplicitCloudPassthrough(t *testing.T) { gin.SetMode(gin.TestMode) setTestHome(t, t.TempDir()) diff --git a/server/show_cloud.go b/server/show_cloud.go new file mode 100644 index 000000000..cfe408a10 --- /dev/null +++ b/server/show_cloud.go @@ -0,0 +1,107 @@ +package server + +import ( + "encoding/json" + "io" + "log/slog" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/api" + internalcloud "github.com/ollama/ollama/internal/cloud" +) + +type cloudShowResponse struct { + StatusCode int + Headers http.Header + Body []byte +} + +func newJSONCloudShowResponse(statusCode int, payload any) *cloudShowResponse { + body, err := json.Marshal(payload) + if err != nil { + body = []byte(`{"error":"internal server error"}`) + } + + headers := make(http.Header) + headers.Set("Content-Type", "application/json; charset=utf-8") + + return &cloudShowResponse{ + StatusCode: statusCode, + Headers: headers, + Body: body, + } +} + +func cloudUnauthorizedShowResponse() *cloudShowResponse { + payload := gin.H{"error": "unauthorized"} + if signinURL, err := cloudProxySigninURL(); err == nil { + payload["signin_url"] = signinURL + } + return newJSONCloudShowResponse(http.StatusUnauthorized, payload) +} + +func writeCloudShowResponse(c *gin.Context, resp *cloudShowResponse) { + if resp == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) + return + } + + copyProxyResponseHeaders(c.Writer.Header(), resp.Headers) + c.Status(resp.StatusCode) + if len(resp.Body) > 0 { + _, _ = c.Writer.Write(resp.Body) + } +} + +func (s *Server) resolveCloudShow(c *gin.Context, req api.ShowRequest) (*api.ShowResponse, *cloudShowResponse) { + if disabled, _ := internalcloud.Status(); disabled { + return nil, newJSONCloudShowResponse(http.StatusForbidden, gin.H{ + "error": internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable), + }) + } + + body, err := json.Marshal(req) + if err != nil { + return nil, newJSONCloudShowResponse(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + + outReq, err := buildCloudProxyRequest(c, c.Request.URL.Path, body) + if err != nil { + return nil, newJSONCloudShowResponse(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + + if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil { + slog.Warn("cloud proxy signing failed", "error", err) + return nil, cloudUnauthorizedShowResponse() + } + + resp, err := http.DefaultClient.Do(outReq) + if err != nil { + return nil, newJSONCloudShowResponse(http.StatusBadGateway, gin.H{"error": err.Error()}) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, newJSONCloudShowResponse(http.StatusBadGateway, gin.H{"error": err.Error()}) + } + + if resp.StatusCode >= http.StatusBadRequest { + return nil, &cloudShowResponse{ + StatusCode: resp.StatusCode, + Headers: resp.Header.Clone(), + Body: respBody, + } + } + + var showResp api.ShowResponse + if len(respBody) > 0 { + if err := json.Unmarshal(respBody, &showResp); err != nil { + return nil, newJSONCloudShowResponse(http.StatusBadGateway, gin.H{"error": err.Error()}) + } + } + + return &showResp, nil +}