server: decompress zstd request bodies in cloud passthrough middleware (#14827)

When a zstd-compressed request (e.g. from Codex CLI) hit /v1/responses
with a cloud model, the request failed.

Fix by decompressing zstd bodies before
model extraction, so cloud models are detected and proxied directly
without the writer being wrapped.
This commit is contained in:
Bruce MacDonald
2026-03-13 15:06:47 -07:00
committed by GitHub
parent 870599f5da
commit 3980c0217d
3 changed files with 117 additions and 1 deletions

View File

@@ -18,6 +18,9 @@ import (
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
) )
// maxDecompressedBodySize is the upper bound, in bytes, on the size of a
// decompressed request body (20 << 20 bytes, i.e. 20 MiB).
const maxDecompressedBodySize = 20 << 20
type BaseWriter struct { type BaseWriter struct {
gin.ResponseWriter gin.ResponseWriter
} }
@@ -512,7 +515,7 @@ func ResponsesMiddleware() gin.HandlerFunc {
return return
} }
defer reader.Close() defer reader.Close()
c.Request.Body = io.NopCloser(reader) c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize)
c.Request.Header.Del("Content-Encoding") c.Request.Header.Del("Content-Encoding")
} }

View File

@@ -16,6 +16,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/auth" "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
@@ -29,6 +30,9 @@ const (
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL" cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search" legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
cloudProxyClientVersionHeader = "X-Ollama-Client-Version" cloudProxyClientVersionHeader = "X-Ollama-Client-Version"
// maxDecompressedBodySize limits the size of a decompressed request body
maxDecompressedBodySize = 20 << 20
) )
var ( var (
@@ -73,6 +77,19 @@ func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
return return
} }
// Decompress zstd-encoded request bodies so we can inspect the model
if c.GetHeader("Content-Encoding") == "zstd" {
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to decompress request body"})
c.Abort()
return
}
defer reader.Close()
c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize)
c.Request.Header.Del("Content-Encoding")
}
// TODO(drifkin): Avoid full-body buffering here for model detection. // TODO(drifkin): Avoid full-body buffering here for model detection.
// A future optimization can parse just enough JSON to read "model" (and // A future optimization can parse just enough JSON to read "model" (and
// optionally short-circuit cloud-disabled explicit-cloud requests) while // optionally short-circuit cloud-disabled explicit-cloud requests) while

View File

@@ -1,10 +1,14 @@
package server package server
import ( import (
"bytes"
"io"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
) )
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) { func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
@@ -137,6 +141,98 @@ func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
} }
} }
// TestCloudPassthroughMiddleware_ZstdBody sends a zstd-compressed request that
// names a cloud model through cloudPassthroughMiddleware and asserts that the
// request does not fall through to the handler registered after it.
func TestCloudPassthroughMiddleware_ZstdBody(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const wantModel = "test-model:cloud"
	payload := []byte(`{"model":"test-model:cloud","messages":[{"role":"user","content":"hi"}]}`)

	// Compress the JSON payload with zstd, as a compressing client would.
	var zbuf bytes.Buffer
	enc, err := zstd.NewWriter(&zbuf)
	if err != nil {
		t.Fatalf("zstd writer: %v", err)
	}
	if _, err := enc.Write(payload); err != nil {
		t.Fatalf("zstd write: %v", err)
	}
	if err := enc.Close(); err != nil {
		t.Fatalf("zstd close: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(zbuf.Bytes()))
	req.Header.Set("Content-Encoding", "zstd")
	rec := httptest.NewRecorder()

	// reachedNext records whether the downstream handler ran, i.e. whether
	// the middleware took the non-cloud path instead of handling the request
	// itself.
	reachedNext := false

	// downstream only runs if the middleware falls through. When it does, it
	// checks that the body it receives is already decompressed and that the
	// Content-Encoding header has been stripped.
	downstream := func(c *gin.Context) {
		reachedNext = true

		raw, err := io.ReadAll(c.Request.Body)
		if err != nil {
			t.Fatalf("read body: %v", err)
		}

		got, found := extractModelField(raw)
		if !found {
			t.Fatal("expected to extract model from decompressed body")
		}
		if got != wantModel {
			t.Fatalf("expected model %q, got %q", wantModel, got)
		}

		if enc := c.GetHeader("Content-Encoding"); enc != "" {
			t.Fatalf("expected Content-Encoding to be removed, got %q", enc)
		}

		c.Status(http.StatusOK)
	}

	router := gin.New()
	router.POST("/v1/responses", cloudPassthroughMiddleware("test"), downstream)
	router.ServeHTTP(rec, req)

	// With the cloud model detected, the middleware is expected to proxy the
	// request rather than continue the chain. No cloud server exists in this
	// test, so the proxy attempt itself is expected to fail; the only thing
	// asserted is that the compressed body did not cause a fall-through to
	// the downstream handler.
	if reachedNext {
		t.Fatal("expected cloud passthrough to detect cloud model from zstd body, but it fell through to next handler")
	}
}
// TestCloudPassthroughMiddleware_ZstdBodyTooLarge sends a zstd-compressed body
// whose decompressed size exceeds maxDecompressedBodySize and asserts that the
// response is 400 Bad Request and that the downstream handler never runs.
func TestCloudPassthroughMiddleware_ZstdBodyTooLarge(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Build a payload of 'A' bytes that is 1 KiB past the decompressed-body
	// limit.
	payload := bytes.Repeat([]byte{'A'}, maxDecompressedBodySize+1024)

	var zbuf bytes.Buffer
	enc, err := zstd.NewWriter(&zbuf)
	if err != nil {
		t.Fatalf("zstd writer: %v", err)
	}
	if _, err := enc.Write(payload); err != nil {
		t.Fatalf("zstd write: %v", err)
	}
	if err := enc.Close(); err != nil {
		t.Fatalf("zstd close: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(zbuf.Bytes()))
	req.Header.Set("Content-Encoding", "zstd")
	rec := httptest.NewRecorder()

	// The downstream handler fails the test outright if it is ever invoked.
	unreachable := func(c *gin.Context) {
		t.Fatal("handler should not be reached for oversized body")
	}

	router := gin.New()
	router.POST("/v1/responses", cloudPassthroughMiddleware("test"), unreachable)
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", got)
	}
}
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) { func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil) req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
if err != nil { if err != nil {