diff --git a/server/cloud_proxy.go b/server/cloud_proxy.go
index 4ab9c1a77..70fe215f9 100644
--- a/server/cloud_proxy.go
+++ b/server/cloud_proxy.go
@@ -226,7 +226,24 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
 	copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
 	c.Status(resp.StatusCode)
 
-	if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
+	var bodyWriter http.ResponseWriter = c.Writer
+	var framedWriter *jsonlFramingResponseWriter
+	// TEMP(drifkin): only needed on the cloud-proxied first leg of Anthropic
+	// web_search fallback (which is a path we're removing soon). Local
+	// /v1/messages writes one JSON value per streamResponse callback directly
+	// into WebSearchAnthropicWriter, but this proxy copy loop may coalesce
+	// multiple jsonl records into one Write. WebSearchAnthropicWriter currently
+	// unmarshals one JSON value per Write.
+	if path == "/api/chat" && resp.StatusCode == http.StatusOK && c.GetBool(legacyCloudAnthropicKey) {
+		framedWriter = &jsonlFramingResponseWriter{ResponseWriter: c.Writer}
+		bodyWriter = framedWriter
+	}
+
+	err = copyProxyResponseBody(bodyWriter, resp.Body)
+	if err == nil && framedWriter != nil {
+		err = framedWriter.FlushPending()
+	}
+	if err != nil {
 		ctxErr := c.Request.Context().Err()
 		if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
 			slog.Debug(
@@ -240,6 +257,7 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
 		slog.Warn(
 			"cloud proxy response copy failed",
 			"path", c.Request.URL.Path,
+			"upstream_path", path,
 			"status", resp.StatusCode,
 			"request_context_canceled", ctxErr != nil,
 			"request_context_err", ctxErr,
@@ -473,6 +491,69 @@ func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
 	}
 }
 
+// jsonlFramingResponseWriter wraps an http.ResponseWriter and re-frames a
+// newline-delimited body: each non-empty line is forwarded to the wrapped
+// writer in its own Write call, trimmed of surrounding whitespace (including
+// the newline). Bytes after the last newline stay buffered until more data
+// arrives or FlushPending is called.
+type jsonlFramingResponseWriter struct {
+	http.ResponseWriter
+	// pending holds bytes received but not yet forwarded.
+	pending []byte
+}
+
+// Flush forwards to the wrapped writer when it implements http.Flusher. It
+// does not emit buffered bytes; FlushPending does that.
+func (w *jsonlFramingResponseWriter) Flush() {
+	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
+// Write buffers p and forwards every complete line accumulated so far. It
+// always reports len(p) bytes as consumed, even when forwarding fails: p has
+// been copied into the buffer by then.
+func (w *jsonlFramingResponseWriter) Write(p []byte) (int, error) {
+	w.pending = append(w.pending, p...)
+	return len(p), w.flushCompleteLines()
+}
+
+// FlushPending forwards any buffered bytes that were never terminated by a
+// newline, trimmed of surrounding whitespace, and clears the buffer. It
+// writes nothing when the buffer is empty or holds only whitespace.
+func (w *jsonlFramingResponseWriter) FlushPending() error {
+	trailing := bytes.TrimSpace(w.pending)
+	w.pending = nil
+	if len(trailing) == 0 {
+		return nil
+	}
+
+	_, err := w.ResponseWriter.Write(trailing)
+	return err
+}
+
+// flushCompleteLines forwards each newline-terminated line in the buffer as a
+// separate Write, skipping blank lines, and stops at the first write error.
+func (w *jsonlFramingResponseWriter) flushCompleteLines() error {
+	for {
+		newline := bytes.IndexByte(w.pending, '\n')
+		if newline < 0 {
+			return nil
+		}
+
+		line := bytes.TrimSpace(w.pending[:newline])
+		// The line leaves the buffer before the write, so a failed write is not retried.
+		w.pending = w.pending[newline+1:]
+		if len(line) == 0 {
+			continue
+		}
+
+		if _, err := w.ResponseWriter.Write(line); err != nil {
+			return err
+		}
+	}
+}
+
 func isHopByHopHeader(name string) bool {
 	_, ok := hopByHopHeaders[strings.ToLower(name)]
 	return ok
diff --git a/server/cloud_proxy_test.go b/server/cloud_proxy_test.go
index 1bac5cc62..950ec2bc2 100644
--- a/server/cloud_proxy_test.go
+++ b/server/cloud_proxy_test.go
@@ -248,3 +248,79 @@ func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
 		t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
 	}
 }
+
+// TestJSONLFramingResponseWriter_SplitsCoalescedLines checks that two JSONL
+// records delivered in a single Write are forwarded as two separate writes
+// with their newlines stripped.
+func TestJSONLFramingResponseWriter_SplitsCoalescedLines(t *testing.T) {
+	rec := &chunkRecorder{header: http.Header{}}
+	w := &jsonlFramingResponseWriter{ResponseWriter: rec}
+
+	payload := []byte("{\"a\":1}\n{\"b\":2}\n")
+	if n, err := w.Write(payload); err != nil {
+		t.Fatalf("write failed: %v", err)
+	} else if n != len(payload) {
+		t.Fatalf("write byte count mismatch: got %d want %d", n, len(payload))
+	}
+
+	if err := w.FlushPending(); err != nil {
+		t.Fatalf("FlushPending failed: %v", err)
+	}
+
+	if len(rec.chunks) != 2 {
+		t.Fatalf("expected 2 framed writes, got %d", len(rec.chunks))
+	}
+	if got := string(rec.chunks[0]); got != `{"a":1}` {
+		t.Fatalf("first chunk mismatch: got %q", got)
+	}
+	if got := string(rec.chunks[1]); got != `{"b":2}` {
+		t.Fatalf("second chunk mismatch: got %q", got)
+	}
+}
+
+// TestJSONLFramingResponseWriter_FlushPendingWritesTrailingLine checks that
+// bytes without a terminating newline are held back until FlushPending.
+func TestJSONLFramingResponseWriter_FlushPendingWritesTrailingLine(t *testing.T) {
+	rec := &chunkRecorder{header: http.Header{}}
+	w := &jsonlFramingResponseWriter{ResponseWriter: rec}
+
+	if _, err := w.Write([]byte("{\"a\":1")); err != nil {
+		t.Fatalf("write failed: %v", err)
+	}
+	if len(rec.chunks) != 0 {
+		t.Fatalf("expected no writes before newline/flush, got %d", len(rec.chunks))
+	}
+
+	if err := w.FlushPending(); err != nil {
+		t.Fatalf("FlushPending failed: %v", err)
+	}
+	if len(rec.chunks) != 1 {
+		t.Fatalf("expected 1 write after FlushPending, got %d", len(rec.chunks))
+	}
+	if got := string(rec.chunks[0]); got != `{"a":1` {
+		t.Fatalf("trailing chunk mismatch: got %q", got)
+	}
+}
+
+// chunkRecorder is an http.ResponseWriter test double that records the
+// payload of every Write call as a separate chunk.
+type chunkRecorder struct {
+	header http.Header
+	status int
+	chunks [][]byte
+}
+
+func (r *chunkRecorder) Header() http.Header {
+	return r.header
+}
+
+func (r *chunkRecorder) WriteHeader(statusCode int) {
+	r.status = statusCode
+}
+
+func (r *chunkRecorder) Write(p []byte) (int, error) {
+	// Copy p so the recorded chunk does not alias the caller's buffer.
+	cp := append([]byte(nil), p...)
+	r.chunks = append(r.chunks, cp)
+	return len(p), nil
+}
diff --git a/server/routes_cloud_test.go b/server/routes_cloud_test.go
index 8bbc52e08..aaaf5b73d 100644
--- a/server/routes_cloud_test.go
+++ b/server/routes_cloud_test.go
@@ -652,6 +652,71 @@ func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) {
 		}
 	})
 
+	t.Run("v1 messages web_search fallback frames coalesced jsonl chunks", func(t *testing.T) {
+		type upstreamCapture struct {
+			path string
+		}
+		capture := &upstreamCapture{}
+		upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			capture.path = r.URL.Path
+			w.Header().Set("Content-Type", "application/x-ndjson")
+			w.WriteHeader(http.StatusOK)
+
+			// Both records go out in a single Write, mimicking a coalesced upstream chunk.
+			combined := strings.Join([]string{
+				`{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hel"},"done":false}`,
+				`{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"lo"},"done":true}`,
+			}, "\n") + "\n"
+			_, _ = w.Write([]byte(combined))
+		}))
+		defer upstream.Close()
+
+		original := cloudProxyBaseURL
+		cloudProxyBaseURL = upstream.URL
+		t.Cleanup(func() { cloudProxyBaseURL = original })
+
+		s := &Server{}
+		router, err := s.GenerateRoutes(nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		local := httptest.NewServer(router)
+		defer local.Close()
+
+		reqBody := `{
+			"model":"gpt-oss:120b-cloud",
+			"max_tokens":10,
+			"stream":true,
+			"messages":[{"role":"user","content":"search the web"}],
+			"tools":[{"type":"web_search_20250305","name":"web_search"}]
+		}`
+		req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages?beta=true", bytes.NewBufferString(reqBody))
+		if err != nil {
+			t.Fatal(err)
+		}
+		req.Header.Set("Content-Type", "application/json")
+
+		resp, err := local.Client().Do(req)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer resp.Body.Close()
+
+		body, err := io.ReadAll(resp.Body)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if resp.StatusCode != http.StatusOK {
+			t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
+		}
+		if capture.path != "/api/chat" {
+			t.Fatalf("expected upstream path /api/chat for web_search fallback, got %q", capture.path)
+		}
+		if !strings.Contains(string(body), "event: message_stop") {
+			t.Fatalf("expected anthropic streaming message_stop event, got body %q", string(body))
+		}
+	})
+
 	t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) {
 		upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`)
 		defer upstream.Close()