mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
openai: split mixed thinking stream chunks via ToChunks (#14648)
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -76,22 +77,29 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||
w.toolCallSent = true
|
||||
}
|
||||
|
||||
chunks := openai.ToChunks(w.id, chatResponse, w.toolCallSent)
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
for _, c := range chunks {
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||
w.toolCallSent = true
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if chatResponse.Done {
|
||||
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
|
||||
if len(chunks) > 0 {
|
||||
c = chunks[len(chunks)-1]
|
||||
} else {
|
||||
slog.Warn("ToChunks returned no chunks; falling back to ToChunk for usage chunk", "id", w.id, "model", chatResponse.Model)
|
||||
}
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := openai.ToUsage(chatResponse)
|
||||
c.Usage = &u
|
||||
|
||||
@@ -76,6 +76,299 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func sseDataFrames(body string) []string {
|
||||
frames := strings.Split(body, "\n\n")
|
||||
data := make([]string, 0, len(frames))
|
||||
for _, frame := range frames {
|
||||
frame = strings.TrimSpace(frame)
|
||||
if !strings.HasPrefix(frame, "data: ") {
|
||||
continue
|
||||
}
|
||||
data = append(data, strings.TrimPrefix(frame, "data: "))
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestChatWriter_StreamMixedThinkingAndContentEmitsSplitChunks(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
context, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
writer := &ChatWriter{
|
||||
stream: true,
|
||||
streamOptions: &openai.StreamOptions{IncludeUsage: true},
|
||||
id: "chatcmpl-test",
|
||||
BaseWriter: BaseWriter{ResponseWriter: context.Writer},
|
||||
}
|
||||
|
||||
response := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "reasoning",
|
||||
Content: "final answer",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 3,
|
||||
EvalCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal response: %v", err)
|
||||
}
|
||||
|
||||
if _, err = writer.Write(data); err != nil {
|
||||
t.Fatalf("write response: %v", err)
|
||||
}
|
||||
|
||||
if got := recorder.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Fatalf("expected Content-Type text/event-stream, got %q", got)
|
||||
}
|
||||
|
||||
frames := sseDataFrames(recorder.Body.String())
|
||||
if len(frames) != 4 {
|
||||
t.Fatalf("expected 4 SSE data frames (2 chunks + usage + [DONE]), got %d:\n%s", len(frames), recorder.Body.String())
|
||||
}
|
||||
if frames[3] != "[DONE]" {
|
||||
t.Fatalf("expected final frame [DONE], got %q", frames[3])
|
||||
}
|
||||
|
||||
var reasoningChunk openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
|
||||
t.Fatalf("unmarshal reasoning chunk: %v", err)
|
||||
}
|
||||
|
||||
var contentChunk openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(frames[1]), &contentChunk); err != nil {
|
||||
t.Fatalf("unmarshal content chunk: %v", err)
|
||||
}
|
||||
|
||||
var usageChunk openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(frames[2]), &usageChunk); err != nil {
|
||||
t.Fatalf("unmarshal usage chunk: %v", err)
|
||||
}
|
||||
|
||||
if len(reasoningChunk.Choices) != 1 {
|
||||
t.Fatalf("expected 1 reasoning choice, got %d", len(reasoningChunk.Choices))
|
||||
}
|
||||
if reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
|
||||
t.Fatalf("expected reasoning chunk reasoning %q, got %q", "reasoning", reasoningChunk.Choices[0].Delta.Reasoning)
|
||||
}
|
||||
if reasoningChunk.Choices[0].Delta.Content != "" {
|
||||
t.Fatalf("expected reasoning chunk content to be empty, got %v", reasoningChunk.Choices[0].Delta.Content)
|
||||
}
|
||||
if reasoningChunk.Choices[0].FinishReason != nil {
|
||||
t.Fatalf("expected reasoning chunk finish reason nil, got %v", reasoningChunk.Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
if len(contentChunk.Choices) != 1 {
|
||||
t.Fatalf("expected 1 content choice, got %d", len(contentChunk.Choices))
|
||||
}
|
||||
if contentChunk.Choices[0].Delta.Reasoning != "" {
|
||||
t.Fatalf("expected content chunk reasoning to be empty, got %q", contentChunk.Choices[0].Delta.Reasoning)
|
||||
}
|
||||
if contentChunk.Choices[0].Delta.Content != "final answer" {
|
||||
t.Fatalf("expected content chunk content %q, got %v", "final answer", contentChunk.Choices[0].Delta.Content)
|
||||
}
|
||||
if contentChunk.Choices[0].FinishReason == nil || *contentChunk.Choices[0].FinishReason != "stop" {
|
||||
t.Fatalf("expected content chunk finish reason %q, got %v", "stop", contentChunk.Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
if usageChunk.Usage == nil {
|
||||
t.Fatal("expected usage chunk to include usage")
|
||||
}
|
||||
if usageChunk.Usage.TotalTokens != 5 {
|
||||
t.Fatalf("expected usage total tokens 5, got %d", usageChunk.Usage.TotalTokens)
|
||||
}
|
||||
if len(usageChunk.Choices) != 0 {
|
||||
t.Fatalf("expected usage chunk choices to be empty, got %d", len(usageChunk.Choices))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatWriter_StreamSingleChunkPathStillEmitsOneChunk(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
context, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
writer := &ChatWriter{
|
||||
stream: true,
|
||||
id: "chatcmpl-test",
|
||||
BaseWriter: BaseWriter{ResponseWriter: context.Writer},
|
||||
}
|
||||
|
||||
response := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Content: "single chunk",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal response: %v", err)
|
||||
}
|
||||
|
||||
if _, err = writer.Write(data); err != nil {
|
||||
t.Fatalf("write response: %v", err)
|
||||
}
|
||||
|
||||
frames := sseDataFrames(recorder.Body.String())
|
||||
if len(frames) != 2 {
|
||||
t.Fatalf("expected 2 SSE data frames (1 chunk + [DONE]), got %d:\n%s", len(frames), recorder.Body.String())
|
||||
}
|
||||
if frames[1] != "[DONE]" {
|
||||
t.Fatalf("expected final frame [DONE], got %q", frames[1])
|
||||
}
|
||||
|
||||
var chunk openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(frames[0]), &chunk); err != nil {
|
||||
t.Fatalf("unmarshal chunk: %v", err)
|
||||
}
|
||||
if len(chunk.Choices) != 1 {
|
||||
t.Fatalf("expected 1 chunk choice, got %d", len(chunk.Choices))
|
||||
}
|
||||
if chunk.Choices[0].Delta.Content != "single chunk" {
|
||||
t.Fatalf("expected chunk content %q, got %v", "single chunk", chunk.Choices[0].Delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatWriter_StreamMixedThinkingAndToolCallsWithoutDoneEmitsChunksOnly(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
context, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
writer := &ChatWriter{
|
||||
stream: true,
|
||||
streamOptions: &openai.StreamOptions{IncludeUsage: true},
|
||||
id: "chatcmpl-test",
|
||||
BaseWriter: BaseWriter{ResponseWriter: context.Writer},
|
||||
}
|
||||
|
||||
response := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "reasoning",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_234",
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Portland",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: false,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal response: %v", err)
|
||||
}
|
||||
|
||||
if _, err = writer.Write(data); err != nil {
|
||||
t.Fatalf("write response: %v", err)
|
||||
}
|
||||
|
||||
frames := sseDataFrames(recorder.Body.String())
|
||||
if len(frames) != 2 {
|
||||
t.Fatalf("expected 2 SSE data frames (reasoning + tool-calls), got %d:\n%s", len(frames), recorder.Body.String())
|
||||
}
|
||||
if frames[len(frames)-1] == "[DONE]" {
|
||||
t.Fatalf("did not expect [DONE] frame for non-final chunk: %s", recorder.Body.String())
|
||||
}
|
||||
|
||||
var reasoningChunk openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
|
||||
t.Fatalf("unmarshal reasoning chunk: %v", err)
|
||||
}
|
||||
|
||||
var toolCallChunk openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(frames[1]), &toolCallChunk); err != nil {
|
||||
t.Fatalf("unmarshal tool-call chunk: %v", err)
|
||||
}
|
||||
|
||||
if len(reasoningChunk.Choices) != 1 || reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
|
||||
t.Fatalf("expected first chunk to be reasoning-only, got %+v", reasoningChunk.Choices)
|
||||
}
|
||||
if len(toolCallChunk.Choices) != 1 || len(toolCallChunk.Choices[0].Delta.ToolCalls) != 1 {
|
||||
t.Fatalf("expected second chunk to contain tool calls, got %+v", toolCallChunk.Choices)
|
||||
}
|
||||
if toolCallChunk.Choices[0].FinishReason != nil {
|
||||
t.Fatalf("expected nil finish reason for non-final tool-call chunk, got %v", toolCallChunk.Choices[0].FinishReason)
|
||||
}
|
||||
if !writer.toolCallSent {
|
||||
t.Fatal("expected toolCallSent to be tracked after tool-call chunk emission")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatWriter_StreamMixedThinkingAndContentWithoutDoneEmitsChunksOnly(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
context, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
writer := &ChatWriter{
|
||||
stream: true,
|
||||
streamOptions: &openai.StreamOptions{IncludeUsage: true},
|
||||
id: "chatcmpl-test",
|
||||
BaseWriter: BaseWriter{ResponseWriter: context.Writer},
|
||||
}
|
||||
|
||||
response := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "reasoning",
|
||||
Content: "partial content",
|
||||
},
|
||||
Done: false,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal response: %v", err)
|
||||
}
|
||||
|
||||
if _, err = writer.Write(data); err != nil {
|
||||
t.Fatalf("write response: %v", err)
|
||||
}
|
||||
|
||||
frames := sseDataFrames(recorder.Body.String())
|
||||
if len(frames) != 2 {
|
||||
t.Fatalf("expected 2 SSE data frames (reasoning + content), got %d:\n%s", len(frames), recorder.Body.String())
|
||||
}
|
||||
if frames[len(frames)-1] == "[DONE]" {
|
||||
t.Fatalf("did not expect [DONE] frame for non-final chunk: %s", recorder.Body.String())
|
||||
}
|
||||
|
||||
var reasoningChunk openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
|
||||
t.Fatalf("unmarshal reasoning chunk: %v", err)
|
||||
}
|
||||
|
||||
var contentChunk openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(frames[1]), &contentChunk); err != nil {
|
||||
t.Fatalf("unmarshal content chunk: %v", err)
|
||||
}
|
||||
|
||||
if len(reasoningChunk.Choices) != 1 || reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
|
||||
t.Fatalf("expected first chunk to be reasoning-only, got %+v", reasoningChunk.Choices)
|
||||
}
|
||||
if len(contentChunk.Choices) != 1 || contentChunk.Choices[0].Delta.Content != "partial content" {
|
||||
t.Fatalf("expected second chunk to contain content, got %+v", contentChunk.Choices)
|
||||
}
|
||||
if contentChunk.Choices[0].FinishReason != nil {
|
||||
t.Fatalf("expected nil finish reason for non-final content chunk, got %v", contentChunk.Choices[0].FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user