mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
anthropic: handle images in tool_result content blocks
This commit is contained in:
@@ -372,6 +372,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
return convertedRequest, nil
|
||||
}
|
||||
|
||||
func extractBase64Image(blockMap map[string]any) (api.ImageData, error) {
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported", sourceType)
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
@@ -414,26 +432,12 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
|
||||
case "image":
|
||||
imageBlocks++
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
decoded, err := extractBase64Image(blockMap)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
logutil.Trace("anthropic: failed to extract image", "role", role, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
@@ -462,6 +466,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
toolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
var resultImages []api.ImageData
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
@@ -469,10 +474,18 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
switch cbMap["type"] {
|
||||
case "text":
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
case "image":
|
||||
decoded, err := extractBase64Image(cbMap)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: failed to extract image from tool_result", "role", role, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
resultImages = append(resultImages, decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -481,6 +494,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
Images: resultImages,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
|
||||
@@ -266,6 +266,124 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResultContainingImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_456",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "Here is the screenshot:"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_456" {
|
||||
t.Errorf("expected tool_call_id 'call_456', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "Here is the screenshot:" {
|
||||
t.Errorf("expected content 'Here is the screenshot:', got %q", msg.Content)
|
||||
}
|
||||
if len(msg.Images) != 1 {
|
||||
t.Fatalf("expected 1 image in tool result, got %d", len(msg.Images))
|
||||
}
|
||||
if string(msg.Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch in tool result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResultContainingMultipleImages(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_789",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
map[string]any{"type": "text", "text": "First image above, second below:"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if len(msg.Images) != 2 {
|
||||
t.Fatalf("expected 2 images in tool result, got %d", len(msg.Images))
|
||||
}
|
||||
for i, img := range msg.Images {
|
||||
if string(img) != string(imgData) {
|
||||
t.Errorf("image %d data mismatch in tool result", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
|
||||
Reference in New Issue
Block a user