mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
writeError in both OpenAI and Anthropic middleware writers would return a raw json.SyntaxError when the error payload wasn't valid JSON (e.g. "invalid character 'e' looking for beginning of value"). Fall back to using the raw bytes as the error message instead. Also use the actual HTTP status code rather than hardcoding 500, so error types map correctly
1634 lines
41 KiB
Go
1634 lines
41 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/klauspost/compress/zstd"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/openai"
|
|
)
|
|
|
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
|
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
|
props := api.NewToolPropertiesMap()
|
|
for k, v := range m {
|
|
props.Set(k, v)
|
|
}
|
|
return props
|
|
}
|
|
|
|
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
|
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
|
args := api.NewToolCallFunctionArguments()
|
|
for k, v := range m {
|
|
args.Set(k, v)
|
|
}
|
|
return args
|
|
}
|
|
|
|
// argsComparer provides a cmp option for comparing ToolCallFunctionArguments
// by value: both sides are projected to plain maps via ToMap before comparing.
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
	return cmp.Equal(a.ToMap(), b.ToMap())
})
|
|
|
|
// propsComparer provides cmp options for comparing ToolPropertiesMap by value
|
|
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
|
|
if a == nil && b == nil {
|
|
return true
|
|
}
|
|
if a == nil || b == nil {
|
|
return false
|
|
}
|
|
return cmp.Equal(a.ToMap(), b.ToMap())
|
|
})
|
|
|
|
const (
	// prefix is the data-URL header prepended to the base64 image payload in
	// request fixtures.
	prefix = `data:image/jpeg;base64,`
	// image is a tiny base64-encoded image payload (the bytes are a PNG —
	// "iVBORw0KGgo" decodes to the PNG signature — despite the jpeg prefix
	// above; the middleware only base64-decodes it, so the mismatch is
	// harmless here).
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
|
|
|
|
var (
	// False and True are addressable booleans for populating the *bool
	// Stream field on expected api requests.
	False = false
	True  = true
)
|
|
|
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
err := json.Unmarshal(bodyBytes, capturedRequest)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// sseDataFrames extracts the payload of every "data: " frame from an SSE
// response body, in order, skipping frames of any other type.
func sseDataFrames(body string) []string {
	chunks := strings.Split(body, "\n\n")
	payloads := make([]string, 0, len(chunks))
	for _, chunk := range chunks {
		trimmed := strings.TrimSpace(chunk)
		if payload, ok := strings.CutPrefix(trimmed, "data: "); ok {
			payloads = append(payloads, payload)
		}
	}
	return payloads
}
|
|
|
|
// TestChatWriter_StreamMixedThinkingAndContentEmitsSplitChunks verifies that a
// final streamed response carrying both thinking and content is emitted as two
// separate SSE chunks (reasoning first, then content with the finish reason),
// followed by a usage chunk and the [DONE] sentinel.
func TestChatWriter_StreamMixedThinkingAndContentEmitsSplitChunks(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	context, _ := gin.CreateTestContext(recorder)

	writer := &ChatWriter{
		stream:        true,
		streamOptions: &openai.StreamOptions{IncludeUsage: true},
		id:            "chatcmpl-test",
		BaseWriter:    BaseWriter{ResponseWriter: context.Writer},
	}

	// Final (Done) response carrying both thinking and content.
	response := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "reasoning",
			Content:  "final answer",
		},
		Done:       true,
		DoneReason: "stop",
		Metrics: api.Metrics{
			PromptEvalCount: 3,
			EvalCount:       2,
		},
	}

	data, err := json.Marshal(response)
	if err != nil {
		t.Fatalf("marshal response: %v", err)
	}

	if _, err = writer.Write(data); err != nil {
		t.Fatalf("write response: %v", err)
	}

	if got := recorder.Header().Get("Content-Type"); got != "text/event-stream" {
		t.Fatalf("expected Content-Type text/event-stream, got %q", got)
	}

	frames := sseDataFrames(recorder.Body.String())
	if len(frames) != 4 {
		t.Fatalf("expected 4 SSE data frames (2 chunks + usage + [DONE]), got %d:\n%s", len(frames), recorder.Body.String())
	}
	if frames[3] != "[DONE]" {
		t.Fatalf("expected final frame [DONE], got %q", frames[3])
	}

	var reasoningChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
		t.Fatalf("unmarshal reasoning chunk: %v", err)
	}

	var contentChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[1]), &contentChunk); err != nil {
		t.Fatalf("unmarshal content chunk: %v", err)
	}

	var usageChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[2]), &usageChunk); err != nil {
		t.Fatalf("unmarshal usage chunk: %v", err)
	}

	// First chunk: reasoning only, no content, no finish reason.
	if len(reasoningChunk.Choices) != 1 {
		t.Fatalf("expected 1 reasoning choice, got %d", len(reasoningChunk.Choices))
	}
	if reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
		t.Fatalf("expected reasoning chunk reasoning %q, got %q", "reasoning", reasoningChunk.Choices[0].Delta.Reasoning)
	}
	if reasoningChunk.Choices[0].Delta.Content != "" {
		t.Fatalf("expected reasoning chunk content to be empty, got %v", reasoningChunk.Choices[0].Delta.Content)
	}
	if reasoningChunk.Choices[0].FinishReason != nil {
		t.Fatalf("expected reasoning chunk finish reason nil, got %v", reasoningChunk.Choices[0].FinishReason)
	}

	// Second chunk: content only, carrying the "stop" finish reason.
	if len(contentChunk.Choices) != 1 {
		t.Fatalf("expected 1 content choice, got %d", len(contentChunk.Choices))
	}
	if contentChunk.Choices[0].Delta.Reasoning != "" {
		t.Fatalf("expected content chunk reasoning to be empty, got %q", contentChunk.Choices[0].Delta.Reasoning)
	}
	if contentChunk.Choices[0].Delta.Content != "final answer" {
		t.Fatalf("expected content chunk content %q, got %v", "final answer", contentChunk.Choices[0].Delta.Content)
	}
	if contentChunk.Choices[0].FinishReason == nil || *contentChunk.Choices[0].FinishReason != "stop" {
		t.Fatalf("expected content chunk finish reason %q, got %v", "stop", contentChunk.Choices[0].FinishReason)
	}

	// Third chunk: usage only (3 prompt + 2 eval = 5 total), no choices.
	if usageChunk.Usage == nil {
		t.Fatal("expected usage chunk to include usage")
	}
	if usageChunk.Usage.TotalTokens != 5 {
		t.Fatalf("expected usage total tokens 5, got %d", usageChunk.Usage.TotalTokens)
	}
	if len(usageChunk.Choices) != 0 {
		t.Fatalf("expected usage chunk choices to be empty, got %d", len(usageChunk.Choices))
	}
}
|
|
|
|
// TestChatWriter_StreamSingleChunkPathStillEmitsOneChunk verifies that a final
// streamed response with content only (no thinking) is emitted as exactly one
// SSE chunk followed by the [DONE] sentinel.
func TestChatWriter_StreamSingleChunkPathStillEmitsOneChunk(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	context, _ := gin.CreateTestContext(recorder)

	writer := &ChatWriter{
		stream:     true,
		id:         "chatcmpl-test",
		BaseWriter: BaseWriter{ResponseWriter: context.Writer},
	}

	response := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Content: "single chunk",
		},
		Done:       true,
		DoneReason: "stop",
	}

	data, err := json.Marshal(response)
	if err != nil {
		t.Fatalf("marshal response: %v", err)
	}

	if _, err = writer.Write(data); err != nil {
		t.Fatalf("write response: %v", err)
	}

	frames := sseDataFrames(recorder.Body.String())
	if len(frames) != 2 {
		t.Fatalf("expected 2 SSE data frames (1 chunk + [DONE]), got %d:\n%s", len(frames), recorder.Body.String())
	}
	if frames[1] != "[DONE]" {
		t.Fatalf("expected final frame [DONE], got %q", frames[1])
	}

	var chunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[0]), &chunk); err != nil {
		t.Fatalf("unmarshal chunk: %v", err)
	}
	if len(chunk.Choices) != 1 {
		t.Fatalf("expected 1 chunk choice, got %d", len(chunk.Choices))
	}
	if chunk.Choices[0].Delta.Content != "single chunk" {
		t.Fatalf("expected chunk content %q, got %v", "single chunk", chunk.Choices[0].Delta.Content)
	}
}
|
|
|
|
// TestChatWriter_StreamMixedThinkingAndToolCallsWithoutDoneEmitsChunksOnly
// verifies that a non-final (Done: false) response carrying both thinking and
// tool calls is split into a reasoning chunk and a tool-call chunk, with no
// finish reason, no [DONE] sentinel, and the writer's toolCallSent flag set.
func TestChatWriter_StreamMixedThinkingAndToolCallsWithoutDoneEmitsChunksOnly(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	context, _ := gin.CreateTestContext(recorder)

	writer := &ChatWriter{
		stream:        true,
		streamOptions: &openai.StreamOptions{IncludeUsage: true},
		id:            "chatcmpl-test",
		BaseWriter:    BaseWriter{ResponseWriter: context.Writer},
	}

	response := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "reasoning",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_234",
					Function: api.ToolCallFunction{
						Index: 0,
						Name:  "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "Portland",
						}),
					},
				},
			},
		},
		Done: false,
	}

	data, err := json.Marshal(response)
	if err != nil {
		t.Fatalf("marshal response: %v", err)
	}

	if _, err = writer.Write(data); err != nil {
		t.Fatalf("write response: %v", err)
	}

	frames := sseDataFrames(recorder.Body.String())
	if len(frames) != 2 {
		t.Fatalf("expected 2 SSE data frames (reasoning + tool-calls), got %d:\n%s", len(frames), recorder.Body.String())
	}
	if frames[len(frames)-1] == "[DONE]" {
		t.Fatalf("did not expect [DONE] frame for non-final chunk: %s", recorder.Body.String())
	}

	var reasoningChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
		t.Fatalf("unmarshal reasoning chunk: %v", err)
	}

	var toolCallChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[1]), &toolCallChunk); err != nil {
		t.Fatalf("unmarshal tool-call chunk: %v", err)
	}

	if len(reasoningChunk.Choices) != 1 || reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
		t.Fatalf("expected first chunk to be reasoning-only, got %+v", reasoningChunk.Choices)
	}
	if len(toolCallChunk.Choices) != 1 || len(toolCallChunk.Choices[0].Delta.ToolCalls) != 1 {
		t.Fatalf("expected second chunk to contain tool calls, got %+v", toolCallChunk.Choices)
	}
	if toolCallChunk.Choices[0].FinishReason != nil {
		t.Fatalf("expected nil finish reason for non-final tool-call chunk, got %v", toolCallChunk.Choices[0].FinishReason)
	}
	if !writer.toolCallSent {
		t.Fatal("expected toolCallSent to be tracked after tool-call chunk emission")
	}
}
|
|
|
|
// TestChatWriter_StreamMixedThinkingAndContentWithoutDoneEmitsChunksOnly
// verifies that a non-final (Done: false) response carrying both thinking and
// content is split into a reasoning chunk and a content chunk, with no finish
// reason and no [DONE] sentinel.
func TestChatWriter_StreamMixedThinkingAndContentWithoutDoneEmitsChunksOnly(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	context, _ := gin.CreateTestContext(recorder)

	writer := &ChatWriter{
		stream:        true,
		streamOptions: &openai.StreamOptions{IncludeUsage: true},
		id:            "chatcmpl-test",
		BaseWriter:    BaseWriter{ResponseWriter: context.Writer},
	}

	response := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Thinking: "reasoning",
			Content:  "partial content",
		},
		Done: false,
	}

	data, err := json.Marshal(response)
	if err != nil {
		t.Fatalf("marshal response: %v", err)
	}

	if _, err = writer.Write(data); err != nil {
		t.Fatalf("write response: %v", err)
	}

	frames := sseDataFrames(recorder.Body.String())
	if len(frames) != 2 {
		t.Fatalf("expected 2 SSE data frames (reasoning + content), got %d:\n%s", len(frames), recorder.Body.String())
	}
	if frames[len(frames)-1] == "[DONE]" {
		t.Fatalf("did not expect [DONE] frame for non-final chunk: %s", recorder.Body.String())
	}

	var reasoningChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[0]), &reasoningChunk); err != nil {
		t.Fatalf("unmarshal reasoning chunk: %v", err)
	}

	var contentChunk openai.ChatCompletionChunk
	if err := json.Unmarshal([]byte(frames[1]), &contentChunk); err != nil {
		t.Fatalf("unmarshal content chunk: %v", err)
	}

	if len(reasoningChunk.Choices) != 1 || reasoningChunk.Choices[0].Delta.Reasoning != "reasoning" {
		t.Fatalf("expected first chunk to be reasoning-only, got %+v", reasoningChunk.Choices)
	}
	if len(contentChunk.Choices) != 1 || contentChunk.Choices[0].Delta.Content != "partial content" {
		t.Fatalf("expected second chunk to contain content, got %+v", contentChunk.Choices)
	}
	if contentChunk.Choices[0].FinishReason != nil {
		t.Fatalf("expected nil finish reason for non-final content chunk, got %v", contentChunk.Choices[0].FinishReason)
	}
}
|
|
|
|
// TestChatMiddleware verifies that ChatMiddleware translates OpenAI-style
// chat-completion requests into api.ChatRequest values — covering options,
// stream flags, image content, tool calls (with/without content and thinking),
// tool responses (by call ID and by name), tool definitions, and forwarding of
// malformed requests as OpenAI-style error responses.
func TestChatMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.ChatRequest
		err  openai.ErrorResponse
	}

	// Populated by captureRequestMiddleware on each request; reset per subtest.
	var capturedRequest *api.ChatRequest

	testCases := []testCase{
		{
			name: "chat handler",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with options",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				],
				"stream": true,
				"max_tokens": 999,
				"seed": 123,
				"stop": ["\n", "stop"],
				"temperature": 3.0,
				"frequency_penalty": 4.0,
				"presence_penalty": 5.0,
				"top_p": 6.0,
				"response_format": {"type": "json_object"}
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
					"seed":              123.0,
					"stop":              []any{"\n", "stop"},
					"temperature":       3.0,
					"frequency_penalty": 4.0,
					"presence_penalty":  5.0,
					"top_p":             6.0,
				},
				Format: json.RawMessage(`"json"`),
				Stream: &True,
			},
		},
		{
			name: "chat handler with streaming usage",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				],
				"stream": true,
				"stream_options": {"include_usage": true},
				"max_tokens": 999,
				"seed": 123,
				"stop": ["\n", "stop"],
				"temperature": 3.0,
				"frequency_penalty": 4.0,
				"presence_penalty": 5.0,
				"top_p": 6.0,
				"response_format": {"type": "json_object"}
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
					"seed":              123.0,
					"stop":              []any{"\n", "stop"},
					"temperature":       3.0,
					"frequency_penalty": 4.0,
					"presence_penalty":  5.0,
					"top_p":             6.0,
				},
				Format: json.RawMessage(`"json"`),
				Stream: &True,
			},
		},
		{
			// Mixed text + image_url content is split into two user messages:
			// one carrying the text and one carrying the decoded image bytes.
			name: "chat handler with image content",
			body: `{
				"model": "test-model",
				"messages": [
					{
						"role": "user",
						"content": [
							{
								"type": "text",
								"text": "Hello"
							},
							{
								"type": "image_url",
								"image_url": {
									"url": "` + prefix + image + `"
								}
							}
						]
					}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
					{
						Role: "user",
						Images: []api.ImageData{
							func() []byte {
								img, _ := base64.StdEncoding.DecodeString(image)
								return img
							}(),
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with tools",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								ID: "id",
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: testArgs(map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									}),
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with tools and content",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role:    "assistant",
						Content: "Let's see what the weather is like in Paris",
						ToolCalls: []api.ToolCall{
							{
								ID: "id",
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: testArgs(map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									}),
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			// An explicitly empty content string is dropped from the expected
			// message (only the tool calls remain).
			name: "chat handler with tools and empty content",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								ID: "id",
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: testArgs(map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									}),
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			// The OpenAI "reasoning" field maps to the api Message Thinking field.
			name: "chat handler with tools and thinking content",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role:     "assistant",
						Thinking: "Let's see what the weather is like in Paris",
						ToolCalls: []api.ToolCall{
							{
								ID: "id",
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: testArgs(map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									}),
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			// tool_call_id on the tool message resolves ToolName from the
			// matching assistant tool call and is carried through as ToolCallID.
			name: "tool response with call ID",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
					{"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								ID: "id_abc",
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: testArgs(map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									}),
								},
							},
						},
					},
					{
						Role:       "tool",
						Content:    "The weather in Paris is 20 degrees Celsius",
						ToolName:   "get_current_weather",
						ToolCallID: "id_abc",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			// A tool message may alternatively reference the tool by "name";
			// in that case no ToolCallID is expected.
			name: "tool response with name",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
					{"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								ID: "id",
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: testArgs(map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									}),
								},
							},
						},
					},
					{
						Role:     "tool",
						Content:  "The weather in Paris is 20 degrees Celsius",
						ToolName: "get_current_weather",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			// Tool definitions in the request are converted to api.Tool values.
			name: "chat handler with streaming tools",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris?"}
				],
				"stream": true,
				"tools": [{
					"type": "function",
					"function": {
						"name": "get_weather",
						"description": "Get the current weather",
						"parameters": {
							"type": "object",
							"required": ["location"],
							"properties": {
								"location": {
									"type": "string",
									"description": "The city and state"
								},
								"unit": {
									"type": "string",
									"enum": ["celsius", "fahrenheit"]
								}
							}
						}
					}
				}]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris?",
					},
				},
				Tools: []api.Tool{
					{
						Type: "function",
						Function: api.ToolFunction{
							Name:        "get_weather",
							Description: "Get the current weather",
							Parameters: api.ToolFunctionParameters{
								Type:     "object",
								Required: []string{"location"},
								Properties: testPropsMap(map[string]api.ToolProperty{
									"location": {
										Type:        api.PropertyType{"string"},
										Description: "The city and state",
									},
									"unit": {
										Type: api.PropertyType{"string"},
										Enum: []any{"celsius", "fahrenheit"},
									},
								}),
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &True,
			},
		},
		{
			// Invalid content type (a number) must surface as an OpenAI-style
			// invalid_request_error, not a panic or opaque failure.
			name: "chat handler error forwarding",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": 2}
				]
			}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "invalid message content type: float64",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/chat", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			defer func() { capturedRequest = nil }()

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp openai.ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
				return
			}
			if diff := cmp.Diff(&tc.req, capturedRequest, argsComparer, propsComparer); diff != "" {
				t.Fatalf("requests did not match: %+v", diff)
			}
			if diff := cmp.Diff(tc.err, errResp); diff != "" {
				t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
			}
		})
	}
}
|
|
|
|
func TestCompletionsMiddleware(t *testing.T) {
|
|
type testCase struct {
|
|
name string
|
|
body string
|
|
req api.GenerateRequest
|
|
err openai.ErrorResponse
|
|
}
|
|
|
|
var capturedRequest *api.GenerateRequest
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "completions handler",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "Hello",
|
|
"temperature": 0.8,
|
|
"stop": ["\n", "stop"],
|
|
"suffix": "suffix"
|
|
}`,
|
|
req: api.GenerateRequest{
|
|
Model: "test-model",
|
|
Prompt: "Hello",
|
|
Options: map[string]any{
|
|
"frequency_penalty": 0.0,
|
|
"presence_penalty": 0.0,
|
|
"temperature": 0.8,
|
|
"top_p": 1.0,
|
|
"stop": []any{"\n", "stop"},
|
|
},
|
|
Suffix: "suffix",
|
|
Stream: &False,
|
|
},
|
|
},
|
|
{
|
|
name: "completions handler stream",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "Hello",
|
|
"stream": true,
|
|
"temperature": 0.8,
|
|
"stop": ["\n", "stop"],
|
|
"suffix": "suffix"
|
|
}`,
|
|
req: api.GenerateRequest{
|
|
Model: "test-model",
|
|
Prompt: "Hello",
|
|
Options: map[string]any{
|
|
"frequency_penalty": 0.0,
|
|
"presence_penalty": 0.0,
|
|
"temperature": 0.8,
|
|
"top_p": 1.0,
|
|
"stop": []any{"\n", "stop"},
|
|
},
|
|
Suffix: "suffix",
|
|
Stream: &True,
|
|
},
|
|
},
|
|
{
|
|
name: "completions handler stream with usage",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "Hello",
|
|
"stream": true,
|
|
"stream_options": {"include_usage": true},
|
|
"temperature": 0.8,
|
|
"stop": ["\n", "stop"],
|
|
"suffix": "suffix"
|
|
}`,
|
|
req: api.GenerateRequest{
|
|
Model: "test-model",
|
|
Prompt: "Hello",
|
|
Options: map[string]any{
|
|
"frequency_penalty": 0.0,
|
|
"presence_penalty": 0.0,
|
|
"temperature": 0.8,
|
|
"top_p": 1.0,
|
|
"stop": []any{"\n", "stop"},
|
|
},
|
|
Suffix: "suffix",
|
|
Stream: &True,
|
|
},
|
|
},
|
|
{
|
|
name: "completions handler error forwarding",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "Hello",
|
|
"temperature": null,
|
|
"stop": [1, 2],
|
|
"suffix": "suffix"
|
|
}`,
|
|
err: openai.ErrorResponse{
|
|
Error: openai.Error{
|
|
Message: "invalid type for 'stop' field: float64",
|
|
Type: "invalid_request_error",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
endpoint := func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
}
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
var errResp openai.ErrorResponse
|
|
if resp.Code != http.StatusOK {
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
t.Fatal("requests did not match")
|
|
}
|
|
|
|
if !reflect.DeepEqual(tc.err, errResp) {
|
|
t.Fatal("errors did not match")
|
|
}
|
|
|
|
capturedRequest = nil
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEmbeddingsMiddleware(t *testing.T) {
|
|
type testCase struct {
|
|
name string
|
|
body string
|
|
req api.EmbedRequest
|
|
err openai.ErrorResponse
|
|
}
|
|
|
|
var capturedRequest *api.EmbedRequest
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "embed handler single input",
|
|
body: `{
|
|
"input": "Hello",
|
|
"model": "test-model"
|
|
}`,
|
|
req: api.EmbedRequest{
|
|
Input: "Hello",
|
|
Model: "test-model",
|
|
},
|
|
},
|
|
{
|
|
name: "embed handler batch input",
|
|
body: `{
|
|
"input": ["Hello", "World"],
|
|
"model": "test-model"
|
|
}`,
|
|
req: api.EmbedRequest{
|
|
Input: []any{"Hello", "World"},
|
|
Model: "test-model",
|
|
},
|
|
},
|
|
{
|
|
name: "embed handler error forwarding",
|
|
body: `{
|
|
"model": "test-model"
|
|
}`,
|
|
err: openai.ErrorResponse{
|
|
Error: openai.Error{
|
|
Message: "invalid input",
|
|
Type: "invalid_request_error",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
endpoint := func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
}
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
var errResp openai.ErrorResponse
|
|
if resp.Code != http.StatusOK {
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
t.Fatal("requests did not match")
|
|
}
|
|
|
|
if !reflect.DeepEqual(tc.err, errResp) {
|
|
t.Fatal("errors did not match")
|
|
}
|
|
|
|
capturedRequest = nil
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestListMiddleware(t *testing.T) {
|
|
type testCase struct {
|
|
name string
|
|
endpoint func(c *gin.Context)
|
|
resp string
|
|
}
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "list handler",
|
|
endpoint: func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, api.ListResponse{
|
|
Models: []api.ListModelResponse{
|
|
{
|
|
Name: "test-model",
|
|
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
|
},
|
|
},
|
|
})
|
|
},
|
|
resp: `{
|
|
"object": "list",
|
|
"data": [
|
|
{
|
|
"id": "test-model",
|
|
"object": "model",
|
|
"created": 1686935002,
|
|
"owned_by": "library"
|
|
}
|
|
]
|
|
}`,
|
|
},
|
|
{
|
|
name: "list handler empty output",
|
|
endpoint: func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, api.ListResponse{})
|
|
},
|
|
resp: `{
|
|
"object": "list",
|
|
"data": null
|
|
}`,
|
|
},
|
|
}
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
for _, tc := range testCases {
|
|
router := gin.New()
|
|
router.Use(ListMiddleware())
|
|
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
|
|
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
var expected, actual map[string]any
|
|
err := json.Unmarshal([]byte(tc.resp), &expected)
|
|
if err != nil {
|
|
t.Fatalf("failed to unmarshal expected response: %v", err)
|
|
}
|
|
|
|
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
|
if err != nil {
|
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(expected, actual) {
|
|
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRetrieveMiddleware(t *testing.T) {
|
|
type testCase struct {
|
|
name string
|
|
endpoint func(c *gin.Context)
|
|
resp string
|
|
}
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "retrieve handler",
|
|
endpoint: func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, api.ShowResponse{
|
|
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
|
})
|
|
},
|
|
resp: `{
|
|
"id":"test-model",
|
|
"object":"model",
|
|
"created":1686935002,
|
|
"owned_by":"library"}
|
|
`,
|
|
},
|
|
{
|
|
name: "retrieve handler error forwarding",
|
|
endpoint: func(c *gin.Context) {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
|
|
},
|
|
resp: `{
|
|
"error": {
|
|
"code": null,
|
|
"message": "model not found",
|
|
"param": null,
|
|
"type": "invalid_request_error"
|
|
}
|
|
}`,
|
|
},
|
|
}
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
for _, tc := range testCases {
|
|
router := gin.New()
|
|
router.Use(RetrieveMiddleware())
|
|
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
|
|
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
var expected, actual map[string]any
|
|
err := json.Unmarshal([]byte(tc.resp), &expected)
|
|
if err != nil {
|
|
t.Fatalf("failed to unmarshal expected response: %v", err)
|
|
}
|
|
|
|
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
|
if err != nil {
|
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(expected, actual) {
|
|
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestImageGenerationsMiddleware(t *testing.T) {
|
|
type testCase struct {
|
|
name string
|
|
body string
|
|
req api.GenerateRequest
|
|
err openai.ErrorResponse
|
|
}
|
|
|
|
var capturedRequest *api.GenerateRequest
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "image generation basic",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "a beautiful sunset"
|
|
}`,
|
|
req: api.GenerateRequest{
|
|
Model: "test-model",
|
|
Prompt: "a beautiful sunset",
|
|
},
|
|
},
|
|
{
|
|
name: "image generation with size",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "a beautiful sunset",
|
|
"size": "512x768"
|
|
}`,
|
|
req: api.GenerateRequest{
|
|
Model: "test-model",
|
|
Prompt: "a beautiful sunset",
|
|
Width: 512,
|
|
Height: 768,
|
|
},
|
|
},
|
|
{
|
|
name: "image generation missing prompt",
|
|
body: `{
|
|
"model": "test-model"
|
|
}`,
|
|
err: openai.ErrorResponse{
|
|
Error: openai.Error{
|
|
Message: "prompt is required",
|
|
Type: "invalid_request_error",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "image generation missing model",
|
|
body: `{
|
|
"prompt": "a beautiful sunset"
|
|
}`,
|
|
err: openai.ErrorResponse{
|
|
Error: openai.Error{
|
|
Message: "model is required",
|
|
Type: "invalid_request_error",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
endpoint := func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
}
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
defer func() { capturedRequest = nil }()
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if tc.err.Error.Message != "" {
|
|
var errResp openai.ErrorResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
|
t.Fatalf("errors did not match:\n%s", diff)
|
|
}
|
|
return
|
|
}
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
|
t.Fatalf("requests did not match:\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestImageWriterResponse(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
// Test that ImageWriter transforms GenerateResponse to OpenAI format
|
|
endpoint := func(c *gin.Context) {
|
|
resp := api.GenerateResponse{
|
|
Model: "test-model",
|
|
CreatedAt: time.Unix(1234567890, 0).UTC(),
|
|
Done: true,
|
|
Image: "dGVzdC1pbWFnZS1kYXRh", // base64 of "test-image-data"
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.Write(append(data, '\n'))
|
|
}
|
|
|
|
router := gin.New()
|
|
router.Use(ImageGenerationsMiddleware())
|
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
|
|
|
body := `{"model": "test-model", "prompt": "test"}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var imageResp openai.ImageGenerationResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
|
|
if imageResp.Created != 1234567890 {
|
|
t.Errorf("expected created 1234567890, got %d", imageResp.Created)
|
|
}
|
|
|
|
if len(imageResp.Data) != 1 {
|
|
t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
|
|
}
|
|
|
|
if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
|
|
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
|
}
|
|
}
|
|
|
|
func TestImageEditsMiddleware(t *testing.T) {
|
|
type testCase struct {
|
|
name string
|
|
body string
|
|
req api.GenerateRequest
|
|
err openai.ErrorResponse
|
|
}
|
|
|
|
var capturedRequest *api.GenerateRequest
|
|
|
|
// Base64-encoded test image (1x1 pixel PNG)
|
|
testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
|
|
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "image edit basic",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "make it blue",
|
|
"image": "` + testImage + `"
|
|
}`,
|
|
req: api.GenerateRequest{
|
|
Model: "test-model",
|
|
Prompt: "make it blue",
|
|
Images: []api.ImageData{decodedImage},
|
|
},
|
|
},
|
|
{
|
|
name: "image edit with size",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "make it blue",
|
|
"image": "` + testImage + `",
|
|
"size": "512x768"
|
|
}`,
|
|
req: api.GenerateRequest{
|
|
Model: "test-model",
|
|
Prompt: "make it blue",
|
|
Images: []api.ImageData{decodedImage},
|
|
Width: 512,
|
|
Height: 768,
|
|
},
|
|
},
|
|
{
|
|
name: "image edit missing prompt",
|
|
body: `{
|
|
"model": "test-model",
|
|
"image": "` + testImage + `"
|
|
}`,
|
|
err: openai.ErrorResponse{
|
|
Error: openai.Error{
|
|
Message: "prompt is required",
|
|
Type: "invalid_request_error",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "image edit missing model",
|
|
body: `{
|
|
"prompt": "make it blue",
|
|
"image": "` + testImage + `"
|
|
}`,
|
|
err: openai.ErrorResponse{
|
|
Error: openai.Error{
|
|
Message: "model is required",
|
|
Type: "invalid_request_error",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "image edit missing image",
|
|
body: `{
|
|
"model": "test-model",
|
|
"prompt": "make it blue"
|
|
}`,
|
|
err: openai.ErrorResponse{
|
|
Error: openai.Error{
|
|
Message: "image is required",
|
|
Type: "invalid_request_error",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
endpoint := func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
}
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
defer func() { capturedRequest = nil }()
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if tc.err.Error.Message != "" {
|
|
var errResp openai.ErrorResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
|
t.Fatalf("errors did not match:\n%s", diff)
|
|
}
|
|
return
|
|
}
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
|
t.Fatalf("requests did not match:\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func zstdCompress(t *testing.T, data []byte) []byte {
|
|
t.Helper()
|
|
var buf bytes.Buffer
|
|
w, err := zstd.NewWriter(&buf)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := w.Write(data); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := w.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return buf.Bytes()
|
|
}
|
|
|
|
func TestResponsesMiddlewareZstd(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
useZstd bool
|
|
oversized bool
|
|
wantCode int
|
|
wantModel string
|
|
wantMessage string
|
|
}{
|
|
{
|
|
name: "plain JSON",
|
|
body: `{"model": "test-model", "input": "Hello"}`,
|
|
wantCode: http.StatusOK,
|
|
wantModel: "test-model",
|
|
wantMessage: "Hello",
|
|
},
|
|
{
|
|
name: "zstd compressed",
|
|
body: `{"model": "test-model", "input": "Hello"}`,
|
|
useZstd: true,
|
|
wantCode: http.StatusOK,
|
|
wantModel: "test-model",
|
|
wantMessage: "Hello",
|
|
},
|
|
{
|
|
name: "zstd over max decompressed size",
|
|
oversized: true,
|
|
useZstd: true,
|
|
wantCode: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var capturedRequest *api.ChatRequest
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
var bodyReader io.Reader
|
|
if tt.oversized {
|
|
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
|
|
} else if tt.useZstd {
|
|
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
|
|
} else {
|
|
bodyReader = strings.NewReader(tt.body)
|
|
}
|
|
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if tt.useZstd || tt.oversized {
|
|
req.Header.Set("Content-Encoding", "zstd")
|
|
}
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != tt.wantCode {
|
|
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
|
|
}
|
|
|
|
if tt.wantCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
if capturedRequest == nil {
|
|
t.Fatal("expected captured request, got nil")
|
|
}
|
|
if capturedRequest.Model != tt.wantModel {
|
|
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
|
|
}
|
|
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
|
|
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
|
|
}
|
|
})
|
|
}
|
|
}
|