mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
When a zstd-compressed request (e.g. from Codex CLI) hit /v1/responses with a cloud model, the request failed. Fixed by decompressing zstd bodies before model extraction, so that cloud models are detected and proxied directly without the writer being wrapped.
251 lines
7.9 KiB
Go
251 lines
7.9 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/klauspost/compress/zstd"
|
|
)
|
|
|
|
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
|
src := http.Header{}
|
|
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
|
|
src.Add("X-Trace-Hop", "drop-me")
|
|
src.Add("X-Alt-Hop", "drop-me-too")
|
|
src.Add("Keep-Alive", "timeout=5")
|
|
src.Add("X-End-To-End", "keep-me")
|
|
|
|
dst := http.Header{}
|
|
copyProxyRequestHeaders(dst, src)
|
|
|
|
if got := dst.Get("Connection"); got != "" {
|
|
t.Fatalf("expected Connection to be stripped, got %q", got)
|
|
}
|
|
if got := dst.Get("Keep-Alive"); got != "" {
|
|
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
|
|
}
|
|
if got := dst.Get("X-Trace-Hop"); got != "" {
|
|
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
|
|
}
|
|
if got := dst.Get("X-Alt-Hop"); got != "" {
|
|
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
|
|
}
|
|
if got := dst.Get("X-End-To-End"); got != "keep-me" {
|
|
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
|
src := http.Header{}
|
|
src.Add("Connection", "X-Upstream-Hop")
|
|
src.Add("X-Upstream-Hop", "drop-me")
|
|
src.Add("Content-Type", "application/json")
|
|
src.Add("X-Server-Trace", "keep-me")
|
|
|
|
dst := http.Header{}
|
|
copyProxyResponseHeaders(dst, src)
|
|
|
|
if got := dst.Get("Connection"); got != "" {
|
|
t.Fatalf("expected Connection to be stripped, got %q", got)
|
|
}
|
|
if got := dst.Get("X-Upstream-Hop"); got != "" {
|
|
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
|
|
}
|
|
if got := dst.Get("Content-Type"); got != "application/json" {
|
|
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
|
|
}
|
|
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
|
|
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
|
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if overridden {
|
|
t.Fatal("expected override=false for empty input")
|
|
}
|
|
if baseURL != defaultCloudProxyBaseURL {
|
|
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
|
|
}
|
|
if signingHost != defaultCloudProxySigningHost {
|
|
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
|
|
}
|
|
}
|
|
|
|
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
|
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !overridden {
|
|
t.Fatal("expected override=true")
|
|
}
|
|
if baseURL != "http://localhost:8080" {
|
|
t.Fatalf("unexpected base URL: %q", baseURL)
|
|
}
|
|
if signingHost != "localhost" {
|
|
t.Fatalf("unexpected signing host: %q", signingHost)
|
|
}
|
|
}
|
|
|
|
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
|
|
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
|
|
if err == nil {
|
|
t.Fatal("expected error for non-loopback override in release mode")
|
|
}
|
|
}
|
|
|
|
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
|
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !overridden {
|
|
t.Fatal("expected override=true")
|
|
}
|
|
if baseURL != "https://example.com:8443" {
|
|
t.Fatalf("unexpected base URL: %q", baseURL)
|
|
}
|
|
if signingHost != "example.com" {
|
|
t.Fatalf("unexpected signing host: %q", signingHost)
|
|
}
|
|
}
|
|
|
|
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
|
|
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
|
|
if err == nil {
|
|
t.Fatal("expected error for non-loopback http override in dev mode")
|
|
}
|
|
}
|
|
|
|
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
|
|
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create request: %v", err)
|
|
}
|
|
|
|
got := buildCloudSignatureChallenge(req, "123")
|
|
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
|
|
if got != want {
|
|
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
|
}
|
|
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
|
|
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
|
}
|
|
}
|
|
|
|
func TestCloudPassthroughMiddleware_ZstdBody(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
plainBody := []byte(`{"model":"test-model:cloud","messages":[{"role":"user","content":"hi"}]}`)
|
|
var compressed bytes.Buffer
|
|
w, err := zstd.NewWriter(&compressed)
|
|
if err != nil {
|
|
t.Fatalf("zstd writer: %v", err)
|
|
}
|
|
if _, err := w.Write(plainBody); err != nil {
|
|
t.Fatalf("zstd write: %v", err)
|
|
}
|
|
if err := w.Close(); err != nil {
|
|
t.Fatalf("zstd close: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes()))
|
|
req.Header.Set("Content-Encoding", "zstd")
|
|
rec := httptest.NewRecorder()
|
|
|
|
// Track whether the middleware detected the cloud model by checking
|
|
// if c.Next() was called (non-cloud path) vs c.Abort() (cloud path).
|
|
nextCalled := false
|
|
|
|
r := gin.New()
|
|
r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) {
|
|
nextCalled = true
|
|
// Verify the body is decompressed and Content-Encoding is removed.
|
|
body, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
t.Fatalf("read body: %v", err)
|
|
}
|
|
model, ok := extractModelField(body)
|
|
if !ok {
|
|
t.Fatal("expected to extract model from decompressed body")
|
|
}
|
|
if model != "test-model:cloud" {
|
|
t.Fatalf("expected model %q, got %q", "test-model:cloud", model)
|
|
}
|
|
if enc := c.GetHeader("Content-Encoding"); enc != "" {
|
|
t.Fatalf("expected Content-Encoding to be removed, got %q", enc)
|
|
}
|
|
c.Status(http.StatusOK)
|
|
})
|
|
r.ServeHTTP(rec, req)
|
|
|
|
// The cloud passthrough middleware should detect the cloud model and
|
|
// proxy (abort), so the next handler should NOT be called.
|
|
// However, since there's no actual cloud server to proxy to, the
|
|
// middleware will attempt to proxy and fail. We just verify it didn't
|
|
// fall through to c.Next() due to failure to read the compressed body.
|
|
if nextCalled {
|
|
t.Fatal("expected cloud passthrough to detect cloud model from zstd body, but it fell through to next handler")
|
|
}
|
|
}
|
|
|
|
func TestCloudPassthroughMiddleware_ZstdBodyTooLarge(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
// Create a body that exceeds the 20MB limit
|
|
oversized := make([]byte, maxDecompressedBodySize+1024)
|
|
for i := range oversized {
|
|
oversized[i] = 'A'
|
|
}
|
|
|
|
var compressed bytes.Buffer
|
|
w, err := zstd.NewWriter(&compressed)
|
|
if err != nil {
|
|
t.Fatalf("zstd writer: %v", err)
|
|
}
|
|
if _, err := w.Write(oversized); err != nil {
|
|
t.Fatalf("zstd write: %v", err)
|
|
}
|
|
if err := w.Close(); err != nil {
|
|
t.Fatalf("zstd close: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes()))
|
|
req.Header.Set("Content-Encoding", "zstd")
|
|
rec := httptest.NewRecorder()
|
|
|
|
r := gin.New()
|
|
r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) {
|
|
t.Fatal("handler should not be reached for oversized body")
|
|
})
|
|
r.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Fatalf("expected status 400, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
|
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create request: %v", err)
|
|
}
|
|
|
|
got := buildCloudSignatureChallenge(req, "123")
|
|
want := "POST,/v1/messages?beta=true&ts=123"
|
|
if got != want {
|
|
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
|
}
|
|
if req.URL.RawQuery != "beta=true&ts=123" {
|
|
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
|
}
|
|
}
|