diff --git a/middleware/openai.go b/middleware/openai.go index 8527f00de..76853aca5 100644 --- a/middleware/openai.go +++ b/middleware/openai.go @@ -18,6 +18,9 @@ import ( "github.com/ollama/ollama/openai" ) +// maxDecompressedBodySize limits the size of a decompressed request body +const maxDecompressedBodySize = 20 << 20 + type BaseWriter struct { gin.ResponseWriter } @@ -512,7 +515,7 @@ func ResponsesMiddleware() gin.HandlerFunc { return } defer reader.Close() - c.Request.Body = io.NopCloser(reader) + c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize) c.Request.Header.Del("Content-Encoding") } diff --git a/server/cloud_proxy.go b/server/cloud_proxy.go index ca9a10da4..4ab9c1a77 100644 --- a/server/cloud_proxy.go +++ b/server/cloud_proxy.go @@ -16,6 +16,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/klauspost/compress/zstd" "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" @@ -29,6 +30,9 @@ const ( cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL" legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search" cloudProxyClientVersionHeader = "X-Ollama-Client-Version" + + // maxDecompressedBodySize limits the size of a decompressed request body + maxDecompressedBodySize = 20 << 20 ) var ( @@ -73,6 +77,19 @@ func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc { return } + // Decompress zstd-encoded request bodies so we can inspect the model + if c.GetHeader("Content-Encoding") == "zstd" { + reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20)) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "failed to decompress request body"}) + c.Abort() + return + } + defer reader.Close() + c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize) + c.Request.Header.Del("Content-Encoding") + } + // TODO(drifkin): Avoid full-body buffering here for model detection. // A future optimization can parse just enough JSON to read "model" (and // optionally short-circuit cloud-disabled explicit-cloud requests) while diff --git a/server/cloud_proxy_test.go b/server/cloud_proxy_test.go index 1a7b27956..1bac5cc62 100644 --- a/server/cloud_proxy_test.go +++ b/server/cloud_proxy_test.go @@ -1,10 +1,14 @@ package server import ( + "bytes" + "io" "net/http" + "net/http/httptest" "testing" "github.com/gin-gonic/gin" + "github.com/klauspost/compress/zstd" ) func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) { @@ -137,6 +141,98 @@ func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) { } } +func TestCloudPassthroughMiddleware_ZstdBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + plainBody := []byte(`{"model":"test-model:cloud","messages":[{"role":"user","content":"hi"}]}`) + var compressed bytes.Buffer + w, err := zstd.NewWriter(&compressed) + if err != nil { + t.Fatalf("zstd writer: %v", err) + } + if _, err := w.Write(plainBody); err != nil { + t.Fatalf("zstd write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("zstd close: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes())) + req.Header.Set("Content-Encoding", "zstd") + rec := httptest.NewRecorder() + + // Track whether the middleware detected the cloud model by checking + // if c.Next() was called (non-cloud path) vs c.Abort() (cloud path). + nextCalled := false + + r := gin.New() + r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) { + nextCalled = true + // Verify the body is decompressed and Content-Encoding is removed. + body, err := io.ReadAll(c.Request.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + model, ok := extractModelField(body) + if !ok { + t.Fatal("expected to extract model from decompressed body") + } + if model != "test-model:cloud" { + t.Fatalf("expected model %q, got %q", "test-model:cloud", model) + } + if enc := c.GetHeader("Content-Encoding"); enc != "" { + t.Fatalf("expected Content-Encoding to be removed, got %q", enc) + } + c.Status(http.StatusOK) + }) + r.ServeHTTP(rec, req) + + // The cloud passthrough middleware should detect the cloud model and + // proxy (abort), so the next handler should NOT be called. + // However, since there's no actual cloud server to proxy to, the + // middleware will attempt to proxy and fail. We just verify it didn't + // fall through to c.Next() due to failure to read the compressed body. + if nextCalled { + t.Fatal("expected cloud passthrough to detect cloud model from zstd body, but it fell through to next handler") + } +} + +func TestCloudPassthroughMiddleware_ZstdBodyTooLarge(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Create a body that exceeds the 20MB limit + oversized := make([]byte, maxDecompressedBodySize+1024) + for i := range oversized { + oversized[i] = 'A' + } + + var compressed bytes.Buffer + w, err := zstd.NewWriter(&compressed) + if err != nil { + t.Fatalf("zstd writer: %v", err) + } + if _, err := w.Write(oversized); err != nil { + t.Fatalf("zstd write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("zstd close: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes())) + req.Header.Set("Content-Encoding", "zstd") + rec := httptest.NewRecorder() + + r := gin.New() + r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) { + t.Fatal("handler should not be reached for oversized body") + }) + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", rec.Code) + } +} + func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) { req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil) if err != nil {