anthropic: handle images in tool_result content blocks

This commit is contained in:
ParthSareen
2026-03-24 11:43:48 -07:00
parent 22c2bdbd8a
commit 40f56cf543
2 changed files with 152 additions and 20 deletions

View File

@@ -372,6 +372,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
return convertedRequest, nil
}
// extractBase64Image decodes the image payload of an Anthropic "image"
// content block.
//
// blockMap is the decoded JSON object of the block. Its "source" entry must be
// an object whose "type" is "base64" and whose "data" is a non-empty string in
// standard base64 encoding; the decoded bytes are returned.
//
// An error is returned when:
//   - "source" is missing or is not an object,
//   - the source type is anything other than "base64",
//   - "data" is missing, is not a string, or is empty,
//   - "data" is not valid standard base64.
func extractBase64Image(blockMap map[string]any) (api.ImageData, error) {
	source, ok := blockMap["source"].(map[string]any)
	if !ok {
		return nil, errors.New("invalid image source")
	}

	// A missing or non-string "type" leaves sourceType empty, which is
	// rejected by the same check as any other unsupported type.
	sourceType, _ := source["type"].(string)
	if sourceType != "base64" {
		return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported", sourceType)
	}

	// The empty string decodes successfully to zero bytes, so absent or empty
	// data must be rejected explicitly or it would produce an empty image.
	data, ok := source["data"].(string)
	if !ok || data == "" {
		return nil, errors.New("invalid base64 image data: missing data")
	}

	decoded, err := base64.StdEncoding.DecodeString(data)
	if err != nil {
		return nil, fmt.Errorf("invalid base64 image data: %w", err)
	}
	return decoded, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
@@ -414,26 +432,12 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
case "image":
imageBlocks++
source, ok := blockMap["source"].(map[string]any)
if !ok {
logutil.Trace("anthropic: invalid image source", "role", role)
return nil, errors.New("invalid image source")
decoded, err := extractBase64Image(blockMap)
if err != nil {
logutil.Trace("anthropic: failed to extract image", "role", role, "error", err)
return nil, err
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
images = append(images, decoded)
case "tool_use":
toolUseBlocks++
@@ -462,6 +466,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
toolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
var resultImages []api.ImageData
switch c := blockMap["content"].(type) {
case string:
@@ -469,10 +474,18 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
switch cbMap["type"] {
case "text":
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
case "image":
decoded, err := extractBase64Image(cbMap)
if err != nil {
logutil.Trace("anthropic: failed to extract image from tool_result", "role", role, "error", err)
return nil, err
}
resultImages = append(resultImages, decoded)
}
}
}
@@ -481,6 +494,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
Images: resultImages,
ToolCallID: toolUseID,
})

View File

@@ -266,6 +266,124 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
}
}
// TestFromMessagesRequest_WithToolResultContainingImage checks that a
// tool_result block whose content holds a text block followed by a base64
// image block is converted into exactly one "tool" message carrying the tool
// call ID, the text, and the decoded image bytes.
func TestFromMessagesRequest_WithToolResultContainingImage(t *testing.T) {
	// Raw bytes the converted message is expected to carry for the image.
	wantImage, _ := base64.StdEncoding.DecodeString(testImage)

	textBlock := map[string]any{"type": "text", "text": "Here is the screenshot:"}
	imageBlock := map[string]any{
		"type": "image",
		"source": map[string]any{
			"type":       "base64",
			"media_type": "image/png",
			"data":       testImage,
		},
	}
	toolResult := map[string]any{
		"type":        "tool_result",
		"tool_use_id": "call_456",
		"content":     []any{textBlock, imageBlock},
	}

	converted, err := FromMessagesRequest(MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: []any{toolResult}},
		},
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if count := len(converted.Messages); count != 1 {
		t.Fatalf("expected 1 message, got %d", count)
	}

	toolMsg := converted.Messages[0]
	if toolMsg.Role != "tool" {
		t.Errorf("expected role 'tool', got %q", toolMsg.Role)
	}
	if toolMsg.ToolCallID != "call_456" {
		t.Errorf("expected tool_call_id 'call_456', got %q", toolMsg.ToolCallID)
	}
	if toolMsg.Content != "Here is the screenshot:" {
		t.Errorf("expected content 'Here is the screenshot:', got %q", toolMsg.Content)
	}

	// Stop here if the image is missing: indexing below would panic.
	if count := len(toolMsg.Images); count != 1 {
		t.Fatalf("expected 1 image in tool result, got %d", count)
	}
	if string(toolMsg.Images[0]) != string(wantImage) {
		t.Error("image data mismatch in tool result")
	}
}
// TestFromMessagesRequest_WithToolResultContainingMultipleImages checks that a
// tool_result block whose content is image, text, image is converted into
// exactly one "tool" message that keeps the tool call ID, the text, and both
// decoded images.
func TestFromMessagesRequest_WithToolResultContainingMultipleImages(t *testing.T) {
	// Fail early on a broken fixture so later mismatches are not misattributed
	// to the conversion code.
	imgData, err := base64.StdEncoding.DecodeString(testImage)
	if err != nil {
		t.Fatalf("decoding test image fixture: %v", err)
	}

	// newImageBlock returns a fresh base64 image content block so the two
	// images in the request do not share a map.
	newImageBlock := func() map[string]any {
		return map[string]any{
			"type": "image",
			"source": map[string]any{
				"type":       "base64",
				"media_type": "image/png",
				"data":       testImage,
			},
		}
	}

	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "user",
				Content: []any{
					map[string]any{
						"type":        "tool_result",
						"tool_use_id": "call_789",
						"content": []any{
							newImageBlock(),
							map[string]any{"type": "text", "text": "First image above, second below:"},
							newImageBlock(),
						},
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	msg := result.Messages[0]
	if msg.Role != "tool" {
		t.Errorf("expected role 'tool', got %q", msg.Role)
	}
	// The text sits between the two images; make sure it and the tool call ID
	// are still carried over when images surround it.
	if msg.ToolCallID != "call_789" {
		t.Errorf("expected tool_call_id 'call_789', got %q", msg.ToolCallID)
	}
	if msg.Content != "First image above, second below:" {
		t.Errorf("expected content 'First image above, second below:', got %q", msg.Content)
	}

	if len(msg.Images) != 2 {
		t.Fatalf("expected 2 images in tool result, got %d", len(msg.Images))
	}
	for i, img := range msg.Images {
		if string(img) != string(imgData) {
			t.Errorf("image %d data mismatch in tool result", i)
		}
	}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",