diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index 94d435932..37eb0baf4 100755 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -1,17 +1,25 @@ package anthropic import ( + "bytes" + "context" "crypto/rand" "encoding/base64" "encoding/json" "errors" "fmt" + "io" "log/slog" "net/http" + "net/url" + "strconv" "strings" "time" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" + internalcloud "github.com/ollama/ollama/internal/cloud" + "github.com/ollama/ollama/logutil" ) // Error types matching Anthropic API @@ -82,22 +90,25 @@ type MessageParam struct { // Text and Thinking use pointers so they serialize as the field being present (even if empty) // only when set, which is required for SDK streaming accumulation. type ContentBlock struct { - Type string `json:"type"` // text, image, tool_use, tool_result, thinking + Type string `json:"type"` // text, image, tool_use, tool_result, thinking, server_tool_use, web_search_tool_result // For text blocks - pointer so field only appears when set (SDK requires it for accumulation) Text *string `json:"text,omitempty"` + // For text blocks with citations + Citations []Citation `json:"citations,omitempty"` + // For image blocks Source *ImageSource `json:"source,omitempty"` - // For tool_use blocks + // For tool_use and server_tool_use blocks ID string `json:"id,omitempty"` Name string `json:"name,omitempty"` Input any `json:"input,omitempty"` - // For tool_result blocks + // For tool_result and web_search_tool_result blocks ToolUseID string `json:"tool_use_id,omitempty"` - Content any `json:"content,omitempty"` // string or []ContentBlock + Content any `json:"content,omitempty"` // string, []ContentBlock, []WebSearchResult, or WebSearchToolResultError IsError bool `json:"is_error,omitempty"` // For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation) @@ -105,6 +116,30 @@ type ContentBlock struct { Signature string `json:"signature,omitempty"` } +// 
Citation represents a citation in a text block +type Citation struct { + Type string `json:"type"` // "web_search_result_location" + URL string `json:"url"` + Title string `json:"title"` + EncryptedIndex string `json:"encrypted_index,omitempty"` + CitedText string `json:"cited_text,omitempty"` +} + +// WebSearchResult represents a single web search result +type WebSearchResult struct { + Type string `json:"type"` // "web_search_result" + URL string `json:"url"` + Title string `json:"title"` + EncryptedContent string `json:"encrypted_content,omitempty"` + PageAge string `json:"page_age,omitempty"` +} + +// WebSearchToolResultError represents an error from web search +type WebSearchToolResultError struct { + Type string `json:"type"` // "web_search_tool_result_error" + ErrorCode string `json:"error_code"` +} + // ImageSource represents the source of an image type ImageSource struct { Type string `json:"type"` // "base64" or "url" @@ -115,10 +150,13 @@ type ImageSource struct { // Tool represents a tool definition type Tool struct { - Type string `json:"type,omitempty"` // "custom" for user-defined tools + Type string `json:"type,omitempty"` // "custom" for user-defined tools, or "web_search_20250305" for web search Name string `json:"name"` Description string `json:"description,omitempty"` InputSchema json.RawMessage `json:"input_schema,omitempty"` + + // Web search specific fields + MaxUses int `json:"max_uses,omitempty"` } // ToolChoice controls how the model uses tools @@ -233,6 +271,8 @@ type StreamErrorEvent struct { // FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) { + logutil.Trace("anthropic: converting request", "req", TraceMessagesRequest(r)) + var messages []api.Message if r.System != nil { @@ -259,9 +299,10 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) { } } - for _, msg := range r.Messages { + for i, msg := range 
r.Messages { converted, err := convertMessage(msg) if err != nil { + logutil.Trace("anthropic: message conversion failed", "index", i, "role", msg.Role, "err", err) return nil, err } messages = append(messages, converted...) @@ -288,8 +329,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) { } var tools api.Tools + hasBuiltinWebSearch := false for _, t := range r.Tools { - tool, err := convertTool(t) + if strings.HasPrefix(t.Type, "web_search") { + hasBuiltinWebSearch = true + break + } + } + + for _, t := range r.Tools { + // Anthropic built-in web_search maps to Ollama function name "web_search". + // If a user-defined tool also uses that name in the same request, drop the + // user-defined one to avoid ambiguous tool-call routing. + if hasBuiltinWebSearch && !strings.HasPrefix(t.Type, "web_search") && t.Name == "web_search" { + logutil.Trace("anthropic: dropping colliding custom web_search tool", "tool", TraceTool(t)) + continue + } + + tool, _, err := convertTool(t) if err != nil { return nil, err } @@ -302,15 +359,17 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) { } stream := r.Stream - - return &api.ChatRequest{ + convertedRequest := &api.ChatRequest{ Model: r.Model, Messages: messages, Options: options, Stream: &stream, Tools: tools, Think: think, - }, nil + } + logutil.Trace("anthropic: converted request", "req", TraceChatRequest(convertedRequest)) + + return convertedRequest, nil } // convertMessage converts an Anthropic MessageParam to Ollama api.Message(s) @@ -328,10 +387,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { var toolCalls []api.ToolCall var thinking string var toolResults []api.Message + textBlocks := 0 + imageBlocks := 0 + toolUseBlocks := 0 + toolResultBlocks := 0 + serverToolUseBlocks := 0 + webSearchToolResultBlocks := 0 + thinkingBlocks := 0 + unknownBlocks := 0 for _, block := range content { blockMap, ok := block.(map[string]any) if !ok { + 
logutil.Trace("anthropic: invalid content block format", "role", role) return nil, errors.New("invalid content block format") } @@ -339,13 +407,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { switch blockType { case "text": + textBlocks++ if text, ok := blockMap["text"].(string); ok { textContent.WriteString(text) } case "image": + imageBlocks++ source, ok := blockMap["source"].(map[string]any) if !ok { + logutil.Trace("anthropic: invalid image source", "role", role) return nil, errors.New("invalid image source") } @@ -354,21 +425,26 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { data, _ := source["data"].(string) decoded, err := base64.StdEncoding.DecodeString(data) if err != nil { + logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err) return nil, fmt.Errorf("invalid base64 image data: %w", err) } images = append(images, decoded) } else { + logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType) return nil, fmt.Errorf("invalid image source type: %s. 
Only base64 images are supported.", sourceType) } // URL images would need to be fetched - skip for now case "tool_use": + toolUseBlocks++ id, ok := blockMap["id"].(string) if !ok { + logutil.Trace("anthropic: tool_use block missing id", "role", role) return nil, errors.New("tool_use block missing required 'id' field") } name, ok := blockMap["name"].(string) if !ok { + logutil.Trace("anthropic: tool_use block missing name", "role", role) return nil, errors.New("tool_use block missing required 'name' field") } tc := api.ToolCall{ @@ -383,6 +459,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { toolCalls = append(toolCalls, tc) case "tool_result": + toolResultBlocks++ toolUseID, _ := blockMap["tool_use_id"].(string) var resultContent string @@ -408,9 +485,36 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { }) case "thinking": + thinkingBlocks++ if t, ok := blockMap["thinking"].(string); ok { thinking = t } + + case "server_tool_use": + serverToolUseBlocks++ + id, _ := blockMap["id"].(string) + name, _ := blockMap["name"].(string) + tc := api.ToolCall{ + ID: id, + Function: api.ToolCallFunction{ + Name: name, + }, + } + if input, ok := blockMap["input"].(map[string]any); ok { + tc.Function.Arguments = mapToArgs(input) + } + toolCalls = append(toolCalls, tc) + + case "web_search_tool_result": + webSearchToolResultBlocks++ + toolUseID, _ := blockMap["tool_use_id"].(string) + toolResults = append(toolResults, api.Message{ + Role: "tool", + Content: formatWebSearchToolResultContent(blockMap["content"]), + ToolCallID: toolUseID, + }) + default: + unknownBlocks++ } } @@ -427,6 +531,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { // Add tool results as separate messages messages = append(messages, toolResults...) 
+ logutil.Trace("anthropic: converted block message", + "role", role, + "blocks", len(content), + "text", textBlocks, + "image", imageBlocks, + "tool_use", toolUseBlocks, + "tool_result", toolResultBlocks, + "server_tool_use", serverToolUseBlocks, + "web_search_result", webSearchToolResultBlocks, + "thinking", thinkingBlocks, + "unknown", unknownBlocks, + "messages", TraceAPIMessages(messages), + ) default: return nil, fmt.Errorf("invalid message content type: %T", content) @@ -435,12 +552,94 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { return messages, nil } -// convertTool converts an Anthropic Tool to an Ollama api.Tool -func convertTool(t Tool) (api.Tool, error) { +func formatWebSearchToolResultContent(content any) string { + switch c := content.(type) { + case string: + return c + case []WebSearchResult: + var resultContent strings.Builder + for _, item := range c { + if item.Type != "web_search_result" { + continue + } + fmt.Fprintf(&resultContent, "- %s: %s\n", item.Title, item.URL) + } + return resultContent.String() + case []any: + var resultContent strings.Builder + for _, item := range c { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + switch itemMap["type"] { + case "web_search_result": + title, _ := itemMap["title"].(string) + url, _ := itemMap["url"].(string) + fmt.Fprintf(&resultContent, "- %s: %s\n", title, url) + case "web_search_tool_result_error": + errorCode, _ := itemMap["error_code"].(string) + if errorCode == "" { + return "web_search_tool_result_error" + } + return "web_search_tool_result_error: " + errorCode + } + } + return resultContent.String() + case map[string]any: + if c["type"] == "web_search_tool_result_error" { + errorCode, _ := c["error_code"].(string) + if errorCode == "" { + return "web_search_tool_result_error" + } + return "web_search_tool_result_error: " + errorCode + } + data, err := json.Marshal(c) + if err != nil { + return "" + } + return string(data) + case 
WebSearchToolResultError: + if c.ErrorCode == "" { + return "web_search_tool_result_error" + } + return "web_search_tool_result_error: " + c.ErrorCode + default: + data, err := json.Marshal(c) + if err != nil { + return "" + } + return string(data) + } +} + +// convertTool converts an Anthropic Tool to an Ollama api.Tool, returning true if it's a server tool +func convertTool(t Tool) (api.Tool, bool, error) { + if strings.HasPrefix(t.Type, "web_search") { + props := api.NewToolPropertiesMap() + props.Set("query", api.ToolProperty{ + Type: api.PropertyType{"string"}, + Description: "The search query to look up on the web", + }) + return api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: "web_search", + Description: "Search the web for current information. Use this to find up-to-date information about any topic.", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"query"}, + Properties: props, + }, + }, + }, true, nil + } + var params api.ToolFunctionParameters if len(t.InputSchema) > 0 { if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil { - return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err) + logutil.Trace("anthropic: invalid tool schema", "tool", t.Name, "err", err) + return api.Tool{}, false, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err) } } @@ -451,7 +650,7 @@ func convertTool(t Tool) (api.Tool, error) { Description: t.Description, Parameters: params, }, - }, nil + }, false, nil } // ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse @@ -899,3 +1098,113 @@ func countContentBlock(block any) int { return total } + +// OllamaWebSearchRequest represents a request to the Ollama web search API +type OllamaWebSearchRequest struct { + Query string `json:"query"` + MaxResults int `json:"max_results,omitempty"` +} + +// OllamaWebSearchResult represents a single search result from Ollama API +type OllamaWebSearchResult struct { + 
Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` +} + +// OllamaWebSearchResponse represents the response from the Ollama web search API +type OllamaWebSearchResponse struct { + Results []OllamaWebSearchResult `json:"results"` +} + +var WebSearchEndpoint = "https://ollama.com/api/web_search" + +func WebSearch(ctx context.Context, query string, maxResults int) (*OllamaWebSearchResponse, error) { + if internalcloud.Disabled() { + logutil.TraceContext(ctx, "anthropic: web search blocked", "reason", "cloud_disabled") + return nil, errors.New(internalcloud.DisabledError("web search is unavailable")) + } + + if maxResults <= 0 { + maxResults = 5 + } + if maxResults > 10 { + maxResults = 10 + } + + reqBody := OllamaWebSearchRequest{ + Query: query, + MaxResults: maxResults, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal web search request: %w", err) + } + + searchURL, err := url.Parse(WebSearchEndpoint) + if err != nil { + return nil, fmt.Errorf("failed to parse web search URL: %w", err) + } + logutil.TraceContext(ctx, "anthropic: web search request", + "query", TraceTruncateString(query), + "max_results", maxResults, + "url", searchURL.String(), + ) + + q := searchURL.Query() + q.Set("ts", strconv.FormatInt(time.Now().Unix(), 10)) + searchURL.RawQuery = q.Encode() + + signature := "" + if strings.EqualFold(searchURL.Hostname(), "ollama.com") { + challenge := fmt.Sprintf("%s,%s", http.MethodPost, searchURL.RequestURI()) + signature, err = auth.Sign(ctx, []byte(challenge)) + if err != nil { + return nil, fmt.Errorf("failed to sign web search request: %w", err) + } + } + logutil.TraceContext(ctx, "anthropic: web search auth", "signed", signature != "") + + req, err := http.NewRequestWithContext(ctx, "POST", searchURL.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create web search request: %w", err) + } + + req.Header.Set("Content-Type", 
"application/json") + if signature != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature)) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("web search request failed: %w", err) + } + defer resp.Body.Close() + logutil.TraceContext(ctx, "anthropic: web search response", "status", resp.StatusCode) + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("web search returned status %d: %s", resp.StatusCode, string(respBody)) + } + + var searchResp OllamaWebSearchResponse + if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil { + return nil, fmt.Errorf("failed to decode web search response: %w", err) + } + logutil.TraceContext(ctx, "anthropic: web search results", "count", len(searchResp.Results)) + + return &searchResp, nil +} + +func ConvertOllamaToAnthropicResults(ollamaResults *OllamaWebSearchResponse) []WebSearchResult { + var results []WebSearchResult + for _, r := range ollamaResults.Results { + results = append(results, WebSearchResult{ + Type: "web_search_result", + URL: r.URL, + Title: r.Title, + }) + } + return results +} diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go index 2f2717bf0..9aade4099 100755 --- a/anthropic/anthropic_test.go +++ b/anthropic/anthropic_test.go @@ -3,6 +3,7 @@ package anthropic import ( "encoding/base64" "encoding/json" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -300,6 +301,78 @@ func TestFromMessagesRequest_WithTools(t *testing.T) { } } +func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Tools: []Tool{ + { + Type: "web_search_20250305", + Name: "web_search", + }, + { + Type: "custom", + Name: "web_search", + Description: "User-defined web search that should be dropped", + InputSchema: 
json.RawMessage(`{"type":"invalid"}`), + }, + { + Type: "custom", + Name: "get_weather", + Description: "Get current weather", + InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`), + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Tools) != 2 { + t.Fatalf("expected 2 tools after dropping custom web_search, got %d", len(result.Tools)) + } + if result.Tools[0].Function.Name != "web_search" { + t.Fatalf("expected first tool to be built-in web_search, got %q", result.Tools[0].Function.Name) + } + if result.Tools[1].Function.Name != "get_weather" { + t.Fatalf("expected second tool to be get_weather, got %q", result.Tools[1].Function.Name) + } +} + +func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Tools: []Tool{ + { + Type: "custom", + Name: "web_search", + Description: "User-defined web search", + InputSchema: json.RawMessage(`{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`), + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Tools) != 1 { + t.Fatalf("expected 1 custom tool, got %d", len(result.Tools)) + } + if result.Tools[0].Function.Name != "web_search" { + t.Fatalf("expected custom tool name web_search, got %q", result.Tools[0].Function.Name) + } + if result.Tools[0].Function.Description != "User-defined web search" { + t.Fatalf("expected custom description preserved, got %q", result.Tools[0].Function.Description) + } +} + func TestFromMessagesRequest_WithThinking(t *testing.T) { req := MessagesRequest{ Model: "test-model", @@ -1063,3 +1136,320 @@ func TestEstimateTokens_EmptyContent(t *testing.T) { t.Errorf("expected 0 tokens for 
empty content, got %d", tokens) } } + +// Web Search Tests + +func TestConvertTool_WebSearch(t *testing.T) { + tool := Tool{ + Type: "web_search_20250305", + Name: "web_search", + MaxUses: 5, + } + + result, isServerTool, err := convertTool(tool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !isServerTool { + t.Error("expected isServerTool to be true for web_search tool") + } + + if result.Type != "function" { + t.Errorf("expected type 'function', got %q", result.Type) + } + + if result.Function.Name != "web_search" { + t.Errorf("expected name 'web_search', got %q", result.Function.Name) + } + + if result.Function.Description == "" { + t.Error("expected non-empty description for web_search tool") + } + + // Check that query parameter is defined + if result.Function.Parameters.Properties == nil { + t.Fatal("expected properties to be defined") + } + + queryProp, ok := result.Function.Parameters.Properties.Get("query") + if !ok { + t.Error("expected 'query' property to be defined") + } + + if len(queryProp.Type) == 0 || queryProp.Type[0] != "string" { + t.Errorf("expected query type to be 'string', got %v", queryProp.Type) + } +} + +func TestConvertTool_RegularTool(t *testing.T) { + tool := Tool{ + Type: "custom", + Name: "get_weather", + Description: "Get the weather", + InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`), + } + + result, isServerTool, err := convertTool(tool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if isServerTool { + t.Error("expected isServerTool to be false for regular tool") + } + + if result.Function.Name != "get_weather" { + t.Errorf("expected name 'get_weather', got %q", result.Function.Name) + } +} + +func TestConvertMessage_ServerToolUse(t *testing.T) { + msg := MessageParam{ + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "server_tool_use", + "id": "srvtoolu_123", + "name": "web_search", + "input": map[string]any{"query": "test 
query"}, + }, + }, + } + + messages, err := convertMessage(msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(messages)) + } + + if len(messages[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(messages[0].ToolCalls)) + } + + tc := messages[0].ToolCalls[0] + if tc.ID != "srvtoolu_123" { + t.Errorf("expected tool call ID 'srvtoolu_123', got %q", tc.ID) + } + + if tc.Function.Name != "web_search" { + t.Errorf("expected tool name 'web_search', got %q", tc.Function.Name) + } +} + +func TestConvertMessage_WebSearchToolResult(t *testing.T) { + msg := MessageParam{ + Role: "user", + Content: []any{ + map[string]any{ + "type": "web_search_tool_result", + "tool_use_id": "srvtoolu_123", + "content": []any{ + map[string]any{ + "type": "web_search_result", + "title": "Test Result", + "url": "https://example.com", + }, + }, + }, + }, + } + + messages, err := convertMessage(msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should have a tool result message + if len(messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(messages)) + } + + if messages[0].Role != "tool" { + t.Errorf("expected role 'tool', got %q", messages[0].Role) + } + + if messages[0].ToolCallID != "srvtoolu_123" { + t.Errorf("expected tool_call_id 'srvtoolu_123', got %q", messages[0].ToolCallID) + } + + if messages[0].Content == "" { + t.Error("expected non-empty content from web search results") + } +} + +func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) { + msg := MessageParam{ + Role: "user", + Content: []any{ + map[string]any{ + "type": "web_search_tool_result", + "tool_use_id": "srvtoolu_empty", + "content": []any{}, + }, + }, + } + + messages, err := convertMessage(msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(messages)) + } + if 
messages[0].Role != "tool" { + t.Fatalf("expected role tool, got %q", messages[0].Role) + } + if messages[0].ToolCallID != "srvtoolu_empty" { + t.Fatalf("expected tool_call_id srvtoolu_empty, got %q", messages[0].ToolCallID) + } + if messages[0].Content != "" { + t.Fatalf("expected empty content for empty web search results, got %q", messages[0].Content) + } +} + +func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) { + msg := MessageParam{ + Role: "user", + Content: []any{ + map[string]any{ + "type": "web_search_tool_result", + "tool_use_id": "srvtoolu_error", + "content": map[string]any{ + "type": "web_search_tool_result_error", + "error_code": "max_uses_exceeded", + }, + }, + }, + } + + messages, err := convertMessage(msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(messages)) + } + if messages[0].Role != "tool" { + t.Fatalf("expected role tool, got %q", messages[0].Role) + } + if messages[0].ToolCallID != "srvtoolu_error" { + t.Fatalf("expected tool_call_id srvtoolu_error, got %q", messages[0].ToolCallID) + } + if !strings.Contains(messages[0].Content, "max_uses_exceeded") { + t.Fatalf("expected error code in converted tool content, got %q", messages[0].Content) + } +} + +func TestConvertOllamaToAnthropicResults(t *testing.T) { + ollamaResp := &OllamaWebSearchResponse{ + Results: []OllamaWebSearchResult{ + { + Title: "Test Title", + URL: "https://example.com", + Content: "Test content", + }, + { + Title: "Another Result", + URL: "https://example.org", + Content: "More content", + }, + }, + } + + results := ConvertOllamaToAnthropicResults(ollamaResp) + + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + + if results[0].Type != "web_search_result" { + t.Errorf("expected type 'web_search_result', got %q", results[0].Type) + } + + if results[0].Title != "Test Title" { + t.Errorf("expected title 'Test Title', 
got %q", results[0].Title) + } + + if results[0].URL != "https://example.com" { + t.Errorf("expected URL 'https://example.com', got %q", results[0].URL) + } +} + +func TestWebSearchTypes(t *testing.T) { + // Test that WebSearchResult serializes correctly + result := WebSearchResult{ + Type: "web_search_result", + URL: "https://example.com", + Title: "Test", + EncryptedContent: "abc123", + PageAge: "2025-01-01", + } + + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("failed to marshal WebSearchResult: %v", err) + } + + var unmarshaled WebSearchResult + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("failed to unmarshal WebSearchResult: %v", err) + } + + if unmarshaled.Type != result.Type { + t.Errorf("type mismatch: expected %q, got %q", result.Type, unmarshaled.Type) + } + + // Test WebSearchToolResultError + errResult := WebSearchToolResultError{ + Type: "web_search_tool_result_error", + ErrorCode: "max_uses_exceeded", + } + + data, err = json.Marshal(errResult) + if err != nil { + t.Fatalf("failed to marshal WebSearchToolResultError: %v", err) + } + + var unmarshaledErr WebSearchToolResultError + if err := json.Unmarshal(data, &unmarshaledErr); err != nil { + t.Fatalf("failed to unmarshal WebSearchToolResultError: %v", err) + } + + if unmarshaledErr.ErrorCode != "max_uses_exceeded" { + t.Errorf("error_code mismatch: expected 'max_uses_exceeded', got %q", unmarshaledErr.ErrorCode) + } +} + +func TestCitation(t *testing.T) { + citation := Citation{ + Type: "web_search_result_location", + URL: "https://example.com", + Title: "Example", + EncryptedIndex: "enc123", + CitedText: "Some cited text...", + } + + data, err := json.Marshal(citation) + if err != nil { + t.Fatalf("failed to marshal Citation: %v", err) + } + + var unmarshaled Citation + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("failed to unmarshal Citation: %v", err) + } + + if unmarshaled.Type != "web_search_result_location" { + t.Errorf("type 
mismatch: expected 'web_search_result_location', got %q", unmarshaled.Type) + } + + if unmarshaled.CitedText != "Some cited text..." { + t.Errorf("cited_text mismatch: expected 'Some cited text...', got %q", unmarshaled.CitedText) + } +} diff --git a/anthropic/trace.go b/anthropic/trace.go new file mode 100644 index 000000000..65dc8dbc2 --- /dev/null +++ b/anthropic/trace.go @@ -0,0 +1,352 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + "sort" + + "github.com/ollama/ollama/api" +) + +// Trace truncation limits. +const ( + TraceMaxStringRunes = 240 + TraceMaxSliceItems = 8 + TraceMaxMapEntries = 16 + TraceMaxDepth = 4 +) + +// TraceTruncateString shortens s to TraceMaxStringRunes, appending a count of +// omitted characters when truncated. +func TraceTruncateString(s string) string { + if len(s) == 0 { + return s + } + runes := []rune(s) + if len(runes) <= TraceMaxStringRunes { + return s + } + return fmt.Sprintf("%s...(+%d chars)", string(runes[:TraceMaxStringRunes]), len(runes)-TraceMaxStringRunes) +} + +// TraceJSON round-trips v through JSON and returns a compacted representation. +func TraceJSON(v any) any { + if v == nil { + return nil + } + data, err := json.Marshal(v) + if err != nil { + return map[string]any{"marshal_error": err.Error(), "type": fmt.Sprintf("%T", v)} + } + var out any + if err := json.Unmarshal(data, &out); err != nil { + return TraceTruncateString(string(data)) + } + return TraceCompactValue(out, 0) +} + +// TraceCompactValue recursively truncates strings, slices, and maps for trace +// output. depth tracks recursion to enforce TraceMaxDepth. 
+func TraceCompactValue(v any, depth int) any { + if v == nil { + return nil + } + if depth >= TraceMaxDepth { + switch t := v.(type) { + case string: + return TraceTruncateString(t) + case []any: + return fmt.Sprintf("", len(t)) + case map[string]any: + return fmt.Sprintf("", len(t)) + default: + return fmt.Sprintf("<%T>", v) + } + } + switch t := v.(type) { + case string: + return TraceTruncateString(t) + case []any: + limit := min(len(t), TraceMaxSliceItems) + out := make([]any, 0, limit+1) + for i := range limit { + out = append(out, TraceCompactValue(t[i], depth+1)) + } + if len(t) > limit { + out = append(out, fmt.Sprintf("... +%d more items", len(t)-limit)) + } + return out + case map[string]any: + keys := make([]string, 0, len(t)) + for k := range t { + keys = append(keys, k) + } + sort.Strings(keys) + limit := min(len(keys), TraceMaxMapEntries) + out := make(map[string]any, limit+1) + for i := range limit { + out[keys[i]] = TraceCompactValue(t[keys[i]], depth+1) + } + if len(keys) > limit { + out["__truncated_keys"] = len(keys) - limit + } + return out + default: + return t + } +} + +// --------------------------------------------------------------------------- +// Anthropic request/response tracing +// --------------------------------------------------------------------------- + +// TraceMessagesRequest returns a compact trace representation of a MessagesRequest. +func TraceMessagesRequest(r MessagesRequest) map[string]any { + return map[string]any{ + "model": r.Model, + "max_tokens": r.MaxTokens, + "messages": traceMessageParams(r.Messages), + "system": traceAnthropicContent(r.System), + "stream": r.Stream, + "tools": traceTools(r.Tools), + "tool_choice": TraceJSON(r.ToolChoice), + "thinking": TraceJSON(r.Thinking), + "stop_sequences": r.StopSequences, + "temperature": ptrVal(r.Temperature), + "top_p": ptrVal(r.TopP), + "top_k": ptrVal(r.TopK), + } +} + +// TraceMessagesResponse returns a compact trace representation of a MessagesResponse. 
+func TraceMessagesResponse(r MessagesResponse) map[string]any { + return map[string]any{ + "id": r.ID, + "model": r.Model, + "content": TraceJSON(r.Content), + "stop_reason": r.StopReason, + "usage": r.Usage, + } +} + +func traceMessageParams(msgs []MessageParam) []map[string]any { + out := make([]map[string]any, 0, len(msgs)) + for _, m := range msgs { + out = append(out, map[string]any{ + "role": m.Role, + "content": traceAnthropicContent(m.Content), + }) + } + return out +} + +func traceAnthropicContent(content any) any { + switch c := content.(type) { + case nil: + return nil + case string: + return TraceTruncateString(c) + case []any: + blocks := make([]any, 0, len(c)) + for _, block := range c { + blockMap, ok := block.(map[string]any) + if !ok { + blocks = append(blocks, TraceCompactValue(block, 0)) + continue + } + blocks = append(blocks, traceAnthropicBlock(blockMap)) + } + return blocks + default: + return TraceJSON(c) + } +} + +func traceAnthropicBlock(block map[string]any) map[string]any { + blockType, _ := block["type"].(string) + out := map[string]any{"type": blockType} + switch blockType { + case "text": + if text, ok := block["text"].(string); ok { + out["text"] = TraceTruncateString(text) + } else { + out["text"] = TraceCompactValue(block["text"], 0) + } + case "thinking": + if thinking, ok := block["thinking"].(string); ok { + out["thinking"] = TraceTruncateString(thinking) + } else { + out["thinking"] = TraceCompactValue(block["thinking"], 0) + } + case "tool_use", "server_tool_use": + out["id"] = block["id"] + out["name"] = block["name"] + out["input"] = TraceCompactValue(block["input"], 0) + case "tool_result", "web_search_tool_result": + out["tool_use_id"] = block["tool_use_id"] + out["content"] = TraceCompactValue(block["content"], 0) + case "image": + if source, ok := block["source"].(map[string]any); ok { + out["source"] = map[string]any{ + "type": source["type"], + "media_type": source["media_type"], + "url": source["url"], + "data_len": 
len(fmt.Sprint(source["data"])), + } + } + default: + out["block"] = TraceCompactValue(block, 0) + } + return out +} + +func traceTools(tools []Tool) []map[string]any { + out := make([]map[string]any, 0, len(tools)) + for _, t := range tools { + out = append(out, TraceTool(t)) + } + return out +} + +// TraceTool returns a compact trace representation of an Anthropic Tool. +func TraceTool(t Tool) map[string]any { + return map[string]any{ + "type": t.Type, + "name": t.Name, + "description": TraceTruncateString(t.Description), + "input_schema": TraceJSON(t.InputSchema), + "max_uses": t.MaxUses, + } +} + +// ContentBlockTypes returns the type strings from content (when it's []any blocks). +func ContentBlockTypes(content any) []string { + blocks, ok := content.([]any) + if !ok { + return nil + } + types := make([]string, 0, len(blocks)) + for _, block := range blocks { + blockMap, ok := block.(map[string]any) + if !ok { + types = append(types, fmt.Sprintf("%T", block)) + continue + } + t, _ := blockMap["type"].(string) + types = append(types, t) + } + return types +} + +func ptrVal[T any](v *T) any { + if v == nil { + return nil + } + return *v +} + +// --------------------------------------------------------------------------- +// Ollama api.* tracing (shared between anthropic and middleware packages) +// --------------------------------------------------------------------------- + +// TraceChatRequest returns a compact trace representation of an Ollama ChatRequest. +func TraceChatRequest(req *api.ChatRequest) map[string]any { + if req == nil { + return nil + } + stream := false + if req.Stream != nil { + stream = *req.Stream + } + return map[string]any{ + "model": req.Model, + "messages": TraceAPIMessages(req.Messages), + "tools": TraceAPITools(req.Tools), + "stream": stream, + "options": req.Options, + "think": TraceJSON(req.Think), + } +} + +// TraceChatResponse returns a compact trace representation of an Ollama ChatResponse. 
+func TraceChatResponse(resp api.ChatResponse) map[string]any { + return map[string]any{ + "model": resp.Model, + "done": resp.Done, + "done_reason": resp.DoneReason, + "message": TraceAPIMessage(resp.Message), + "metrics": TraceJSON(resp.Metrics), + } +} + +// TraceAPIMessages returns compact trace representations for a slice of api.Message. +func TraceAPIMessages(msgs []api.Message) []map[string]any { + out := make([]map[string]any, 0, len(msgs)) + for _, m := range msgs { + out = append(out, TraceAPIMessage(m)) + } + return out +} + +// TraceAPIMessage returns a compact trace representation of a single api.Message. +func TraceAPIMessage(m api.Message) map[string]any { + return map[string]any{ + "role": m.Role, + "content": TraceTruncateString(m.Content), + "thinking": TraceTruncateString(m.Thinking), + "images": traceImageSizes(m.Images), + "tool_calls": traceToolCalls(m.ToolCalls), + "tool_name": m.ToolName, + "tool_call_id": m.ToolCallID, + } +} + +func traceImageSizes(images []api.ImageData) []int { + if len(images) == 0 { + return nil + } + sizes := make([]int, 0, len(images)) + for _, img := range images { + sizes = append(sizes, len(img)) + } + return sizes +} + +// TraceAPITools returns compact trace representations for a slice of api.Tool. +func TraceAPITools(tools api.Tools) []map[string]any { + out := make([]map[string]any, 0, len(tools)) + for _, t := range tools { + out = append(out, TraceAPITool(t)) + } + return out +} + +// TraceAPITool returns a compact trace representation of a single api.Tool. +func TraceAPITool(t api.Tool) map[string]any { + return map[string]any{ + "type": t.Type, + "name": t.Function.Name, + "description": TraceTruncateString(t.Function.Description), + "parameters": TraceJSON(t.Function.Parameters), + } +} + +// TraceToolCall returns a compact trace representation of an api.ToolCall. 
+func TraceToolCall(tc api.ToolCall) map[string]any { + return map[string]any{ + "id": tc.ID, + "name": tc.Function.Name, + "args": TraceJSON(tc.Function.Arguments), + } +} + +func traceToolCalls(tcs []api.ToolCall) []map[string]any { + if len(tcs) == 0 { + return nil + } + out := make([]map[string]any, 0, len(tcs)) + for _, tc := range tcs { + out = append(out, TraceToolCall(tc)) + } + return out +} diff --git a/middleware/anthropic.go b/middleware/anthropic.go index 5df87a84a..85c95e60c 100644 --- a/middleware/anthropic.go +++ b/middleware/anthropic.go @@ -2,15 +2,22 @@ package middleware import ( "bytes" + "context" "encoding/json" "fmt" "io" + "log/slog" "net/http" + "strings" + "time" "github.com/gin-gonic/gin" "github.com/ollama/ollama/anthropic" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" + internalcloud "github.com/ollama/ollama/internal/cloud" + "github.com/ollama/ollama/logutil" ) // AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format @@ -18,7 +25,6 @@ type AnthropicWriter struct { BaseWriter stream bool id string - model string converter *anthropic.StreamConverter } @@ -31,7 +37,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) { } w.ResponseWriter.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error)) + err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.Status(), errData.Error)) if err != nil { return 0, err } @@ -40,18 +46,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) { } func (w *AnthropicWriter) writeEvent(eventType string, data any) error { - d, err := json.Marshal(data) - if err != nil { - return err - } - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d))) - if err != nil { - return err - } - if f, ok := w.ResponseWriter.(http.Flusher); ok { - f.Flush() - } - return nil + return 
writeSSE(w.ResponseWriter, eventType, data) } func (w *AnthropicWriter) writeResponse(data []byte) (int, error) { @@ -65,6 +60,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) { w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") events := w.converter.Process(chatResponse) + logutil.Trace("anthropic middleware: stream chunk", "resp", anthropic.TraceChatResponse(chatResponse), "events", len(events)) for _, event := range events { if err := w.writeEvent(event.Event, event.Data); err != nil { return 0, err @@ -75,6 +71,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) { w.ResponseWriter.Header().Set("Content-Type", "application/json") response := anthropic.ToMessagesResponse(w.id, chatResponse) + logutil.Trace("anthropic middleware: converted response", "resp", anthropic.TraceMessagesResponse(response)) return len(data), json.NewEncoder(w.ResponseWriter).Encode(response) } @@ -87,9 +84,743 @@ func (w *AnthropicWriter) Write(data []byte) (int, error) { return w.writeResponse(data) } +// WebSearchAnthropicWriter intercepts responses containing web_search tool calls, +// executes the search, re-invokes the model with results, and assembles the +// Anthropic-format response (server_tool_use + web_search_tool_result + text). 
+type WebSearchAnthropicWriter struct { + BaseWriter + newLoopContext func() (context.Context, context.CancelFunc) + inner *AnthropicWriter + req anthropic.MessagesRequest // original Anthropic request + chatReq *api.ChatRequest // converted Ollama request (for followup calls) + stream bool + + estimatedInputTokens int + + terminalSent bool + + observedPromptEvalCount int + observedEvalCount int + + loopInFlight bool + loopBaseInputTok int + loopBaseOutputTok int + loopResultCh chan webSearchLoopResult + + streamMessageStarted bool + streamHasOpenBlock bool + streamOpenBlockIndex int + streamNextIndex int +} + +const maxWebSearchLoops = 3 + +type webSearchLoopResult struct { + response anthropic.MessagesResponse + loopErr *webSearchLoopError +} + +type webSearchLoopError struct { + code string + query string + usage anthropic.Usage + err error +} + +func (e *webSearchLoopError) Error() string { + if e.err == nil { + return e.code + } + return fmt.Sprintf("%s: %v", e.code, e.err) +} + +func (w *WebSearchAnthropicWriter) Write(data []byte) (int, error) { + if w.terminalSent { + return len(data), nil + } + + code := w.Status() + if code != http.StatusOK { + return w.inner.writeError(data) + } + + var chatResponse api.ChatResponse + if err := json.Unmarshal(data, &chatResponse); err != nil { + return 0, err + } + w.recordObservedUsage(chatResponse.Metrics) + + if w.stream && w.loopInFlight { + if !chatResponse.Done { + return len(data), nil + } + if err := w.writeLoopResult(); err != nil { + return len(data), err + } + return len(data), nil + } + + webSearchCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(chatResponse.Message.ToolCalls) + logutil.Trace("anthropic middleware: upstream chunk", + "resp", anthropic.TraceChatResponse(chatResponse), + "web_search", hasWebSearch, + "other_tools", hasOtherTools, + ) + if hasWebSearch && hasOtherTools { + // Prefer web_search if both server and client tools are present in one chunk. 
+ slog.Debug("preferring web_search tool call over client tool calls in mixed tool response") + } + + if !hasWebSearch { + if w.stream { + if err := w.writePassthroughStreamChunk(chatResponse); err != nil { + return 0, err + } + return len(data), nil + } + return w.inner.writeResponse(data) + } + + if w.stream { + // Let the original generation continue to completion while web search runs in parallel. + logutil.Trace("anthropic middleware: starting async web_search loop", + "tool_call", anthropic.TraceToolCall(webSearchCall), + "resp", anthropic.TraceChatResponse(chatResponse), + ) + w.startLoopWorker(chatResponse, webSearchCall) + if chatResponse.Done { + if err := w.writeLoopResult(); err != nil { + return len(data), err + } + } + return len(data), nil + } + + loopCtx, cancel := w.startLoopContext() + defer cancel() + + initialUsage := anthropic.Usage{ + InputTokens: max(w.observedPromptEvalCount, chatResponse.Metrics.PromptEvalCount), + OutputTokens: max(w.observedEvalCount, chatResponse.Metrics.EvalCount), + } + logutil.Trace("anthropic middleware: starting sync web_search loop", + "tool_call", anthropic.TraceToolCall(webSearchCall), + "resp", anthropic.TraceChatResponse(chatResponse), + "usage", initialUsage, + ) + response, loopErr := w.runWebSearchLoop(loopCtx, chatResponse, webSearchCall, initialUsage) + if loopErr != nil { + return len(data), w.sendError(loopErr.code, loopErr.query, loopErr.usage) + } + + if err := w.writeTerminalResponse(response); err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initialResponse api.ChatResponse, initialToolCall api.ToolCall, initialUsage anthropic.Usage) (anthropic.MessagesResponse, *webSearchLoopError) { + followUpMessages := make([]api.Message, 0, len(w.chatReq.Messages)+maxWebSearchLoops*2) + followUpMessages = append(followUpMessages, w.chatReq.Messages...) + + followUpTools := append(api.Tools(nil), w.chatReq.Tools...) 
+ usage := initialUsage + logutil.TraceContext(ctx, "anthropic middleware: web_search loop init", + "model", w.req.Model, + "tool_call", anthropic.TraceToolCall(initialToolCall), + "messages", len(followUpMessages), + "tools", len(followUpTools), + "max_loops", maxWebSearchLoops, + ) + + currentResponse := initialResponse + currentToolCall := initialToolCall + + var serverContent []anthropic.ContentBlock + + if !isCloudModelName(w.req.Model) { + logutil.TraceContext(ctx, "anthropic middleware: web_search execution blocked", "reason", "non_cloud_model") + return anthropic.MessagesResponse{}, &webSearchLoopError{ + code: "web_search_not_supported_for_local_models", + query: extractQueryFromToolCall(&initialToolCall), + usage: usage, + } + } + + for loop := 1; loop <= maxWebSearchLoops; loop++ { + query := extractQueryFromToolCall(¤tToolCall) + logutil.TraceContext(ctx, "anthropic middleware: web_search loop iteration", + "loop", loop, + "query", anthropic.TraceTruncateString(query), + "messages", len(followUpMessages), + ) + if query == "" { + return anthropic.MessagesResponse{}, &webSearchLoopError{ + code: "invalid_request", + query: "", + usage: usage, + } + } + + const defaultMaxResults = 5 + searchResp, err := anthropic.WebSearch(ctx, query, defaultMaxResults) + if err != nil { + logutil.TraceContext(ctx, "anthropic middleware: web_search request failed", + "loop", loop, + "query", query, + "error", err, + ) + return anthropic.MessagesResponse{}, &webSearchLoopError{ + code: "unavailable", + query: query, + usage: usage, + err: err, + } + } + logutil.TraceContext(ctx, "anthropic middleware: web_search results", + "loop", loop, + "results", len(searchResp.Results), + ) + + toolUseID := loopServerToolUseID(w.inner.id, loop) + searchResults := anthropic.ConvertOllamaToAnthropicResults(searchResp) + serverContent = append(serverContent, + anthropic.ContentBlock{ + Type: "server_tool_use", + ID: toolUseID, + Name: "web_search", + Input: map[string]any{"query": 
query}, + }, + anthropic.ContentBlock{ + Type: "web_search_tool_result", + ToolUseID: toolUseID, + Content: searchResults, + }, + ) + + assistantMsg := buildWebSearchAssistantMessage(currentResponse, currentToolCall) + toolResultMsg := api.Message{ + Role: "tool", + Content: formatWebSearchResultsForToolMessage(searchResp.Results), + ToolCallID: currentToolCall.ID, + } + followUpMessages = append(followUpMessages, assistantMsg, toolResultMsg) + + followUpResponse, err := w.callFollowUpChat(ctx, followUpMessages, followUpTools) + if err != nil { + logutil.TraceContext(ctx, "anthropic middleware: followup /api/chat failed", + "loop", loop, + "query", query, + "error", err, + ) + return anthropic.MessagesResponse{}, &webSearchLoopError{ + code: "api_error", + query: query, + usage: usage, + err: err, + } + } + logutil.TraceContext(ctx, "anthropic middleware: followup response", + "loop", loop, + "resp", anthropic.TraceChatResponse(followUpResponse), + ) + + usage.InputTokens += followUpResponse.Metrics.PromptEvalCount + usage.OutputTokens += followUpResponse.Metrics.EvalCount + + nextToolCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(followUpResponse.Message.ToolCalls) + if hasWebSearch && hasOtherTools { + // Prefer web_search if both server and client tools are present in one chunk. 
+ slog.Debug("preferring web_search tool call over client tool calls in mixed followup response") + } + + if !hasWebSearch { + finalResponse := w.combineServerAndFinalContent(serverContent, followUpResponse, usage) + logutil.TraceContext(ctx, "anthropic middleware: web_search loop complete", + "loop", loop, + "resp", anthropic.TraceMessagesResponse(finalResponse), + ) + return finalResponse, nil + } + + currentResponse = followUpResponse + currentToolCall = nextToolCall + } + + maxLoopQuery := extractQueryFromToolCall(¤tToolCall) + maxLoopToolUseID := loopServerToolUseID(w.inner.id, maxWebSearchLoops+1) + serverContent = append(serverContent, + anthropic.ContentBlock{ + Type: "server_tool_use", + ID: maxLoopToolUseID, + Name: "web_search", + Input: map[string]any{"query": maxLoopQuery}, + }, + anthropic.ContentBlock{ + Type: "web_search_tool_result", + ToolUseID: maxLoopToolUseID, + Content: anthropic.WebSearchToolResultError{ + Type: "web_search_tool_result_error", + ErrorCode: "max_uses_exceeded", + }, + }, + ) + + maxResponse := anthropic.MessagesResponse{ + ID: w.inner.id, + Type: "message", + Role: "assistant", + Model: w.req.Model, + Content: serverContent, + StopReason: "end_turn", + Usage: usage, + } + logutil.TraceContext(ctx, "anthropic middleware: web_search loop max reached", + "resp", anthropic.TraceMessagesResponse(maxResponse), + ) + return maxResponse, nil +} + +func (w *WebSearchAnthropicWriter) startLoopWorker(initialResponse api.ChatResponse, initialToolCall api.ToolCall) { + if w.loopInFlight { + return + } + + initialUsage := anthropic.Usage{ + InputTokens: max(w.observedPromptEvalCount, initialResponse.Metrics.PromptEvalCount), + OutputTokens: max(w.observedEvalCount, initialResponse.Metrics.EvalCount), + } + w.loopBaseInputTok = initialUsage.InputTokens + w.loopBaseOutputTok = initialUsage.OutputTokens + w.loopResultCh = make(chan webSearchLoopResult, 1) + w.loopInFlight = true + logutil.Trace("anthropic middleware: loop worker started", + 
"usage", initialUsage, + "tool_call", anthropic.TraceToolCall(initialToolCall), + ) + + go func() { + ctx, cancel := w.startLoopContext() + defer cancel() + + response, loopErr := w.runWebSearchLoop(ctx, initialResponse, initialToolCall, initialUsage) + w.loopResultCh <- webSearchLoopResult{ + response: response, + loopErr: loopErr, + } + }() +} + +func (w *WebSearchAnthropicWriter) writeLoopResult() error { + if w.loopResultCh == nil { + return w.sendError("api_error", "", w.currentObservedUsage()) + } + + result := <-w.loopResultCh + w.loopResultCh = nil + w.loopInFlight = false + if result.loopErr != nil { + logutil.Trace("anthropic middleware: loop worker returned error", + "code", result.loopErr.code, + "query", result.loopErr.query, + "usage", result.loopErr.usage, + "error", result.loopErr.err, + ) + usage := result.loopErr.usage + w.applyObservedUsageDeltaToUsage(&usage) + return w.sendError(result.loopErr.code, result.loopErr.query, usage) + } + logutil.Trace("anthropic middleware: loop worker done", "resp", anthropic.TraceMessagesResponse(result.response)) + + w.applyObservedUsageDelta(&result.response) + return w.writeTerminalResponse(result.response) +} + +func (w *WebSearchAnthropicWriter) applyObservedUsageDelta(response *anthropic.MessagesResponse) { + w.applyObservedUsageDeltaToUsage(&response.Usage) +} + +func (w *WebSearchAnthropicWriter) recordObservedUsage(metrics api.Metrics) { + if metrics.PromptEvalCount > w.observedPromptEvalCount { + w.observedPromptEvalCount = metrics.PromptEvalCount + } + if metrics.EvalCount > w.observedEvalCount { + w.observedEvalCount = metrics.EvalCount + } +} + +func (w *WebSearchAnthropicWriter) applyObservedUsageDeltaToUsage(usage *anthropic.Usage) { + if deltaIn := w.observedPromptEvalCount - w.loopBaseInputTok; deltaIn > 0 { + usage.InputTokens += deltaIn + } + if deltaOut := w.observedEvalCount - w.loopBaseOutputTok; deltaOut > 0 { + usage.OutputTokens += deltaOut + } +} + +func (w *WebSearchAnthropicWriter) 
currentObservedUsage() anthropic.Usage { + return anthropic.Usage{ + InputTokens: w.observedPromptEvalCount, + OutputTokens: w.observedEvalCount, + } +} + +func (w *WebSearchAnthropicWriter) startLoopContext() (context.Context, context.CancelFunc) { + if w.newLoopContext != nil { + return w.newLoopContext() + } + return context.WithTimeout(context.Background(), 5*time.Minute) +} + +func (w *WebSearchAnthropicWriter) combineServerAndFinalContent(serverContent []anthropic.ContentBlock, finalResponse api.ChatResponse, usage anthropic.Usage) anthropic.MessagesResponse { + converted := anthropic.ToMessagesResponse(w.inner.id, finalResponse) + + content := make([]anthropic.ContentBlock, 0, len(serverContent)+len(converted.Content)) + content = append(content, serverContent...) + content = append(content, converted.Content...) + + return anthropic.MessagesResponse{ + ID: w.inner.id, + Type: "message", + Role: "assistant", + Model: w.req.Model, + Content: content, + StopReason: converted.StopReason, + StopSequence: converted.StopSequence, + Usage: usage, + } +} + +func buildWebSearchAssistantMessage(response api.ChatResponse, webSearchCall api.ToolCall) api.Message { + assistantMsg := api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{webSearchCall}, + } + if response.Message.Content != "" { + assistantMsg.Content = response.Message.Content + } + if response.Message.Thinking != "" { + assistantMsg.Thinking = response.Message.Thinking + } + return assistantMsg +} + +func formatWebSearchResultsForToolMessage(results []anthropic.OllamaWebSearchResult) string { + var resultText strings.Builder + for _, r := range results { + fmt.Fprintf(&resultText, "Title: %s\nURL: %s\n", r.Title, r.URL) + if r.Content != "" { + fmt.Fprintf(&resultText, "Content: %s\n", r.Content) + } + resultText.WriteString("\n") + } + return resultText.String() +} + +func findWebSearchToolCall(toolCalls []api.ToolCall) (api.ToolCall, bool, bool) { + var webSearchCall api.ToolCall + hasWebSearch 
:= false + hasOtherTools := false + + for _, toolCall := range toolCalls { + if toolCall.Function.Name == "web_search" { + if !hasWebSearch { + webSearchCall = toolCall + hasWebSearch = true + } + continue + } + hasOtherTools = true + } + + return webSearchCall, hasWebSearch, hasOtherTools +} + +func loopServerToolUseID(messageID string, loop int) string { + base := serverToolUseID(messageID) + if loop <= 1 { + return base + } + return fmt.Sprintf("%s_%d", base, loop) +} + +func (w *WebSearchAnthropicWriter) callFollowUpChat(ctx context.Context, messages []api.Message, tools api.Tools) (api.ChatResponse, error) { + streaming := false + followUp := api.ChatRequest{ + Model: w.chatReq.Model, + Messages: messages, + Stream: &streaming, + Tools: tools, + Options: w.chatReq.Options, + } + + body, err := json.Marshal(followUp) + if err != nil { + return api.ChatResponse{}, err + } + + chatURL := envconfig.Host().String() + "/api/chat" + logutil.TraceContext(ctx, "anthropic middleware: followup request", + "url", chatURL, + "req", anthropic.TraceChatRequest(&followUp), + ) + httpReq, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(body)) + if err != nil { + return api.ChatResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + return api.ChatResponse{}, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + logutil.TraceContext(ctx, "anthropic middleware: followup non-200 response", + "status", resp.StatusCode, + "response", strings.TrimSpace(string(respBody)), + ) + return api.ChatResponse{}, fmt.Errorf("followup /api/chat returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + } + + var chatResp api.ChatResponse + if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { + return api.ChatResponse{}, err + } + logutil.TraceContext(ctx, "anthropic middleware: 
followup decoded", "resp", anthropic.TraceChatResponse(chatResp)) + + return chatResp, nil +} + +func (w *WebSearchAnthropicWriter) writePassthroughStreamChunk(chatResponse api.ChatResponse) error { + events := w.inner.converter.Process(chatResponse) + for _, event := range events { + switch e := event.Data.(type) { + case anthropic.MessageStartEvent: + w.streamMessageStarted = true + case anthropic.ContentBlockStartEvent: + w.streamHasOpenBlock = true + w.streamOpenBlockIndex = e.Index + if e.Index+1 > w.streamNextIndex { + w.streamNextIndex = e.Index + 1 + } + case anthropic.ContentBlockStopEvent: + if w.streamHasOpenBlock && w.streamOpenBlockIndex == e.Index { + w.streamHasOpenBlock = false + } + if e.Index+1 > w.streamNextIndex { + w.streamNextIndex = e.Index + 1 + } + case anthropic.MessageStopEvent: + w.terminalSent = true + } + + if err := writeSSE(w.ResponseWriter, event.Event, event.Data); err != nil { + return err + } + } + + return nil +} + +func (w *WebSearchAnthropicWriter) ensureStreamMessageStart(usage anthropic.Usage) error { + if w.streamMessageStarted { + return nil + } + + inputTokens := usage.InputTokens + if inputTokens == 0 { + inputTokens = w.estimatedInputTokens + } + + if err := writeSSE(w.ResponseWriter, "message_start", anthropic.MessageStartEvent{ + Type: "message_start", + Message: anthropic.MessagesResponse{ + ID: w.inner.id, + Type: "message", + Role: "assistant", + Model: w.req.Model, + Content: []anthropic.ContentBlock{}, + Usage: anthropic.Usage{ + InputTokens: inputTokens, + }, + }, + }); err != nil { + return err + } + + w.streamMessageStarted = true + return nil +} + +func (w *WebSearchAnthropicWriter) closeOpenStreamBlock() error { + if !w.streamHasOpenBlock { + return nil + } + + if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{ + Type: "content_block_stop", + Index: w.streamOpenBlockIndex, + }); err != nil { + return err + } + + if w.streamOpenBlockIndex+1 > w.streamNextIndex { + 
w.streamNextIndex = w.streamOpenBlockIndex + 1 + } + w.streamHasOpenBlock = false + return nil +} + +func (w *WebSearchAnthropicWriter) writeStreamContentBlocks(content []anthropic.ContentBlock) error { + for _, block := range content { + index := w.streamNextIndex + if block.Type == "text" { + emptyText := "" + if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{ + Type: "content_block_start", + Index: index, + ContentBlock: anthropic.ContentBlock{ + Type: "text", + Text: &emptyText, + }, + }); err != nil { + return err + } + + text := "" + if block.Text != nil { + text = *block.Text + } + if err := writeSSE(w.ResponseWriter, "content_block_delta", anthropic.ContentBlockDeltaEvent{ + Type: "content_block_delta", + Index: index, + Delta: anthropic.Delta{ + Type: "text_delta", + Text: text, + }, + }); err != nil { + return err + } + } else { + if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{ + Type: "content_block_start", + Index: index, + ContentBlock: block, + }); err != nil { + return err + } + } + + if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{ + Type: "content_block_stop", + Index: index, + }); err != nil { + return err + } + + w.streamNextIndex++ + } + + return nil +} + +func (w *WebSearchAnthropicWriter) writeTerminalResponse(response anthropic.MessagesResponse) error { + if w.terminalSent { + return nil + } + + if !w.stream { + w.ResponseWriter.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w.ResponseWriter).Encode(response); err != nil { + return err + } + w.terminalSent = true + return nil + } + + if err := w.ensureStreamMessageStart(response.Usage); err != nil { + return err + } + if err := w.closeOpenStreamBlock(); err != nil { + return err + } + if err := w.writeStreamContentBlocks(response.Content); err != nil { + return err + } + + if err := writeSSE(w.ResponseWriter, "message_delta", 
anthropic.MessageDeltaEvent{ + Type: "message_delta", + Delta: anthropic.MessageDelta{ + StopReason: response.StopReason, + }, + Usage: anthropic.DeltaUsage{ + InputTokens: response.Usage.InputTokens, + OutputTokens: response.Usage.OutputTokens, + }, + }); err != nil { + return err + } + + if err := writeSSE(w.ResponseWriter, "message_stop", anthropic.MessageStopEvent{ + Type: "message_stop", + }); err != nil { + return err + } + + w.terminalSent = true + return nil +} + +// streamResponse emits a complete MessagesResponse as SSE events. +func (w *WebSearchAnthropicWriter) streamResponse(response anthropic.MessagesResponse) error { + return w.writeTerminalResponse(response) +} + +func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query string, usage anthropic.Usage) anthropic.MessagesResponse { + toolUseID := serverToolUseID(w.inner.id) + + return anthropic.MessagesResponse{ + ID: w.inner.id, + Type: "message", + Role: "assistant", + Model: w.req.Model, + Content: []anthropic.ContentBlock{ + { + Type: "server_tool_use", + ID: toolUseID, + Name: "web_search", + Input: map[string]any{"query": query}, + }, + { + Type: "web_search_tool_result", + ToolUseID: toolUseID, + Content: anthropic.WebSearchToolResultError{ + Type: "web_search_tool_result_error", + ErrorCode: errorCode, + }, + }, + }, + StopReason: "end_turn", + Usage: usage, + } +} + +// sendError sends a web search error response. 
+func (w *WebSearchAnthropicWriter) sendError(errorCode, query string, usage anthropic.Usage) error { + response := w.webSearchErrorResponse(errorCode, query, usage) + logutil.Trace("anthropic middleware: web_search error", "code", errorCode, "query", query, "usage", usage) + return w.writeTerminalResponse(response) +} + // AnthropicMessagesMiddleware handles Anthropic Messages API requests func AnthropicMessagesMiddleware() gin.HandlerFunc { return func(c *gin.Context) { + requestCtx := c.Request.Context() + var req anthropic.MessagesRequest err := c.ShouldBindJSON(&req) if err != nil { @@ -134,11 +865,10 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc { // Estimate input tokens for streaming (actual count not available until generation completes) estimatedTokens := anthropic.EstimateInputTokens(req) - w := &AnthropicWriter{ + innerWriter := &AnthropicWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, stream: req.Stream, id: messageID, - model: req.Model, converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens), } @@ -148,8 +878,78 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc { c.Writer.Header().Set("Connection", "keep-alive") } - c.Writer = w + if hasWebSearchTool(req.Tools) { + // Guard against runtime cloud-disable policy (OLLAMA_NO_CLOUD/server.json) + // for cloud models. Local models may still receive web_search tool definitions; + // execution is validated when the model actually emits a web_search tool call. 
+ if isCloudModelName(req.Model) { + if disabled, _ := internalcloud.Status(); disabled { + c.AbortWithStatusJSON(http.StatusForbidden, anthropic.NewError(http.StatusForbidden, internalcloud.DisabledError("web search is unavailable"))) + return + } + } + + c.Writer = &WebSearchAnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + newLoopContext: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(requestCtx, 5*time.Minute) + }, + inner: innerWriter, + req: req, + chatReq: chatReq, + stream: req.Stream, + estimatedInputTokens: estimatedTokens, + } + } else { + c.Writer = innerWriter + } c.Next() } } + +// hasWebSearchTool checks if the request tools include a web_search tool +func hasWebSearchTool(tools []anthropic.Tool) bool { + for _, tool := range tools { + if strings.HasPrefix(tool.Type, "web_search") { + return true + } + } + return false +} + +func isCloudModelName(name string) bool { + return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud") +} + +// extractQueryFromToolCall extracts the search query from a web_search tool call +func extractQueryFromToolCall(tc *api.ToolCall) string { + q, ok := tc.Function.Arguments.Get("query") + if !ok { + return "" + } + if s, ok := q.(string); ok { + return s + } + return "" +} + +// writeSSE writes a Server-Sent Event +func writeSSE(w http.ResponseWriter, eventType string, data any) error { + d, err := json.Marshal(data) + if err != nil { + return err + } + if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, d); err != nil { + return err + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return nil +} + +// serverToolUseID derives a server tool use ID from a message ID +func serverToolUseID(messageID string) string { + return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_") +} diff --git a/middleware/anthropic_test.go b/middleware/anthropic_test.go index a913fd3c4..dacdd5ce6 100644 --- a/middleware/anthropic_test.go +++ 
b/middleware/anthropic_test.go @@ -605,3 +605,2375 @@ func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) { t.Error("expected relax_thinking flag to be set in context") } } + +// Web Search Tests + +func TestHasWebSearchTool(t *testing.T) { + tests := []struct { + name string + tools []anthropic.Tool + expected bool + }{ + { + name: "no tools", + tools: nil, + expected: false, + }, + { + name: "regular tool only", + tools: []anthropic.Tool{ + {Type: "custom", Name: "get_weather"}, + }, + expected: false, + }, + { + name: "web search tool", + tools: []anthropic.Tool{ + {Type: "web_search_20250305", Name: "web_search"}, + }, + expected: true, + }, + { + name: "mixed tools", + tools: []anthropic.Tool{ + {Type: "custom", Name: "get_weather"}, + {Type: "web_search_20250305", Name: "web_search"}, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := hasWebSearchTool(tt.tools) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestExtractQueryFromToolCall(t *testing.T) { + tests := []struct { + name string + tc *api.ToolCall + expected string + }{ + { + name: "valid query", + tc: &api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "test search"), + }, + }, + expected: "test search", + }, + { + name: "empty arguments", + tc: &api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "web_search", + }, + }, + expected: "", + }, + { + name: "no query key", + tc: &api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("other", "value"), + }, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractQueryFromToolCall(tt.tc) + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +// makeArgs is a test helper that creates ToolCallFunctionArguments +func 
makeArgs(key string, value any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + args.Set(key, value) + return args +} + +// --- Web Search Integration Tests --- + +// TestWebSearchServerToolUseID tests the ID derivation logic. +func TestWebSearchServerToolUseID(t *testing.T) { + tests := []struct { + msgID string + expected string + }{ + {"msg_abc123", "srvtoolu_abc123"}, + {"msg_", "srvtoolu_"}, + {"nomsgprefix", "srvtoolu_nomsgprefix"}, + } + for _, tt := range tests { + got := serverToolUseID(tt.msgID) + if got != tt.expected { + t.Errorf("serverToolUseID(%q) = %q, want %q", tt.msgID, got, tt.expected) + } + } +} + +// TestWebSearchNoWebSearchTool verifies that when there is no web_search tool, +// requests pass through to the normal AnthropicWriter without interception. +func TestWebSearchNoWebSearchTool(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "Normal response", + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{"model":"test-model","max_tokens":100,"messages":[{"role":"user","content":"Hello"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if result.Type != "message" { + t.Errorf("expected type 'message', got %q", 
result.Type) + } + if len(result.Content) != 1 || result.Content[0].Type != "text" { + t.Fatalf("expected single text block, got %d blocks", len(result.Content)) + } + if *result.Content[0].Text != "Normal response" { + t.Errorf("expected text 'Normal response', got %q", *result.Content[0].Text) + } +} + +// TestWebSearchToolPresent_ModelDoesNotCallIt_NonStreaming verifies that when +// the web_search tool is present but the model does not call it, the response +// passes through normally (non-streaming case). +func TestWebSearchToolPresent_ModelDoesNotCallIt_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "I can answer that without searching.", + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 8}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"What is 2+2?"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if result.Type != "message" { + t.Errorf("expected type 'message', got %q", result.Type) + } + if len(result.Content) != 1 || result.Content[0].Type != "text" { + t.Fatalf("expected single text block, got %+v", 
result.Content) + } + if *result.Content[0].Text != "I can answer that without searching." { + t.Errorf("unexpected text: %q", *result.Content[0].Text) + } + if result.StopReason != "end_turn" { + t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason) + } +} + +// TestWebSearchToolPresent_ModelDoesNotCallIt_Streaming verifies the streaming +// pass-through case when the model does not invoke web_search. +func TestWebSearchToolPresent_ModelDoesNotCallIt_Streaming(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + // Simulate streaming: two partial chunks then a final chunk + chunks := []api.ChatResponse{ + { + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "Hello "}, + Done: false, + }, + { + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "world"}, + Done: false, + }, + { + Model: "test-model", + Message: api.Message{Role: "assistant", Content: ""}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + }, + } + c.Writer.WriteHeader(http.StatusOK) + for _, chunk := range chunks { + data, _ := json.Marshal(chunk) + _, _ = c.Writer.Write(data) + } + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "stream":true, + "messages":[{"role":"user","content":"Hi"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + // Parse SSE events + events := parseSSEEvents(t, resp.Body.String()) + + // Should have standard streaming event flow + if len(events) == 0 { + t.Fatal("expected SSE 
events, got none") + } + + // First event should be message_start + if events[0].event != "message_start" { + t.Errorf("first event should be message_start, got %q", events[0].event) + } + + // Should have content_block_start for text + hasTextStart := false + hasTextDelta := false + hasMessageStop := false + for _, e := range events { + if e.event == "content_block_start" { + var cbs anthropic.ContentBlockStartEvent + if err := json.Unmarshal([]byte(e.data), &cbs); err == nil { + if cbs.ContentBlock.Type == "text" { + hasTextStart = true + } + } + } + if e.event == "content_block_delta" { + var cbd anthropic.ContentBlockDeltaEvent + if err := json.Unmarshal([]byte(e.data), &cbd); err == nil { + if cbd.Delta.Type == "text_delta" { + hasTextDelta = true + } + } + } + if e.event == "message_stop" { + hasMessageStop = true + } + } + if !hasTextStart { + t.Error("expected content_block_start with text type") + } + if !hasTextDelta { + t.Error("expected content_block_delta with text_delta") + } + if !hasMessageStop { + t.Error("expected message_stop event") + } +} + +// TestWebSearchToolPresent_ModelCallsIt_NonStreaming tests the full web search flow +// in non-streaming mode. It mocks the followup /api/chat call using a local HTTP server. 
+func TestWebSearchToolPresent_ModelCallsIt_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + // Create a mock Ollama server that responds to the followup /api/chat call + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "Based on my search, the answer is 42.", + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 50, EvalCount: 20}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + + // Set OLLAMA_HOST to our mock server so the followup call goes there + t.Setenv("OLLAMA_HOST", followupServer.URL) + + // Also mock the web search API + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Test Result", URL: "https://example.com/result", Content: "Some content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + + // Point DoWebSearch at our mock search server + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_001", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "meaning of life"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 15, EvalCount: 3}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + 
"model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"What is the meaning of life?"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v\nbody: %s", err, resp.Body.String()) + } + + if result.Type != "message" { + t.Errorf("expected type 'message', got %q", result.Type) + } + if result.Role != "assistant" { + t.Errorf("expected role 'assistant', got %q", result.Role) + } + + // Should have 3 blocks: server_tool_use + web_search_tool_result + text + if len(result.Content) != 3 { + t.Fatalf("expected 3 content blocks, got %d: %+v", len(result.Content), result.Content) + } + + if result.Content[0].Type != "server_tool_use" { + t.Errorf("expected first block type 'server_tool_use', got %q", result.Content[0].Type) + } + if result.Content[0].Name != "web_search" { + t.Errorf("expected name 'web_search', got %q", result.Content[0].Name) + } + + if result.Content[1].Type != "web_search_tool_result" { + t.Errorf("expected second block type 'web_search_tool_result', got %q", result.Content[1].Type) + } + if result.Content[1].ToolUseID != result.Content[0].ID { + t.Errorf("tool_use_id mismatch: %q != %q", result.Content[1].ToolUseID, result.Content[0].ID) + } + + if result.Content[2].Type != "text" { + t.Errorf("expected third block type 'text', got %q", result.Content[2].Type) + } + if result.Content[2].Text == nil || *result.Content[2].Text == "" { + t.Error("expected non-empty text in third block") + } + + if result.StopReason != "end_turn" { + t.Errorf("expected 
stop_reason 'end_turn', got %q", result.StopReason) + } +} + +// TestWebSearchToolPresent_ModelCallsIt_Streaming tests the streaming SSE output +// when the model calls web_search with mocked search and followup endpoints. +func TestWebSearchToolPresent_ModelCallsIt_Streaming(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + // Mock followup /api/chat server + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "Here are the latest news."}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 40, EvalCount: 15}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + // Mock web search API + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "News Result", URL: "https://example.com/news", Content: "Breaking news"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + // Simulate buffered streaming: non-final chunk then final with tool call + chunks := []api.ChatResponse{ + { + Model: "test-model", + Message: api.Message{Role: "assistant"}, + Done: false, + }, + { + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_002", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "latest news"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + 
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2}, + }, + } + c.Writer.WriteHeader(http.StatusOK) + for _, chunk := range chunks { + data, _ := json.Marshal(chunk) + _, _ = c.Writer.Write(data) + } + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "stream":true, + "messages":[{"role":"user","content":"What is the latest news?"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + events := parseSSEEvents(t, resp.Body.String()) + + // Success path: 10 events (3 blocks: server_tool_use, web_search_tool_result, text with delta) + expectedEventTypes := []string{ + "message_start", + "content_block_start", // server_tool_use + "content_block_stop", + "content_block_start", // web_search_tool_result + "content_block_stop", + "content_block_start", // text (empty) + "content_block_delta", // text_delta with actual content + "content_block_stop", + "message_delta", + "message_stop", + } + + if len(events) != len(expectedEventTypes) { + t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events)) + } + + for i, expected := range expectedEventTypes { + if events[i].event != expected { + t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event) + } + } + + // Verify text delta has the followup model's content + var textDelta anthropic.ContentBlockDeltaEvent + if err := json.Unmarshal([]byte(events[6].data), &textDelta); err != nil { + t.Fatalf("failed to parse text delta: %v", err) + } + if textDelta.Delta.Type != "text_delta" { + t.Errorf("expected delta type 'text_delta', got %q", textDelta.Delta.Type) + } + if textDelta.Delta.Text != "Here are the latest 
news." { + t.Errorf("expected followup text, got %q", textDelta.Delta.Text) + } +} + +// TestWebSearchStreamResponse tests the streamResponse method directly by constructing +// a WebSearchAnthropicWriter and calling streamResponse with a known response. +func TestWebSearchStreamResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + text := "Here is the answer." + + response := anthropic.MessagesResponse{ + ID: "msg_test123", + Type: "message", + Role: "assistant", + Model: "test-model", + Content: []anthropic.ContentBlock{ + { + Type: "server_tool_use", + ID: "srvtoolu_test123", + Name: "web_search", + Input: map[string]any{"query": "test query"}, + }, + { + Type: "web_search_tool_result", + ToolUseID: "srvtoolu_test123", + Content: []anthropic.WebSearchResult{ + {Type: "web_search_result", URL: "https://example.com", Title: "Example"}, + }, + }, + { + Type: "text", + Text: &text, + }, + }, + StopReason: "end_turn", + Usage: anthropic.Usage{InputTokens: 20, OutputTokens: 10}, + } + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + innerWriter := &AnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer}, + stream: true, + id: "msg_test123", + } + wsWriter := &WebSearchAnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer}, + inner: innerWriter, + stream: true, + req: anthropic.MessagesRequest{Model: "test-model"}, + } + + if err := wsWriter.streamResponse(response); err != nil { + t.Fatalf("streamResponse error: %v", err) + } + + events := parseSSEEvents(t, rec.Body.String()) + + // Verify full event sequence + expectedEventTypes := []string{ + "message_start", + "content_block_start", // server_tool_use (index 0) + "content_block_stop", // index 0 + "content_block_start", // web_search_tool_result (index 1) + "content_block_stop", // index 1 + "content_block_start", // text (index 2) + "content_block_delta", // text_delta + "content_block_stop", // index 2 + "message_delta", + "message_stop", + } + + 
if len(events) != len(expectedEventTypes) { + t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events)) + } + + for i, expected := range expectedEventTypes { + if events[i].event != expected { + t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event) + } + } + + // Verify message_start content + var msgStart anthropic.MessageStartEvent + if err := json.Unmarshal([]byte(events[0].data), &msgStart); err != nil { + t.Fatalf("failed to parse message_start: %v", err) + } + if msgStart.Message.ID != "msg_test123" { + t.Errorf("expected message ID 'msg_test123', got %q", msgStart.Message.ID) + } + if msgStart.Message.Role != "assistant" { + t.Errorf("expected role 'assistant', got %q", msgStart.Message.Role) + } + if len(msgStart.Message.Content) != 0 { + t.Errorf("expected empty content in message_start, got %d blocks", len(msgStart.Message.Content)) + } + + // Verify content_block_start for server_tool_use (event index 1) + var toolStart anthropic.ContentBlockStartEvent + if err := json.Unmarshal([]byte(events[1].data), &toolStart); err != nil { + t.Fatalf("failed to parse server_tool_use start: %v", err) + } + if toolStart.Index != 0 { + t.Errorf("expected index 0, got %d", toolStart.Index) + } + if toolStart.ContentBlock.Type != "server_tool_use" { + t.Errorf("expected type 'server_tool_use', got %q", toolStart.ContentBlock.Type) + } + if toolStart.ContentBlock.ID != "srvtoolu_test123" { + t.Errorf("expected ID 'srvtoolu_test123', got %q", toolStart.ContentBlock.ID) + } + + // Verify content_block_start for web_search_tool_result (event index 3) + var searchStart anthropic.ContentBlockStartEvent + if err := json.Unmarshal([]byte(events[3].data), &searchStart); err != nil { + t.Fatalf("failed to parse web_search_tool_result start: %v", err) + } + if searchStart.Index != 1 { + t.Errorf("expected index 1, got %d", searchStart.Index) + } + if searchStart.ContentBlock.Type != "web_search_tool_result" { + 
t.Errorf("expected type 'web_search_tool_result', got %q", searchStart.ContentBlock.Type) + } + + // Verify text block: content_block_start (event index 5) + var textStart anthropic.ContentBlockStartEvent + if err := json.Unmarshal([]byte(events[5].data), &textStart); err != nil { + t.Fatalf("failed to parse text start: %v", err) + } + if textStart.Index != 2 { + t.Errorf("expected index 2, got %d", textStart.Index) + } + if textStart.ContentBlock.Type != "text" { + t.Errorf("expected type 'text', got %q", textStart.ContentBlock.Type) + } + // Text in start should be empty + if textStart.ContentBlock.Text == nil || *textStart.ContentBlock.Text != "" { + t.Errorf("expected empty text in content_block_start, got %v", textStart.ContentBlock.Text) + } + + // Verify text delta (event index 6) + var textDelta anthropic.ContentBlockDeltaEvent + if err := json.Unmarshal([]byte(events[6].data), &textDelta); err != nil { + t.Fatalf("failed to parse text delta: %v", err) + } + if textDelta.Index != 2 { + t.Errorf("expected index 2, got %d", textDelta.Index) + } + if textDelta.Delta.Type != "text_delta" { + t.Errorf("expected delta type 'text_delta', got %q", textDelta.Delta.Type) + } + if textDelta.Delta.Text != "Here is the answer." 
{ + t.Errorf("expected delta text 'Here is the answer.', got %q", textDelta.Delta.Text) + } + + // Verify message_delta (event index 8) + var msgDelta anthropic.MessageDeltaEvent + if err := json.Unmarshal([]byte(events[8].data), &msgDelta); err != nil { + t.Fatalf("failed to parse message_delta: %v", err) + } + if msgDelta.Delta.StopReason != "end_turn" { + t.Errorf("expected stop_reason 'end_turn', got %q", msgDelta.Delta.StopReason) + } + if msgDelta.Usage.InputTokens != 20 { + t.Errorf("expected input_tokens 20, got %d", msgDelta.Usage.InputTokens) + } + if msgDelta.Usage.OutputTokens != 10 { + t.Errorf("expected output_tokens 10, got %d", msgDelta.Usage.OutputTokens) + } +} + +// TestWebSearchSendError_NonStreaming tests sendError produces correct response shape. +func TestWebSearchSendError_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + innerWriter := &AnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer}, + stream: false, + id: "msg_err001", + } + wsWriter := &WebSearchAnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer}, + inner: innerWriter, + stream: false, + req: anthropic.MessagesRequest{Model: "test-model"}, + } + + errorUsage := anthropic.Usage{InputTokens: 7, OutputTokens: 2} + if err := wsWriter.sendError("unavailable", "test query", errorUsage); err != nil { + t.Fatalf("sendError error: %v", err) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(rec.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v\nbody: %s", err, rec.Body.String()) + } + + if result.Type != "message" { + t.Errorf("expected type 'message', got %q", result.Type) + } + if result.ID != "msg_err001" { + t.Errorf("expected ID 'msg_err001', got %q", result.ID) + } + + // Should have exactly 2 blocks: server_tool_use + web_search_tool_result + if len(result.Content) != 2 { + t.Fatalf("expected 2 content blocks, got %d", 
len(result.Content)) + } + + // Block 0: server_tool_use + if result.Content[0].Type != "server_tool_use" { + t.Errorf("expected 'server_tool_use', got %q", result.Content[0].Type) + } + expectedToolID := "srvtoolu_err001" + if result.Content[0].ID != expectedToolID { + t.Errorf("expected ID %q, got %q", expectedToolID, result.Content[0].ID) + } + if result.Content[0].Name != "web_search" { + t.Errorf("expected name 'web_search', got %q", result.Content[0].Name) + } + // Verify input contains the query + inputMap, ok := result.Content[0].Input.(map[string]any) + if !ok { + t.Fatalf("expected Input to be map, got %T", result.Content[0].Input) + } + if inputMap["query"] != "test query" { + t.Errorf("expected query 'test query', got %v", inputMap["query"]) + } + + // Block 1: web_search_tool_result with error + if result.Content[1].Type != "web_search_tool_result" { + t.Errorf("expected 'web_search_tool_result', got %q", result.Content[1].Type) + } + if result.Content[1].ToolUseID != expectedToolID { + t.Errorf("expected tool_use_id %q, got %q", expectedToolID, result.Content[1].ToolUseID) + } + + // The Content field should be a WebSearchToolResultError + contentJSON, _ := json.Marshal(result.Content[1].Content) + var errContent anthropic.WebSearchToolResultError + if err := json.Unmarshal(contentJSON, &errContent); err != nil { + t.Fatalf("failed to parse error content: %v\nraw: %s", err, string(contentJSON)) + } + if errContent.Type != "web_search_tool_result_error" { + t.Errorf("expected error type 'web_search_tool_result_error', got %q", errContent.Type) + } + if errContent.ErrorCode != "unavailable" { + t.Errorf("expected error_code 'unavailable', got %q", errContent.ErrorCode) + } + + if result.StopReason != "end_turn" { + t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason) + } + if result.Usage != errorUsage { + t.Errorf("expected usage %+v, got %+v", errorUsage, result.Usage) + } +} + +// TestWebSearchSendError_Streaming tests sendError in 
streaming mode produces proper SSE. +func TestWebSearchSendError_Streaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + innerWriter := &AnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer}, + stream: true, + id: "msg_err002", + } + wsWriter := &WebSearchAnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer}, + inner: innerWriter, + stream: true, + req: anthropic.MessagesRequest{Model: "test-model"}, + } + + errorUsage := anthropic.Usage{InputTokens: 9, OutputTokens: 4} + if err := wsWriter.sendError("invalid_request", "bad query", errorUsage); err != nil { + t.Fatalf("sendError error: %v", err) + } + + events := parseSSEEvents(t, rec.Body.String()) + + // Error response has 2 blocks: server_tool_use + web_search_tool_result + // Expected events: message_start, + // content_block_start(server_tool_use), content_block_stop, + // content_block_start(web_search_tool_result), content_block_stop, + // message_delta, message_stop + expectedEventTypes := []string{ + "message_start", + "content_block_start", + "content_block_stop", + "content_block_start", + "content_block_stop", + "message_delta", + "message_stop", + } + + if len(events) != len(expectedEventTypes) { + t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events)) + } + + for i, expected := range expectedEventTypes { + if events[i].event != expected { + t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event) + } + } + + // Verify the server_tool_use block + var toolStart anthropic.ContentBlockStartEvent + if err := json.Unmarshal([]byte(events[1].data), &toolStart); err != nil { + t.Fatalf("failed to parse server_tool_use start: %v", err) + } + if toolStart.ContentBlock.Type != "server_tool_use" { + t.Errorf("expected 'server_tool_use', got %q", toolStart.ContentBlock.Type) + } + + // Verify the web_search_tool_result block + var 
resultStart anthropic.ContentBlockStartEvent + if err := json.Unmarshal([]byte(events[3].data), &resultStart); err != nil { + t.Fatalf("failed to parse web_search_tool_result start: %v", err) + } + if resultStart.ContentBlock.Type != "web_search_tool_result" { + t.Errorf("expected 'web_search_tool_result', got %q", resultStart.ContentBlock.Type) + } + + var msgDelta anthropic.MessageDeltaEvent + if err := json.Unmarshal([]byte(events[5].data), &msgDelta); err != nil { + t.Fatalf("failed to parse message_delta: %v", err) + } + if msgDelta.Usage.InputTokens != errorUsage.InputTokens || msgDelta.Usage.OutputTokens != errorUsage.OutputTokens { + t.Fatalf("expected usage %+v in message_delta, got %+v", errorUsage, msgDelta.Usage) + } +} + +// TestWebSearchSendError_EmptyQuery tests sendError with an empty query. +func TestWebSearchSendError_EmptyQuery(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + innerWriter := &AnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer}, + stream: false, + id: "msg_empty001", + } + wsWriter := &WebSearchAnthropicWriter{ + BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer}, + inner: innerWriter, + stream: false, + req: anthropic.MessagesRequest{Model: "test-model"}, + } + + if err := wsWriter.sendError("invalid_request", "", anthropic.Usage{}); err != nil { + t.Fatalf("sendError error: %v", err) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(rec.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if len(result.Content) != 2 { + t.Fatalf("expected 2 content blocks, got %d", len(result.Content)) + } + + // Verify the input has empty query + inputMap, ok := result.Content[0].Input.(map[string]any) + if !ok { + t.Fatalf("expected Input to be map, got %T", result.Content[0].Input) + } + if inputMap["query"] != "" { + t.Errorf("expected empty query, got %v", inputMap["query"]) + } +} + +// --- 
// SSE parsing helpers ---

// sseEvent is one parsed Server-Sent Event frame: the value of its "event:"
// field and the payload collected from its "data:" field(s).
type sseEvent struct {
	event string
	data  string
}

// parseSSEEvents parses a Server-Sent Events stream held in body and returns
// the named events in the order they appear.
//
// A frame is a run of "event: " / "data: " lines terminated by a blank line.
//   - Frames without an "event: " field are skipped, and their data is
//     discarded so it cannot bleed into the next frame.
//   - Multiple "data: " lines inside one frame are joined with "\n".
//   - Both LF and CRLF line endings are accepted.
//   - A trailing frame that is not terminated by a blank line is dropped.
//   - Lines with any other prefix are ignored.
func parseSSEEvents(t *testing.T, body string) []sseEvent {
	t.Helper()

	var (
		events    []sseEvent
		name      string
		dataLines []string
	)

	for _, line := range strings.Split(body, "\n") {
		// Tolerate CRLF streams; otherwise the blank-line check below never fires.
		line = strings.TrimSuffix(line, "\r")

		switch {
		case strings.HasPrefix(line, "event: "):
			name = strings.TrimPrefix(line, "event: ")
		case strings.HasPrefix(line, "data: "):
			dataLines = append(dataLines, strings.TrimPrefix(line, "data: "))
		case line == "":
			if name != "" {
				events = append(events, sseEvent{event: name, data: strings.Join(dataLines, "\n")})
			}
			// Always reset at a frame boundary, even when the frame was
			// unnamed, so stale data never leaks into the following event.
			name = ""
			dataLines = nil
		}
	}
	return events
}

// eventNames returns the event type of every element of events, in order.
// It is used to produce readable failure messages.
func eventNames(events []sseEvent) []string {
	names := make([]string, len(events))
	for i, e := range events {
		names[i] = e.event
	}
	return names
}

// TestWebSearchCloudModelGating tests web_search behavior across model types.
+func TestWebSearchCloudModelGating(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + t.Run("local model allowed when web_search is not called", func(t *testing.T) { + handlerCalled := false + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + handlerCalled = true + resp := api.ChatResponse{ + Model: "llama3.2", + Message: api.Message{Role: "assistant", Content: "hello"}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + if !handlerCalled { + t.Error("handler should be called for local model when web_search is not called") + } + }) + + t.Run("local model emits web_search and gets structured error", func(t *testing.T) { + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "llama3.2", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_local_ws", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "hello"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 8, EvalCount: 2}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := 
`{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if len(result.Content) != 2 { + t.Fatalf("expected 2 content blocks for local model web_search error, got %d", len(result.Content)) + } + contentJSON, _ := json.Marshal(result.Content[1].Content) + var errContent anthropic.WebSearchToolResultError + if err := json.Unmarshal(contentJSON, &errContent); err != nil { + t.Fatalf("failed to parse web_search error content: %v", err) + } + if errContent.ErrorCode != "web_search_not_supported_for_local_models" { + t.Fatalf("expected web_search_not_supported_for_local_models, got %q", errContent.ErrorCode) + } + }) + + t.Run("model ending in cloud without cloud suffix treated as local", func(t *testing.T) { + handlerCalled := false + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + handlerCalled = true + resp := api.ChatResponse{ + Model: "notreallycloud", + Message: api.Message{Role: "assistant", Content: "hello"}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{"model":"notreallycloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", 
strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if !handlerCalled { + t.Error("handler should be called for non-cloud model when web_search is not called") + } + if resp.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + }) + + t.Run("cloud model with size tag allowed", func(t *testing.T) { + handlerCalled := false + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + handlerCalled = true + resp := api.ChatResponse{ + Model: "gpt-oss:120b", + Message: api.Message{Role: "assistant", Content: "hello"}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{"model":"gpt-oss:120b-cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if !handlerCalled { + t.Error("handler should be called for cloud model") + } + if resp.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + }) + + t.Run("cloud model allowed", func(t *testing.T) { + handlerCalled := false + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + handlerCalled = true + resp := api.ChatResponse{ + Model: "kimi-k2.5", + Message: api.Message{Role: "assistant", Content: "hello"}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = 
c.Writer.Write(data) + }) + + body := `{"model":"kimi-k2.5:cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if !handlerCalled { + t.Error("handler should be called for cloud model") + } + if resp.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + }) + + t.Run("cloud disabled blocks web search for cloud model", func(t *testing.T) { + t.Setenv("OLLAMA_NO_CLOUD", "1") + + handlerCalled := false + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + handlerCalled = true + }) + + body := `{"model":"kimi-k2.5:cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", resp.Code, resp.Body.String()) + } + if handlerCalled { + t.Fatal("handler should not be called when cloud is disabled") + } + + var errResp anthropic.ErrorResponse + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatalf("failed to parse error response: %v", err) + } + if !strings.Contains(errResp.Error.Message, "ollama cloud is disabled") { + t.Fatalf("expected cloud disabled error, got: %q", errResp.Error.Message) + } + }) + + t.Run("cloud disabled does not block local model if web_search is not called", func(t *testing.T) { + t.Setenv("OLLAMA_NO_CLOUD", "1") + + handlerCalled := false + router := gin.New() + 
router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + handlerCalled = true + resp := api.ChatResponse{ + Model: "llama3.2", + Message: api.Message{Role: "assistant", Content: "hello"}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + if !handlerCalled { + t.Fatal("handler should be called for local model when web_search is not called") + } + }) +} + +func TestWebSearchDoesNotRequireAuthorizationHeaderForMockEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + var authHeader string + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader = r.Header.Get("Authorization") + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "done"}, + Done: true, + DoneReason: "stop", + 
Metrics: api.Metrics{PromptEvalCount: 5, EvalCount: 2}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_auth", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "auth test"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 4, EvalCount: 1}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"test auth"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + if authHeader != "" { + t.Fatalf("expected no Authorization header for mock web search endpoint, got %q", authHeader) + } +} + +// TestWebSearchSearchAPIError tests that a failing search API returns a proper error response. 
+func TestWebSearchSearchAPIError(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + // Mock search server that returns 500 + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal error", http.StatusInternalServerError) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_err", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "test"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"test"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + // Error response: server_tool_use + web_search_tool_result with error + if len(result.Content) != 2 { + t.Fatalf("expected 2 content blocks for error, got %d", len(result.Content)) + } + if result.Content[0].Type != "server_tool_use" { + 
t.Errorf("expected 'server_tool_use', got %q", result.Content[0].Type) + } + if result.Content[1].Type != "web_search_tool_result" { + t.Errorf("expected 'web_search_tool_result', got %q", result.Content[1].Type) + } + if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 2 { + t.Fatalf("expected usage input=10 output=2, got %+v", result.Usage) + } +} + +func TestWebSearchStreamingImmediateTakeover(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "After search."}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 10}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + chunks := []api.ChatResponse{ + { + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "Preface "}, + Done: false, + }, + { + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_stream_1", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "latest updates"), + }, + }, + }, + }, + Done: false, + }, + { + Model: 
"test-model", + Message: api.Message{Role: "assistant", Content: "ignored chunk"}, + Done: false, + }, + { + Model: "test-model", + Message: api.Message{Role: "assistant"}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 9, EvalCount: 4}, + }, + } + c.Writer.WriteHeader(http.StatusOK) + for _, chunk := range chunks { + data, _ := json.Marshal(chunk) + _, _ = c.Writer.Write(data) + } + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "stream":true, + "messages":[{"role":"user","content":"Find updates"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + events := parseSSEEvents(t, resp.Body.String()) + if countEventsByName(events, "message_start") != 1 { + t.Fatalf("expected exactly one message_start, got %d", countEventsByName(events, "message_start")) + } + if countEventsByName(events, "message_stop") != 1 { + t.Fatalf("expected exactly one message_stop, got %d", countEventsByName(events, "message_stop")) + } + + textDeltas := collectTextDeltas(t, events) + if !containsString(textDeltas, "Preface ") { + t.Fatalf("expected passthrough text delta, got %v", textDeltas) + } + if !containsString(textDeltas, "After search.") { + t.Fatalf("expected post-search text delta, got %v", textDeltas) + } + if containsString(textDeltas, "ignored chunk") { + t.Fatalf("unexpected text from chunks after takeover: %v", textDeltas) + } +} + +func TestWebSearchStreamingUsageUsesObservedChunkMetrics(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := api.ChatResponse{ + Model: 
"test-model", + Message: api.Message{Role: "assistant", Content: "After search."}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 7}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + chunks := []api.ChatResponse{ + { + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "Preface "}, + Done: false, + Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 4}, + }, + { + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_stream_usage", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "latest updates"), + }, + }, + }, + }, + Done: false, + Metrics: api.Metrics{PromptEvalCount: 0, EvalCount: 0}, + }, + { + Model: "test-model", + Message: api.Message{Role: "assistant"}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 4}, + }, + } + c.Writer.WriteHeader(http.StatusOK) + for _, chunk := range chunks { + data, _ := json.Marshal(chunk) + _, _ = c.Writer.Write(data) + } + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "stream":true, + "messages":[{"role":"user","content":"Find updates"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + 
}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + events := parseSSEEvents(t, resp.Body.String()) + var messageDelta anthropic.MessageDeltaEvent + found := false + for _, event := range events { + if event.event != "message_delta" { + continue + } + if err := json.Unmarshal([]byte(event.data), &messageDelta); err != nil { + t.Fatalf("failed to unmarshal message_delta: %v", err) + } + found = true + break + } + if !found { + t.Fatal("expected message_delta event") + } + if messageDelta.Usage.InputTokens != 32 { + t.Fatalf("expected aggregated input tokens 32 (12 passthrough + 20 followup), got %d", messageDelta.Usage.InputTokens) + } + if messageDelta.Usage.OutputTokens != 11 { + t.Fatalf("expected aggregated output tokens 11 (4 passthrough + 7 followup), got %d", messageDelta.Usage.OutputTokens) + } +} + +func TestWebSearchMixedToolCallsPreferWebSearch(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "Search answer."}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 11, EvalCount: 6}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + 
originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_other", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: makeArgs("location", "SF"), + }, + }, + { + ID: "call_ws_mixed", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "latest weather"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"Weather?"}], + "tools":[ + {"type":"web_search_20250305","name":"web_search"}, + {"type":"custom","name":"get_weather","input_schema":{"type":"object"}} + ] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if len(result.Content) < 3 { + t.Fatalf("expected at least 3 blocks, got %d", len(result.Content)) + } + if result.Content[0].Type != "server_tool_use" { + t.Fatalf("expected server_tool_use first, got %q", result.Content[0].Type) + } + if result.Content[1].Type != "web_search_tool_result" { + t.Fatalf("expected web_search_tool_result second, got %q", 
result.Content[1].Type) + } + + for _, block := range result.Content { + if block.Type == "tool_use" && block.Name == "get_weather" { + t.Fatalf("did not expect get_weather tool_use in mixed web_search-preferred path: %+v", result.Content) + } + } +} + +func TestWebSearchFollowupClientToolStopReasonToolUse(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_weather_final", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: makeArgs("location", "New York"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 25, EvalCount: 7}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_tool_use", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "forecast"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 15, EvalCount: 3}, + } + data, _ 
:= json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"Do I need an umbrella?"}], + "tools":[ + {"type":"web_search_20250305","name":"web_search"}, + {"type":"custom","name":"get_weather","input_schema":{"type":"object"}} + ] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if result.StopReason != "tool_use" { + t.Fatalf("expected stop_reason tool_use, got %q", result.StopReason) + } + if len(result.Content) < 3 { + t.Fatalf("expected server blocks + tool_use, got %d blocks", len(result.Content)) + } + last := result.Content[len(result.Content)-1] + if last.Type != "tool_use" { + t.Fatalf("expected final block tool_use, got %q", last.Type) + } + if last.Name != "get_weather" { + t.Fatalf("expected final tool name get_weather, got %q", last.Name) + } + if result.Usage.InputTokens != 40 || result.Usage.OutputTokens != 10 { + t.Fatalf("unexpected aggregated usage: %+v", result.Usage) + } +} + +func TestWebSearchMultiIterationLoop(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + followupCall := 0 + followupDecodeErr := false + missingWebSearchTool := false + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var followupReq api.ChatRequest + if err := json.NewDecoder(r.Body).Decode(&followupReq); err != nil { + followupDecodeErr = true + http.Error(w, "bad request", http.StatusBadRequest) + return + } + hasWebSearchTool := false + 
for _, tool := range followupReq.Tools { + if tool.Function.Name == "web_search" { + hasWebSearchTool = true + break + } + } + if !hasWebSearchTool { + missingWebSearchTool = true + } + + followupCall++ + switch followupCall { + case 1: + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_2", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "loop query 2"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 2}, + } + _ = json.NewEncoder(w).Encode(resp) + case 2: + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "Final answer after 2 searches."}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 30, EvalCount: 3}, + } + _ = json.NewEncoder(w).Encode(resp) + default: + t.Fatalf("unexpected extra followup call: %d", followupCall) + } + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_1", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "loop query 1"), + }, + }, + }, + }, + 
Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 1}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"do multiple searches"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + if followupCall != 2 { + t.Fatalf("expected 2 followup calls, got %d", followupCall) + } + if followupDecodeErr { + t.Fatal("failed to decode followup request body") + } + if missingWebSearchTool { + t.Fatal("expected followup requests to retain web_search tool definition") + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + serverToolUses := 0 + webResults := 0 + for _, block := range result.Content { + if block.Type == "server_tool_use" { + serverToolUses++ + } + if block.Type == "web_search_tool_result" { + webResults++ + } + } + if serverToolUses != 2 || webResults != 2 { + t.Fatalf("expected two search iterations, got server_tool_use=%d web_search_tool_result=%d", serverToolUses, webResults) + } + + if result.Usage.InputTokens != 60 || result.Usage.OutputTokens != 6 { + t.Fatalf("unexpected aggregated usage: %+v", result.Usage) + } +} + +func TestWebSearchLoopMaxLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + followupCall := 0 + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + followupCall++ + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + 
Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_loop_limit", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "loop query next"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 7, EvalCount: 2}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_initial", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "loop query 1"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 5, EvalCount: 1}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"keep searching"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, 
got %d: %s", resp.Code, resp.Body.String()) + } + if followupCall != 3 { + t.Fatalf("expected 3 followup calls before max loop error, got %d", followupCall) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + last := result.Content[len(result.Content)-1] + if last.Type != "web_search_tool_result" { + t.Fatalf("expected last block web_search_tool_result, got %q", last.Type) + } + contentJSON, _ := json.Marshal(last.Content) + var errContent anthropic.WebSearchToolResultError + if err := json.Unmarshal(contentJSON, &errContent); err != nil { + t.Fatalf("failed to parse web search error content: %v", err) + } + if errContent.ErrorCode != "max_uses_exceeded" { + t.Fatalf("expected max_uses_exceeded error, got %q", errContent.ErrorCode) + } + if result.StopReason != "end_turn" { + t.Fatalf("expected end_turn, got %q", result.StopReason) + } +} + +func TestWebSearchStreamingFinalStopReasonToolUse(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_weather_stream", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: makeArgs("location", "Seattle"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 14, EvalCount: 5}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + 
})) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + chunks := []api.ChatResponse{ + { + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "Let me check. "}, + Done: false, + }, + { + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_ws_stream_tool_use", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "weather seattle"), + }, + }, + }, + }, + Done: false, + }, + { + Model: "test-model", + Message: api.Message{Role: "assistant"}, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 3}, + }, + } + c.Writer.WriteHeader(http.StatusOK) + for _, chunk := range chunks { + data, _ := json.Marshal(chunk) + _, _ = c.Writer.Write(data) + } + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "stream":true, + "messages":[{"role":"user","content":"Should I take a jacket?"}], + "tools":[ + {"type":"web_search_20250305","name":"web_search"}, + {"type":"custom","name":"get_weather","input_schema":{"type":"object"}} + ] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + events := parseSSEEvents(t, resp.Body.String()) + if countEventsByName(events, "message_start") != 1 { + t.Fatalf("expected exactly one message_start, got %d", countEventsByName(events, "message_start")) + } + + var messageDelta anthropic.MessageDeltaEvent + foundMessageDelta := false + foundToolUse := false 
+ for _, event := range events { + if event.event == "message_delta" { + foundMessageDelta = true + if err := json.Unmarshal([]byte(event.data), &messageDelta); err != nil { + t.Fatalf("failed to unmarshal message_delta: %v", err) + } + } + if event.event == "content_block_start" { + var start anthropic.ContentBlockStartEvent + if err := json.Unmarshal([]byte(event.data), &start); err != nil { + t.Fatalf("failed to unmarshal content_block_start: %v", err) + } + if start.ContentBlock.Type == "tool_use" && start.ContentBlock.Name == "get_weather" { + foundToolUse = true + } + } + } + + if !foundMessageDelta { + t.Fatal("expected message_delta event") + } + if messageDelta.Delta.StopReason != "tool_use" { + t.Fatalf("expected stop_reason tool_use, got %q", messageDelta.Delta.StopReason) + } + if !foundToolUse { + t.Fatal("expected tool_use content block for get_weather") + } +} + +func TestWebSearchFollowupNon200ReturnsApiError(t *testing.T) { + gin.SetMode(gin.TestMode) + enableCloudForTest(t) + + followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer followupServer.Close() + t.Setenv("OLLAMA_HOST", followupServer.URL) + + searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := anthropic.OllamaWebSearchResponse{ + Results: []anthropic.OllamaWebSearchResult{ + {Title: "Result", URL: "https://example.com", Content: "content"}, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer searchServer.Close() + originalEndpoint := anthropic.WebSearchEndpoint + anthropic.WebSearchEndpoint = searchServer.URL + defer func() { anthropic.WebSearchEndpoint = originalEndpoint }() + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: 
[]api.ToolCall{ + { + ID: "call_ws_non200", + Function: api.ToolCallFunction{ + Name: "web_search", + Arguments: makeArgs("query", "test"), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 9, EvalCount: 1}, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{ + "model":"test-model:cloud", + "max_tokens":100, + "messages":[{"role":"user","content":"test"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String()) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if len(result.Content) != 2 { + t.Fatalf("expected 2 blocks in error response, got %d", len(result.Content)) + } + + contentJSON, _ := json.Marshal(result.Content[1].Content) + var errContent anthropic.WebSearchToolResultError + if err := json.Unmarshal(contentJSON, &errContent); err != nil { + t.Fatalf("failed to parse error content: %v", err) + } + if errContent.ErrorCode != "api_error" { + t.Fatalf("expected api_error, got %q", errContent.ErrorCode) + } + if result.Usage.InputTokens != 9 || result.Usage.OutputTokens != 1 { + t.Fatalf("expected usage input=9 output=1, got %+v", result.Usage) + } +} + +func countEventsByName(events []sseEvent, eventName string) int { + count := 0 + for _, event := range events { + if event.event == eventName { + count++ + } + } + return count +} + +func collectTextDeltas(t *testing.T, events []sseEvent) []string { + t.Helper() + + var deltas []string + for _, event := range events { + if event.event != "content_block_delta" { + 
continue + } + + var delta anthropic.ContentBlockDeltaEvent + if err := json.Unmarshal([]byte(event.data), &delta); err != nil { + t.Fatalf("failed to unmarshal content_block_delta: %v", err) + } + if delta.Delta.Type == "text_delta" { + deltas = append(deltas, delta.Delta.Text) + } + } + + return deltas +} + +func containsString(values []string, target string) bool { + for _, value := range values { + if value == target { + return true + } + } + return false +} diff --git a/middleware/test_home_test.go b/middleware/test_home_test.go new file mode 100644 index 000000000..6c013c147 --- /dev/null +++ b/middleware/test_home_test.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "testing" + + "github.com/ollama/ollama/envconfig" +) + +func setTestHome(t *testing.T, home string) { + t.Helper() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + envconfig.ReloadServerConfig() +} + +// enableCloudForTest sets HOME to a clean temp dir and clears OLLAMA_NO_CLOUD +// so that cloud features are enabled for the duration of the test. +func enableCloudForTest(t *testing.T) { + t.Helper() + t.Setenv("OLLAMA_NO_CLOUD", "") + setTestHome(t, t.TempDir()) +}