mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
Update the LRU last-used time only on the nodes that are actually used during processing, rather than on all snapshots along the path. This allows eviction to remove nodes more accurately, so we can avoid other heuristics to auto-merge nodes.
575 lines
17 KiB
Go
575 lines
17 KiB
Go
// cache.go manages a shared KV cache across conversations using a compressed
// prefix trie. Each trie node stores a token sequence (edge) and optional
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
//
// Key properties:
//   - Only one path through the trie is "active" (backed by live MLX arrays)
//     at a time. Switching paths pages out the frontier node and pages in the
//     new path.
//   - Snapshots are only captured at the frontier (end) of the active path.
//     Intermediate node snapshots come from split prefill.
//   - All cache layers must stay at the same token offset.
//   - Sibling edges must not share a common token prefix (compressed trie
//     invariant).
//   - begin() always re-evaluates at least one token so the pipeline can seed
//     generation, even on a full prefix match.
|
|
|
|
package mlxrunner
|
|
|
|
import (
|
|
"cmp"
|
|
"fmt"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/logutil"
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
|
)
|
|
|
|
// maxPagedOutBytes is the eviction threshold (8 GiB) for memory held by
// paged-out snapshots.
const maxPagedOutBytes int64 = 8 * (1 << 30)
|
|
|
|
type kvCache struct {
|
|
root *trieNode // root of the prefix trie
|
|
activePath []*trieNode // current root→leaf path with live MLX arrays
|
|
caches []cache.Cache
|
|
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
|
|
}
|
|
|
|
// cacheSession manages caches for a single pipeline run.
|
|
// Callers should append generated tokens to outputs and
|
|
// defer close to save the cache state.
|
|
type cacheSession struct {
|
|
cache *kvCache
|
|
inputs []int32
|
|
outputs []int32
|
|
|
|
caches []cache.Cache
|
|
remaining []int32
|
|
|
|
// snapshotOffset, if > 0, is a trie node boundary where we need to
|
|
// capture a snapshot during prefill. This enables future requests
|
|
// branching at this point to restore non-rewindable caches (e.g.
|
|
// RecurrentCache) instead of re-evaluating from scratch.
|
|
snapshotOffset int
|
|
}
|
|
|
|
func (c *kvCache) ensureCaches(m base.Model) {
|
|
if len(c.caches) != 0 {
|
|
return
|
|
}
|
|
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
|
c.caches = cacheFactory.NewCaches()
|
|
return
|
|
}
|
|
c.caches = make([]cache.Cache, m.NumLayers())
|
|
for i := range c.caches {
|
|
c.caches[i] = cache.NewKVCache()
|
|
}
|
|
}
|
|
|
|
func (c *kvCache) ensureRoot() {
|
|
if c.root == nil {
|
|
c.root = &trieNode{
|
|
lastUsed: time.Now(),
|
|
}
|
|
c.activePath = []*trieNode{c.root}
|
|
}
|
|
}
|
|
|
|
// begin prepares caches for a new request. It finds the nearest
|
|
// matching cache or creates new caches if none match.
|
|
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|
c.ensureCaches(m)
|
|
c.ensureRoot()
|
|
|
|
matchPath, matched := findBestMatch(c.root, inputs)
|
|
originalMatched := matched
|
|
|
|
// Always keep at least one token to re-evaluate so the
|
|
// pipeline can seed token generation from it.
|
|
if matched == len(inputs) && matched > 0 {
|
|
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
|
|
}
|
|
|
|
// Switch to the matched path, paging in/out as needed.
|
|
c.switchToPath(matchPath, matched)
|
|
|
|
// switchToPath aligns caches to a common offset
|
|
prefix := c.minCacheOffset()
|
|
remaining := inputs[prefix:]
|
|
|
|
// Schedule a snapshot at the branch point during prefill so future
|
|
// requests diverging here can restore instead of re-evaluating.
|
|
var snapshotAt int
|
|
if prefix < matched {
|
|
snapshotAt = matched
|
|
}
|
|
|
|
args := []any{"total", len(inputs), "matched", originalMatched}
|
|
args = append(args, "cached", prefix, "left", len(remaining))
|
|
if snapshotAt > 0 {
|
|
args = append(args, "pending_snapshot", snapshotAt)
|
|
}
|
|
if prefix == 0 {
|
|
slog.Info("cache miss", args...)
|
|
} else {
|
|
slog.Info("cache hit", args...)
|
|
}
|
|
|
|
return &cacheSession{
|
|
cache: c,
|
|
inputs: inputs,
|
|
snapshotOffset: snapshotAt,
|
|
caches: c.caches,
|
|
remaining: remaining,
|
|
}
|
|
}
|
|
|
|
// switchToPath transitions from the current active path to a new path,
|
|
// paging out diverging segments and paging in the new path.
|
|
func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
|
|
defer c.enforceEvictionPolicy()
|
|
|
|
// Find common ancestor index.
|
|
commonLen := 0
|
|
for commonLen < len(c.activePath) && commonLen < len(newPath) {
|
|
if c.activePath[commonLen] != newPath[commonLen] {
|
|
break
|
|
}
|
|
commonLen++
|
|
}
|
|
|
|
ancestorOffset := 0
|
|
if commonLen > 0 {
|
|
ancestorOffset = c.activePath[commonLen-1].endOffset
|
|
}
|
|
|
|
var pageOutCount, pageInCount int
|
|
|
|
// Page out the leaf of the old path. Only the leaf's live cache
|
|
// state is correct — intermediate nodes already have snapshots
|
|
// captured during their creation (splitNode + prefill). Snapshotting
|
|
// non-leaf nodes here would produce wrong results for non-rewindable
|
|
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
|
|
// the intermediate boundary.
|
|
leaf := len(c.activePath) - 1
|
|
leafDiverges := leaf >= commonLen
|
|
leafNeedsRewind := matched < c.activePath[leaf].endOffset
|
|
if leafDiverges || leafNeedsRewind {
|
|
node := c.activePath[leaf]
|
|
if !node.hasAllSnapshots() {
|
|
fromOffset := node.startOffset()
|
|
snaps := make([]cache.Snapshot, len(c.caches))
|
|
for j, kv := range c.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
snaps[j] = kv.Snapshot(fromOffset)
|
|
}
|
|
node.setSnapshots(snaps, &c.pagedOutBytes)
|
|
pageOutCount++
|
|
logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
|
|
}
|
|
}
|
|
|
|
// Rewind each cache to the target offset or free it. When matched
|
|
// falls within the ancestor's range (same-path case), we rewind
|
|
// directly to the match point. Otherwise we rewind to the ancestor
|
|
// and let page-in bring us forward to matched.
|
|
rewindTarget := min(ancestorOffset, matched)
|
|
for _, kv := range c.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
if !kv.Restore(nil, rewindTarget) {
|
|
kv.Free()
|
|
}
|
|
}
|
|
|
|
// Page in — walk the full new path, restoring from snapshots.
|
|
// Freed caches naturally pick up the first available snapshot.
|
|
// Caches already past a node skip it via offset check.
|
|
pageIn:
|
|
for _, node := range newPath {
|
|
if !node.hasSnapshots() {
|
|
continue
|
|
}
|
|
nodeTarget := min(node.endOffset, matched)
|
|
for j, kv := range c.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
|
continue
|
|
}
|
|
if kv.Offset() >= nodeTarget {
|
|
continue
|
|
}
|
|
if !kv.Restore(node.snapshots[j], nodeTarget) {
|
|
// Restore failed — stop page-in and let alignment
|
|
// bring all caches to a consistent offset.
|
|
break pageIn
|
|
}
|
|
}
|
|
if node.endOffset > ancestorOffset {
|
|
pageInCount++
|
|
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
|
|
}
|
|
}
|
|
|
|
// Align all caches to the minimum offset.
|
|
c.activePath = newPath
|
|
minOff := c.minCacheOffset()
|
|
for _, kv := range c.caches {
|
|
if kv != nil && kv.Offset() != minOff {
|
|
if !kv.Restore(nil, minOff) {
|
|
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
|
|
c.freeAll()
|
|
break
|
|
}
|
|
}
|
|
}
|
|
for i := len(c.activePath) - 1; i >= 0; i-- {
|
|
if c.activePath[i].endOffset <= minOff {
|
|
c.activePath = c.activePath[:i+1]
|
|
break
|
|
}
|
|
}
|
|
|
|
// Update last-used time on only the final used node. For recurrent
|
|
// caches we don't need the intermediate snapshots and for KV caches
|
|
// we can reslice the data out of merged edges.
|
|
if len(c.activePath) > 0 {
|
|
c.activePath[len(c.activePath)-1].lastUsed = time.Now()
|
|
}
|
|
|
|
if pageOutCount > 0 || pageInCount > 0 {
|
|
slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
|
|
}
|
|
}
|
|
|
|
// snapshot creates a snapshot at the current cache position. During prefill,
|
|
// it is called at branch points (user=false) to create restore points for
|
|
// future diverging requests and with user=true to mark an explicit reusable
|
|
// restore point.
|
|
func (s *cacheSession) snapshot(user bool) {
|
|
c := s.cache
|
|
cacheOffset := c.minCacheOffset()
|
|
if cacheOffset <= 0 {
|
|
return
|
|
}
|
|
|
|
// Clear pending intermediate snapshot if we've reached or passed it.
|
|
if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
|
|
s.snapshotOffset = 0
|
|
}
|
|
|
|
// The last node in activePath is the frontier where caches are advancing.
|
|
// cacheOffset is always >= its endOffset: begin() restores caches to this
|
|
// boundary and prefill advances monotonically forward.
|
|
frontier := c.activePath[len(c.activePath)-1]
|
|
|
|
// If the frontier already ends at cacheOffset, just ensure it has snapshots.
|
|
if frontier.endOffset == cacheOffset {
|
|
if user {
|
|
frontier.user = true
|
|
}
|
|
if !frontier.hasAllSnapshots() {
|
|
s.attachSnapshots(frontier, cacheOffset)
|
|
}
|
|
return
|
|
}
|
|
|
|
if frontier.endOffset > cacheOffset {
|
|
slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
|
|
return
|
|
}
|
|
|
|
// Advance the trie to cacheOffset — find or create a node there.
|
|
edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset]
|
|
frontier = c.advancePath(frontier, edgeTokens, cacheOffset)
|
|
|
|
// Attach fresh snapshots from the live caches. Always use fresh
|
|
// snapshots even if the node already has some (e.g. from splitNode's
|
|
// Cache.Split which may be incomplete for non-splittable caches
|
|
// like RecurrentCache).
|
|
if user {
|
|
frontier.user = true
|
|
}
|
|
s.attachSnapshots(frontier, cacheOffset)
|
|
}
|
|
|
|
// advancePath advances the active path from the current frontier by matching
|
|
// tokens against existing trie children, splitting partial matches, and
|
|
// appending any remaining tokens as new nodes. Returns the new frontier.
|
|
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
|
|
// Check if existing children already cover some or all of tokens.
|
|
// tokens may span multiple trie nodes when extending a previous run's
|
|
// leaf and this snapshot now overlaps that same range.
|
|
matchPath, matched := findBestMatch(frontier, tokens)
|
|
// matchPath[0] is frontier itself; the rest are newly traversed nodes.
|
|
remaining := tokens[matched:]
|
|
|
|
// Check for a partial match within the last node's edge — if so, split it.
|
|
if len(matchPath) > 1 {
|
|
lastNode := matchPath[len(matchPath)-1]
|
|
matchedInEdge := frontier.endOffset + matched - lastNode.startOffset()
|
|
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
|
|
matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes)
|
|
}
|
|
}
|
|
|
|
// Append traversed nodes (excluding frontier) to the active path.
|
|
c.activePath = append(c.activePath, matchPath[1:]...)
|
|
dest := matchPath[len(matchPath)-1]
|
|
|
|
if len(remaining) > 0 {
|
|
// Drop non-user snapshots so appendTokens can extend in-place
|
|
// rather than creating a new child node.
|
|
if len(dest.children) == 0 && !dest.user {
|
|
dest.setSnapshots(nil, &c.pagedOutBytes)
|
|
}
|
|
newDest := dest.appendTokens(c.root, remaining, endOffset)
|
|
if newDest != dest {
|
|
c.activePath = append(c.activePath, newDest)
|
|
}
|
|
dest = newDest
|
|
}
|
|
return dest
|
|
}
|
|
|
|
// attachSnapshots attaches cache snapshots to a trie node at the given offset.
|
|
// The node must be on the active path (and thus protected from eviction;
|
|
// lastUsed is updated in close()). All non-nil caches must be at the same
|
|
// offset (cacheOffset); a mismatch indicates a bug in the caller.
|
|
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
|
c := s.cache
|
|
|
|
if c.activePath[len(c.activePath)-1] != node {
|
|
slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
|
|
return
|
|
}
|
|
|
|
snaps := make([]cache.Snapshot, len(c.caches))
|
|
for i, kv := range c.caches {
|
|
if kv != nil {
|
|
if kv.Offset() != cacheOffset {
|
|
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
|
|
}
|
|
snaps[i] = kv.Snapshot(node.startOffset())
|
|
}
|
|
}
|
|
node.setSnapshots(snaps, &c.pagedOutBytes)
|
|
node.lastUsed = time.Now()
|
|
slog.Debug("created snapshot", "offset", cacheOffset)
|
|
c.enforceEvictionPolicy()
|
|
}
|
|
|
|
// freeAll releases all cache layers.
|
|
func (c *kvCache) freeAll() {
|
|
for _, kv := range c.caches {
|
|
if kv != nil {
|
|
kv.Free()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *kvCache) minCacheOffset() int {
|
|
offset := 0
|
|
found := false
|
|
for _, kv := range c.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
if off := kv.Offset(); !found || off < offset {
|
|
offset = off
|
|
found = true
|
|
}
|
|
}
|
|
return offset
|
|
}
|
|
|
|
// close saves the token state if the forward pass ran.
|
|
func (s *cacheSession) close() {
|
|
offset := s.cache.minCacheOffset()
|
|
if offset <= 0 {
|
|
return
|
|
}
|
|
|
|
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
|
for _, kv := range s.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
arrays = append(arrays, kv.State()...)
|
|
}
|
|
|
|
// Ensure that if we have run the forward pass and set the metadata
|
|
// that we also actually have the data.
|
|
mlx.AsyncEval(arrays...)
|
|
|
|
// Advance the trie frontier with any newly generated tokens.
|
|
c := s.cache
|
|
if len(c.activePath) > 0 {
|
|
frontier := c.activePath[len(c.activePath)-1]
|
|
stored := append(s.inputs, s.outputs...)
|
|
|
|
if offset > frontier.endOffset {
|
|
newTokens := stored[frontier.endOffset:offset]
|
|
c.advancePath(frontier, newTokens, offset)
|
|
}
|
|
c.activePath[len(c.activePath)-1].lastUsed = time.Now()
|
|
}
|
|
}
|
|
|
|
// enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits.
|
|
func (c *kvCache) enforceEvictionPolicy() {
|
|
if c.pagedOutBytes <= maxPagedOutBytes {
|
|
return
|
|
}
|
|
|
|
activeSet := make(map[*trieNode]bool, len(c.activePath))
|
|
for _, n := range c.activePath {
|
|
activeSet[n] = true
|
|
}
|
|
|
|
for c.pagedOutBytes > maxPagedOutBytes {
|
|
var best *trieNode
|
|
walkNodes(c.root, func(n *trieNode) bool {
|
|
if n == c.root || activeSet[n] || len(n.children) > 1 {
|
|
return true
|
|
}
|
|
// Evict: oldest, then deepest, then largest.
|
|
if best == nil || cmp.Or(
|
|
n.lastUsed.Compare(best.lastUsed),
|
|
cmp.Compare(best.endOffset, n.endOffset),
|
|
cmp.Compare(best.snapshotBytes(), n.snapshotBytes()),
|
|
) < 0 {
|
|
best = n
|
|
}
|
|
return true
|
|
})
|
|
if best == nil {
|
|
break
|
|
}
|
|
c.evictNode(best)
|
|
}
|
|
}
|
|
|
|
// evictNode evicts a single node from the trie, freeing its snapshot memory.
|
|
func (c *kvCache) evictNode(node *trieNode) {
|
|
if len(node.children) == 0 {
|
|
// Leaf: remove entirely.
|
|
slog.Debug("evicting leaf", "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
|
removeNode(node, &c.pagedOutBytes)
|
|
} else if len(node.children) == 1 {
|
|
// Interior node with one child: merge with child.
|
|
before := c.pagedOutBytes
|
|
tokens := len(node.tokens)
|
|
mergeWithChild(node, c.caches, &c.pagedOutBytes)
|
|
slog.Debug("evicting interior node", "offset", node.startOffset(), "tokens", tokens, "freed", mlx.PrettyBytes(int(before-c.pagedOutBytes)))
|
|
} else {
|
|
panic("evictNode called on multi-child branch point")
|
|
}
|
|
}
|
|
|
|
func (c *kvCache) dumpTree() {
|
|
// Summary stats
|
|
var cacheBytes int
|
|
for _, kv := range c.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
for _, a := range kv.State() {
|
|
if a != nil {
|
|
cacheBytes += a.NumBytes()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build active path set for marking.
|
|
active := make(map[*trieNode]bool, len(c.activePath))
|
|
for _, n := range c.activePath {
|
|
active[n] = true
|
|
}
|
|
|
|
var nodeCount, snapshotCount int
|
|
var pagedBytes int64
|
|
var lines []string
|
|
var dump func(n *trieNode, prefix string, isLast bool)
|
|
dump = func(n *trieNode, prefix string, isLast bool) {
|
|
if n == nil {
|
|
return
|
|
}
|
|
nodeCount++
|
|
|
|
// Build connector
|
|
var connector string
|
|
if n.parent == nil {
|
|
connector = ""
|
|
} else if isLast {
|
|
connector = prefix + "`-- "
|
|
} else {
|
|
connector = prefix + "|-- "
|
|
}
|
|
|
|
// Node label
|
|
nodeBytes := n.snapshotBytes()
|
|
pagedBytes += nodeBytes
|
|
|
|
label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
|
|
if nodeBytes > 0 {
|
|
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
|
}
|
|
if !n.lastUsed.IsZero() {
|
|
label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
|
|
}
|
|
var flags []string
|
|
if n.user {
|
|
flags = append(flags, "user")
|
|
}
|
|
if n.hasAllSnapshots() {
|
|
snapshotCount++
|
|
flags = append(flags, "snap")
|
|
}
|
|
if active[n] {
|
|
flags = append(flags, "active")
|
|
}
|
|
if len(flags) > 0 {
|
|
label += " (" + flags[0]
|
|
for _, f := range flags[1:] {
|
|
label += ", " + f
|
|
}
|
|
label += ")"
|
|
}
|
|
lines = append(lines, connector+label)
|
|
|
|
// Recurse children
|
|
childPrefix := prefix
|
|
if n.parent != nil {
|
|
if isLast {
|
|
childPrefix += " "
|
|
} else {
|
|
childPrefix += "| "
|
|
}
|
|
}
|
|
for i, child := range n.children {
|
|
dump(child, childPrefix, i == len(n.children)-1)
|
|
}
|
|
}
|
|
dump(c.root, "", true)
|
|
|
|
offset := c.minCacheOffset()
|
|
logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
|
|
offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount))
|
|
for i, l := range lines {
|
|
if i == 0 {
|
|
logutil.Trace("cache trie: " + l)
|
|
} else {
|
|
logutil.Trace(" " + l)
|
|
}
|
|
}
|
|
}
|