mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
anthropic: enable websearch (#14246)
This commit is contained in:
@@ -1,17 +1,25 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
@@ -82,22 +90,25 @@ type MessageParam struct {
|
||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||
// only when set, which is required for SDK streaming accumulation.
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking, server_tool_use, web_search_tool_result
|
||||
|
||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// For text blocks with citations
|
||||
Citations []Citation `json:"citations,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
// For tool_use and server_tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
// For tool_result and web_search_tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
Content any `json:"content,omitempty"` // string, []ContentBlock, []WebSearchResult, or WebSearchToolResultError
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
@@ -105,6 +116,30 @@ type ContentBlock struct {
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// Citation represents a citation in a text block
|
||||
type Citation struct {
|
||||
Type string `json:"type"` // "web_search_result_location"
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
EncryptedIndex string `json:"encrypted_index,omitempty"`
|
||||
CitedText string `json:"cited_text,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchResult represents a single web search result
|
||||
type WebSearchResult struct {
|
||||
Type string `json:"type"` // "web_search_result"
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
EncryptedContent string `json:"encrypted_content,omitempty"`
|
||||
PageAge string `json:"page_age,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchToolResultError represents an error from web search
|
||||
type WebSearchToolResultError struct {
|
||||
Type string `json:"type"` // "web_search_tool_result_error"
|
||||
ErrorCode string `json:"error_code"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
@@ -115,10 +150,13 @@ type ImageSource struct {
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools, or "web_search_20250305" for web search
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
|
||||
// Web search specific fields
|
||||
MaxUses int `json:"max_uses,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
@@ -233,6 +271,8 @@ type StreamErrorEvent struct {
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
logutil.Trace("anthropic: converting request", "req", TraceMessagesRequest(r))
|
||||
|
||||
var messages []api.Message
|
||||
|
||||
if r.System != nil {
|
||||
@@ -259,9 +299,10 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range r.Messages {
|
||||
for i, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: message conversion failed", "index", i, "role", msg.Role, "err", err)
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
@@ -288,8 +329,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
var tools api.Tools
|
||||
hasBuiltinWebSearch := false
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
hasBuiltinWebSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range r.Tools {
|
||||
// Anthropic built-in web_search maps to Ollama function name "web_search".
|
||||
// If a user-defined tool also uses that name in the same request, drop the
|
||||
// user-defined one to avoid ambiguous tool-call routing.
|
||||
if hasBuiltinWebSearch && !strings.HasPrefix(t.Type, "web_search") && t.Name == "web_search" {
|
||||
logutil.Trace("anthropic: dropping colliding custom web_search tool", "tool", TraceTool(t))
|
||||
continue
|
||||
}
|
||||
|
||||
tool, _, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -302,15 +359,17 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
convertedRequest := &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
logutil.Trace("anthropic: converted request", "req", TraceChatRequest(convertedRequest))
|
||||
|
||||
return convertedRequest, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
@@ -328,10 +387,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
textBlocks := 0
|
||||
imageBlocks := 0
|
||||
toolUseBlocks := 0
|
||||
toolResultBlocks := 0
|
||||
serverToolUseBlocks := 0
|
||||
webSearchToolResultBlocks := 0
|
||||
thinkingBlocks := 0
|
||||
unknownBlocks := 0
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid content block format", "role", role)
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
@@ -339,13 +407,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
textBlocks++
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
imageBlocks++
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
@@ -354,21 +425,26 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
@@ -383,6 +459,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
@@ -408,9 +485,36 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
thinkingBlocks++
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
|
||||
case "server_tool_use":
|
||||
serverToolUseBlocks++
|
||||
id, _ := blockMap["id"].(string)
|
||||
name, _ := blockMap["name"].(string)
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "web_search_tool_result":
|
||||
webSearchToolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchToolResultContent(blockMap["content"]),
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
default:
|
||||
unknownBlocks++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,6 +531,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
logutil.Trace("anthropic: converted block message",
|
||||
"role", role,
|
||||
"blocks", len(content),
|
||||
"text", textBlocks,
|
||||
"image", imageBlocks,
|
||||
"tool_use", toolUseBlocks,
|
||||
"tool_result", toolResultBlocks,
|
||||
"server_tool_use", serverToolUseBlocks,
|
||||
"web_search_result", webSearchToolResultBlocks,
|
||||
"thinking", thinkingBlocks,
|
||||
"unknown", unknownBlocks,
|
||||
"messages", TraceAPIMessages(messages),
|
||||
)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
@@ -435,12 +552,94 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
func formatWebSearchToolResultContent(content any) string {
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return c
|
||||
case []WebSearchResult:
|
||||
var resultContent strings.Builder
|
||||
for _, item := range c {
|
||||
if item.Type != "web_search_result" {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&resultContent, "- %s: %s\n", item.Title, item.URL)
|
||||
}
|
||||
return resultContent.String()
|
||||
case []any:
|
||||
var resultContent strings.Builder
|
||||
for _, item := range c {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch itemMap["type"] {
|
||||
case "web_search_result":
|
||||
title, _ := itemMap["title"].(string)
|
||||
url, _ := itemMap["url"].(string)
|
||||
fmt.Fprintf(&resultContent, "- %s: %s\n", title, url)
|
||||
case "web_search_tool_result_error":
|
||||
errorCode, _ := itemMap["error_code"].(string)
|
||||
if errorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + errorCode
|
||||
}
|
||||
}
|
||||
return resultContent.String()
|
||||
case map[string]any:
|
||||
if c["type"] == "web_search_tool_result_error" {
|
||||
errorCode, _ := c["error_code"].(string)
|
||||
if errorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + errorCode
|
||||
}
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
case WebSearchToolResultError:
|
||||
if c.ErrorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + c.ErrorCode
|
||||
default:
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool, returning true if it's a server tool
|
||||
func convertTool(t Tool) (api.Tool, bool, error) {
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("query", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The search query to look up on the web",
|
||||
})
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "web_search",
|
||||
Description: "Search the web for current information. Use this to find up-to-date information about any topic.",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"query"},
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
logutil.Trace("anthropic: invalid tool schema", "tool", t.Name, "err", err)
|
||||
return api.Tool{}, false, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,7 +650,7 @@ func convertTool(t Tool) (api.Tool, error) {
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}, false, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
@@ -899,3 +1098,113 @@ func countContentBlock(block any) int {
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
// OllamaWebSearchRequest represents a request to the Ollama web search API
|
||||
type OllamaWebSearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
MaxResults int `json:"max_results,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaWebSearchResult represents a single search result from Ollama API
|
||||
type OllamaWebSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// OllamaWebSearchResponse represents the response from the Ollama web search API
|
||||
type OllamaWebSearchResponse struct {
|
||||
Results []OllamaWebSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
var WebSearchEndpoint = "https://ollama.com/api/web_search"
|
||||
|
||||
func WebSearch(ctx context.Context, query string, maxResults int) (*OllamaWebSearchResponse, error) {
|
||||
if internalcloud.Disabled() {
|
||||
logutil.TraceContext(ctx, "anthropic: web search blocked", "reason", "cloud_disabled")
|
||||
return nil, errors.New(internalcloud.DisabledError("web search is unavailable"))
|
||||
}
|
||||
|
||||
if maxResults <= 0 {
|
||||
maxResults = 5
|
||||
}
|
||||
if maxResults > 10 {
|
||||
maxResults = 10
|
||||
}
|
||||
|
||||
reqBody := OllamaWebSearchRequest{
|
||||
Query: query,
|
||||
MaxResults: maxResults,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal web search request: %w", err)
|
||||
}
|
||||
|
||||
searchURL, err := url.Parse(WebSearchEndpoint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse web search URL: %w", err)
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search request",
|
||||
"query", TraceTruncateString(query),
|
||||
"max_results", maxResults,
|
||||
"url", searchURL.String(),
|
||||
)
|
||||
|
||||
q := searchURL.Query()
|
||||
q.Set("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
searchURL.RawQuery = q.Encode()
|
||||
|
||||
signature := ""
|
||||
if strings.EqualFold(searchURL.Hostname(), "ollama.com") {
|
||||
challenge := fmt.Sprintf("%s,%s", http.MethodPost, searchURL.RequestURI())
|
||||
signature, err = auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign web search request: %w", err)
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search auth", "signed", signature != "")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create web search request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if signature != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("web search request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
logutil.TraceContext(ctx, "anthropic: web search response", "status", resp.StatusCode)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("web search returned status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var searchResp OllamaWebSearchResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode web search response: %w", err)
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search results", "count", len(searchResp.Results))
|
||||
|
||||
return &searchResp, nil
|
||||
}
|
||||
|
||||
func ConvertOllamaToAnthropicResults(ollamaResults *OllamaWebSearchResponse) []WebSearchResult {
|
||||
var results []WebSearchResult
|
||||
for _, r := range ollamaResults.Results {
|
||||
results = append(results, WebSearchResult{
|
||||
Type: "web_search_result",
|
||||
URL: r.URL,
|
||||
Title: r.Title,
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package anthropic
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -300,6 +301,78 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "web_search",
|
||||
Description: "User-defined web search that should be dropped",
|
||||
InputSchema: json.RawMessage(`{"type":"invalid"}`),
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 2 {
|
||||
t.Fatalf("expected 2 tools after dropping custom web_search, got %d", len(result.Tools))
|
||||
}
|
||||
if result.Tools[0].Function.Name != "web_search" {
|
||||
t.Fatalf("expected first tool to be built-in web_search, got %q", result.Tools[0].Function.Name)
|
||||
}
|
||||
if result.Tools[1].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected second tool to be get_weather, got %q", result.Tools[1].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "web_search",
|
||||
Description: "User-defined web search",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 custom tool, got %d", len(result.Tools))
|
||||
}
|
||||
if result.Tools[0].Function.Name != "web_search" {
|
||||
t.Fatalf("expected custom tool name web_search, got %q", result.Tools[0].Function.Name)
|
||||
}
|
||||
if result.Tools[0].Function.Description != "User-defined web search" {
|
||||
t.Fatalf("expected custom description preserved, got %q", result.Tools[0].Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
@@ -1063,3 +1136,320 @@ func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Web Search Tests
|
||||
|
||||
func TestConvertTool_WebSearch(t *testing.T) {
|
||||
tool := Tool{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
MaxUses: 5,
|
||||
}
|
||||
|
||||
result, isServerTool, err := convertTool(tool)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !isServerTool {
|
||||
t.Error("expected isServerTool to be true for web_search tool")
|
||||
}
|
||||
|
||||
if result.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", result.Type)
|
||||
}
|
||||
|
||||
if result.Function.Name != "web_search" {
|
||||
t.Errorf("expected name 'web_search', got %q", result.Function.Name)
|
||||
}
|
||||
|
||||
if result.Function.Description == "" {
|
||||
t.Error("expected non-empty description for web_search tool")
|
||||
}
|
||||
|
||||
// Check that query parameter is defined
|
||||
if result.Function.Parameters.Properties == nil {
|
||||
t.Fatal("expected properties to be defined")
|
||||
}
|
||||
|
||||
queryProp, ok := result.Function.Parameters.Properties.Get("query")
|
||||
if !ok {
|
||||
t.Error("expected 'query' property to be defined")
|
||||
}
|
||||
|
||||
if len(queryProp.Type) == 0 || queryProp.Type[0] != "string" {
|
||||
t.Errorf("expected query type to be 'string', got %v", queryProp.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertTool_RegularTool(t *testing.T) {
|
||||
tool := Tool{
|
||||
Type: "custom",
|
||||
Name: "get_weather",
|
||||
Description: "Get the weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||
}
|
||||
|
||||
result, isServerTool, err := convertTool(tool)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if isServerTool {
|
||||
t.Error("expected isServerTool to be false for regular tool")
|
||||
}
|
||||
|
||||
if result.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "server_tool_use",
|
||||
"id": "srvtoolu_123",
|
||||
"name": "web_search",
|
||||
"input": map[string]any{"query": "test query"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
|
||||
if len(messages[0].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(messages[0].ToolCalls))
|
||||
}
|
||||
|
||||
tc := messages[0].ToolCalls[0]
|
||||
if tc.ID != "srvtoolu_123" {
|
||||
t.Errorf("expected tool call ID 'srvtoolu_123', got %q", tc.ID)
|
||||
}
|
||||
|
||||
if tc.Function.Name != "web_search" {
|
||||
t.Errorf("expected tool name 'web_search', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_123",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "web_search_result",
|
||||
"title": "Test Result",
|
||||
"url": "https://example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should have a tool result message
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
|
||||
if messages[0].Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", messages[0].Role)
|
||||
}
|
||||
|
||||
if messages[0].ToolCallID != "srvtoolu_123" {
|
||||
t.Errorf("expected tool_call_id 'srvtoolu_123', got %q", messages[0].ToolCallID)
|
||||
}
|
||||
|
||||
if messages[0].Content == "" {
|
||||
t.Error("expected non-empty content from web search results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_empty",
|
||||
"content": []any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Role != "tool" {
|
||||
t.Fatalf("expected role tool, got %q", messages[0].Role)
|
||||
}
|
||||
if messages[0].ToolCallID != "srvtoolu_empty" {
|
||||
t.Fatalf("expected tool_call_id srvtoolu_empty, got %q", messages[0].ToolCallID)
|
||||
}
|
||||
if messages[0].Content != "" {
|
||||
t.Fatalf("expected empty content for empty web search results, got %q", messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_error",
|
||||
"content": map[string]any{
|
||||
"type": "web_search_tool_result_error",
|
||||
"error_code": "max_uses_exceeded",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Role != "tool" {
|
||||
t.Fatalf("expected role tool, got %q", messages[0].Role)
|
||||
}
|
||||
if messages[0].ToolCallID != "srvtoolu_error" {
|
||||
t.Fatalf("expected tool_call_id srvtoolu_error, got %q", messages[0].ToolCallID)
|
||||
}
|
||||
if !strings.Contains(messages[0].Content, "max_uses_exceeded") {
|
||||
t.Fatalf("expected error code in converted tool content, got %q", messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOllamaToAnthropicResults(t *testing.T) {
|
||||
ollamaResp := &OllamaWebSearchResponse{
|
||||
Results: []OllamaWebSearchResult{
|
||||
{
|
||||
Title: "Test Title",
|
||||
URL: "https://example.com",
|
||||
Content: "Test content",
|
||||
},
|
||||
{
|
||||
Title: "Another Result",
|
||||
URL: "https://example.org",
|
||||
Content: "More content",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
results := ConvertOllamaToAnthropicResults(ollamaResp)
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(results))
|
||||
}
|
||||
|
||||
if results[0].Type != "web_search_result" {
|
||||
t.Errorf("expected type 'web_search_result', got %q", results[0].Type)
|
||||
}
|
||||
|
||||
if results[0].Title != "Test Title" {
|
||||
t.Errorf("expected title 'Test Title', got %q", results[0].Title)
|
||||
}
|
||||
|
||||
if results[0].URL != "https://example.com" {
|
||||
t.Errorf("expected URL 'https://example.com', got %q", results[0].URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSearchTypes(t *testing.T) {
|
||||
// Test that WebSearchResult serializes correctly
|
||||
result := WebSearchResult{
|
||||
Type: "web_search_result",
|
||||
URL: "https://example.com",
|
||||
Title: "Test",
|
||||
EncryptedContent: "abc123",
|
||||
PageAge: "2025-01-01",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal WebSearchResult: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaled WebSearchResult
|
||||
if err := json.Unmarshal(data, &unmarshaled); err != nil {
|
||||
t.Fatalf("failed to unmarshal WebSearchResult: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.Type != result.Type {
|
||||
t.Errorf("type mismatch: expected %q, got %q", result.Type, unmarshaled.Type)
|
||||
}
|
||||
|
||||
// Test WebSearchToolResultError
|
||||
errResult := WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: "max_uses_exceeded",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(errResult)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal WebSearchToolResultError: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaledErr WebSearchToolResultError
|
||||
if err := json.Unmarshal(data, &unmarshaledErr); err != nil {
|
||||
t.Fatalf("failed to unmarshal WebSearchToolResultError: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaledErr.ErrorCode != "max_uses_exceeded" {
|
||||
t.Errorf("error_code mismatch: expected 'max_uses_exceeded', got %q", unmarshaledErr.ErrorCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitation(t *testing.T) {
|
||||
citation := Citation{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com",
|
||||
Title: "Example",
|
||||
EncryptedIndex: "enc123",
|
||||
CitedText: "Some cited text...",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(citation)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal Citation: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaled Citation
|
||||
if err := json.Unmarshal(data, &unmarshaled); err != nil {
|
||||
t.Fatalf("failed to unmarshal Citation: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.Type != "web_search_result_location" {
|
||||
t.Errorf("type mismatch: expected 'web_search_result_location', got %q", unmarshaled.Type)
|
||||
}
|
||||
|
||||
if unmarshaled.CitedText != "Some cited text..." {
|
||||
t.Errorf("cited_text mismatch: expected 'Some cited text...', got %q", unmarshaled.CitedText)
|
||||
}
|
||||
}
|
||||
|
||||
352
anthropic/trace.go
Normal file
352
anthropic/trace.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Trace truncation limits.
|
||||
const (
|
||||
TraceMaxStringRunes = 240
|
||||
TraceMaxSliceItems = 8
|
||||
TraceMaxMapEntries = 16
|
||||
TraceMaxDepth = 4
|
||||
)
|
||||
|
||||
// TraceTruncateString shortens s to TraceMaxStringRunes, appending a count of
|
||||
// omitted characters when truncated.
|
||||
func TraceTruncateString(s string) string {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
runes := []rune(s)
|
||||
if len(runes) <= TraceMaxStringRunes {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%s...(+%d chars)", string(runes[:TraceMaxStringRunes]), len(runes)-TraceMaxStringRunes)
|
||||
}
|
||||
|
||||
// TraceJSON round-trips v through JSON and returns a compacted representation.
|
||||
func TraceJSON(v any) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return map[string]any{"marshal_error": err.Error(), "type": fmt.Sprintf("%T", v)}
|
||||
}
|
||||
var out any
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
return TraceTruncateString(string(data))
|
||||
}
|
||||
return TraceCompactValue(out, 0)
|
||||
}
|
||||
|
||||
// TraceCompactValue recursively truncates strings, slices, and maps for trace
|
||||
// output. depth tracks recursion to enforce TraceMaxDepth.
|
||||
func TraceCompactValue(v any, depth int) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
if depth >= TraceMaxDepth {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return TraceTruncateString(t)
|
||||
case []any:
|
||||
return fmt.Sprintf("<array len=%d>", len(t))
|
||||
case map[string]any:
|
||||
return fmt.Sprintf("<object keys=%d>", len(t))
|
||||
default:
|
||||
return fmt.Sprintf("<%T>", v)
|
||||
}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return TraceTruncateString(t)
|
||||
case []any:
|
||||
limit := min(len(t), TraceMaxSliceItems)
|
||||
out := make([]any, 0, limit+1)
|
||||
for i := range limit {
|
||||
out = append(out, TraceCompactValue(t[i], depth+1))
|
||||
}
|
||||
if len(t) > limit {
|
||||
out = append(out, fmt.Sprintf("... +%d more items", len(t)-limit))
|
||||
}
|
||||
return out
|
||||
case map[string]any:
|
||||
keys := make([]string, 0, len(t))
|
||||
for k := range t {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
limit := min(len(keys), TraceMaxMapEntries)
|
||||
out := make(map[string]any, limit+1)
|
||||
for i := range limit {
|
||||
out[keys[i]] = TraceCompactValue(t[keys[i]], depth+1)
|
||||
}
|
||||
if len(keys) > limit {
|
||||
out["__truncated_keys"] = len(keys) - limit
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic request/response tracing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TraceMessagesRequest returns a compact trace representation of a MessagesRequest.
|
||||
func TraceMessagesRequest(r MessagesRequest) map[string]any {
|
||||
return map[string]any{
|
||||
"model": r.Model,
|
||||
"max_tokens": r.MaxTokens,
|
||||
"messages": traceMessageParams(r.Messages),
|
||||
"system": traceAnthropicContent(r.System),
|
||||
"stream": r.Stream,
|
||||
"tools": traceTools(r.Tools),
|
||||
"tool_choice": TraceJSON(r.ToolChoice),
|
||||
"thinking": TraceJSON(r.Thinking),
|
||||
"stop_sequences": r.StopSequences,
|
||||
"temperature": ptrVal(r.Temperature),
|
||||
"top_p": ptrVal(r.TopP),
|
||||
"top_k": ptrVal(r.TopK),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceMessagesResponse returns a compact trace representation of a MessagesResponse.
|
||||
func TraceMessagesResponse(r MessagesResponse) map[string]any {
|
||||
return map[string]any{
|
||||
"id": r.ID,
|
||||
"model": r.Model,
|
||||
"content": TraceJSON(r.Content),
|
||||
"stop_reason": r.StopReason,
|
||||
"usage": r.Usage,
|
||||
}
|
||||
}
|
||||
|
||||
func traceMessageParams(msgs []MessageParam) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, map[string]any{
|
||||
"role": m.Role,
|
||||
"content": traceAnthropicContent(m.Content),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func traceAnthropicContent(content any) any {
|
||||
switch c := content.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case string:
|
||||
return TraceTruncateString(c)
|
||||
case []any:
|
||||
blocks := make([]any, 0, len(c))
|
||||
for _, block := range c {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
blocks = append(blocks, TraceCompactValue(block, 0))
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, traceAnthropicBlock(blockMap))
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return TraceJSON(c)
|
||||
}
|
||||
}
|
||||
|
||||
func traceAnthropicBlock(block map[string]any) map[string]any {
|
||||
blockType, _ := block["type"].(string)
|
||||
out := map[string]any{"type": blockType}
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
out["text"] = TraceTruncateString(text)
|
||||
} else {
|
||||
out["text"] = TraceCompactValue(block["text"], 0)
|
||||
}
|
||||
case "thinking":
|
||||
if thinking, ok := block["thinking"].(string); ok {
|
||||
out["thinking"] = TraceTruncateString(thinking)
|
||||
} else {
|
||||
out["thinking"] = TraceCompactValue(block["thinking"], 0)
|
||||
}
|
||||
case "tool_use", "server_tool_use":
|
||||
out["id"] = block["id"]
|
||||
out["name"] = block["name"]
|
||||
out["input"] = TraceCompactValue(block["input"], 0)
|
||||
case "tool_result", "web_search_tool_result":
|
||||
out["tool_use_id"] = block["tool_use_id"]
|
||||
out["content"] = TraceCompactValue(block["content"], 0)
|
||||
case "image":
|
||||
if source, ok := block["source"].(map[string]any); ok {
|
||||
out["source"] = map[string]any{
|
||||
"type": source["type"],
|
||||
"media_type": source["media_type"],
|
||||
"url": source["url"],
|
||||
"data_len": len(fmt.Sprint(source["data"])),
|
||||
}
|
||||
}
|
||||
default:
|
||||
out["block"] = TraceCompactValue(block, 0)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func traceTools(tools []Tool) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
out = append(out, TraceTool(t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceTool returns a compact trace representation of an Anthropic Tool.
|
||||
func TraceTool(t Tool) map[string]any {
|
||||
return map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Name,
|
||||
"description": TraceTruncateString(t.Description),
|
||||
"input_schema": TraceJSON(t.InputSchema),
|
||||
"max_uses": t.MaxUses,
|
||||
}
|
||||
}
|
||||
|
||||
// ContentBlockTypes returns the type strings from content (when it's []any blocks).
|
||||
func ContentBlockTypes(content any) []string {
|
||||
blocks, ok := content.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
types := make([]string, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
types = append(types, fmt.Sprintf("%T", block))
|
||||
continue
|
||||
}
|
||||
t, _ := blockMap["type"].(string)
|
||||
types = append(types, t)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
func ptrVal[T any](v *T) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ollama api.* tracing (shared between anthropic and middleware packages)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TraceChatRequest returns a compact trace representation of an Ollama ChatRequest.
|
||||
func TraceChatRequest(req *api.ChatRequest) map[string]any {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
stream := false
|
||||
if req.Stream != nil {
|
||||
stream = *req.Stream
|
||||
}
|
||||
return map[string]any{
|
||||
"model": req.Model,
|
||||
"messages": TraceAPIMessages(req.Messages),
|
||||
"tools": TraceAPITools(req.Tools),
|
||||
"stream": stream,
|
||||
"options": req.Options,
|
||||
"think": TraceJSON(req.Think),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceChatResponse returns a compact trace representation of an Ollama ChatResponse.
|
||||
func TraceChatResponse(resp api.ChatResponse) map[string]any {
|
||||
return map[string]any{
|
||||
"model": resp.Model,
|
||||
"done": resp.Done,
|
||||
"done_reason": resp.DoneReason,
|
||||
"message": TraceAPIMessage(resp.Message),
|
||||
"metrics": TraceJSON(resp.Metrics),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceAPIMessages returns compact trace representations for a slice of api.Message.
|
||||
func TraceAPIMessages(msgs []api.Message) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, TraceAPIMessage(m))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceAPIMessage returns a compact trace representation of a single api.Message.
|
||||
func TraceAPIMessage(m api.Message) map[string]any {
|
||||
return map[string]any{
|
||||
"role": m.Role,
|
||||
"content": TraceTruncateString(m.Content),
|
||||
"thinking": TraceTruncateString(m.Thinking),
|
||||
"images": traceImageSizes(m.Images),
|
||||
"tool_calls": traceToolCalls(m.ToolCalls),
|
||||
"tool_name": m.ToolName,
|
||||
"tool_call_id": m.ToolCallID,
|
||||
}
|
||||
}
|
||||
|
||||
func traceImageSizes(images []api.ImageData) []int {
|
||||
if len(images) == 0 {
|
||||
return nil
|
||||
}
|
||||
sizes := make([]int, 0, len(images))
|
||||
for _, img := range images {
|
||||
sizes = append(sizes, len(img))
|
||||
}
|
||||
return sizes
|
||||
}
|
||||
|
||||
// TraceAPITools returns compact trace representations for a slice of api.Tool.
|
||||
func TraceAPITools(tools api.Tools) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
out = append(out, TraceAPITool(t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceAPITool returns a compact trace representation of a single api.Tool.
|
||||
func TraceAPITool(t api.Tool) map[string]any {
|
||||
return map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Function.Name,
|
||||
"description": TraceTruncateString(t.Function.Description),
|
||||
"parameters": TraceJSON(t.Function.Parameters),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceToolCall returns a compact trace representation of an api.ToolCall.
|
||||
func TraceToolCall(tc api.ToolCall) map[string]any {
|
||||
return map[string]any{
|
||||
"id": tc.ID,
|
||||
"name": tc.Function.Name,
|
||||
"args": TraceJSON(tc.Function.Arguments),
|
||||
}
|
||||
}
|
||||
|
||||
func traceToolCalls(tcs []api.ToolCall) []map[string]any {
|
||||
if len(tcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(tcs))
|
||||
for _, tc := range tcs {
|
||||
out = append(out, TraceToolCall(tc))
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user