From 40f56cf543b69b3320193afc3ef1b288696ec16a Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 24 Mar 2026 11:43:48 -0700 Subject: [PATCH] anthropic: handle images in tool_result content blocks --- anthropic/anthropic.go | 54 +++++++++++------ anthropic/anthropic_test.go | 118 ++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 20 deletions(-) diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index cde6360e5..3077a72e5 100755 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -372,6 +372,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) { return convertedRequest, nil } +func extractBase64Image(blockMap map[string]any) (api.ImageData, error) { + source, ok := blockMap["source"].(map[string]any) + if !ok { + return nil, errors.New("invalid image source") + } + + sourceType, _ := source["type"].(string) + if sourceType == "base64" { + data, _ := source["data"].(string) + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, fmt.Errorf("invalid base64 image data: %w", err) + } + return decoded, nil + } + return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported", sourceType) +} + // convertMessage converts an Anthropic MessageParam to Ollama api.Message(s) func convertMessage(msg MessageParam) ([]api.Message, error) { var messages []api.Message @@ -414,26 +432,12 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { case "image": imageBlocks++ - source, ok := blockMap["source"].(map[string]any) - if !ok { - logutil.Trace("anthropic: invalid image source", "role", role) - return nil, errors.New("invalid image source") + decoded, err := extractBase64Image(blockMap) + if err != nil { + logutil.Trace("anthropic: failed to extract image", "role", role, "error", err) + return nil, err } - - sourceType, _ := source["type"].(string) - if sourceType == "base64" { - data, _ := source["data"].(string) - decoded, err := base64.StdEncoding.DecodeString(data) - if err != nil { - logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err) - return nil, fmt.Errorf("invalid base64 image data: %w", err) - } - images = append(images, decoded) - } else { - logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType) - return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType) - } - // URL images would need to be fetched - skip for now + images = append(images, decoded) case "tool_use": toolUseBlocks++ @@ -462,6 +466,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { toolResultBlocks++ toolUseID, _ := blockMap["tool_use_id"].(string) var resultContent string + var resultImages []api.ImageData switch c := blockMap["content"].(type) { case string: @@ -469,10 +474,18 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { case []any: for _, cb := range c { if cbMap, ok := cb.(map[string]any); ok { - if cbMap["type"] == "text" { + switch cbMap["type"] { + case "text": if text, ok := cbMap["text"].(string); ok { resultContent += text } + case "image": + decoded, err := extractBase64Image(cbMap) + if err != nil { + logutil.Trace("anthropic: failed to extract image from tool_result", "role", role, "error", err) + return nil, err + } + resultImages = append(resultImages, decoded) } } } @@ -481,6 +494,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { toolResults = append(toolResults, api.Message{ Role: "tool", Content: resultContent, + Images: resultImages, ToolCallID: toolUseID, }) diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go index faa98a2ef..045a87436 100755 --- a/anthropic/anthropic_test.go +++ b/anthropic/anthropic_test.go @@ -266,6 +266,124 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) { } } +func TestFromMessagesRequest_WithToolResultContainingImage(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(testImage) + + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "user", + Content: []any{ + map[string]any{ + "type": "tool_result", + "tool_use_id": "call_456", + "content": []any{ + map[string]any{"type": "text", "text": "Here is the screenshot:"}, + map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": "image/png", + "data": testImage, + }, + }, + }, + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + msg := result.Messages[0] + if msg.Role != "tool" { + t.Errorf("expected role 'tool', got %q", msg.Role) + } + if msg.ToolCallID != "call_456" { + t.Errorf("expected tool_call_id 'call_456', got %q", msg.ToolCallID) + } + if msg.Content != "Here is the screenshot:" { + t.Errorf("expected content 'Here is the screenshot:', got %q", msg.Content) + } + if len(msg.Images) != 1 { + t.Fatalf("expected 1 image in tool result, got %d", len(msg.Images)) + } + if string(msg.Images[0]) != string(imgData) { + t.Error("image data mismatch in tool result") + } +} + +func TestFromMessagesRequest_WithToolResultContainingMultipleImages(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(testImage) + + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "user", + Content: []any{ + map[string]any{ + "type": "tool_result", + "tool_use_id": "call_789", + "content": []any{ + map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": "image/png", + "data": testImage, + }, + }, + map[string]any{"type": "text", "text": "First image above, second below:"}, + map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": "image/png", + "data": testImage, + }, + }, + }, + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + msg := result.Messages[0] + if msg.Role != "tool" { + t.Errorf("expected role 'tool', got %q", msg.Role) + } + if len(msg.Images) != 2 { + t.Fatalf("expected 2 images in tool result, got %d", len(msg.Images)) + } + for i, img := range msg.Images { + if string(img) != string(imgData) { + t.Errorf("image %d data mismatch in tool result", i) + } + } +} + func TestFromMessagesRequest_WithTools(t *testing.T) { req := MessagesRequest{ Model: "test-model",