mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
server: decompress zstd request bodies in cloud passthrough middleware (#14827)
When a zstd-compressed request (e.g. from Codex CLI) hits /v1/responses with a cloud model the request failed. Fix by decompressing zstd bodies before model extraction, so cloud models are detected and proxied directly without the writer being wrapped.
This commit is contained in:
@@ -18,6 +18,9 @@ import (
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
// maxDecompressedBodySize limits the size of a decompressed request body
|
||||
const maxDecompressedBodySize = 20 << 20
|
||||
|
||||
type BaseWriter struct {
|
||||
gin.ResponseWriter
|
||||
}
|
||||
@@ -512,7 +515,7 @@ func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -29,6 +30,9 @@ const (
|
||||
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
|
||||
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
|
||||
cloudProxyClientVersionHeader = "X-Ollama-Client-Version"
|
||||
|
||||
// maxDecompressedBodySize limits the size of a decompressed request body
|
||||
maxDecompressedBodySize = 20 << 20
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -73,6 +77,19 @@ func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Decompress zstd-encoded request bodies so we can inspect the model
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to decompress request body"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
||||
// A future optimization can parse just enough JSON to read "model" (and
|
||||
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
@@ -137,6 +141,98 @@ func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudPassthroughMiddleware_ZstdBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
plainBody := []byte(`{"model":"test-model:cloud","messages":[{"role":"user","content":"hi"}]}`)
|
||||
var compressed bytes.Buffer
|
||||
w, err := zstd.NewWriter(&compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd writer: %v", err)
|
||||
}
|
||||
if _, err := w.Write(plainBody); err != nil {
|
||||
t.Fatalf("zstd write: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("zstd close: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Track whether the middleware detected the cloud model by checking
|
||||
// if c.Next() was called (non-cloud path) vs c.Abort() (cloud path).
|
||||
nextCalled := false
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) {
|
||||
nextCalled = true
|
||||
// Verify the body is decompressed and Content-Encoding is removed.
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
model, ok := extractModelField(body)
|
||||
if !ok {
|
||||
t.Fatal("expected to extract model from decompressed body")
|
||||
}
|
||||
if model != "test-model:cloud" {
|
||||
t.Fatalf("expected model %q, got %q", "test-model:cloud", model)
|
||||
}
|
||||
if enc := c.GetHeader("Content-Encoding"); enc != "" {
|
||||
t.Fatalf("expected Content-Encoding to be removed, got %q", enc)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
// The cloud passthrough middleware should detect the cloud model and
|
||||
// proxy (abort), so the next handler should NOT be called.
|
||||
// However, since there's no actual cloud server to proxy to, the
|
||||
// middleware will attempt to proxy and fail. We just verify it didn't
|
||||
// fall through to c.Next() due to failure to read the compressed body.
|
||||
if nextCalled {
|
||||
t.Fatal("expected cloud passthrough to detect cloud model from zstd body, but it fell through to next handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudPassthroughMiddleware_ZstdBodyTooLarge(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Create a body that exceeds the 20MB limit
|
||||
oversized := make([]byte, maxDecompressedBodySize+1024)
|
||||
for i := range oversized {
|
||||
oversized[i] = 'A'
|
||||
}
|
||||
|
||||
var compressed bytes.Buffer
|
||||
w, err := zstd.NewWriter(&compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd writer: %v", err)
|
||||
}
|
||||
if _, err := w.Write(oversized); err != nil {
|
||||
t.Fatalf("zstd write: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("zstd close: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) {
|
||||
t.Fatal("handler should not be reached for oversized body")
|
||||
})
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user