cloud_proxy: for the web_search legacy path, flush on newlines (#14897)

`WebSearchAnthropicWriter` expects a single object per write. The new
transparent proxy will instead send it whatever bytes it sees. This
cloud-model + local-orchestration + cloud-search is a temporary code
path, so instead of making the web search code more robust to this, I
put an adapter in the middle that will flush line-by-line to preserve
the old behavior.
This commit is contained in:
Devon Rifkin
2026-03-17 13:30:17 -07:00
committed by GitHub
parent d727aacd04
commit e37a9b4c01
3 changed files with 197 additions and 1 deletions

View File

@@ -226,7 +226,24 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
c.Status(resp.StatusCode)
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
var bodyWriter http.ResponseWriter = c.Writer
var framedWriter *jsonlFramingResponseWriter
// TEMP(drifkin): only needed on the cloud-proxied first leg of Anthropic
// web_search fallback (which is a path we're removing soon). Local
// /v1/messages writes one JSON value per streamResponse callback directly
// into WebSearchAnthropicWriter, but this proxy copy loop may coalesce
// multiple jsonl records into one Write. WebSearchAnthropicWriter currently
// unmarshals one JSON value per Write.
if path == "/api/chat" && resp.StatusCode == http.StatusOK && c.GetBool(legacyCloudAnthropicKey) {
framedWriter = &jsonlFramingResponseWriter{ResponseWriter: c.Writer}
bodyWriter = framedWriter
}
err = copyProxyResponseBody(bodyWriter, resp.Body)
if err == nil && framedWriter != nil {
err = framedWriter.FlushPending()
}
if err != nil {
ctxErr := c.Request.Context().Err()
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
slog.Debug(
@@ -240,6 +257,7 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
slog.Warn(
"cloud proxy response copy failed",
"path", c.Request.URL.Path,
"upstream_path", path,
"status", resp.StatusCode,
"request_context_canceled", ctxErr != nil,
"request_context_err", ctxErr,
@@ -473,6 +491,55 @@ func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
}
}
type jsonlFramingResponseWriter struct {
http.ResponseWriter
pending []byte
}
func (w *jsonlFramingResponseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *jsonlFramingResponseWriter) Write(p []byte) (int, error) {
w.pending = append(w.pending, p...)
if err := w.flushCompleteLines(); err != nil {
return len(p), err
}
return len(p), nil
}
func (w *jsonlFramingResponseWriter) FlushPending() error {
trailing := bytes.TrimSpace(w.pending)
w.pending = nil
if len(trailing) == 0 {
return nil
}
_, err := w.ResponseWriter.Write(trailing)
return err
}
func (w *jsonlFramingResponseWriter) flushCompleteLines() error {
for {
newline := bytes.IndexByte(w.pending, '\n')
if newline < 0 {
return nil
}
line := bytes.TrimSpace(w.pending[:newline])
w.pending = w.pending[newline+1:]
if len(line) == 0 {
continue
}
if _, err := w.ResponseWriter.Write(line); err != nil {
return err
}
}
}
func isHopByHopHeader(name string) bool {
_, ok := hopByHopHeaders[strings.ToLower(name)]
return ok

View File

@@ -248,3 +248,71 @@ func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
}
}
func TestJSONLFramingResponseWriter_SplitsCoalescedLines(t *testing.T) {
rec := &chunkRecorder{header: http.Header{}}
w := &jsonlFramingResponseWriter{ResponseWriter: rec}
payload := []byte("{\"a\":1}\n{\"b\":2}\n")
if n, err := w.Write(payload); err != nil {
t.Fatalf("write failed: %v", err)
} else if n != len(payload) {
t.Fatalf("write byte count mismatch: got %d want %d", n, len(payload))
}
if err := w.FlushPending(); err != nil {
t.Fatalf("FlushPending failed: %v", err)
}
if len(rec.chunks) != 2 {
t.Fatalf("expected 2 framed writes, got %d", len(rec.chunks))
}
if got := string(rec.chunks[0]); got != `{"a":1}` {
t.Fatalf("first chunk mismatch: got %q", got)
}
if got := string(rec.chunks[1]); got != `{"b":2}` {
t.Fatalf("second chunk mismatch: got %q", got)
}
}
func TestJSONLFramingResponseWriter_FlushPendingWritesTrailingLine(t *testing.T) {
rec := &chunkRecorder{header: http.Header{}}
w := &jsonlFramingResponseWriter{ResponseWriter: rec}
if _, err := w.Write([]byte("{\"a\":1")); err != nil {
t.Fatalf("write failed: %v", err)
}
if len(rec.chunks) != 0 {
t.Fatalf("expected no writes before newline/flush, got %d", len(rec.chunks))
}
if err := w.FlushPending(); err != nil {
t.Fatalf("FlushPending failed: %v", err)
}
if len(rec.chunks) != 1 {
t.Fatalf("expected 1 write after FlushPending, got %d", len(rec.chunks))
}
if got := string(rec.chunks[0]); got != `{"a":1` {
t.Fatalf("trailing chunk mismatch: got %q", got)
}
}
type chunkRecorder struct {
header http.Header
status int
chunks [][]byte
}
func (r *chunkRecorder) Header() http.Header {
return r.header
}
func (r *chunkRecorder) WriteHeader(statusCode int) {
r.status = statusCode
}
func (r *chunkRecorder) Write(p []byte) (int, error) {
cp := append([]byte(nil), p...)
r.chunks = append(r.chunks, cp)
return len(p), nil
}

View File

@@ -652,6 +652,67 @@ func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) {
}
})
t.Run("v1 messages web_search fallback frames coalesced jsonl chunks", func(t *testing.T) {
type upstreamCapture struct {
path string
}
capture := &upstreamCapture{}
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capture.path = r.URL.Path
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(http.StatusOK)
combined := strings.Join([]string{
`{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hel"},"done":false}`,
`{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"lo"},"done":true}`,
}, "\n") + "\n"
_, _ = w.Write([]byte(combined))
}))
defer upstream.Close()
original := cloudProxyBaseURL
cloudProxyBaseURL = upstream.URL
t.Cleanup(func() { cloudProxyBaseURL = original })
s := &Server{}
router, err := s.GenerateRoutes(nil)
if err != nil {
t.Fatal(err)
}
local := httptest.NewServer(router)
defer local.Close()
reqBody := `{
"model":"gpt-oss:120b-cloud",
"max_tokens":10,
"stream":true,
"messages":[{"role":"user","content":"search the web"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages?beta=true", bytes.NewBufferString(reqBody))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := local.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
}
if capture.path != "/api/chat" {
t.Fatalf("expected upstream path /api/chat for web_search fallback, got %q", capture.path)
}
if !strings.Contains(string(body), "event: message_stop") {
t.Fatalf("expected anthropic streaming message_stop event, got body %q", string(body))
}
})
t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) {
upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`)
defer upstream.Close()