openai: split mixed thinking stream chunks via ToChunks (#14648)

This commit is contained in:
Parth Sareen
2026-03-11 14:21:29 -07:00
committed by GitHub
parent c222735c02
commit 97013a190c
4 changed files with 627 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"math/rand"
"net/http"
"strings"
@@ -76,22 +77,29 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
// chat chunk
if w.stream {
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
w.toolCallSent = true
}
chunks := openai.ToChunks(w.id, chatResponse, w.toolCallSent)
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
for _, c := range chunks {
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
w.toolCallSent = true
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
if chatResponse.Done {
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
if len(chunks) > 0 {
c = chunks[len(chunks)-1]
} else {
slog.Warn("ToChunks returned no chunks; falling back to ToChunk for usage chunk", "id", w.id, "model", chatResponse.Model)
}
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := openai.ToUsage(chatResponse)
c.Usage = &u

View File

@@ -76,6 +76,299 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
}
}
// sseDataFrames extracts the payload of every "data: " frame from an SSE
// response body. Frames without the data prefix (comments, other event
// fields, trailing empties) are skipped.
func sseDataFrames(body string) []string {
	rawFrames := strings.Split(body, "\n\n")
	payloads := make([]string, 0, len(rawFrames))
	for _, raw := range rawFrames {
		trimmed := strings.TrimSpace(raw)
		if payload, ok := strings.CutPrefix(trimmed, "data: "); ok {
			payloads = append(payloads, payload)
		}
	}
	return payloads
}
// TestChatWriter_StreamMixedThinkingAndContentEmitsSplitChunks verifies that a
// final (Done) response carrying both thinking and content is streamed as two
// separate delta chunks — reasoning first, then content with the finish
// reason — followed by a usage chunk and the "[DONE]" terminator.
func TestChatWriter_StreamMixedThinkingAndContentEmitsSplitChunks(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	context, _ := gin.CreateTestContext(recorder)
	// IncludeUsage forces the writer to emit a trailing usage chunk on Done.
	writer := &ChatWriter{
		stream:        true,
		streamOptions: &openai.StreamOptions{IncludeUsage: true},
		id:            "chatcmpl-test",
		BaseWriter:    BaseWriter{ResponseWriter: context.Writer},
	}
	// One logical response holding both thinking and content, so the writer
	// must split it into two chunks.
	response := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "reasoning",
			Content:  "final answer",
		},
		Done:       true,
		DoneReason: "stop",
		Metrics: api.Metrics{
			PromptEvalCount: 3,
			EvalCount:       2,
		},
	}
	data, err := json.Marshal(response)
	if err != nil {
		t.Fatalf("marshal response: %v", err)
	}
	if _, err = writer.Write(data); err != nil {
		t.Fatalf("write response: %v", err)
	}
	if got := recorder.Header().Get("Content-Type"); got != "text/event-stream" {
		t.Fatalf("expected Content-Type text/event-stream, got %q", got)
	}
	// Expected frame order: reasoning chunk, content chunk, usage chunk, [DONE].
	frames := sseDataFrames(recorder.Body.String())
	if len(frames) != 4 {
		t.Fatalf("expected 4 SSE data frames (2 chunks + usage + [DONE]), got %d:\n%s", len(frames), recorder.Body.String())
	}
	if frames[3] != "[DONE]" {
		t.Fatalf("expected final frame [DONE], got %q", frames[3])
	}
	var reasoningChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
		t.Fatalf("unmarshal reasoning chunk: %v", err)
	}
	var contentChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[1]), &contentChunk); err != nil {
		t.Fatalf("unmarshal content chunk: %v", err)
	}
	var usageChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[2]), &usageChunk); err != nil {
		t.Fatalf("unmarshal usage chunk: %v", err)
	}
	// First chunk: reasoning only, no content, and no finish reason yet.
	if len(reasoningChunk.Choices) != 1 {
		t.Fatalf("expected 1 reasoning choice, got %d", len(reasoningChunk.Choices))
	}
	if reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
		t.Fatalf("expected reasoning chunk reasoning %q, got %q", "reasoning", reasoningChunk.Choices[0].Delta.Reasoning)
	}
	if reasoningChunk.Choices[0].Delta.Content != "" {
		t.Fatalf("expected reasoning chunk content to be empty, got %v", reasoningChunk.Choices[0].Delta.Content)
	}
	if reasoningChunk.Choices[0].FinishReason != nil {
		t.Fatalf("expected reasoning chunk finish reason nil, got %v", reasoningChunk.Choices[0].FinishReason)
	}
	// Second chunk: content only, carrying the "stop" finish reason.
	if len(contentChunk.Choices) != 1 {
		t.Fatalf("expected 1 content choice, got %d", len(contentChunk.Choices))
	}
	if contentChunk.Choices[0].Delta.Reasoning != "" {
		t.Fatalf("expected content chunk reasoning to be empty, got %q", contentChunk.Choices[0].Delta.Reasoning)
	}
	if contentChunk.Choices[0].Delta.Content != "final answer" {
		t.Fatalf("expected content chunk content %q, got %v", "final answer", contentChunk.Choices[0].Delta.Content)
	}
	if contentChunk.Choices[0].FinishReason == nil || *contentChunk.Choices[0].FinishReason != "stop" {
		t.Fatalf("expected content chunk finish reason %q, got %v", "stop", contentChunk.Choices[0].FinishReason)
	}
	// Usage chunk: totals from Metrics (3 prompt + 2 eval) and no choices.
	if usageChunk.Usage == nil {
		t.Fatal("expected usage chunk to include usage")
	}
	if usageChunk.Usage.TotalTokens != 5 {
		t.Fatalf("expected usage total tokens 5, got %d", usageChunk.Usage.TotalTokens)
	}
	if len(usageChunk.Choices) != 0 {
		t.Fatalf("expected usage chunk choices to be empty, got %d", len(usageChunk.Choices))
	}
}
// TestChatWriter_StreamSingleChunkPathStillEmitsOneChunk verifies that a
// content-only (non-mixed) response is not split: it streams as a single
// delta chunk followed only by the "[DONE]" terminator.
func TestChatWriter_StreamSingleChunkPathStillEmitsOneChunk(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	context, _ := gin.CreateTestContext(recorder)
	// No streamOptions here, so no usage chunk is expected even though Done.
	writer := &ChatWriter{
		stream:     true,
		id:         "chatcmpl-test",
		BaseWriter: BaseWriter{ResponseWriter: context.Writer},
	}
	response := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Content: "single chunk",
		},
		Done:       true,
		DoneReason: "stop",
	}
	data, err := json.Marshal(response)
	if err != nil {
		t.Fatalf("marshal response: %v", err)
	}
	if _, err = writer.Write(data); err != nil {
		t.Fatalf("write response: %v", err)
	}
	frames := sseDataFrames(recorder.Body.String())
	if len(frames) != 2 {
		t.Fatalf("expected 2 SSE data frames (1 chunk + [DONE]), got %d:\n%s", len(frames), recorder.Body.String())
	}
	if frames[1] != "[DONE]" {
		t.Fatalf("expected final frame [DONE], got %q", frames[1])
	}
	var chunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[0]), &chunk); err != nil {
		t.Fatalf("unmarshal chunk: %v", err)
	}
	if len(chunk.Choices) != 1 {
		t.Fatalf("expected 1 chunk choice, got %d", len(chunk.Choices))
	}
	if chunk.Choices[0].Delta.Content != "single chunk" {
		t.Fatalf("expected chunk content %q, got %v", "single chunk", chunk.Choices[0].Delta.Content)
	}
}
// TestChatWriter_StreamMixedThinkingAndToolCallsWithoutDoneEmitsChunksOnly
// verifies that a non-final (Done=false) mixed thinking/tool-call response is
// split into a reasoning chunk and a tool-call chunk with no usage chunk or
// [DONE] frame, and that the writer records that tool calls were sent.
func TestChatWriter_StreamMixedThinkingAndToolCallsWithoutDoneEmitsChunksOnly(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	context, _ := gin.CreateTestContext(recorder)
	// IncludeUsage is set, but usage must only be emitted on Done responses.
	writer := &ChatWriter{
		stream:        true,
		streamOptions: &openai.StreamOptions{IncludeUsage: true},
		id:            "chatcmpl-test",
		BaseWriter:    BaseWriter{ResponseWriter: context.Writer},
	}
	response := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "reasoning",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_234",
					Function: api.ToolCallFunction{
						Index: 0,
						Name:  "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "Portland",
						}),
					},
				},
			},
		},
		Done: false,
	}
	data, err := json.Marshal(response)
	if err != nil {
		t.Fatalf("marshal response: %v", err)
	}
	if _, err = writer.Write(data); err != nil {
		t.Fatalf("write response: %v", err)
	}
	frames := sseDataFrames(recorder.Body.String())
	if len(frames) != 2 {
		t.Fatalf("expected 2 SSE data frames (reasoning + tool-calls), got %d:\n%s", len(frames), recorder.Body.String())
	}
	if frames[len(frames)-1] == "[DONE]" {
		t.Fatalf("did not expect [DONE] frame for non-final chunk: %s", recorder.Body.String())
	}
	var reasoningChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
		t.Fatalf("unmarshal reasoning chunk: %v", err)
	}
	var toolCallChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[1]), &toolCallChunk); err != nil {
		t.Fatalf("unmarshal tool-call chunk: %v", err)
	}
	if len(reasoningChunk.Choices) != 1 || reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
		t.Fatalf("expected first chunk to be reasoning-only, got %+v", reasoningChunk.Choices)
	}
	if len(toolCallChunk.Choices) != 1 || len(toolCallChunk.Choices[0].Delta.ToolCalls) != 1 {
		t.Fatalf("expected second chunk to contain tool calls, got %+v", toolCallChunk.Choices)
	}
	// The stream is not done, so even the tool-call chunk carries no finish reason.
	if toolCallChunk.Choices[0].FinishReason != nil {
		t.Fatalf("expected nil finish reason for non-final tool-call chunk, got %v", toolCallChunk.Choices[0].FinishReason)
	}
	// The writer's internal flag must track that tool calls went out, so later
	// chunks can suppress duplicates.
	if !writer.toolCallSent {
		t.Fatal("expected toolCallSent to be tracked after tool-call chunk emission")
	}
}
// TestChatWriter_StreamMixedThinkingAndContentWithoutDoneEmitsChunksOnly
// verifies that a non-final (Done=false) mixed thinking/content response is
// split into a reasoning chunk and a content chunk, with no usage chunk, no
// [DONE] frame, and no finish reason on either chunk.
func TestChatWriter_StreamMixedThinkingAndContentWithoutDoneEmitsChunksOnly(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	context, _ := gin.CreateTestContext(recorder)
	// IncludeUsage is set, but usage must only be emitted on Done responses.
	writer := &ChatWriter{
		stream:        true,
		streamOptions: &openai.StreamOptions{IncludeUsage: true},
		id:            "chatcmpl-test",
		BaseWriter:    BaseWriter{ResponseWriter: context.Writer},
	}
	response := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "reasoning",
			Content:  "partial content",
		},
		Done: false,
	}
	data, err := json.Marshal(response)
	if err != nil {
		t.Fatalf("marshal response: %v", err)
	}
	if _, err = writer.Write(data); err != nil {
		t.Fatalf("write response: %v", err)
	}
	frames := sseDataFrames(recorder.Body.String())
	if len(frames) != 2 {
		t.Fatalf("expected 2 SSE data frames (reasoning + content), got %d:\n%s", len(frames), recorder.Body.String())
	}
	if frames[len(frames)-1] == "[DONE]" {
		t.Fatalf("did not expect [DONE] frame for non-final chunk: %s", recorder.Body.String())
	}
	var reasoningChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
		t.Fatalf("unmarshal reasoning chunk: %v", err)
	}
	var contentChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[1]), &contentChunk); err != nil {
		t.Fatalf("unmarshal content chunk: %v", err)
	}
	if len(reasoningChunk.Choices) != 1 || reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
		t.Fatalf("expected first chunk to be reasoning-only, got %+v", reasoningChunk.Choices)
	}
	if len(contentChunk.Choices) != 1 || contentChunk.Choices[0].Delta.Content != "partial content" {
		t.Fatalf("expected second chunk to contain content, got %+v", contentChunk.Choices)
	}
	if contentChunk.Choices[0].FinishReason != nil {
		t.Fatalf("expected nil finish reason for non-final content chunk, got %v", contentChunk.Choices[0].FinishReason)
	}
}
func TestChatMiddleware(t *testing.T) {
type testCase struct {
name string

View File

@@ -291,8 +291,7 @@ func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}
}
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
toolCalls := ToToolCalls(r.Message.ToolCalls)
var logprobs *ChoiceLogprobs
@@ -323,6 +322,36 @@ func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
}
}
// ToChunks converts an api.ChatResponse into the ChatCompletionChunk values
// to stream. A response carrying both thinking and content/tool calls is
// split into a reasoning-only chunk followed by a content/tool-call chunk;
// any other response maps to exactly one chunk.
func ToChunks(id string, r api.ChatResponse, toolCallSent bool) []ChatCompletionChunk {
	mixed := r.Message.Thinking != "" && (r.Message.Content != "" || len(r.Message.ToolCalls) > 0)
	if !mixed {
		return []ChatCompletionChunk{toChunk(id, r, toolCallSent)}
	}

	first := toChunk(id, r, toolCallSent)
	second := toChunk(id, r, toolCallSent)

	// First chunk carries only the reasoning delta and must not end the
	// stream. Its logprobs may include tokens that belong to the second half
	// of the split, since we only divide thinking from content/tool calls.
	first.Choices[0].Delta.Content = ""
	first.Choices[0].Delta.ToolCalls = nil
	first.Choices[0].FinishReason = nil

	// Second chunk carries the content/tool calls and the finish reason. It
	// shares the first chunk's timestamp — both represent one logical
	// emission — and drops logprobs so they are sent exactly once.
	second.Created = first.Created
	second.Choices[0].Delta.Reasoning = ""
	second.Choices[0].Logprobs = nil

	return []ChatCompletionChunk{first, second}
}
// ToChunk converts an api.ChatResponse to a single ChatCompletionChunk, even
// when the response mixes thinking with content or tool calls.
//
// Deprecated: use ToChunks for streaming conversion; it splits mixed
// thinking/content responses into separate chunks.
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
	return toChunk(id, r, toolCallSent)
}
// ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage {
return Usage{

View File

@@ -413,6 +413,289 @@ func TestToChatCompletion_WithoutLogprobs(t *testing.T) {
}
}
// TestToChunks_SplitsThinkingAndContent verifies that a done response holding
// both thinking and content is converted into two chunks: a reasoning chunk
// with no content or finish reason, and a content chunk carrying the "stop"
// finish reason.
func TestToChunks_SplitsThinkingAndContent(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "step-by-step",
			Content:  "final answer",
		},
		Done:       true,
		DoneReason: "stop",
	}
	chunks := ToChunks("test-id", resp, false)
	if len(chunks) != 2 {
		t.Fatalf("expected 2 chunks, got %d", len(chunks))
	}
	// First chunk: reasoning only.
	reasoning := chunks[0].Choices[0]
	if reasoning.Delta.Reasoning != "step-by-step" {
		t.Fatalf("expected reasoning chunk to contain thinking, got %q", reasoning.Delta.Reasoning)
	}
	if reasoning.Delta.Content != "" {
		t.Fatalf("expected reasoning chunk content to be empty, got %v", reasoning.Delta.Content)
	}
	if len(reasoning.Delta.ToolCalls) != 0 {
		t.Fatalf("expected reasoning chunk tool calls to be empty, got %d", len(reasoning.Delta.ToolCalls))
	}
	if reasoning.FinishReason != nil {
		t.Fatalf("expected reasoning chunk finish reason to be nil, got %q", *reasoning.FinishReason)
	}
	// Second chunk: content plus the finish reason.
	content := chunks[1].Choices[0]
	if content.Delta.Reasoning != "" {
		t.Fatalf("expected content chunk reasoning to be empty, got %q", content.Delta.Reasoning)
	}
	if content.Delta.Content != "final answer" {
		t.Fatalf("expected content chunk content %q, got %v", "final answer", content.Delta.Content)
	}
	if content.FinishReason == nil || *content.FinishReason != "stop" {
		t.Fatalf("expected content chunk finish reason %q, got %v", "stop", content.FinishReason)
	}
}
// TestToChunks_SplitsThinkingAndToolCalls verifies that a done response
// holding both thinking and tool calls is converted into two chunks: a
// reasoning chunk with no tool calls or finish reason, and a tool-call chunk
// carrying the tool-calls finish reason.
func TestToChunks_SplitsThinkingAndToolCalls(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "need a tool",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_123",
					Function: api.ToolCallFunction{
						Index: 0,
						Name:  "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "Seattle",
						}),
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}
	chunks := ToChunks("test-id", resp, false)
	if len(chunks) != 2 {
		t.Fatalf("expected 2 chunks, got %d", len(chunks))
	}
	// First chunk: reasoning only.
	reasoning := chunks[0].Choices[0]
	if reasoning.Delta.Reasoning != "need a tool" {
		t.Fatalf("expected reasoning chunk to contain thinking, got %q", reasoning.Delta.Reasoning)
	}
	if len(reasoning.Delta.ToolCalls) != 0 {
		t.Fatalf("expected reasoning chunk tool calls to be empty, got %d", len(reasoning.Delta.ToolCalls))
	}
	if reasoning.FinishReason != nil {
		t.Fatalf("expected reasoning chunk finish reason to be nil, got %q", *reasoning.FinishReason)
	}
	// Second chunk: the tool call, finished with finishReasonToolCalls rather
	// than the response's DoneReason.
	toolCallChunk := chunks[1].Choices[0]
	if toolCallChunk.Delta.Reasoning != "" {
		t.Fatalf("expected tool-call chunk reasoning to be empty, got %q", toolCallChunk.Delta.Reasoning)
	}
	if len(toolCallChunk.Delta.ToolCalls) != 1 {
		t.Fatalf("expected one tool call in second chunk, got %d", len(toolCallChunk.Delta.ToolCalls))
	}
	if toolCallChunk.Delta.ToolCalls[0].ID != "call_123" {
		t.Fatalf("expected tool call id %q, got %q", "call_123", toolCallChunk.Delta.ToolCalls[0].ID)
	}
	if toolCallChunk.FinishReason == nil || *toolCallChunk.FinishReason != finishReasonToolCalls {
		t.Fatalf("expected tool-call chunk finish reason %q, got %v", finishReasonToolCalls, toolCallChunk.FinishReason)
	}
}
// TestToChunks_SingleChunkForNonMixedResponses verifies that responses that
// do not mix thinking with content or tool calls (thinking-only,
// content-only, tool-calls-only) are never split and always produce exactly
// one chunk.
func TestToChunks_SingleChunkForNonMixedResponses(t *testing.T) {
	toolCalls := []api.ToolCall{
		{
			ID: "call_456",
			Function: api.ToolCallFunction{
				Index: 0,
				Name:  "get_time",
				Arguments: testArgs(map[string]any{
					"timezone": "UTC",
				}),
			},
		},
	}
	tests := []struct {
		name    string
		message api.Message
	}{
		{
			name:    "thinking-only",
			message: api.Message{Thinking: "pondering"},
		},
		{
			name:    "content-only",
			message: api.Message{Content: "hello"},
		},
		{
			name:    "toolcalls-only",
			message: api.Message{ToolCalls: toolCalls},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			resp := api.ChatResponse{
				Model:   "test-model",
				Message: tt.message,
			}
			chunks := ToChunks("test-id", resp, false)
			if len(chunks) != 1 {
				t.Fatalf("expected 1 chunk, got %d", len(chunks))
			}
		})
	}
}
// TestToChunks_SplitsThinkingAndToolCallsWhenNotDone verifies that a mixed
// thinking/tool-call response is split even when the response is not final,
// and that neither resulting chunk carries a finish reason in that case.
func TestToChunks_SplitsThinkingAndToolCallsWhenNotDone(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "need a tool",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_789",
					Function: api.ToolCallFunction{
						Index: 0,
						Name:  "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "San Francisco",
						}),
					},
				},
			},
		},
		Done: false,
	}
	chunks := ToChunks("test-id", resp, false)
	if len(chunks) != 2 {
		t.Fatalf("expected 2 chunks, got %d", len(chunks))
	}
	reasoning := chunks[0].Choices[0]
	if reasoning.Delta.Reasoning != "need a tool" {
		t.Fatalf("expected reasoning chunk to contain thinking, got %q", reasoning.Delta.Reasoning)
	}
	if reasoning.FinishReason != nil {
		t.Fatalf("expected reasoning chunk finish reason nil, got %v", reasoning.FinishReason)
	}
	toolCallChunk := chunks[1].Choices[0]
	if len(toolCallChunk.Delta.ToolCalls) != 1 {
		t.Fatalf("expected one tool call in second chunk, got %d", len(toolCallChunk.Delta.ToolCalls))
	}
	if toolCallChunk.Delta.ToolCalls[0].ID != "call_789" {
		t.Fatalf("expected tool call id %q, got %q", "call_789", toolCallChunk.Delta.ToolCalls[0].ID)
	}
	// Not done yet, so no finish reason even on the tool-call chunk.
	if toolCallChunk.FinishReason != nil {
		t.Fatalf("expected tool-call chunk finish reason nil when not done, got %v", toolCallChunk.FinishReason)
	}
}
// TestToChunks_SplitsThinkingAndContentWhenNotDone verifies that a mixed
// thinking/content response is split even when the response is not final,
// and that neither resulting chunk carries a finish reason in that case.
func TestToChunks_SplitsThinkingAndContentWhenNotDone(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "thinking",
			Content:  "partial content",
		},
		Done: false,
	}
	chunks := ToChunks("test-id", resp, false)
	if len(chunks) != 2 {
		t.Fatalf("expected 2 chunks, got %d", len(chunks))
	}
	reasoning := chunks[0].Choices[0]
	if reasoning.Delta.Reasoning != "thinking" {
		t.Fatalf("expected reasoning chunk to contain thinking, got %q", reasoning.Delta.Reasoning)
	}
	if reasoning.FinishReason != nil {
		t.Fatalf("expected reasoning chunk finish reason nil, got %v", reasoning.FinishReason)
	}
	content := chunks[1].Choices[0]
	if content.Delta.Content != "partial content" {
		t.Fatalf("expected content chunk content %q, got %v", "partial content", content.Delta.Content)
	}
	// Not done yet, so no finish reason even on the content chunk.
	if content.FinishReason != nil {
		t.Fatalf("expected content chunk finish reason nil when not done, got %v", content.FinishReason)
	}
}
// TestToChunks_SplitSendsLogprobsOnlyOnFirstChunk verifies that when a mixed
// response is split, the response's logprobs are attached to the first
// (reasoning) chunk only, and the second chunk's logprobs are nil so they are
// not sent twice.
func TestToChunks_SplitSendsLogprobsOnlyOnFirstChunk(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "thinking",
			Content:  "content",
		},
		Logprobs: []api.Logprob{
			{
				TokenLogprob: api.TokenLogprob{
					Token:   "tok",
					Logprob: -0.25,
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}
	chunks := ToChunks("test-id", resp, false)
	if len(chunks) != 2 {
		t.Fatalf("expected 2 chunks, got %d", len(chunks))
	}
	first := chunks[0].Choices[0]
	if first.Logprobs == nil {
		t.Fatal("expected first chunk to include logprobs")
	}
	if len(first.Logprobs.Content) != 1 || first.Logprobs.Content[0].Token != "tok" {
		t.Fatalf("unexpected first chunk logprobs: %+v", first.Logprobs.Content)
	}
	second := chunks[1].Choices[0]
	if second.Logprobs != nil {
		t.Fatalf("expected second chunk logprobs to be nil, got %+v", second.Logprobs)
	}
}
// TestToChunk_LegacyMixedThinkingAndContentSingleChunk verifies that the
// deprecated ToChunk keeps its historical behavior: a mixed thinking/content
// response stays in a single chunk whose delta carries both fields.
func TestToChunk_LegacyMixedThinkingAndContentSingleChunk(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "reasoning",
			Content:  "answer",
		},
		Done:       true,
		DoneReason: "stop",
	}
	chunk := ToChunk("test-id", resp, false)
	if len(chunk.Choices) != 1 {
		t.Fatalf("expected 1 choice, got %d", len(chunk.Choices))
	}
	delta := chunk.Choices[0].Delta
	if delta.Reasoning != "reasoning" {
		t.Fatalf("expected reasoning %q, got %q", "reasoning", delta.Reasoning)
	}
	if delta.Content != "answer" {
		t.Fatalf("expected content %q, got %v", "answer", delta.Content)
	}
}
func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
tests := []struct {
name string