mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
cloud_proxy: for the web_search legacy path, flush on newlines (#14897)
`WebSearchAnthropicWriter` expects a single object per write. The new transparent proxy will instead send it whatever bytes it sees. This cloud-model + local-orchestration + cloud-search is a temporary code path, so instead of making the web search code more robust to this, I put an adapter in the middle that will flush line-by-line to preserve the old behavior.
This commit is contained in:
@@ -226,7 +226,24 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
|
||||
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
|
||||
var bodyWriter http.ResponseWriter = c.Writer
|
||||
var framedWriter *jsonlFramingResponseWriter
|
||||
// TEMP(drifkin): only needed on the cloud-proxied first leg of Anthropic
|
||||
// web_search fallback (which is a path we're removing soon). Local
|
||||
// /v1/messages writes one JSON value per streamResponse callback directly
|
||||
// into WebSearchAnthropicWriter, but this proxy copy loop may coalesce
|
||||
// multiple jsonl records into one Write. WebSearchAnthropicWriter currently
|
||||
// unmarshals one JSON value per Write.
|
||||
if path == "/api/chat" && resp.StatusCode == http.StatusOK && c.GetBool(legacyCloudAnthropicKey) {
|
||||
framedWriter = &jsonlFramingResponseWriter{ResponseWriter: c.Writer}
|
||||
bodyWriter = framedWriter
|
||||
}
|
||||
|
||||
err = copyProxyResponseBody(bodyWriter, resp.Body)
|
||||
if err == nil && framedWriter != nil {
|
||||
err = framedWriter.FlushPending()
|
||||
}
|
||||
if err != nil {
|
||||
ctxErr := c.Request.Context().Err()
|
||||
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
|
||||
slog.Debug(
|
||||
@@ -240,6 +257,7 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
|
||||
slog.Warn(
|
||||
"cloud proxy response copy failed",
|
||||
"path", c.Request.URL.Path,
|
||||
"upstream_path", path,
|
||||
"status", resp.StatusCode,
|
||||
"request_context_canceled", ctxErr != nil,
|
||||
"request_context_err", ctxErr,
|
||||
@@ -473,6 +491,55 @@ func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
type jsonlFramingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
pending []byte
|
||||
}
|
||||
|
||||
func (w *jsonlFramingResponseWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *jsonlFramingResponseWriter) Write(p []byte) (int, error) {
|
||||
w.pending = append(w.pending, p...)
|
||||
if err := w.flushCompleteLines(); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *jsonlFramingResponseWriter) FlushPending() error {
|
||||
trailing := bytes.TrimSpace(w.pending)
|
||||
w.pending = nil
|
||||
if len(trailing) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := w.ResponseWriter.Write(trailing)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *jsonlFramingResponseWriter) flushCompleteLines() error {
|
||||
for {
|
||||
newline := bytes.IndexByte(w.pending, '\n')
|
||||
if newline < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
line := bytes.TrimSpace(w.pending[:newline])
|
||||
w.pending = w.pending[newline+1:]
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := w.ResponseWriter.Write(line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isHopByHopHeader(name string) bool {
|
||||
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
||||
return ok
|
||||
|
||||
@@ -248,3 +248,71 @@ func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLFramingResponseWriter_SplitsCoalescedLines(t *testing.T) {
|
||||
rec := &chunkRecorder{header: http.Header{}}
|
||||
w := &jsonlFramingResponseWriter{ResponseWriter: rec}
|
||||
|
||||
payload := []byte("{\"a\":1}\n{\"b\":2}\n")
|
||||
if n, err := w.Write(payload); err != nil {
|
||||
t.Fatalf("write failed: %v", err)
|
||||
} else if n != len(payload) {
|
||||
t.Fatalf("write byte count mismatch: got %d want %d", n, len(payload))
|
||||
}
|
||||
|
||||
if err := w.FlushPending(); err != nil {
|
||||
t.Fatalf("FlushPending failed: %v", err)
|
||||
}
|
||||
|
||||
if len(rec.chunks) != 2 {
|
||||
t.Fatalf("expected 2 framed writes, got %d", len(rec.chunks))
|
||||
}
|
||||
if got := string(rec.chunks[0]); got != `{"a":1}` {
|
||||
t.Fatalf("first chunk mismatch: got %q", got)
|
||||
}
|
||||
if got := string(rec.chunks[1]); got != `{"b":2}` {
|
||||
t.Fatalf("second chunk mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLFramingResponseWriter_FlushPendingWritesTrailingLine(t *testing.T) {
|
||||
rec := &chunkRecorder{header: http.Header{}}
|
||||
w := &jsonlFramingResponseWriter{ResponseWriter: rec}
|
||||
|
||||
if _, err := w.Write([]byte("{\"a\":1")); err != nil {
|
||||
t.Fatalf("write failed: %v", err)
|
||||
}
|
||||
if len(rec.chunks) != 0 {
|
||||
t.Fatalf("expected no writes before newline/flush, got %d", len(rec.chunks))
|
||||
}
|
||||
|
||||
if err := w.FlushPending(); err != nil {
|
||||
t.Fatalf("FlushPending failed: %v", err)
|
||||
}
|
||||
if len(rec.chunks) != 1 {
|
||||
t.Fatalf("expected 1 write after FlushPending, got %d", len(rec.chunks))
|
||||
}
|
||||
if got := string(rec.chunks[0]); got != `{"a":1` {
|
||||
t.Fatalf("trailing chunk mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
type chunkRecorder struct {
|
||||
header http.Header
|
||||
status int
|
||||
chunks [][]byte
|
||||
}
|
||||
|
||||
func (r *chunkRecorder) Header() http.Header {
|
||||
return r.header
|
||||
}
|
||||
|
||||
func (r *chunkRecorder) WriteHeader(statusCode int) {
|
||||
r.status = statusCode
|
||||
}
|
||||
|
||||
func (r *chunkRecorder) Write(p []byte) (int, error) {
|
||||
cp := append([]byte(nil), p...)
|
||||
r.chunks = append(r.chunks, cp)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
@@ -652,6 +652,67 @@ func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("v1 messages web_search fallback frames coalesced jsonl chunks", func(t *testing.T) {
|
||||
type upstreamCapture struct {
|
||||
path string
|
||||
}
|
||||
capture := &upstreamCapture{}
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capture.path = r.URL.Path
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
combined := strings.Join([]string{
|
||||
`{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hel"},"done":false}`,
|
||||
`{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"lo"},"done":true}`,
|
||||
}, "\n") + "\n"
|
||||
_, _ = w.Write([]byte(combined))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
original := cloudProxyBaseURL
|
||||
cloudProxyBaseURL = upstream.URL
|
||||
t.Cleanup(func() { cloudProxyBaseURL = original })
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
reqBody := `{
|
||||
"model":"gpt-oss:120b-cloud",
|
||||
"max_tokens":10,
|
||||
"stream":true,
|
||||
"messages":[{"role":"user","content":"search the web"}],
|
||||
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
||||
}`
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages?beta=true", bytes.NewBufferString(reqBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
if capture.path != "/api/chat" {
|
||||
t.Fatalf("expected upstream path /api/chat for web_search fallback, got %q", capture.path)
|
||||
}
|
||||
if !strings.Contains(string(body), "event: message_stop") {
|
||||
t.Fatalf("expected anthropic streaming message_stop event, got body %q", string(body))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) {
|
||||
upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`)
|
||||
defer upstream.Close()
|
||||
|
||||
Reference in New Issue
Block a user