[WIP] api/show: fall back to local on cloud 404

It's possible to have a local model with a `:cloud` suffix. If the cloud
responds with a 404, then we now fall back to looking for it locally
before returning to the client.

The cloud proxy has been slightly refactored so parts of it can be
reused here (in this flow we inspect the response, so it's no longer a
pure passthrough)
This commit is contained in:
Devon Rifkin
2026-03-18 15:56:15 -07:00
parent 676d9845ba
commit f9a46b73da
4 changed files with 248 additions and 28 deletions

View File

@@ -176,16 +176,10 @@ func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation) proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
} }
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) { func buildCloudProxyRequest(c *gin.Context, path string, body []byte) (*http.Request, error) {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
return
}
baseURL, err := url.Parse(cloudProxyBaseURL) baseURL, err := url.Parse(cloudProxyBaseURL)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return nil, err
return
} }
targetURL := baseURL.ResolveReference(&url.URL{ targetURL := baseURL.ResolveReference(&url.URL{
@@ -195,8 +189,7 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body)) outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return nil, err
return
} }
copyProxyRequestHeaders(outReq.Header, c.Request.Header) copyProxyRequestHeaders(outReq.Header, c.Request.Header)
@@ -207,22 +200,10 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
outReq.Header.Set("Content-Type", "application/json") outReq.Header.Set("Content-Type", "application/json")
} }
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil { return outReq, nil
slog.Warn("cloud proxy signing failed", "error", err) }
writeCloudUnauthorized(c)
return
}
// TODO(drifkin): Add phase-specific proxy timeouts.
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
// we should not enforce a short total timeout for long-lived responses.
resp, err := http.DefaultClient.Do(outReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
defer resp.Body.Close()
func writeCloudProxyResponse(c *gin.Context, path string, resp *http.Response) {
copyProxyResponseHeaders(c.Writer.Header(), resp.Header) copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
@@ -239,7 +220,7 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
bodyWriter = framedWriter bodyWriter = framedWriter
} }
err = copyProxyResponseBody(bodyWriter, resp.Body) err := copyProxyResponseBody(bodyWriter, resp.Body)
if err == nil && framedWriter != nil { if err == nil && framedWriter != nil {
err = framedWriter.FlushPending() err = framedWriter.FlushPending()
} }
@@ -263,8 +244,38 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
"request_context_err", ctxErr, "request_context_err", ctxErr,
"error", err, "error", err,
) )
}
}
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
return return
} }
outReq, err := buildCloudProxyRequest(c, path, body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
slog.Warn("cloud proxy signing failed", "error", err)
writeCloudUnauthorized(c)
return
}
// TODO(drifkin): Add phase-specific proxy timeouts.
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
// we should not enforce a short total timeout for long-lived responses.
resp, err := http.DefaultClient.Do(outReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
defer resp.Body.Close()
writeCloudProxyResponse(c, path, resp)
} }
func replaceJSONModelField(body []byte, model string) ([]byte, error) { func replaceJSONModelField(body []byte, model string) ([]byte, error) {

View File

@@ -1129,8 +1129,30 @@ func (s *Server) ShowHandler(c *gin.Context) {
} }
if modelRef.Source == modelSourceCloud { if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base cloudReq := req
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable) cloudReq.Model = modelRef.Base
cloudShow, cloudResp := s.resolveCloudShow(c, cloudReq)
// Best-effort compatibility: if cloud alias lookup fails, retry local
// lookup using the original verbatim model string (including source
// suffix/tag) so legacy/stub manifest names continue to resolve.
if cloudResp != nil && cloudResp.StatusCode == http.StatusNotFound {
localReq := req
localReq.Model = modelRef.Original
localReq.Name = ""
if resp, err := GetModelInfo(localReq); err == nil {
c.JSON(http.StatusOK, resp)
return
}
}
if cloudShow != nil {
c.JSON(http.StatusOK, cloudShow)
return
}
writeCloudShowResponse(c, cloudResp)
return return
} }

View File

@@ -797,6 +797,86 @@ func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) {
}) })
} }
func TestShowExplicitCloudFallsBackToLocalOnCloud404(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())
type upstreamCapture struct {
path string
body string
}
capture := &upstreamCapture{}
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
payload, _ := io.ReadAll(r.Body)
capture.path = r.URL.Path
capture.body = string(payload)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`{"error":"model not found"}`))
}))
defer upstream.Close()
original := cloudProxyBaseURL
cloudProxyBaseURL = upstream.URL
t.Cleanup(func() { cloudProxyBaseURL = original })
s := &Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "fallback-alias:cloud",
From: "fallback-upstream:cloud",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected create status 200, got %d (%s)", w.Code, w.Body.String())
}
router, err := s.GenerateRoutes(nil)
if err != nil {
t.Fatal(err)
}
local := httptest.NewServer(router)
defer local.Close()
reqBody := `{"model":"fallback-alias:cloud"}`
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/show", bytes.NewBufferString(reqBody))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := local.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
}
var showResp api.ShowResponse
if err := json.Unmarshal(body, &showResp); err != nil {
t.Fatalf("failed to decode show response: %v", err)
}
if showResp.RemoteModel != "fallback-upstream" {
t.Fatalf("expected remote_model fallback-upstream, got %q", showResp.RemoteModel)
}
if capture.path != "/api/show" {
t.Fatalf("expected upstream path /api/show, got %q", capture.path)
}
if !strings.Contains(capture.body, `"model":"fallback-alias"`) {
t.Fatalf("expected normalized cloud lookup model in upstream body, got %q", capture.body)
}
}
func TestCloudDisabledBlocksExplicitCloudPassthrough(t *testing.T) { func TestCloudDisabledBlocksExplicitCloudPassthrough(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir()) setTestHome(t, t.TempDir())

107
server/show_cloud.go Normal file
View File

@@ -0,0 +1,107 @@
package server
import (
"encoding/json"
"io"
"log/slog"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
)
type cloudShowResponse struct {
StatusCode int
Headers http.Header
Body []byte
}
func newJSONCloudShowResponse(statusCode int, payload any) *cloudShowResponse {
body, err := json.Marshal(payload)
if err != nil {
body = []byte(`{"error":"internal server error"}`)
}
headers := make(http.Header)
headers.Set("Content-Type", "application/json; charset=utf-8")
return &cloudShowResponse{
StatusCode: statusCode,
Headers: headers,
Body: body,
}
}
func cloudUnauthorizedShowResponse() *cloudShowResponse {
payload := gin.H{"error": "unauthorized"}
if signinURL, err := cloudProxySigninURL(); err == nil {
payload["signin_url"] = signinURL
}
return newJSONCloudShowResponse(http.StatusUnauthorized, payload)
}
func writeCloudShowResponse(c *gin.Context, resp *cloudShowResponse) {
if resp == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
return
}
copyProxyResponseHeaders(c.Writer.Header(), resp.Headers)
c.Status(resp.StatusCode)
if len(resp.Body) > 0 {
_, _ = c.Writer.Write(resp.Body)
}
}
func (s *Server) resolveCloudShow(c *gin.Context, req api.ShowRequest) (*api.ShowResponse, *cloudShowResponse) {
if disabled, _ := internalcloud.Status(); disabled {
return nil, newJSONCloudShowResponse(http.StatusForbidden, gin.H{
"error": internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable),
})
}
body, err := json.Marshal(req)
if err != nil {
return nil, newJSONCloudShowResponse(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
outReq, err := buildCloudProxyRequest(c, c.Request.URL.Path, body)
if err != nil {
return nil, newJSONCloudShowResponse(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
slog.Warn("cloud proxy signing failed", "error", err)
return nil, cloudUnauthorizedShowResponse()
}
resp, err := http.DefaultClient.Do(outReq)
if err != nil {
return nil, newJSONCloudShowResponse(http.StatusBadGateway, gin.H{"error": err.Error()})
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, newJSONCloudShowResponse(http.StatusBadGateway, gin.H{"error": err.Error()})
}
if resp.StatusCode >= http.StatusBadRequest {
return nil, &cloudShowResponse{
StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(),
Body: respBody,
}
}
var showResp api.ShowResponse
if len(respBody) > 0 {
if err := json.Unmarshal(respBody, &showResp); err != nil {
return nil, newJSONCloudShowResponse(http.StatusBadGateway, gin.H{"error": err.Error()})
}
}
return &showResp, nil
}