Files
ollama/x/mlxrunner/cache.go
Patrick Devine 857cffd22a bugfix: fix crash bug in token cache logic
This change fixes a problem in the token cache logic that caused panics from empty token arrays,
by ensuring at least one token remains on full cache hits in findRemaining. This happens
when there is an exact match in the cache on subsequent generations.
2026-02-26 18:35:44 -08:00

113 lines
2.8 KiB
Go

//go:build mlx
package mlxrunner
import (
"fmt"
"log/slog"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// kvCache pairs the model's caches with the token sequence they currently
// hold, so that a later request sharing a prefix can skip recomputing it.
type kvCache struct {
	// For now we only support a single entry, so this is just one sequence
	tokens []int32
	// caches are created lazily by begin: either the model's own NewCaches
	// result, or one cache.NewKVCache per layer (m.NumLayers()).
	caches []cache.Cache
}
// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
	// cache is the kvCache whose token record close updates.
	cache *kvCache
	// inputs are the request's input tokens as passed to begin.
	inputs []int32
	// outputs collects generated tokens; appended to by the caller.
	outputs []int32
	// caches is cache.caches as it was when the session began.
	caches []cache.Cache
	// remaining is the suffix of inputs not already covered by the cache
	// (as returned by findRemaining); these still need a forward pass.
	remaining []int32
}
// begin starts a session for a new request. On first use it creates the
// caches, preferring a model-supplied NewCaches factory and otherwise
// allocating one KV cache per layer. It then reconciles the cache with
// inputs and records which tokens still need to be processed.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	if len(c.caches) == 0 {
		switch factory := m.(type) {
		case interface{ NewCaches() []cache.Cache }:
			c.caches = factory.NewCaches()
		default:
			layers := make([]cache.Cache, m.NumLayers())
			for i := range layers {
				layers[i] = cache.NewKVCache()
			}
			c.caches = layers
		}
	}

	pending := c.findRemaining(inputs)

	session := &cacheSession{cache: c, inputs: inputs}
	session.caches = c.caches
	session.remaining = pending
	return session
}
// close saves the token state if the forward pass ran, i.e. if the first
// cache reports a positive offset. It schedules evaluation of every cache's
// key/value state and records the first offset tokens of inputs+outputs as
// the sequence the caches now represent, so the next begin can reuse it.
func (s *cacheSession) close() {
	// A model may yield zero caches (e.g. NumLayers() == 0); there is no
	// state to save, and indexing s.caches[0] would panic.
	if len(s.caches) == 0 {
		return
	}
	offset := s.caches[0].Offset()
	if offset <= 0 {
		return
	}

	// Ensure that if we have run the forward pass and set the metadata
	// that we also actually have the data.
	arrays := make([]*mlx.Array, 0, 2*len(s.caches))
	for _, c := range s.caches {
		k, v := c.State()
		arrays = append(arrays, k, v)
	}
	mlx.AsyncEval(arrays...)

	// Build the token record in a freshly allocated slice. Appending outputs
	// directly onto s.inputs could overwrite memory past len(inputs) that
	// the caller still owns, and would leave the cache aliasing that memory.
	// Because cap == len here, an offset larger than the known tokens panics
	// loudly instead of silently exposing stale elements beyond len.
	tokens := make([]int32, 0, len(s.inputs)+len(s.outputs))
	tokens = append(tokens, s.inputs...)
	tokens = append(tokens, s.outputs...)
	s.cache.tokens = tokens[:offset]
}
// findRemaining compares tokens against the cached sequence and returns the
// suffix of tokens that still has to be run through the model.
//
// Cached entries past the longest shared prefix are trimmed from every cache
// and from the token record. On an exact match, one token is deliberately
// left uncached so the model has something to process and a response can be
// sampled. Each call logs either a cache miss or a cache hit.
func (c *kvCache) findRemaining(tokens []int32) []int32 {
	limit := len(tokens)
	if len(c.tokens) < limit {
		limit = len(c.tokens)
	}

	matched := 0
	for matched < limit && tokens[matched] == c.tokens[matched] {
		matched++
	}

	// Leave one token to run through the model so we can sample a response.
	if matched > 0 && matched == len(tokens) {
		matched--
	}

	// Keep the caches and the token record describing the same sequence.
	if stale := len(c.tokens) - matched; stale > 0 {
		for _, layer := range c.caches {
			layer.Trim(stale)
		}
		c.tokens = c.tokens[:matched]
	}

	left := tokens[matched:]
	switch matched {
	case 0:
		slog.Info("Cache miss", "left", len(tokens))
	default:
		slog.Info("Cache hit", "total", len(tokens), "cached", matched, "left", len(left))
	}
	return left
}
// log reports, via logutil.Trace, the first cache's token offset and the
// combined byte size of every cache's key and value state. It is a no-op
// until caches have been created.
func (c *kvCache) log() {
	if len(c.caches) > 0 {
		size := 0
		for i := range c.caches {
			keys, values := c.caches[i].State()
			size += keys.NumBytes() + values.NumBytes()
		}
		msg := fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(size))
		logutil.Trace(msg)
	}
}