This change fixes the token cache logic to avoid panics caused by empty token arrays, by ensuring that at least one token remains to process on a full cache hit in findRemaining. This happens when a subsequent generation is an exact match for the cached sequence.
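
For illustration, a minimal test sketch (hypothetical, not part of this change) of the behavior the guard guarantees: even when the inputs exactly match the cached sequence, findRemaining still returns at least one token.

// Hypothetical test sketch (assumes a _test.go file in package mlxrunner
// with the mlx build tag and `import "testing"`): an exact cache match
// must still leave one token to run through the model, so downstream
// code never sees an empty token batch.
func TestFindRemainingFullHit(t *testing.T) {
	c := &kvCache{tokens: []int32{1, 2, 3}}
	remaining := c.findRemaining([]int32{1, 2, 3})
	if len(remaining) != 1 || remaining[0] != 3 {
		t.Fatalf("expected one remaining token, got %v", remaining)
	}
}
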
//go:build mlx

package mlxrunner

import (
	"fmt"
	"log/slog"

	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

// kvCache holds the cached token sequence and its per-layer caches.
type kvCache struct {
	// For now we only support a single entry, so this is just one sequence
	tokens []int32
	caches []cache.Cache
}

// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
	cache   *kvCache
	inputs  []int32
	outputs []int32

	caches    []cache.Cache
	remaining []int32
}

// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	if len(c.caches) == 0 {
		// Models may supply their own cache layout; otherwise default to
		// one KV cache per layer.
		if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
			c.caches = cacheFactory.NewCaches()
		} else {
			c.caches = make([]cache.Cache, m.NumLayers())
			for i := range c.caches {
				c.caches[i] = cache.NewKVCache()
			}
		}
	}

	remaining := c.findRemaining(inputs)

	return &cacheSession{
		cache:     c,
		inputs:    inputs,
		caches:    c.caches,
		remaining: remaining,
	}
}

// close saves the token state if the forward pass ran.
func (s *cacheSession) close() {
	if offset := s.caches[0].Offset(); offset > 0 {
		// Ensure that if we have run the forward pass and set the metadata
		// that we also actually have the data
		arrays := make([]*mlx.Array, 0, 2*len(s.caches))
		for _, c := range s.caches {
			k, v := c.State()
			arrays = append(arrays, k, v)
		}
		mlx.AsyncEval(arrays...)

		// Keep only the tokens actually processed into the cache; the most
		// recently sampled token may not have run through the model yet.
		s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
	}
}

// findRemaining finds the longest common prefix between tokens and the cached
// sequence, trims stale cache entries, and returns the remaining tokens.
func (c *kvCache) findRemaining(tokens []int32) []int32 {
	prefix := 0
	for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
		prefix++
	}

	if prefix == len(tokens) && prefix > 0 {
		// Leave one token to run through the model so we can sample a response.
		prefix--
	}

	if prefix < len(c.tokens) {
		trim := len(c.tokens) - prefix
		for _, kv := range c.caches {
			kv.Trim(trim)
		}
		c.tokens = c.tokens[:prefix]
	}

	if prefix == 0 {
		slog.Info("Cache miss", "left", len(tokens))
	} else {
		slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
	}
	return tokens[prefix:]
}

// log reports the number of cached tokens and the total cache size at
// trace level.
func (c *kvCache) log() {
	if len(c.caches) == 0 {
		return
	}
	var totalBytes int
	for _, kv := range c.caches {
		k, v := kv.State()
		totalBytes += k.NumBytes() + v.NumBytes()
	}
	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}
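
For context, a hedged sketch of the session lifecycle described by the cacheSession doc comment. The forward and sample parameters are hypothetical stand-ins for the real MLX pipeline; only begin, close, remaining, caches, and outputs come from this file.

// Sketch only: a minimal generation loop over cacheSession, assuming it
// lives in package mlxrunner. Callers append sampled tokens to outputs
// and defer close so the token/cache state is saved.
func generate(kv *kvCache, m base.Model, inputs []int32, maxTokens int,
	forward func([]cache.Cache, []int32) *mlx.Array,
	sample func(*mlx.Array) int32,
) []int32 {
	s := kv.begin(m, inputs)
	defer s.close() // persists state once the forward pass has run

	tokens := s.remaining // only the uncached suffix runs through the model
	for i := 0; i < maxTokens; i++ {
		next := sample(forward(s.caches, tokens))
		s.outputs = append(s.outputs, next)
		tokens = []int32{next}
	}
	return s.outputs
}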