mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
This reverts commit ee25219edd.
This commit is contained in:
@@ -518,26 +518,24 @@ func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
estimatedInputTokens: estimatedInputTokens,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -553,11 +551,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
// Use actual metrics if available, otherwise use estimate
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||
c.inputTokens = c.estimatedInputTokens
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
@@ -785,121 +779,3 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
|
||||
func EstimateInputTokens(req MessagesRequest) int {
|
||||
return estimateTokens(CountTokensRequest{
|
||||
Model: req.Model,
|
||||
Messages: req.Messages,
|
||||
System: req.System,
|
||||
Tools: req.Tools,
|
||||
Thinking: req.Thinking,
|
||||
})
|
||||
}
|
||||
|
||||
// CountTokensResponse represents an Anthropic count_tokens response
|
||||
type CountTokensResponse struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
}
|
||||
|
||||
// estimateTokens returns a rough estimate of tokens (len/4)
|
||||
func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
|
||||
}
|
||||
|
||||
// Return len/4 as rough token estimate, minimum 1 if there's any content
|
||||
tokens := totalLen / 4
|
||||
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
|
||||
tokens = 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func countAnyContent(content any) int {
|
||||
if content == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if source, ok := blockMap["source"].(map[string]any); ok {
|
||||
if data, ok := source["data"].(string); ok {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -605,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
@@ -678,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -731,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
@@ -778,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
@@ -903,7 +903,7 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -937,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
|
||||
Reference in New Issue
Block a user