Files
ollama/x/mlxrunner/cache.go
Jesse Gross 015546fded mlxrunner: schedule periodic snapshots during prefill
Add periodic snapshots every 8k tokens and near the end of the prompt
so that long prompts can be partially restored and thinking/generation
can be retried without full reprocessing.
2026-03-25 15:31:22 -07:00

610 lines
18 KiB
Go

// cache.go manages a shared KV cache across conversations using a compressed
// prefix trie. Each trie node stores a token sequence (edge) and optional
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
//
// Key properties:
// - Only one path through the trie is "active" (backed by live MLX arrays)
// at a time. Switching paths pages out the frontier node and pages in the
// new path.
// - Snapshots are only captured at the frontier (end) of the active path.
// Intermediate node snapshots come from split prefill.
// - All cache layers must stay at the same token offset.
// - Sibling edges must not share a common token prefix (compressed trie
// invariant).
// - begin() always re-evaluates at least one token so the pipeline can seed
// generation, even on a full prefix match.
package mlxrunner
import (
"cmp"
"fmt"
"log/slog"
"slices"
"time"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// maxPagedOutBytes is the eviction threshold for paged-out snapshot memory
// (8 GiB). enforceEvictionPolicy evicts non-active trie nodes while
// kvCache.pagedOutBytes exceeds this value.
const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory
// kvCache is the shared cache state: a prefix trie of token sequences with
// optional per-layer snapshots, plus the live per-layer caches that back the
// currently active trie path.
type kvCache struct {
	root       *trieNode   // root of the prefix trie
	activePath []*trieNode // current root→leaf path with live MLX arrays
	// caches holds the live per-layer caches, created lazily by ensureCaches.
	// Nil entries are skipped by every method that iterates them.
	caches []cache.Cache
	// pagedOutBytes is the running total of snapshot bytes stored in trie
	// nodes. It is updated through setSnapshots/removeNode/mergeWithChild
	// and compared against maxPagedOutBytes to trigger eviction.
	pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
}
// pendingSnapshot is a snapshot scheduled to be taken during prefill.
type pendingSnapshot struct {
	// offset is the absolute token offset at which the snapshot is taken.
	offset int
	// user is true for snapshots requested via requestSnapshot and false for
	// the automatic branch-point snapshot scheduled by begin. Nodes created
	// from user snapshots are marked user, and advancePath will not drop
	// their snapshots.
	user bool
}
// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
	// cache is the shared kvCache this session reads and updates.
	cache *kvCache
	// inputs is the full prompt passed to begin.
	inputs []int32
	// outputs holds generated tokens, appended by the caller.
	outputs []int32
	// caches is the same slice as cache.caches (set by begin).
	caches []cache.Cache
	// remaining is the suffix of inputs not yet in the cache when begin returned.
	remaining []int32
	// pendingSnapshots lists offsets where snapshots should be captured
	// during prefill, sorted by offset. Entries are consumed as the
	// cache advances past them.
	pendingSnapshots []pendingSnapshot
}
// ensureCaches lazily creates the per-layer caches on first use. A model that
// provides its own NewCaches method decides the cache layout; otherwise every
// layer gets a standard KV cache.
func (c *kvCache) ensureCaches(m base.Model) {
	if len(c.caches) > 0 {
		return
	}

	type cacheFactory interface{ NewCaches() []cache.Cache }
	if factory, ok := m.(cacheFactory); ok {
		c.caches = factory.NewCaches()
		return
	}

	c.caches = make([]cache.Cache, m.NumLayers())
	for layer := range c.caches {
		c.caches[layer] = cache.NewKVCache()
	}
}
// ensureRoot lazily creates the empty trie root and makes it the only node on
// the active path.
func (c *kvCache) ensureRoot() {
	if c.root != nil {
		return
	}
	root := &trieNode{lastUsed: time.Now()}
	c.root = root
	c.activePath = []*trieNode{root}
}
// begin prepares the shared caches for a new request and returns a session
// for it. It finds the longest trie prefix matching inputs, switches the live
// caches onto that path, and records which inputs still need prefill.
//
// At least one input token is always left to evaluate so the pipeline can
// seed generation from it, even when the whole prompt is already cached.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	c.ensureCaches(m)
	c.ensureRoot()

	path, matched := findBestMatch(c.root, inputs)
	fullMatched := matched

	// A complete match would leave nothing to evaluate; back off one token.
	if matched > 0 && matched == len(inputs) {
		path, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
	}

	c.switchToPath(path, matched)

	// switchToPath leaves every cache at a common offset; that is the number
	// of prompt tokens that do not need to be re-evaluated.
	cached := c.minCacheOffset()

	session := &cacheSession{
		cache:     c,
		inputs:    inputs,
		caches:    c.caches,
		remaining: inputs[cached:],
	}

	// If the caches could not be restored all the way to the match point,
	// snapshot there during prefill so future requests diverging at this
	// branch point can restore instead of re-evaluating.
	if cached < matched {
		session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false})
	}

	msg := "cache miss"
	if cached != 0 {
		msg = "cache hit"
	}
	slog.Info(msg, "total", len(inputs), "matched", fullMatched, "cached", cached, "left", len(session.remaining))
	return session
}
// switchToPath moves the live caches from the current active path onto
// newPath, leaving them at no more than matched tokens.
//
// It proceeds in four steps:
//  1. Page out: if the old leaf is being left or rewound and lacks snapshots
//     for all layers, snapshot its live state onto the node.
//  2. Rewind every cache to min(common-ancestor offset, matched); caches that
//     cannot rewind are freed.
//  3. Page in: walk newPath, restoring caches from node snapshots up to
//     min(node.endOffset, matched).
//  4. Align all caches to the minimum offset reached, and trim the active
//     path to end at the last node whose endOffset does not exceed it.
//
// The eviction policy is enforced on return.
func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
	defer c.enforceEvictionPolicy()

	// Find common ancestor index.
	commonLen := 0
	for commonLen < len(c.activePath) && commonLen < len(newPath) {
		if c.activePath[commonLen] != newPath[commonLen] {
			break
		}
		commonLen++
	}

	ancestorOffset := 0
	if commonLen > 0 {
		ancestorOffset = c.activePath[commonLen-1].endOffset
	}

	var pageOutCount, pageInCount int

	// Page out the leaf of the old path. Only the leaf's live cache
	// state is correct — intermediate nodes already have snapshots
	// captured during their creation (splitNode + prefill). Snapshotting
	// non-leaf nodes here would produce wrong results for non-rewindable
	// caches (e.g. RecurrentCache) whose state reflects the leaf, not
	// the intermediate boundary.
	if leaf := len(c.activePath) - 1; leaf >= 0 {
		node := c.activePath[leaf]
		leafDiverges := leaf >= commonLen
		leafNeedsRewind := matched < node.endOffset
		if (leafDiverges || leafNeedsRewind) && !node.hasAllSnapshots() {
			fromOffset := node.startOffset()
			snaps := make([]cache.Snapshot, len(c.caches))
			for j, kv := range c.caches {
				if kv == nil {
					continue
				}
				snaps[j] = kv.Snapshot(fromOffset)
			}
			node.setSnapshots(snaps, &c.pagedOutBytes)
			pageOutCount++
			logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
		}
	}

	// Rewind each cache to the target offset or free it. When matched
	// falls within the ancestor's range (same-path case), we rewind
	// directly to the match point. Otherwise we rewind to the ancestor
	// and let page-in bring us forward to matched.
	rewindTarget := min(ancestorOffset, matched)
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		if !kv.Restore(nil, rewindTarget) {
			kv.Free()
		}
	}

	// Page in — walk the full new path, restoring from snapshots.
	// Freed caches naturally pick up the first available snapshot.
	// Caches already past a node skip it via offset check.
pageIn:
	for _, node := range newPath {
		if !node.hasSnapshots() {
			continue
		}
		nodeTarget := min(node.endOffset, matched)
		for j, kv := range c.caches {
			if kv == nil {
				continue
			}
			if j >= len(node.snapshots) || node.snapshots[j] == nil {
				continue
			}
			if kv.Offset() >= nodeTarget {
				continue
			}
			if !kv.Restore(node.snapshots[j], nodeTarget) {
				// Restore failed — stop page-in and let alignment
				// bring all caches to a consistent offset.
				break pageIn
			}
		}
		if node.endOffset > ancestorOffset {
			pageInCount++
			logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
		}
	}

	// Align all caches to the minimum offset.
	c.activePath = newPath
	minOff := c.minCacheOffset()
	for _, kv := range c.caches {
		if kv != nil && kv.Offset() != minOff {
			if !kv.Restore(nil, minOff) {
				slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
				c.freeAll()
				// The caches no longer hold state up to the old minOff.
				// Re-read the real offset so the active path is trimmed
				// to what is actually live; otherwise the frontier would
				// sit past the cache and later snapshots would be skipped.
				minOff = c.minCacheOffset()
				break
			}
		}
	}
	for i := len(c.activePath) - 1; i >= 0; i-- {
		if c.activePath[i].endOffset <= minOff {
			c.activePath = c.activePath[:i+1]
			break
		}
	}

	// Update last-used time on only the final used node. For recurrent
	// caches we don't need the intermediate snapshots and for KV caches
	// we can reslice the data out of merged edges.
	if len(c.activePath) > 0 {
		c.activePath[len(c.activePath)-1].lastUsed = time.Now()
	}

	if pageOutCount > 0 || pageInCount > 0 {
		slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
	}
}
// requestSnapshot schedules a user snapshot at the given absolute token
// offset, to be captured during prefill once the cache reaches it. Offsets at
// or before the already-cached prefix, or beyond the prompt, are ignored. If
// a snapshot is already pending at that offset, it is promoted to a user
// snapshot rather than duplicated.
func (s *cacheSession) requestSnapshot(offset int) {
	cached := len(s.inputs) - len(s.remaining)
	if offset <= cached || offset > len(s.inputs) {
		return
	}

	sameOffset := func(p pendingSnapshot) bool { return p.offset == offset }
	if i := slices.IndexFunc(s.pendingSnapshots, sameOffset); i >= 0 {
		s.pendingSnapshots[i].user = true
		return
	}

	s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true})
	slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int {
		return cmp.Compare(a.offset, b.offset)
	})
}
// nextPendingSnapshot reports the offset of the earliest pending snapshot.
// It returns 0 when nothing is pending.
func (s *cacheSession) nextPendingSnapshot() int {
	if len(s.pendingSnapshots) > 0 {
		return s.pendingSnapshots[0].offset
	}
	return 0
}
// snapshot creates a snapshot at the current cache position (the minimum
// offset across cache layers). It does nothing if that offset is 0.
//
// Pending snapshot entries whose offset has been reached are consumed. If any
// of them was a user request, the snapshotted node is marked user. If the
// active frontier already ends at the current offset, snapshots are attached
// to it (only if it lacks some). Otherwise the trie is advanced to the current
// offset and fresh snapshots are attached to the resulting node.
func (s *cacheSession) snapshot() {
	c := s.cache
	cacheOffset := c.minCacheOffset()
	if cacheOffset <= 0 {
		return
	}

	// Consume pending snapshots up to the current offset and derive
	// the user flag from them.
	user := false
	for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset {
		if s.pendingSnapshots[0].user {
			user = true
		}
		s.pendingSnapshots = s.pendingSnapshots[1:]
	}

	// The last node in activePath is the frontier where caches are advancing.
	// cacheOffset is always >= its endOffset: begin() restores caches to this
	// boundary and prefill advances monotonically forward.
	frontier := c.activePath[len(c.activePath)-1]

	// If the frontier already ends at cacheOffset, just ensure it has snapshots.
	if frontier.endOffset == cacheOffset {
		if user {
			frontier.user = true
		}
		if !frontier.hasAllSnapshots() {
			s.attachSnapshots(frontier, cacheOffset)
		}
		return
	}

	if frontier.endOffset > cacheOffset {
		slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
		return
	}

	// Advance the trie to cacheOffset — find or create a node there.
	// Cap inputs' capacity before appending so outputs never get written
	// into spare capacity of the caller-owned inputs buffer. Cap the edge
	// slice too, in case the trie retains it and later appends to it, so
	// such an append cannot overwrite the tokens that follow it.
	n := len(s.inputs)
	stored := append(s.inputs[:n:n], s.outputs...)
	edgeTokens := stored[frontier.endOffset:cacheOffset:cacheOffset]
	frontier = c.advancePath(frontier, edgeTokens, cacheOffset)

	// Attach fresh snapshots from the live caches. Always use fresh
	// snapshots even if the node already has some (e.g. from splitNode's
	// Cache.Split which may be incomplete for non-splittable caches
	// like RecurrentCache).
	if user {
		frontier.user = true
	}
	s.attachSnapshots(frontier, cacheOffset)
}
// advancePath advances the active path from the current frontier by matching
// tokens against existing trie children, splitting partial matches, and
// appending any remaining tokens as new nodes. Returns the new frontier.
//
// Callers pass the tokens spanning [frontier.endOffset, endOffset) and expect
// frontier to be the last node of c.activePath. Every node traversed or
// created is appended to c.activePath, so on return the result is the last
// element of c.activePath.
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
	// Check if existing children already cover some or all of tokens.
	// tokens may span multiple trie nodes when extending a previous run's
	// leaf and this snapshot now overlaps that same range.
	matchPath, matched := findBestMatch(frontier, tokens)
	// matchPath[0] is frontier itself; the rest are newly traversed nodes.
	remaining := tokens[matched:]

	// Check for a partial match within the last node's edge — if so, split it.
	if len(matchPath) > 1 {
		lastNode := matchPath[len(matchPath)-1]
		// Number of lastNode's edge tokens that matched.
		matchedInEdge := frontier.endOffset + matched - lastNode.startOffset()
		if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
			// Replace the path entry with the node splitNode returns
			// (presumably the matched front half — confirm in splitNode).
			matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes)
		}
	}

	// Append traversed nodes (excluding frontier) to the active path.
	c.activePath = append(c.activePath, matchPath[1:]...)
	dest := matchPath[len(matchPath)-1]

	if len(remaining) > 0 {
		// Drop non-user snapshots so appendTokens can extend in-place
		// rather than creating a new child node.
		if len(dest.children) == 0 && !dest.user {
			dest.setSnapshots(nil, &c.pagedOutBytes)
		}
		newDest := dest.appendTokens(c.root, remaining, endOffset)
		// appendTokens may return dest itself (extended in place) or a
		// new node; only a new node needs to join the active path.
		if newDest != dest {
			c.activePath = append(c.activePath, newDest)
		}
		dest = newDest
	}

	return dest
}
// attachSnapshots snapshots every live cache layer onto node and refreshes its
// lastUsed time, then enforces the eviction policy.
//
// node must be the frontier (last node) of the active path; otherwise a
// warning is logged and nothing happens. Each non-nil cache must already sit
// at cacheOffset — a mismatch is a caller bug and panics. Each snapshot is
// taken starting from node.startOffset().
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
	kc := s.cache
	if frontier := kc.activePath[len(kc.activePath)-1]; frontier != node {
		slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
		return
	}

	layerSnaps := make([]cache.Snapshot, len(kc.caches))
	for layer, kv := range kc.caches {
		if kv == nil {
			continue
		}
		if got := kv.Offset(); got != cacheOffset {
			panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", layer, cacheOffset, got))
		}
		layerSnaps[layer] = kv.Snapshot(node.startOffset())
	}

	node.setSnapshots(layerSnaps, &kc.pagedOutBytes)
	node.lastUsed = time.Now()
	slog.Debug("created snapshot", "offset", cacheOffset)
	kc.enforceEvictionPolicy()
}
// freeAll calls Free on every non-nil cache layer.
func (c *kvCache) freeAll() {
	for i := range c.caches {
		if layer := c.caches[i]; layer != nil {
			layer.Free()
		}
	}
}
// minCacheOffset returns the smallest token offset among the non-nil cache
// layers, or 0 if there are no non-nil layers.
func (c *kvCache) minCacheOffset() int {
	lowest, seen := 0, false
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		off := kv.Offset()
		switch {
		case !seen:
			lowest, seen = off, true
		case off < lowest:
			lowest = off
		}
	}
	return lowest
}
// close saves the token state if the forward pass ran (i.e. the caches are
// past offset 0). It kicks off evaluation of all cache state arrays via
// mlx.AsyncEval. It then extends the trie's active path to cover every token
// now held in the cache (prompt followed by generated outputs) and refreshes
// the frontier's lastUsed time.
func (s *cacheSession) close() {
	c := s.cache
	offset := c.minCacheOffset()
	if offset <= 0 {
		return
	}

	arrays := make([]*mlx.Array, 0, 2*len(s.caches))
	for _, kv := range s.caches {
		if kv == nil {
			continue
		}
		arrays = append(arrays, kv.State()...)
	}
	// Ensure that if we have run the forward pass and set the metadata
	// that we also actually have the data.
	mlx.AsyncEval(arrays...)

	if len(c.activePath) == 0 {
		return
	}

	// Advance the trie frontier with any newly generated tokens.
	frontier := c.activePath[len(c.activePath)-1]
	if offset > frontier.endOffset {
		// Cap inputs' capacity so appending outputs copies instead of
		// writing into spare capacity of the caller-owned inputs buffer.
		// Cap the edge slice in case the trie retains it.
		n := len(s.inputs)
		stored := append(s.inputs[:n:n], s.outputs...)
		c.advancePath(frontier, stored[frontier.endOffset:offset:offset], offset)
	}
	c.activePath[len(c.activePath)-1].lastUsed = time.Now()
}
// enforceEvictionPolicy evicts trie nodes, one at a time, while paged-out
// snapshot memory exceeds maxPagedOutBytes.
//
// The root, nodes on the active path, and branch points (nodes with more than
// one child) are never chosen. Among the rest, the victim is the least
// recently used node. Ties go to the deeper node (larger endOffset), then to
// the one with more snapshot bytes. Eviction stops early if no node is
// eligible.
func (c *kvCache) enforceEvictionPolicy() {
	if c.pagedOutBytes <= maxPagedOutBytes {
		return
	}

	protected := make(map[*trieNode]bool, len(c.activePath))
	for _, n := range c.activePath {
		protected[n] = true
	}

	// evictsBefore reports whether candidate ranks ahead of cur for eviction.
	evictsBefore := func(candidate, cur *trieNode) bool {
		order := cmp.Or(
			candidate.lastUsed.Compare(cur.lastUsed),
			cmp.Compare(cur.endOffset, candidate.endOffset),
			cmp.Compare(cur.snapshotBytes(), candidate.snapshotBytes()),
		)
		return order < 0
	}

	for c.pagedOutBytes > maxPagedOutBytes {
		var victim *trieNode
		walkNodes(c.root, func(n *trieNode) bool {
			eligible := n != c.root && !protected[n] && len(n.children) <= 1
			if eligible && (victim == nil || evictsBefore(n, victim)) {
				victim = n
			}
			return true
		})
		if victim == nil {
			break
		}
		c.evictNode(victim)
	}
}
// evictNode evicts one node from the trie and frees its snapshot memory.
// A leaf is removed outright. A node with exactly one child is merged into
// that child. Calling it on a node with several children is a bug and panics.
func (c *kvCache) evictNode(node *trieNode) {
	switch len(node.children) {
	case 0:
		slog.Debug("evicting leaf", "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
		removeNode(node, &c.pagedOutBytes)
	case 1:
		// Capture the edge length and byte total before merging so the
		// log line can report the amount freed.
		edgeLen := len(node.tokens)
		bytesBefore := c.pagedOutBytes
		mergeWithChild(node, c.caches, &c.pagedOutBytes)
		slog.Debug("evicting interior node", "offset", node.startOffset(), "tokens", edgeLen, "freed", mlx.PrettyBytes(int(bytesBefore-c.pagedOutBytes)))
	default:
		panic("evictNode called on multi-child branch point")
	}
}
// dumpTree writes a trace-level summary of cache memory and an ASCII drawing
// of the trie, one node per line. Each line shows the node's token range,
// edge length, snapshot size, time since last use, and flags (user, snap,
// active).
func (c *kvCache) dumpTree() {
	// Bytes held by the live cache arrays.
	liveBytes := 0
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		for _, arr := range kv.State() {
			if arr == nil {
				continue
			}
			liveBytes += arr.NumBytes()
		}
	}

	onPath := make(map[*trieNode]bool, len(c.activePath))
	for _, n := range c.activePath {
		onPath[n] = true
	}

	var (
		nodeTotal, snapTotal int
		pagedTotal           int64
		rendered             []string
	)

	var visit func(n *trieNode, indent string, last bool)
	visit = func(n *trieNode, indent string, last bool) {
		if n == nil {
			return
		}
		nodeTotal++

		connector := ""
		switch {
		case n.parent == nil:
			// The root is drawn without a connector.
		case last:
			connector = indent + "`-- "
		default:
			connector = indent + "|-- "
		}

		size := n.snapshotBytes()
		pagedTotal += size
		label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
		if size > 0 {
			label += " " + mlx.PrettyBytes(int(size)).String()
		}
		if !n.lastUsed.IsZero() {
			label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
		}

		var flags []string
		if n.user {
			flags = append(flags, "user")
		}
		if n.hasAllSnapshots() {
			snapTotal++
			flags = append(flags, "snap")
		}
		if onPath[n] {
			flags = append(flags, "active")
		}
		for i, f := range flags {
			sep := ", "
			if i == 0 {
				sep = " ("
			}
			label += sep + f
		}
		if len(flags) > 0 {
			label += ")"
		}
		rendered = append(rendered, connector+label)

		childIndent := indent
		if n.parent != nil {
			if last {
				childIndent += " "
			} else {
				childIndent += "| "
			}
		}
		lastIdx := len(n.children) - 1
		for i, child := range n.children {
			visit(child, childIndent, i == lastIdx)
		}
	}
	visit(c.root, "", true)

	logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
		c.minCacheOffset(), mlx.PrettyBytes(liveBytes), mlx.PrettyBytes(int(pagedTotal)), nodeTotal, snapTotal))
	for i, line := range rendered {
		prefix := " "
		if i == 0 {
			prefix = "cache trie: "
		}
		logutil.Trace(prefix + line)
	}
}