From 96e36c0d90b1da23304658a2ba90784b4a1c822d Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 5 Mar 2026 15:45:36 -0800 Subject: [PATCH] mlxrunner: share KV cache across conversations with common prefixes Enable multiple conversations to reuse cached computations when they share token prefixes (e.g. the same system prompt). A prefix trie tracks shared regions so switching between conversations only recomputes tokens that diverge. Inactive conversation state is paged from active GPU memory to other memory and restored on demand, with LRU eviction to keep memory usage bounded. --- x/mlxrunner/cache.go | 629 ++++++++++++++++---- x/mlxrunner/cache/cache.go | 300 ++++++++-- x/mlxrunner/cache/cache_test.go | 271 +++++++++ x/mlxrunner/cache/recurrent.go | 88 ++- x/mlxrunner/cache/recurrent_test.go | 44 ++ x/mlxrunner/cache_test.go | 859 ++++++++++++++++++++++++++++ x/mlxrunner/cache_trie.go | 296 ++++++++++ x/mlxrunner/cache_trie_test.go | 455 +++++++++++++++ x/mlxrunner/mlx/mlx.go | 4 + x/mlxrunner/pipeline.go | 25 +- 10 files changed, 2768 insertions(+), 203 deletions(-) create mode 100644 x/mlxrunner/cache/cache_test.go create mode 100644 x/mlxrunner/cache/recurrent_test.go create mode 100644 x/mlxrunner/cache_test.go create mode 100644 x/mlxrunner/cache_trie.go create mode 100644 x/mlxrunner/cache_trie_test.go diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index 0216ffeaa..a5709101d 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -1,8 +1,26 @@ +// cache.go manages a shared KV cache across conversations using a compressed +// prefix trie. Each trie node stores a token sequence (edge) and optional +// per-layer snapshots that can be paged in/out of the live MLX cache arrays. +// +// Key properties: +// - Only one path through the trie is "active" (backed by live MLX arrays) +// at a time. Switching paths pages out the frontier node and pages in the +// new path. +// - Snapshots are only captured at the frontier (end) of the active path. +// Intermediate node snapshots come from split prefill. +// - All cache layers must stay at the same token offset. +// - Sibling edges must not share a common token prefix (compressed trie +// invariant). +// - begin() always re-evaluates at least one token so the pipeline can seed +// generation, even on a full prefix match. + package mlxrunner import ( + "cmp" "fmt" "log/slog" + "time" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/cache" @@ -10,10 +28,13 @@ import ( "github.com/ollama/ollama/x/mlxrunner/model/base" ) +const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory + type kvCache struct { - // For now we only support a single entry, so this is just one sequence - tokens []int32 - caches []cache.Cache + root *trieNode // root of the prefix trie + activePath []*trieNode // current root→leaf path with live MLX arrays + caches []cache.Cache + pagedOutBytes int64 // total bytes in paged-out snapshots across the trie } // cacheSession manages caches for a single pipeline run. @@ -26,176 +47,538 @@ type cacheSession struct { caches []cache.Cache remaining []int32 + + // snapshotOffset, if > 0, is a trie node boundary where we need to + // capture a snapshot during prefill. This enables future requests + // branching at this point to restore non-rewindable caches (e.g. + // RecurrentCache) instead of re-evaluating from scratch. + snapshotOffset int } -func appendCacheState(dst []*mlx.Array, c cache.Cache) []*mlx.Array { - if c == nil { - return dst +func (c *kvCache) ensureCaches(m base.Model) { + if len(c.caches) != 0 { + return } - - keys, values := c.State() - if keys != nil && keys.Valid() { - dst = append(dst, keys) + if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { + c.caches = cacheFactory.NewCaches() + return } - if values != nil && values.Valid() { - dst = append(dst, values) + c.caches = make([]cache.Cache, m.NumLayers()) + for i := range c.caches { + c.caches[i] = cache.NewKVCache() } - - return dst } -func (c *kvCache) free() { - for i, kv := range c.caches { - if kv == nil { - continue +func (c *kvCache) ensureRoot() { + if c.root == nil { + c.root = &trieNode{ + lastUsed: time.Now(), } - kv.Free() - c.caches[i] = nil - } - c.caches = nil - c.tokens = nil -} - -func (c *kvCache) cachesCanTrim() bool { - for _, kv := range c.caches { - if kv == nil { - continue - } - if !kv.CanTrim() { - return false - } - } - return true -} - -func (c *kvCache) trimToPrefix(prefix int) { - for _, kv := range c.caches { - if kv == nil || !kv.CanTrim() { - continue - } - if trim := kv.Offset() - prefix; trim > 0 { - kv.Trim(trim) - } - } - if prefix < len(c.tokens) { - c.tokens = c.tokens[:prefix] + c.activePath = []*trieNode{c.root} } } // begin prepares caches for a new request. It finds the nearest // matching cache or creates new caches if none match. func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { - ensureCaches := func() { - if len(c.caches) != 0 { - return - } - if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { - c.caches = cacheFactory.NewCaches() - return - } - c.caches = make([]cache.Cache, m.NumLayers()) - for i := range c.caches { - c.caches[i] = cache.NewKVCache() + c.ensureCaches(m) + c.ensureRoot() + + matchPath, matched := findBestMatch(c.root, inputs) + originalMatched := matched + + // Always keep at least one token to re-evaluate so the + // pipeline can seed token generation from it. + if matched == len(inputs) && matched > 0 { + matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1]) + } + + // Check for partial match within a node's edge — truncate path + // to the parent boundary. snapshot() will split the node and + // create the branch point during prefill when caches are ready. + partialMatch := false + if len(matchPath) > 1 { + lastNode := matchPath[len(matchPath)-1] + matchedInEdge := matched - lastNode.startOffset() + if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) { + matchPath = matchPath[:len(matchPath)-1] + partialMatch = true } } - ensureCaches() - remaining := c.findRemaining(inputs) - ensureCaches() + // Switch to the matched path, paging in/out as needed. + c.switchToPath(matchPath) + + // switchToPath aligns caches to a common offset + prefix := c.minCacheOffset() + remaining := inputs[prefix:] + + // Schedule a snapshot at the branch point during prefill so future + // requests diverging here can restore instead of re-evaluating. + var snapshotAt int + if partialMatch || (prefix == 0 && matched > 0) { + snapshotAt = matched + } + + args := []any{"total", len(inputs), "matched", originalMatched} + args = append(args, "cached", prefix, "left", len(remaining)) + if snapshotAt > 0 { + args = append(args, "pending_snapshot", snapshotAt) + } + if prefix == 0 { + slog.Info("cache miss", args...) + } else { + slog.Info("cache hit", args...) + } return &cacheSession{ - cache: c, - inputs: inputs, - caches: c.caches, - remaining: remaining, + cache: c, + inputs: inputs, + snapshotOffset: snapshotAt, + caches: c.caches, + remaining: remaining, } } +// switchToPath transitions from the current active path to a new path, +// paging out diverging segments and paging in the new path. +func (c *kvCache) switchToPath(newPath []*trieNode) { + defer c.enforceEvictionPolicy() + + // Find common ancestor index. + commonLen := 0 + for commonLen < len(c.activePath) && commonLen < len(newPath) { + if c.activePath[commonLen] != newPath[commonLen] { + break + } + commonLen++ + } + + ancestorOffset := 0 + if commonLen > 0 { + ancestorOffset = c.activePath[commonLen-1].endOffset + } + + var pageOutCount, pageInCount int + + // Page out the leaf of the old path. Only the leaf's live cache + // state is correct — intermediate nodes already have snapshots + // captured during their creation (splitNode + prefill). Snapshotting + // non-leaf nodes here would produce wrong results for non-rewindable + // caches (e.g. RecurrentCache) whose state reflects the leaf, not + // the intermediate boundary. + if leaf := len(c.activePath) - 1; leaf >= commonLen { + node := c.activePath[leaf] + if !node.hasAllSnapshots() { + fromOffset := node.startOffset() + snaps := make([]cache.Snapshot, len(c.caches)) + for j, kv := range c.caches { + if kv == nil { + continue + } + snaps[j] = kv.Snapshot(fromOffset) + } + node.setSnapshots(snaps, &c.pagedOutBytes) + pageOutCount++ + logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset)) + } + } + + // Rewind each cache to the ancestor offset or free it. Freed + // caches (e.g. RecurrentCache that can't rewind) will be restored + // from snapshots during page-in. + for _, kv := range c.caches { + if kv == nil { + continue + } + if !kv.Restore(nil, ancestorOffset) { + kv.Free() + } + } + + // Page in — walk the full new path, restoring from snapshots. + // Freed caches naturally pick up the first available snapshot. + // Caches already past a node skip it via offset check. + for _, node := range newPath { + if len(node.snapshots) == 0 { + continue + } + for j, kv := range c.caches { + if kv == nil { + continue + } + if j >= len(node.snapshots) || node.snapshots[j] == nil { + continue + } + if kv.Offset() >= node.endOffset { + continue + } + if !kv.Restore(node.snapshots[j], node.endOffset) { + slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset()) + c.freeAll() + c.activePath = []*trieNode{c.root} + return + } + } + if node.endOffset > ancestorOffset { + pageInCount++ + logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset)) + } + } + + // Align all caches to the minimum offset. + c.activePath = newPath + minOff := c.minCacheOffset() + for _, kv := range c.caches { + if kv != nil && kv.Offset() != minOff { + if !kv.Restore(nil, minOff) { + slog.Warn("failed to restore cache, freeing all caches", "offset", minOff) + c.freeAll() + break + } + } + } + for i := len(c.activePath) - 1; i >= 0; i-- { + if c.activePath[i].endOffset <= minOff { + c.activePath = c.activePath[:i+1] + break + } + } + + if pageOutCount > 0 || pageInCount > 0 { + slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount) + } +} + +// snapshot creates a snapshot at the current cache position. During prefill, +// it is called at branch points (user=false) to create restore points for +// future diverging requests and with user=true to mark an explicit reusable +// restore point. +func (s *cacheSession) snapshot(user bool) { + c := s.cache + cacheOffset := c.minCacheOffset() + if cacheOffset <= 0 { + return + } + + // Clear pending intermediate snapshot if we've reached or passed it. + if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset { + s.snapshotOffset = 0 + } + + // The last node in activePath is the frontier where caches are advancing. + // cacheOffset is always >= its endOffset: begin() restores caches to this + // boundary and prefill advances monotonically forward. + frontier := c.activePath[len(c.activePath)-1] + + // If the frontier already ends at cacheOffset, just ensure it has snapshots. + if frontier.endOffset == cacheOffset { + if user { + frontier.user = true + } + if !frontier.hasAllSnapshots() { + s.attachSnapshots(frontier, cacheOffset) + } + return + } + + if frontier.endOffset > cacheOffset { + slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset) + return + } + + // Advance the trie to cacheOffset — find or create a node there. + edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset] + frontier = c.advancePath(frontier, edgeTokens, cacheOffset) + + // Attach fresh snapshots from the live caches. Always use fresh + // snapshots even if the node already has some (e.g. from splitNode's + // Cache.Split which may be incomplete for non-splittable caches + // like RecurrentCache). + if user { + frontier.user = true + } + s.attachSnapshots(frontier, cacheOffset) +} + +// advancePath advances the active path from the current frontier by matching +// tokens against existing trie children, splitting partial matches, and +// appending any remaining tokens as new nodes. Returns the new frontier. +func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode { + // Check if existing children already cover some or all of tokens. + // tokens may span multiple trie nodes when extending a previous run's + // leaf and this snapshot now overlaps that same range. + matchPath, matched := findBestMatch(frontier, tokens) + // matchPath[0] is frontier itself; the rest are newly traversed nodes. + remaining := tokens[matched:] + + // Check for a partial match within the last node's edge — if so, split it. + if len(matchPath) > 1 { + lastNode := matchPath[len(matchPath)-1] + matchedInEdge := frontier.endOffset + matched - lastNode.startOffset() + if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) { + matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes) + } + } + + // Append traversed nodes (excluding frontier) to the active path. + c.activePath = append(c.activePath, matchPath[1:]...) + dest := matchPath[len(matchPath)-1] + + if len(remaining) > 0 { + // Drop non-user snapshots so appendTokens can extend in-place + // rather than creating a new child node. + if len(dest.children) == 0 && !dest.user { + dest.setSnapshots(nil, &c.pagedOutBytes) + } + newDest := dest.appendTokens(c.root, remaining, endOffset) + if newDest != dest { + c.activePath = append(c.activePath, newDest) + } + dest = newDest + } + return dest +} + +// attachSnapshots attaches cache snapshots to a trie node at the given offset. +// The node must be on the active path (and thus protected from eviction; +// lastUsed is updated in close()). All non-nil caches must be at the same +// offset (cacheOffset); a mismatch indicates a bug in the caller. +func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) { + c := s.cache + + if c.activePath[len(c.activePath)-1] != node { + slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset) + return + } + + snaps := make([]cache.Snapshot, len(c.caches)) + for i, kv := range c.caches { + if kv != nil { + if kv.Offset() != cacheOffset { + panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset())) + } + snaps[i] = kv.Snapshot(node.startOffset()) + } + } + node.setSnapshots(snaps, &c.pagedOutBytes) + slog.Debug("created snapshot", "offset", cacheOffset) + c.enforceEvictionPolicy() +} + +// freeAll releases all cache layers. +func (c *kvCache) freeAll() { + for _, kv := range c.caches { + if kv != nil { + kv.Free() + } + } +} + +func (c *kvCache) minCacheOffset() int { + offset := 0 + found := false + for _, kv := range c.caches { + if kv == nil { + continue + } + if off := kv.Offset(); !found || off < offset { + offset = off + found = true + } + } + return offset +} + // close saves the token state if the forward pass ran. func (s *cacheSession) close() { - if len(s.caches) == 0 { + offset := s.cache.minCacheOffset() + if offset <= 0 { return } - offset := -1 arrays := make([]*mlx.Array, 0, 2*len(s.caches)) for _, kv := range s.caches { if kv == nil { continue } - // Mixed cache types (e.g. recurrent + KV) can transiently report different - // offsets, so use the minimum as the safe reusable token prefix. - if off := kv.Offset(); offset < 0 || off < offset { - offset = off - } - arrays = appendCacheState(arrays, kv) - } - if offset <= 0 { - return + arrays = append(arrays, kv.State()...) } // Ensure that if we have run the forward pass and set the metadata // that we also actually have the data. mlx.AsyncEval(arrays...) - stored := append(s.inputs, s.outputs...) - if offset > len(stored) { - offset = len(stored) - } - s.cache.tokens = stored[:offset] -} + // Advance the trie frontier with any newly generated tokens. + c := s.cache + if len(c.activePath) > 0 { + frontier := c.activePath[len(c.activePath)-1] + stored := append(s.inputs, s.outputs...) -// findRemaining finds the longest common prefix between tokens and the cached -// sequence, trims stale cache entries, and returns the remaining tokens. -func (c *kvCache) findRemaining(tokens []int32) []int32 { - prefix := 0 - for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] { - prefix++ - } - - // Always keep at least one token to re-evaluate so the - // pipeline can seed token generation from it. - if prefix == len(tokens) && prefix > 0 { - prefix-- - } - - if prefix < len(c.tokens) { - if c.cachesCanTrim() { - c.trimToPrefix(prefix) - } else { - c.free() - slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence") - return tokens + if offset > frontier.endOffset { + newTokens := stored[frontier.endOffset:offset] + c.advancePath(frontier, newTokens, offset) + } + now := time.Now() + for _, node := range c.activePath { + node.lastUsed = now } } - - if prefix == 0 { - slog.Info("Cache miss", "left", len(tokens)) - } else { - slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:])) - } - return tokens[prefix:] } -func (c *kvCache) log() { - if len(c.caches) == 0 { +// enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits. +func (c *kvCache) enforceEvictionPolicy() { + if c.pagedOutBytes <= maxPagedOutBytes { return } - offset := -1 - var totalBytes int + + activeSet := make(map[*trieNode]bool, len(c.activePath)) + for _, n := range c.activePath { + activeSet[n] = true + } + + for c.pagedOutBytes > maxPagedOutBytes { + var best *trieNode + walkNodes(c.root, func(n *trieNode) bool { + if n == c.root || activeSet[n] || !n.hasSnapshots() { + return true + } + // Evict: oldest, then deepest, then largest. + if best == nil || cmp.Or( + n.lastUsed.Compare(best.lastUsed), + cmp.Compare(best.endOffset, n.endOffset), + cmp.Compare(best.snapshotBytes(), n.snapshotBytes()), + ) < 0 { + best = n + } + return true + }) + if best == nil { + break + } + c.evictNode(best) + } +} + +// evictNode evicts a single node from the trie, freeing its snapshot memory. +func (c *kvCache) evictNode(node *trieNode) { + if len(node.children) == 0 { + // Leaf: remove entirely. + parent := node.parent + kind := "evicting leaf" + if node.user { + kind = "evicting user snapshot" + } + slog.Debug(kind, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) + removeNode(node, &c.pagedOutBytes) + + // If parent is a regular (non-user-snapshot) node with one remaining child, auto-merge. + if parent != nil && !parent.user && len(parent.children) == 1 && parent != c.root { + logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset)) + mergeWithChild(parent, c.caches, &c.pagedOutBytes) + } + } else if len(node.children) == 1 { + // Interior snapshot node with one child: merge with child. + slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) + mergeWithChild(node, c.caches, &c.pagedOutBytes) + } else { + // Multi-child branch point: drop snapshots but keep the node. + slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) + node.setSnapshots(nil, &c.pagedOutBytes) + } +} + +func (c *kvCache) dumpTree() { + // Summary stats + var cacheBytes int for _, kv := range c.caches { if kv == nil { continue } - if off := kv.Offset(); offset < 0 || off < offset { - offset = off - } - for _, a := range appendCacheState(nil, kv) { - totalBytes += a.NumBytes() + for _, a := range kv.State() { + if a != nil { + cacheBytes += a.NumBytes() + } } } - if offset < 0 { - return + + // Build active path set for marking. + active := make(map[*trieNode]bool, len(c.activePath)) + for _, n := range c.activePath { + active[n] = true + } + + var nodeCount, snapshotCount int + var pagedBytes int64 + var lines []string + var dump func(n *trieNode, prefix string, isLast bool) + dump = func(n *trieNode, prefix string, isLast bool) { + if n == nil { + return + } + nodeCount++ + + // Build connector + var connector string + if n.parent == nil { + connector = "" + } else if isLast { + connector = prefix + "`-- " + } else { + connector = prefix + "|-- " + } + + // Node label + nodeBytes := n.snapshotBytes() + pagedBytes += nodeBytes + + label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens)) + if nodeBytes > 0 { + label += " " + mlx.PrettyBytes(int(nodeBytes)).String() + } + var flags []string + if n.user { + flags = append(flags, "user") + } + if n.hasAllSnapshots() { + snapshotCount++ + flags = append(flags, "snap") + } + if active[n] { + flags = append(flags, "active") + } + if len(flags) > 0 { + label += " (" + flags[0] + for _, f := range flags[1:] { + label += ", " + f + } + label += ")" + } + lines = append(lines, connector+label) + + // Recurse children + childPrefix := prefix + if n.parent != nil { + if isLast { + childPrefix += " " + } else { + childPrefix += "| " + } + } + for i, child := range n.children { + dump(child, childPrefix, i == len(n.children)-1) + } + } + dump(c.root, "", true) + + offset := c.minCacheOffset() + logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d", + offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount)) + for i, l := range lines { + if i == 0 { + logutil.Trace("cache trie: " + l) + } else { + logutil.Trace(" " + l) + } } - logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes))) } diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index a452fbcb2..8e024115d 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -8,13 +8,34 @@ import ( type Cache interface { Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array) // State returns the cache-owned state roots that should be kept/evaluated. - State() (keys, values *mlx.Array) - CanTrim() bool - Trim(int) int - Clone() Cache + State() []*mlx.Array Free() Offset() int - Len() int + + // Snapshot copies cache state from fromOffset to current offset into + // pinned VRAM arrays. The active cache is unchanged. + Snapshot(fromOffset int) Snapshot + + // Restore brings the cache to target. If snapshot is nil, rewinds + // using the cache's own live state. + Restore(snapshot Snapshot, target int) bool + + // Merge combines two sequential snapshots [a,b) and [b,c) into [a,c). + // Takes ownership of both inputs. + Merge(parent, child Snapshot) Snapshot + + // Split divides a snapshot [a,c) at offset b into [a,b) and [b,c). + // Takes ownership of the input. Cache types that cannot split + // (e.g. recurrent) return (nil, snapshot). + Split(snapshot Snapshot, at int) (parent, child Snapshot) +} + +// Snapshot is paged-out cache state that can be restored later. +type Snapshot interface { + // Size returns the byte size of the paged-out data (in VRAM). + Size() int + // Close unpins the snapshot's arrays so they can be freed by Sweep. + Close() } type KVCache struct { @@ -59,40 +80,148 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) } -func (c *KVCache) State() (*mlx.Array, *mlx.Array) { +func (c *KVCache) State() []*mlx.Array { if c.keys == nil || c.values == nil { + return nil + } + return []*mlx.Array{ + c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), + c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), + } +} + +// kvSnapshot holds paged-out KV data for a range [fromOffset, toOffset). +type kvSnapshot struct { + keys, values *mlx.Array + fromOffset, toOffset int +} + +func (s *kvSnapshot) Size() int { return s.keys.NumBytes() + s.values.NumBytes() } +func (s *kvSnapshot) Close() { mlx.Unpin(s.keys, s.values) } + +func (c *KVCache) Snapshot(fromOffset int) Snapshot { + if c.keys == nil || c.offset <= fromOffset { + return nil + } + from := max(0, fromOffset) + to := c.offset + + kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice()) + vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice()) + kCopy := mlx.Copy(kSlice) + vCopy := mlx.Copy(vSlice) + mlx.Pin(kCopy, vCopy) + mlx.AsyncEval(kCopy, vCopy) + + return &kvSnapshot{ + keys: kCopy, + values: vCopy, + fromOffset: from, + toOffset: to, + } +} + +func (c *KVCache) Restore(snapshot Snapshot, target int) bool { + if snapshot == nil { + // Rewind using live state — just clamp offset. + target = max(0, min(target, c.offset)) + c.offset = target + return true + } + + snap := snapshot.(*kvSnapshot) + + // Check that the cache has data up to the snapshot's starting point. + if c.offset < snap.fromOffset { + return false + } + + // Rewind to snapshot start, then feed snapshot data through Update. + c.offset = snap.fromOffset + c.Update(snap.keys, snap.values) + + // Clamp to target if needed (target may be less than full snapshot). + if target < c.offset { + c.offset = target + } + + return true +} + +func (c *KVCache) Merge(parent, child Snapshot) Snapshot { + if parent == nil || child == nil { + if parent != nil { + parent.Close() + } + if child != nil { + child.Close() + } + return nil + } + p := parent.(*kvSnapshot) + ch := child.(*kvSnapshot) + + mk := p.keys.Concatenate(2, ch.keys) + mv := p.values.Concatenate(2, ch.values) + mlx.Pin(mk, mv) + mlx.AsyncEval(mk, mv) + + p.Close() + ch.Close() + + return &kvSnapshot{ + keys: mk, + values: mv, + fromOffset: p.fromOffset, + toOffset: ch.toOffset, + } +} + +func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) { + if snapshot == nil { return nil, nil } - return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), - c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) -} - -func (c *KVCache) CanTrim() bool { return true } - -func (c *KVCache) Trim(n int) int { - n = min(c.offset, n) - c.offset -= n - return n -} - -func (c *KVCache) Clone() Cache { - clone := &KVCache{ - keys: c.keys.Clone(), - values: c.values.Clone(), - offset: c.offset, - step: c.step, + snap := snapshot.(*kvSnapshot) + splitIdx := at - snap.fromOffset + seqLen := snap.toOffset - snap.fromOffset + if splitIdx <= 0 { + return nil, snapshot } - mlx.Pin(clone.keys, clone.values) - return clone + if splitIdx >= seqLen { + return snapshot, nil + } + + pk := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice())) + pv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice())) + ck := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice())) + cv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice())) + mlx.Pin(pk, pv, ck, cv) + mlx.AsyncEval(pk, pv, ck, cv) + + snap.Close() + + p := &kvSnapshot{ + keys: pk, + values: pv, + fromOffset: snap.fromOffset, + toOffset: at, + } + ch := &kvSnapshot{ + keys: ck, + values: cv, + fromOffset: at, + toOffset: snap.toOffset, + } + return p, ch } func (c *KVCache) Free() { mlx.Unpin(c.keys, c.values) c.keys, c.values = nil, nil + c.offset = 0 } func (c *KVCache) Offset() int { return c.offset } -func (c *KVCache) Len() int { return c.offset } // RotatingKVCache implements sliding window attention with bounded memory type RotatingKVCache struct { @@ -184,29 +313,104 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()) } -func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) { +func (c *RotatingKVCache) State() []*mlx.Array { if c.keys == nil || c.values == nil { - return nil, nil + return nil } - return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), - c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) -} - -func (c *RotatingKVCache) CanTrim() bool { return true } - -func (c *RotatingKVCache) Trim(n int) int { - n = min(c.offset, n) - c.offset -= n - c.idx -= n - return n -} - -func (c *RotatingKVCache) Clone() Cache { - return &RotatingKVCache{ - maxSize: c.maxSize, - idx: c.idx, - KVCache: c.KVCache.Clone().(*KVCache), + return []*mlx.Array{ + c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), + c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), } } -func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) } +// rotatingSnapshot holds paged-out data for a RotatingKVCache. +type rotatingSnapshot struct { + kvSnapshot // embedded KV data + idx int // buffer write position at snapshot time +} + +func (s *rotatingSnapshot) Size() int { return s.kvSnapshot.Size() } +func (s *rotatingSnapshot) Close() { s.kvSnapshot.Close() } + +func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot { + if c.keys == nil || c.offset <= fromOffset { + return nil + } + + state := c.State() + k := state[0].Clone() + v := state[1].Clone() + mlx.Pin(k, v) + + return &rotatingSnapshot{ + kvSnapshot: kvSnapshot{ + keys: k, + values: v, + fromOffset: fromOffset, + toOffset: c.offset, + }, + idx: c.idx, + } +} + +func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool { + if snapshot == nil { + // Live rewind is only safe when the buffer hasn't filled yet + // (offset <= maxSize). Once the window has shifted, rewinding + // leaves fewer than maxSize trailing tokens to attend to — + // a snapshot is required to restore the full window. + if c.offset > c.maxSize { + return false + } + target = max(0, min(target, c.offset)) + c.offset = target + c.idx = target + return true + } + + snap := snapshot.(*rotatingSnapshot) + + // Reject if clamping would leave an incomplete window. + if target < snap.toOffset && snap.toOffset > c.maxSize { + return false + } + + // Restore from snapshot: rebuild buffer state. + // Free existing state first. + if c.keys != nil { + mlx.Unpin(c.keys, c.values) + } + c.keys = snap.keys.Clone() + c.values = snap.values.Clone() + mlx.Pin(c.keys, c.values) + c.offset = snap.toOffset + c.idx = snap.idx + + // Clamp to target if needed. + if target < c.offset { + target = max(0, target) + c.offset = target + c.idx = target + } + return true +} + +func (c *RotatingKVCache) Merge(parent, child Snapshot) Snapshot { + // For rotating caches, the child snapshot supersedes the parent + // since it contains the full window state. + if parent != nil { + parent.Close() + } + return child +} + +func (c *RotatingKVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) { + // Rotating cache snapshots contain the full window state. + // Cannot cleanly split a ring buffer at an arbitrary point. + return nil, snapshot +} + +func (c *RotatingKVCache) Free() { + c.KVCache.Free() + c.idx = 0 +} diff --git a/x/mlxrunner/cache/cache_test.go b/x/mlxrunner/cache/cache_test.go new file mode 100644 index 000000000..86c26004a --- /dev/null +++ b/x/mlxrunner/cache/cache_test.go @@ -0,0 +1,271 @@ +package cache + +import ( + "testing" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := mlx.CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } +} + +func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) { + skipIfNoMLX(t) + c := NewKVCache() + + for range 10 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + + // Snapshot [5, 10). + snap := c.Snapshot(5) + + // Free the cache completely — offset is now 0. + c.Free() + + // Restore should fail because cache doesn't have data up to fromOffset=5. + if c.Restore(snap, 10) { + t.Fatal("expected Restore to fail with no base data") + } +} + +// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data +// is preserved through a snapshot→free→restore cycle. +func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) { + skipIfNoMLX(t) + c := NewKVCache() + + for range 10 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + + snap := c.Snapshot(0) + if snap == nil { + t.Fatal("Snapshot returned nil") + } + + // Free and restore to a fresh cache. + c2 := NewKVCache() + if !c2.Restore(snap, 10) { + t.Fatal("Restore failed") + } + if c2.Offset() != 10 { + t.Fatalf("offset = %d, want 10", c2.Offset()) + } + + // Verify State() returns arrays with correct sequence dimension. + state := c2.State() + if len(state) != 2 { + t.Fatalf("State() returned %d arrays, want 2", len(state)) + } + // keys shape: [B, H, seqLen, Dk] + if state[0].Dim(2) != 10 { + t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2)) + } + if state[1].Dim(2) != 10 { + t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2)) + } +} + +// TestKVCacheSplitPreservesData verifies that split produces two snapshots +// that can be sequentially restored to rebuild the original cache state. +func TestKVCacheSplitPreservesData(t *testing.T) { + skipIfNoMLX(t) + c := NewKVCache() + + for range 10 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + + snap := c.Snapshot(0) + parent, child := c.Split(snap, 5) + if parent == nil || child == nil { + t.Fatal("Split returned nil") + } + + // Restore parent → offset=5, seq dim=5. + c2 := NewKVCache() + if !c2.Restore(parent, 5) { + t.Fatal("Restore(parent) failed") + } + if c2.Offset() != 5 { + t.Fatalf("offset after parent = %d, want 5", c2.Offset()) + } + state := c2.State() + if state[0].Dim(2) != 5 { + t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2)) + } + + // Restore child on top → offset=10, seq dim=10. + if !c2.Restore(child, 10) { + t.Fatal("Restore(child) failed") + } + if c2.Offset() != 10 { + t.Fatalf("offset after child = %d, want 10", c2.Offset()) + } + state = c2.State() + if state[0].Dim(2) != 10 { + t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2)) + } +} + +// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back +// produces a snapshot equivalent to the original. +func TestKVCacheSplitMergeRoundTripData(t *testing.T) { + skipIfNoMLX(t) + c := NewKVCache() + + for range 10 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + + snap := c.Snapshot(0) + parent, child := c.Split(snap, 6) + merged := c.Merge(parent, child) + if merged == nil { + t.Fatal("Merge returned nil") + } + + c2 := NewKVCache() + if !c2.Restore(merged, 10) { + t.Fatal("Restore(merged) failed") + } + if c2.Offset() != 10 { + t.Fatalf("offset = %d, want 10", c2.Offset()) + } + + state := c2.State() + if state[0].Dim(2) != 10 { + t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2)) + } + if state[1].Dim(2) != 10 { + t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2)) + } +} + +func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) { + skipIfNoMLX(t) + c := NewRotatingKVCache(4) + + // Feed 10 tokens (window size 4, so positions 0-5 are evicted). + for range 10 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + + // Offset 3 is outside the window. + if c.Restore(nil, 3) { + t.Fatal("Restore(nil, 3) should fail when outside window") + } +} + +// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring +// from a snapshot, the rotating cache has the correct window of data. +func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) { + skipIfNoMLX(t) + c := NewRotatingKVCache(4) + + // Feed 10 tokens one at a time. Window size 4, so only last 4 are kept. + for range 10 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + + snap := c.Snapshot(0) + if snap == nil { + t.Fatal("Snapshot returned nil") + } + + // Feed 5 more tokens. + for range 5 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + + // Restore to offset 10. + if !c.Restore(snap, 10) { + t.Fatal("Restore failed") + } + if c.Offset() != 10 { + t.Fatalf("offset = %d, want 10", c.Offset()) + } + + state := c.State() + if len(state) != 2 { + t.Fatalf("State() returned %d arrays, want 2", len(state)) + } + // Seq dim should be min(offset, maxSize) = min(10, 4) = 4. + seqDim := state[0].Dim(2) + if seqDim != 4 { + t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim) + } +} + +// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a +// snapshot correctly preserves the write position (idx), so subsequent +// single-token updates land in the right buffer slot. +func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) { + skipIfNoMLX(t) + c := NewRotatingKVCache(4) + + // Fill the window: 6 tokens into a size-4 window. + // After this, idx has wrapped and the buffer has rotated. + for range 6 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + if c.Offset() != 6 { + t.Fatalf("offset = %d, want 6", c.Offset()) + } + + snap := c.Snapshot(0) + + // Mutate the cache further so live state diverges from snapshot. + for range 3 { + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + } + + // Restore to snapshot state. + if !c.Restore(snap, 6) { + t.Fatal("Restore failed") + } + if c.Offset() != 6 { + t.Fatalf("offset after restore = %d, want 6", c.Offset()) + } + + // Feed one more token. If idx was restored correctly, this should + // produce a valid window of size 4 at offset 7. + k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) + c.Update(k, v) + + if c.Offset() != 7 { + t.Fatalf("offset after post-restore update = %d, want 7", c.Offset()) + } + state := c.State() + if len(state) != 2 { + t.Fatalf("State() returned %d arrays, want 2", len(state)) + } + seqDim := state[0].Dim(2) + if seqDim != 4 { + t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim) + } +} diff --git a/x/mlxrunner/cache/recurrent.go b/x/mlxrunner/cache/recurrent.go index 86c592be5..4025c69a3 100644 --- a/x/mlxrunner/cache/recurrent.go +++ b/x/mlxrunner/cache/recurrent.go @@ -56,16 +56,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo return detached } -func snapshotPinned(a *mlx.Array) *mlx.Array { - if a == nil || !a.Valid() { - return nil - } - snap := mlx.Copy(a) - mlx.Eval(snap) - mlx.Pin(snap) - return snap -} - func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache { return &RecurrentCache{ convTail: int(convTail), @@ -123,30 +113,69 @@ func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array return keys, values } -func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) { - return c.convState, c.deltaState +func (c *RecurrentCache) State() []*mlx.Array { + return []*mlx.Array{c.convState, c.deltaState} } -func (c *RecurrentCache) CanTrim() bool { return false } - -func (c *RecurrentCache) Trim(n int) int { - // Recurrent state is not directly trimmable. Divergent prefixes must drop the cache. - _ = n - return 0 +// recurrentSnapshot holds paged-out recurrent state. Self-contained — +// does not depend on any parent state. +type recurrentSnapshot struct { + convState, deltaState *mlx.Array + offset int } -func (c *RecurrentCache) Clone() Cache { - clone := &RecurrentCache{ - offset: c.offset, - convTail: c.convTail, - convDim: c.convDim, - numVHeads: c.numVHeads, - headVDim: c.headVDim, - headKDim: c.headKDim, - convState: snapshotPinned(c.convState), - deltaState: snapshotPinned(c.deltaState), +func (s *recurrentSnapshot) Size() int { return s.convState.NumBytes() + s.deltaState.NumBytes() } +func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) } + +func (c *RecurrentCache) Snapshot(fromOffset int) Snapshot { + // Recurrent state is not position-sliceable — always snapshot the full state. + if c.convState == nil && c.deltaState == nil { + return nil } - return clone + + snap := &recurrentSnapshot{offset: c.offset} + snap.convState = c.convState.Clone() + snap.deltaState = c.deltaState.Clone() + mlx.Pin(snap.convState, snap.deltaState) + + return snap +} + +func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool { + if snapshot == nil { + // Recurrent state is cumulative and can't rewind. Only succeed + // if we're already at the target (no-op). + return target == c.offset + } + + snap := snapshot.(*recurrentSnapshot) + + // Recurrent state encodes all tokens up to snap.offset. Restoring + // to a target before that would leave stale state from tokens + // [target, snap.offset) baked in. Only allow restoring forward. + if target < snap.offset { + return false + } + + c.convState = c.setStateRaw(c.convState, snap.convState) + c.deltaState = c.setStateRaw(c.deltaState, snap.deltaState) + c.offset = snap.offset + + return true +} + +func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot { + // Recurrent snapshots are self-contained — child supersedes parent. + if parent != nil { + parent.Close() + } + return child +} + +func (c *RecurrentCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) { + // Recurrent state is cumulative and not position-sliceable. + // Cannot recover intermediate state at the split point. + return nil, snapshot } func (c *RecurrentCache) Free() { @@ -156,4 +185,3 @@ func (c *RecurrentCache) Free() { } func (c *RecurrentCache) Offset() int { return c.offset } -func (c *RecurrentCache) Len() int { return c.offset } diff --git a/x/mlxrunner/cache/recurrent_test.go b/x/mlxrunner/cache/recurrent_test.go new file mode 100644 index 000000000..64d482593 --- /dev/null +++ b/x/mlxrunner/cache/recurrent_test.go @@ -0,0 +1,44 @@ +package cache + +import ( + "testing" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only +// allows restoring forward (target >= snapshot offset), never backward. +func TestRecurrentCacheRestoreDirectionality(t *testing.T) { + skipIfNoMLX(t) + c := NewRecurrentCache(3, 12, 4, 8, 8) + _ = c.ConvState(1, mlx.DTypeFloat16) + _ = c.DeltaState(1, mlx.DTypeFloat16) + c.Advance(10) + + snap := c.Snapshot(0) + + c.Advance(5) // now at 15 + + // Restore backward should fail. + if c.Restore(snap, 5) { + t.Fatal("Restore(snap, 5) should fail — target < snap.offset") + } + + // Restore to exact snap offset should succeed. + if !c.Restore(snap, 10) { + t.Fatal("Restore(snap, 10) should succeed") + } + if c.Offset() != 10 { + t.Fatalf("offset = %d, want 10", c.Offset()) + } + + // Restore forward (target > snap offset) should succeed, offset = snap.offset. + snap2 := c.Snapshot(0) + if !c.Restore(snap2, 15) { + t.Fatal("Restore(snap, 15) should succeed") + } + // Recurrent state is at snap.offset (10), not target (15). + if c.Offset() != 10 { + t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset()) + } +} diff --git a/x/mlxrunner/cache_test.go b/x/mlxrunner/cache_test.go new file mode 100644 index 000000000..13524b124 --- /dev/null +++ b/x/mlxrunner/cache_test.go @@ -0,0 +1,859 @@ +package mlxrunner + +import ( + "slices" + "testing" + + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +// snapshotTracker records every fakeSnapshot created and every Close() call +// so tests can detect leaked (created but never closed) or double-closed snapshots. +type snapshotTracker struct { + all []*fakeSnapshot +} + +func (tr *snapshotTracker) track(s *fakeSnapshot) { + if s == nil { + return + } + s.tracker = tr + tr.all = append(tr.all, s) +} + +// Fake caches that store actual token sequences so tests can verify the right +// data was restored, not just the right offset. + +// fakeSnapshot stores a copy of the token sub-sequence it covers. +type fakeSnapshot struct { + tokens []int32 + from, to int + byteSize int // configurable for eviction tests + + tracker *snapshotTracker + closeCount int +} + +func (s *fakeSnapshot) Size() int { return s.byteSize } +func (s *fakeSnapshot) Close() { + s.closeCount++ +} + +// fakeRewindableCache tracks the full token sequence and supports +// arbitrary rewind via Restore(nil, target). +type fakeRewindableCache struct { + tokens []int32 + tracker *snapshotTracker +} + +func (c *fakeRewindableCache) feed(tokens []int32) { + c.tokens = append(c.tokens, tokens...) +} + +func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { + return nil, nil +} +func (c *fakeRewindableCache) State() []*mlx.Array { return nil } +func (c *fakeRewindableCache) Offset() int { return len(c.tokens) } + +func (c *fakeRewindableCache) Free() { + c.tokens = nil +} + +func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot { + if fromOffset >= len(c.tokens) { + return nil + } + from := fromOffset + if from < 0 { + from = 0 + } + s := &fakeSnapshot{ + tokens: slices.Clone(c.tokens[from:]), + from: from, + to: len(c.tokens), + } + c.tracker.track(s) + return s +} + +func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool { + if snapshot == nil { + // Rewind live state. + if target < 0 { + target = 0 + } + if target > len(c.tokens) { + target = len(c.tokens) + } + c.tokens = c.tokens[:target] + return true + } + s := snapshot.(*fakeSnapshot) + if len(c.tokens) < s.from { + return false // don't have base data up to snapshot start + } + c.tokens = append(c.tokens[:s.from], s.tokens...) + if target < len(c.tokens) { + c.tokens = c.tokens[:target] + } + return true +} + +func (c *fakeRewindableCache) Merge(parent, child cache.Snapshot) cache.Snapshot { + if parent == nil || child == nil { + if parent != nil { + parent.Close() + } + if child != nil { + child.Close() + } + return nil + } + p := parent.(*fakeSnapshot) + ch := child.(*fakeSnapshot) + merged := make([]int32, len(p.tokens)+len(ch.tokens)) + copy(merged, p.tokens) + copy(merged[len(p.tokens):], ch.tokens) + s := &fakeSnapshot{ + tokens: merged, + from: p.from, + to: ch.to, + byteSize: p.byteSize + ch.byteSize, + } + c.tracker.track(s) + p.Close() + ch.Close() + return s +} + +func (c *fakeRewindableCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) { + if snapshot == nil { + return nil, nil + } + s := snapshot.(*fakeSnapshot) + relAt := at - s.from + if relAt <= 0 { + return nil, snapshot + } + if relAt >= len(s.tokens) { + return snapshot, nil + } + p := &fakeSnapshot{ + tokens: slices.Clone(s.tokens[:relAt]), + from: s.from, + to: at, + byteSize: s.byteSize, + } + ch := &fakeSnapshot{ + tokens: slices.Clone(s.tokens[relAt:]), + from: at, + to: s.to, + byteSize: s.byteSize, + } + c.tracker.track(p) + c.tracker.track(ch) + s.Close() + return p, ch +} + +// fakeSlidingWindowCache models RotatingKVCache semantics: stores the full +// token sequence but only the trailing maxSize tokens are "live" in the window. +// Once the window fills, live rewind is impossible without a snapshot. +type fakeSlidingWindowCache struct { + tokens []int32 + maxSize int + tracker *snapshotTracker +} + +func (c *fakeSlidingWindowCache) feed(tokens []int32) { + c.tokens = append(c.tokens, tokens...) +} + +func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { + return nil, nil +} +func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil } +func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) } + +func (c *fakeSlidingWindowCache) Free() { + c.tokens = nil +} + +func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot { + if len(c.tokens) == 0 || len(c.tokens) <= fromOffset { + return nil + } + // Snapshot captures the full window state (like RotatingKVCache.Snapshot). + s := &fakeSnapshot{ + tokens: slices.Clone(c.tokens), + from: 0, + to: len(c.tokens), + } + c.tracker.track(s) + return s +} + +func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool { + if snapshot == nil { + if target == len(c.tokens) { + return true + } + // Live rewind only works when buffer hasn't filled (offset <= maxSize). + if len(c.tokens) > c.maxSize { + return false + } + c.tokens = c.tokens[:target] + return true + } + s := snapshot.(*fakeSnapshot) + c.tokens = slices.Clone(s.tokens) + if target < len(c.tokens) { + c.tokens = c.tokens[:target] + } + return true +} + +func (c *fakeSlidingWindowCache) Merge(parent, child cache.Snapshot) cache.Snapshot { + // Child supersedes parent for sliding window (full window state). + if parent != nil { + parent.Close() + } + return child +} + +func (c *fakeSlidingWindowCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) { + // Can't split a ring buffer at an arbitrary point. + return nil, snapshot +} + +// fakeRecurrentCache models RecurrentCache semantics: stores tokens +// but cannot rewind without a snapshot. +type fakeRecurrentCache struct { + tokens []int32 + tracker *snapshotTracker +} + +func (c *fakeRecurrentCache) feed(tokens []int32) { + c.tokens = append(c.tokens, tokens...) +} + +func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { + return nil, nil +} +func (c *fakeRecurrentCache) State() []*mlx.Array { return nil } +func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) } + +func (c *fakeRecurrentCache) Free() { + c.tokens = nil +} + +func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot { + // Recurrent state is cumulative; snapshot captures the full state. + if len(c.tokens) == 0 { + return nil + } + s := &fakeSnapshot{ + tokens: slices.Clone(c.tokens), + from: 0, + to: len(c.tokens), + } + c.tracker.track(s) + return s +} + +func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool { + if snapshot == nil { + return target == len(c.tokens) // can only no-op + } + s := snapshot.(*fakeSnapshot) + if target < s.to { + return false // can't go backward + } + c.tokens = slices.Clone(s.tokens) + return true +} + +func (c *fakeRecurrentCache) Merge(parent, child cache.Snapshot) cache.Snapshot { + // Child supersedes parent for cumulative state. + if parent != nil { + parent.Close() + } + return child +} + +func (c *fakeRecurrentCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) { + return nil, snapshot // can't split cumulative state +} + +type feedableCache interface { + cache.Cache + feed(tokens []int32) +} + +// testEnv encapsulates a kvCache and its fake caches for a test scenario. +type testEnv struct { + kvc *kvCache + caches []cache.Cache // typed references for assertions + tracker *snapshotTracker +} + +// newTransformerEnv creates a test environment with a single rewindable cache +// (pure transformer model). +func newTransformerEnv() *testEnv { + tracker := &snapshotTracker{} + caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}} + return &testEnv{ + kvc: &kvCache{caches: caches}, + caches: caches, + tracker: tracker, + } +} + +// newSlidingWindowEnv creates a test environment with one rewindable cache and +// one sliding window cache (Mistral-style architecture). +func newSlidingWindowEnv() *testEnv { + tr := &snapshotTracker{} + rc := &fakeRewindableCache{tracker: tr} + sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr} + caches := []cache.Cache{rc, sw} + return &testEnv{ + kvc: &kvCache{caches: caches}, + caches: caches, + tracker: tr, + } +} + +// newRecurrentEnv creates a test environment with one rewindable cache and one +// non-rewindable cache (Jamba-style architecture). +func newRecurrentEnv() *testEnv { + tr := &snapshotTracker{} + rc := &fakeRewindableCache{tracker: tr} + nrc := &fakeRecurrentCache{tracker: tr} + caches := []cache.Cache{rc, nrc} + return &testEnv{ + kvc: &kvCache{caches: caches}, + caches: caches, + tracker: tr, + } +} + +// assertAllTokens checks that every cache in the environment contains exactly +// the expected token sequence. +func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) { + t.Helper() + for i, c := range e.caches { + assertTokens(t, label, c, expected) + // Verify all caches report the same offset. + if i > 0 && c.Offset() != e.caches[0].Offset() { + t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", + label, i, c.Offset(), e.caches[0].Offset()) + } + } +} + +// simulateRequest mirrors the production pipeline lifecycle: +// begin -> prefill with snapshot(false) at branch points -> generate -> close + +type requestResult struct { + remaining []int32 + snapshotOffset int +} + +// simulateRequest runs a request through the harness. If userSnapshotAt > 0, +// a user snapshot (snapshot(true)) is created at that offset during prefill. +func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult { + t.Helper() + + userSnapAt := 0 + if len(userSnapshotAt) > 0 { + userSnapAt = userSnapshotAt[0] + } + + session := kvc.begin(nil, inputs) + result := requestResult{ + remaining: slices.Clone(session.remaining), + snapshotOffset: session.snapshotOffset, + } + + assertCacheOffsetAlignment(t, kvc, "after begin") + + baseOffset := kvc.minCacheOffset() + remaining := inputs[baseOffset:] + + // Collect snapshot points (offset -> user flag) in ascending order. + type snapPoint struct { + offset int + user bool + } + var points []snapPoint + if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset { + points = append(points, snapPoint{session.snapshotOffset, false}) + } + if userSnapAt > 0 && userSnapAt > baseOffset { + points = append(points, snapPoint{userSnapAt, true}) + } + slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset }) + + // Prefill: feed tokens, pausing at each snapshot point. + for _, sp := range points { + count := sp.offset - baseOffset + if count > len(remaining) { + break + } + if count > 0 { + feedAll(kvc.caches, remaining[:count]) + remaining = remaining[count:] + baseOffset = sp.offset + } + assertCacheOffsetAlignment(t, kvc, "at snapshot point") + session.snapshot(sp.user) + } + + // Feed rest of input tokens. + if len(remaining) > 0 { + feedAll(kvc.caches, remaining) + } + + assertCacheOffsetAlignment(t, kvc, "after prefill") + + // Generate tokens. + if len(generated) > 0 { + session.outputs = generated + feedAll(kvc.caches, generated) + } + + assertCacheOffsetAlignment(t, kvc, "before close") + session.close() + return result +} + +func feedAll(caches []cache.Cache, tokens []int32) { + for _, c := range caches { + if fc, ok := c.(feedableCache); ok { + fc.feed(tokens) + } + } +} + +// assertCacheOffsetAlignment verifies all caches report the same offset. +func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) { + t.Helper() + if len(kvc.caches) < 2 { + return + } + expected := kvc.caches[0].Offset() + for i := 1; i < len(kvc.caches); i++ { + if got := kvc.caches[i].Offset(); got != expected { + t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected) + } + } +} + +// assertTokens checks that a feedable cache contains the expected token sequence. +// For sliding window caches, only the trailing maxSize tokens are checked. +func assertTokens(t *testing.T, label string, c cache.Cache, expected []int32) { + t.Helper() + switch fc := c.(type) { + case *fakeRewindableCache: + if !slices.Equal(fc.tokens, expected) { + t.Errorf("%s: rewindable tokens = %v, want %v", label, fc.tokens, expected) + } + case *fakeSlidingWindowCache: + // Sliding window stores full history but only trailing maxSize are live. + // Verify the full token sequence matches (the window semantics are + // enforced by Snapshot/Restore, not by the token log). + if !slices.Equal(fc.tokens, expected) { + t.Errorf("%s: sliding window tokens = %v, want %v", label, fc.tokens, expected) + } + case *fakeRecurrentCache: + if !slices.Equal(fc.tokens, expected) { + t.Errorf("%s: non-rewindable tokens = %v, want %v", label, fc.tokens, expected) + } + default: + t.Fatalf("%s: unknown cache type %T", label, c) + } +} + +// checkTrieInvariants walks the trie and checks structural invariants. +func checkTrieInvariants(t *testing.T, root *trieNode) { + t.Helper() + walkNodes(root, func(n *trieNode) bool { + if n.parent != nil { + if n.startOffset() != n.parent.endOffset { + t.Errorf("node [%d,%d): startOffset %d != parent endOffset %d", + n.startOffset(), n.endOffset, n.startOffset(), n.parent.endOffset) + } + } + if len(n.tokens) != n.endOffset-n.startOffset() { + t.Errorf("node [%d,%d): token count %d != offset span %d", + n.startOffset(), n.endOffset, len(n.tokens), n.endOffset-n.startOffset()) + } + for _, c := range n.children { + if c.parent != n { + t.Errorf("child [%d,%d) parent mismatch", c.startOffset(), c.endOffset) + } + } + // No two siblings should start with the same token. + seen := make(map[int32]bool) + for _, c := range n.children { + if len(c.tokens) > 0 { + first := c.tokens[0] + if seen[first] { + t.Errorf("node [%d,%d): duplicate sibling first token %d", + n.startOffset(), n.endOffset, first) + } + seen[first] = true + } + } + return true + }) +} + +// checkSnapshotLeaks verifies that every tracked snapshot is either still live +// in the trie (closeCount == 0) or has been closed exactly once. It reports +// leaked snapshots (not in trie, never closed) and double-closes. +func checkSnapshotLeaks(t *testing.T, tracker *snapshotTracker, root *trieNode) { + t.Helper() + if tracker == nil { + return + } + + // Collect all live snapshots still referenced by trie nodes. + live := make(map[*fakeSnapshot]bool) + walkNodes(root, func(n *trieNode) bool { + for _, s := range n.snapshots { + if s != nil { + if fs, ok := s.(*fakeSnapshot); ok { + live[fs] = true + } + } + } + return true + }) + + for i, s := range tracker.all { + if live[s] { + if s.closeCount != 0 { + t.Errorf("snapshot #%d [%d,%d) is still in trie but was closed %d time(s)", + i, s.from, s.to, s.closeCount) + } + } else { + if s.closeCount == 0 { + t.Errorf("snapshot #%d [%d,%d) leaked: created but never closed and not in trie", + i, s.from, s.to) + } else if s.closeCount > 1 { + t.Errorf("snapshot #%d [%d,%d) double-closed: closed %d times", + i, s.from, s.to, s.closeCount) + } + } + } +} + +// forEachEnv runs fn as subtests for three realistic model configurations: +// pure transformer, transformer + sliding window (Mistral-style), and +// transformer + recurrent (Jamba-style). Leak checking runs automatically +// at the end of each subtest. +func forEachEnv(t *testing.T, fn func(t *testing.T, env *testEnv)) { + t.Helper() + run := func(t *testing.T, env *testEnv) { + t.Cleanup(func() { + checkSnapshotLeaks(t, env.tracker, env.kvc.root) + }) + fn(t, env) + } + t.Run("Transformer", func(t *testing.T) { run(t, newTransformerEnv()) }) + t.Run("SlidingWindow", func(t *testing.T) { run(t, newSlidingWindowEnv()) }) + t.Run("Recurrent", func(t *testing.T) { run(t, newRecurrentEnv()) }) +} + +// TestBranchCreationAndReuse exercises the core multi-conversation lifecycle: +// two conversations share a prefix and diverge, creating a branch point. +// A third conversation extends the first. Verifies trie structure, cache +// hit lengths, and that semantic caches contain the correct token sequences. +func TestBranchCreationAndReuse(t *testing.T) { + forEachEnv(t, func(t *testing.T, env *testEnv) { + kvc := env.kvc + + // Request A: [1,2,3,4,5,6,7,8] + generate [20,21] — full miss. + resA := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8}, []int32{20, 21}) + if len(resA.remaining) != 8 { + t.Fatalf("A: remaining = %d, want 8 (full miss)", len(resA.remaining)) + } + env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21}) + + // Verify trie was populated by close(). + _, mA := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21}) + if mA != 10 { + t.Fatalf("A findable: expected 10 matched, got %d", mA) + } + + // Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A. + // Partial match in A's edge triggers snapshotOffset. + resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31}) + if resB.snapshotOffset != 5 { + t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset) + } + // Cache was rewound to 0 (partial match truncates path to root), + // so all tokens were re-evaluated. + if len(resB.remaining) != 8 { + t.Fatalf("B: remaining = %d, want 8", len(resB.remaining)) + } + env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31}) + + // Both A and B should be findable in the trie. + _, mA2 := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21}) + if mA2 < 5 { + t.Fatalf("A still findable: expected >= 5 matched, got %d", mA2) + } + _, mB := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31}) + if mB < 5 { + t.Fatalf("B findable: expected >= 5 matched, got %d", mB) + } + + // Request C: [1,2,3,4,5,6,7,8,40,41] — extends A's prefix. + // Should get a cache hit for the shared prefix. + resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}, nil) + if len(resC.remaining) >= 10 { + t.Fatalf("C: remaining = %d, want < 10 (should get cache hit)", len(resC.remaining)) + } + env.assertAllTokens(t, "after C", []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}) + + checkTrieInvariants(t, kvc.root) + }) +} + +// TestExactMatchSeedBehavior verifies the holdback mechanism: when the exact +// same prompt is requested twice, the cache does not overclaim cached work. +// The last token must be re-evaluated to seed generation. +func TestExactMatchSeedBehavior(t *testing.T) { + forEachEnv(t, func(t *testing.T, env *testEnv) { + kvc := env.kvc + + // Request A: first time. + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}) + + // Request B: identical prompt. Holdback means matched=4, partial in + // the 5-token edge, so path truncates to root and all tokens are + // re-evaluated. snapshotOffset should be set at the holdback point. + resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21}) + if len(resB.remaining) != 5 { + t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining)) + } + if resB.snapshotOffset != 4 { + t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset) + } + env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21}) + + checkTrieInvariants(t, kvc.root) + }) +} + +// TestConversationResumption tests the most common pattern: user sends a message, +// gets a response, then sends a follow-up. The follow-up should reuse the cached +// prefix (system prompt + first turn + assistant response). +func TestConversationResumption(t *testing.T) { + forEachEnv(t, func(t *testing.T, env *testEnv) { + kvc := env.kvc + + // Turn 1: system prompt + user message, assistant generates response. + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11, 12}) + env.assertAllTokens(t, "turn 1", []int32{1, 2, 3, 4, 5, 10, 11, 12}) + + // Turn 2: full history + new user message. Should get a cache hit on + // the prefix [1,2,3,4,5,10,11,12] and only need to evaluate [20,21]. + resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21}, []int32{30}) + if len(resB.remaining) > 5 { + t.Fatalf("turn 2: remaining = %d, want <= 5 (should reuse most of history)", len(resB.remaining)) + } + env.assertAllTokens(t, "turn 2", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30}) + + // Turn 3: even longer history. + resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}, nil) + if len(resC.remaining) > 5 { + t.Fatalf("turn 3: remaining = %d, want <= 5", len(resC.remaining)) + } + env.assertAllTokens(t, "turn 3", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}) + + checkTrieInvariants(t, kvc.root) + }) +} + +// TestEvictionPreservesActiveConversations creates multiple conversations sharing +// a system prompt, triggers eviction via large snapshot sizes, and verifies the +// active path and shared prefix survive while memory stays bounded. +func TestEvictionPreservesActiveConversations(t *testing.T) { + forEachEnv(t, func(t *testing.T, env *testEnv) { + kvc := env.kvc + systemPrompt := []int32{1, 2, 3, 4, 5} + + // Create 5 conversations with unique suffixes. + for i := range 5 { + suffix := []int32{int32(100 + i*10), int32(101 + i*10), int32(102 + i*10)} + inputs := append(slices.Clone(systemPrompt), suffix...) + simulateRequest(t, kvc, inputs, []int32{int32(200 + i)}) + } + + // Inflate snapshot sizes to trigger eviction. + walkNodes(kvc.root, func(n *trieNode) bool { + if !n.hasSnapshots() { + return true + } + snaps := make([]cache.Snapshot, len(n.snapshots)) + for i, s := range n.snapshots { + if s != nil { + snaps[i] = &fakeSnapshot{byteSize: 2 * 1024 * 1024 * 1024} // 2 GiB per snapshot + } + } + n.setSnapshots(snaps, &kvc.pagedOutBytes) + return true + }) + + // Run eviction. + kvc.enforceEvictionPolicy() + + // Memory should be within limits. + if kvc.pagedOutBytes > maxPagedOutBytes { + t.Fatalf("pagedOutBytes = %d, want <= %d", kvc.pagedOutBytes, maxPagedOutBytes) + } + + // Active path should be untouched. + if len(kvc.activePath) < 2 { + t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath)) + } + + // System prompt prefix should still be findable (evicting a + // multi-child branch point only drops snapshots, not the node). + _, matched := findBestMatch(kvc.root, systemPrompt) + if matched < len(systemPrompt) { + t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt)) + } + + checkTrieInvariants(t, kvc.root) + }) +} + +// TestUserSnapshotPreservesRestorePoint verifies that user-created snapshots +// (snapshot(true)) resist structural changes that would destroy them: +// - A user node forces new tokens into a child instead of extending in-place +// - The snapshot remains restorable after other branches are added +func TestUserSnapshotPreservesRestorePoint(t *testing.T) { + forEachEnv(t, func(t *testing.T, env *testEnv) { + kvc := env.kvc + + // Request A: user snapshot at offset 5, then generate. + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}, 5) + + assertUserNodeExists(t, kvc, "after A") + + // Request B: extends A's prefix. The user node at offset 5 should + // force tokens into a child rather than extending in-place. + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}, nil) + env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}) + assertUserNodeExists(t, kvc, "after B") + + // Request C: diverge from the user node. + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 30, 31}, []int32{40}) + + // Request D: switch back to A's branch — user snapshot still restorable. + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}, nil) + env.assertAllTokens(t, "back to A", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}) + + checkTrieInvariants(t, kvc.root) + }) +} + +// TestUserSnapshotResistsAutoMerge verifies that when a sibling leaf is evicted, +// a user-marked parent node is not auto-merged with its remaining single child. +func TestUserSnapshotResistsAutoMerge(t *testing.T) { + forEachEnv(t, func(t *testing.T, env *testEnv) { + kvc := env.kvc + + // Request A: user snapshot at offset 3, then continue to offset 5. + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10}, 3) + + // Request B: diverges at the user node, creating a second child. + simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20}) + + userNode := findUserNode(t, kvc) + if len(userNode.children) != 2 { + t.Fatalf("user node children = %d, want 2", len(userNode.children)) + } + + // Inflate snapshot sizes and evict. The non-active branch should be + // evicted, leaving the user node with one child. + walkNodes(kvc.root, func(n *trieNode) bool { + if !n.hasSnapshots() { + return true + } + snaps := make([]cache.Snapshot, len(n.snapshots)) + for i, s := range n.snapshots { + if s != nil { + snaps[i] = &fakeSnapshot{byteSize: 5 * 1024 * 1024 * 1024} + } + } + n.setSnapshots(snaps, &kvc.pagedOutBytes) + return true + }) + kvc.enforceEvictionPolicy() + + // The user node should still exist (not auto-merged) even with one child. + assertUserNodeExists(t, kvc, "after eviction") + + checkTrieInvariants(t, kvc.root) + }) +} + +func findUserNode(t *testing.T, kvc *kvCache) *trieNode { + t.Helper() + var found *trieNode + walkNodes(kvc.root, func(n *trieNode) bool { + if n.user { + found = n + } + return true + }) + if found == nil { + t.Fatal("no user-marked node found") + } + return found +} + +func assertUserNodeExists(t *testing.T, kvc *kvCache, label string) { + t.Helper() + var exists bool + walkNodes(kvc.root, func(n *trieNode) bool { + if n.user { + exists = true + } + return true + }) + if !exists { + t.Fatalf("%s: no user-marked node found", label) + } +} + +// TestBranchSwitchRestoresCorrectState exercises switching back to an older +// branch after working on a different one, verifying that the restored cache +// state contains the correct token sequence for both rewindable and +// non-rewindable caches. +func TestBranchSwitchRestoresCorrectState(t *testing.T) { + forEachEnv(t, func(t *testing.T, env *testEnv) { + kvc := env.kvc + + // Request A: [1,2,3,4,5] + generate [10,11] + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}) + env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 10, 11}) + + // Request B: [1,2,3,6,7] — diverges at token 4 + simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{12, 13}) + env.assertAllTokens(t, "after B", []int32{1, 2, 3, 6, 7, 12, 13}) + + // Request C: switch back to A's branch [1,2,3,4,5,10,11,20] + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20}, nil) + env.assertAllTokens(t, "after C (back to A)", []int32{1, 2, 3, 4, 5, 10, 11, 20}) + + checkTrieInvariants(t, kvc.root) + }) +} diff --git a/x/mlxrunner/cache_trie.go b/x/mlxrunner/cache_trie.go new file mode 100644 index 000000000..3516ba676 --- /dev/null +++ b/x/mlxrunner/cache_trie.go @@ -0,0 +1,296 @@ +package mlxrunner + +import ( + "fmt" + "slices" + "time" + + "github.com/ollama/ollama/x/mlxrunner/cache" +) + +// trieNode represents a node in the compressed prefix trie for KV cache branching. +// Each node stores a compressed edge (multiple tokens) and optional paged-out +// snapshot data per cache layer. +type trieNode struct { + tokens []int32 // compressed edge — multiple tokens per node + endOffset int // cumulative tokens from root to end of this node + parent *trieNode + children []*trieNode + lastUsed time.Time // for LRU eviction + snapshots []cache.Snapshot // per-layer paged-out snapshot data (nil if not paged out) + user bool // true = explicit restore point (resist auto-merge) +} + +// startOffset returns the cumulative token offset at the start of this node's edge. +func (n *trieNode) startOffset() int { + return n.endOffset - len(n.tokens) +} + +// snapshotBytes returns the total bytes of paged-out snapshots on this node. +func (n *trieNode) snapshotBytes() int64 { + var total int64 + for _, s := range n.snapshots { + if s != nil { + total += int64(s.Size()) + } + } + return total +} + +// setSnapshots replaces this node's snapshots with snaps and closes the old ones. +// If counter is non-nil, the net byte delta is applied to it. +func (n *trieNode) setSnapshots(snaps []cache.Snapshot, counter *int64) { + old := n.swapSnapshots(snaps, counter) + for _, s := range old { + if s != nil { + s.Close() + } + } +} + +// swapSnapshots is like setSnapshots but returns the previous snapshots +// without closing them. Use this when the old snapshots will be consumed +// (e.g. by Split/Merge). +func (n *trieNode) swapSnapshots(snaps []cache.Snapshot, counter *int64) []cache.Snapshot { + old := n.snapshots + if counter != nil { + *counter -= n.snapshotBytes() + } + n.snapshots = snaps + if counter != nil { + *counter += n.snapshotBytes() + } + return old +} + +// hasSnapshots returns true if any layer has snapshot data. +func (n *trieNode) hasSnapshots() bool { + return slices.ContainsFunc(n.snapshots, func(s cache.Snapshot) bool { return s != nil }) +} + +// hasAllSnapshots returns true if every layer has snapshot data. +func (n *trieNode) hasAllSnapshots() bool { + return len(n.snapshots) > 0 && !slices.Contains(n.snapshots, nil) +} + +// findBestMatch walks the trie matching input tokens, returning the path of +// nodes traversed and the total number of tokens matched. +func findBestMatch(root *trieNode, tokens []int32) (path []*trieNode, matched int) { + if root == nil { + return nil, 0 + } + + path = []*trieNode{root} + pos := 0 + + node := root + for pos < len(tokens) { + // When multiple children share the same first token (e.g. after + // a split), prefer the child whose full edge matches over one + // that only partially matches. This is just being defensive - it + // shouldn't actually happen. + var best *trieNode + bestMatched := 0 + bestFull := false + for _, child := range node.children { + edge := child.tokens + if len(edge) == 0 { + continue + } + if edge[0] != tokens[pos] { + continue + } + // Count matching tokens in this child's edge. + j := 0 + for j < len(edge) && pos+j < len(tokens) && edge[j] == tokens[pos+j] { + j++ + } + full := j == len(edge) + // Prefer full edge matches; among same type, prefer longer. + if best == nil || (full && !bestFull) || (full == bestFull && j > bestMatched) { + best = child + bestMatched = j + bestFull = full + } + } + if best == nil { + break + } + + pos += bestMatched + path = append(path, best) + + if !bestFull { + // Partial match within this edge + break + } + node = best + } + + return path, pos +} + +// appendTokens either creates a new child node or extends the leaf in place, +// returning the node that now holds the tokens. +func (n *trieNode) appendTokens(root *trieNode, tokens []int32, endOffset int) *trieNode { + if n == root || len(n.children) > 0 || n.hasSnapshots() { + child := &trieNode{ + tokens: make([]int32, len(tokens)), + endOffset: endOffset, + parent: n, + lastUsed: n.lastUsed, + } + copy(child.tokens, tokens) + n.children = append(n.children, child) + return child + } + n.tokens = append(n.tokens, tokens...) + n.endOffset = endOffset + return n +} + +// removeNode removes a leaf node from the trie. +func removeNode(node *trieNode, counter *int64) { + if node.parent == nil { + panic("removeNode called on root") + } + if len(node.children) != 0 { + panic("removeNode called on non-leaf node") + } + p := node.parent + for i, child := range p.children { + if child == node { + p.children = append(p.children[:i], p.children[i+1:]...) + break + } + } + node.parent = nil + node.setSnapshots(nil, counter) +} + +// splitNode splits a node at the given token offset within its edge, +// creating a new parent node. Returns the new parent. +// `at` is relative to the node's edge (0-based index into node.tokens). +// If caches are provided, snapshots are split between parent and child +// using Cache.Split; otherwise snapshots are invalidated. +func splitNode(node *trieNode, at int, caches []cache.Cache, counter *int64) *trieNode { + if at <= 0 || at >= len(node.tokens) { + panic(fmt.Sprintf("splitNode: invalid split offset %d for node with %d tokens", at, len(node.tokens))) + } + + // Create new parent with the prefix of the edge. + newParent := &trieNode{ + tokens: make([]int32, at), + endOffset: node.startOffset() + at, + parent: node.parent, + children: []*trieNode{node}, + lastUsed: node.lastUsed, + } + copy(newParent.tokens, node.tokens[:at]) + + // Update the original node to have only the suffix. + node.tokens = node.tokens[at:] + // endOffset stays the same for the original node. + + // Split snapshots between parent and child using Cache.Split. + // Split consumes the old snapshots, so we remove them first (adjusting + // the counter), then assign the split halves (adjusting it back). + if node.hasSnapshots() { + oldSnaps := node.swapSnapshots(nil, counter) + parentSnaps := make([]cache.Snapshot, len(oldSnaps)) + childSnaps := make([]cache.Snapshot, len(oldSnaps)) + for i, snap := range oldSnaps { + if snap != nil { + parentSnaps[i], childSnaps[i] = caches[i].Split(snap, newParent.endOffset) + } + } + newParent.setSnapshots(parentSnaps, counter) + node.setSnapshots(childSnaps, counter) + } + + // Reparent: replace node with newParent in the old parent's children. + if node.parent != nil { + for i, child := range node.parent.children { + if child == node { + node.parent.children[i] = newParent + break + } + } + } + node.parent = newParent + + return newParent +} + +// mergeWithChild merges a node with its single child: concatenates tokens, +// merges snapshot data via Cache.Merge, and removes the child. +func mergeWithChild(node *trieNode, caches []cache.Cache, counter *int64) { + if len(node.children) != 1 { + panic(fmt.Sprintf("mergeWithChild called on node with %d children", len(node.children))) + } + + child := node.children[0] + + // Concatenate tokens. + node.tokens = append(node.tokens, child.tokens...) + node.endOffset = child.endOffset + + // Merge snapshots per layer. Merge consumes the old snapshots, so we + // remove them first (adjusting the counter), then assign the merged + // result (adjusting it back). + if len(node.snapshots) > 0 || len(child.snapshots) > 0 { + nodeSnaps := node.swapSnapshots(nil, counter) + childSnaps := child.swapSnapshots(nil, counter) + merged := make([]cache.Snapshot, len(caches)) + for i := range caches { + var ps, cs cache.Snapshot + if nodeSnaps != nil { + ps = nodeSnaps[i] + } + if childSnaps != nil { + cs = childSnaps[i] + } + + merged[i] = caches[i].Merge(ps, cs) + } + node.setSnapshots(merged, counter) + } + + // Adopt grandchildren. + node.children = child.children + for _, gc := range node.children { + gc.parent = node + } + + // Inherit user flag from child if child was a user-created snapshot node. + node.user = child.user + + // Update lastUsed to the more recent of the two. + if child.lastUsed.After(node.lastUsed) { + node.lastUsed = child.lastUsed + } + + child.parent = nil + child.children = nil +} + +// walkNodes calls fn for every node in the trie (depth-first). +// If fn returns false, the walk stops. +func walkNodes(root *trieNode, fn func(*trieNode) bool) { + if root == nil { + return + } + var walk func(*trieNode) bool + walk = func(n *trieNode) bool { + if !fn(n) { + return false + } + for _, child := range n.children { + if !walk(child) { + return false + } + } + return true + } + walk(root) +} diff --git a/x/mlxrunner/cache_trie_test.go b/x/mlxrunner/cache_trie_test.go new file mode 100644 index 000000000..43b93ba42 --- /dev/null +++ b/x/mlxrunner/cache_trie_test.go @@ -0,0 +1,455 @@ +package mlxrunner + +import ( + "slices" + "testing" + "time" + + "github.com/ollama/ollama/x/mlxrunner/cache" +) + +func newTestTrie(tokens []int32) *trieNode { + root := &trieNode{lastUsed: time.Now()} + if len(tokens) > 0 { + child := &trieNode{ + tokens: slices.Clone(tokens), + endOffset: len(tokens), + parent: root, + lastUsed: time.Now(), + } + root.children = []*trieNode{child} + } + return root +} + +func TestFindBestMatchMultipleBranches(t *testing.T) { + root := &trieNode{lastUsed: time.Now()} + + branch1 := &trieNode{ + tokens: []int32{1, 2, 3}, + endOffset: 3, + parent: root, + lastUsed: time.Now(), + } + branch2 := &trieNode{ + tokens: []int32{4, 5, 6}, + endOffset: 3, + parent: root, + lastUsed: time.Now(), + } + root.children = []*trieNode{branch1, branch2} + + // Match branch 1. + path, matched := findBestMatch(root, []int32{1, 2, 3, 7}) + if matched != 3 { + t.Fatalf("expected 3 matched, got %d", matched) + } + if len(path) != 2 || path[1] != branch1 { + t.Fatal("expected to match branch1") + } + + // Match branch 2. + path, matched = findBestMatch(root, []int32{4, 5, 6, 8}) + if matched != 3 { + t.Fatalf("expected 3 matched, got %d", matched) + } + if len(path) != 2 || path[1] != branch2 { + t.Fatal("expected to match branch2") + } + + // Match neither. + _, matched = findBestMatch(root, []int32{7, 8, 9}) + if matched != 0 { + t.Fatalf("expected 0 matched, got %d", matched) + } +} + +func TestFindBestMatchPrefersFullEdge(t *testing.T) { + root := &trieNode{lastUsed: time.Now()} + + shared := &trieNode{ + tokens: []int32{1, 2, 3}, + endOffset: 3, + parent: root, + lastUsed: time.Now(), + } + root.children = []*trieNode{shared} + + longer := &trieNode{ + tokens: []int32{10, 11, 12, 13, 14}, + endOffset: 8, + parent: shared, + lastUsed: time.Now(), + } + shorter := &trieNode{ + tokens: []int32{10, 11, 12}, + endOffset: 6, + parent: shared, + lastUsed: time.Now(), + } + // Put longer first so naive first-match would pick it. + shared.children = []*trieNode{longer, shorter} + + input := []int32{1, 2, 3, 10, 11, 12, 99, 100} + path, matched := findBestMatch(root, input) + + if matched != 6 { + t.Fatalf("expected 6 matched, got %d", matched) + } + if len(path) != 3 { + t.Fatalf("expected 3 nodes in path, got %d", len(path)) + } + if path[2] != shorter { + t.Fatal("expected findBestMatch to pick shorter (full edge match), not longer (partial)") + } +} + +func TestFindBestMatchPrefersLongerPartial(t *testing.T) { + root := &trieNode{lastUsed: time.Now()} + + child1 := &trieNode{ + tokens: []int32{1, 2, 3, 4, 5}, + endOffset: 5, + parent: root, + lastUsed: time.Now(), + } + child2 := &trieNode{ + tokens: []int32{1, 2, 9}, + endOffset: 3, + parent: root, + lastUsed: time.Now(), + } + root.children = []*trieNode{child2, child1} + + input := []int32{1, 2, 3, 7, 8} + path, matched := findBestMatch(root, input) + + if matched != 3 { + t.Fatalf("expected 3 matched, got %d", matched) + } + if path[1] != child1 { + t.Fatal("expected findBestMatch to pick child1 (longer partial match)") + } +} + +func TestSplitNodeWithSnapshots(t *testing.T) { + root := newTestTrie([]int32{1, 2, 3, 4, 5}) + child := root.children[0] + + rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}} + child.snapshots = []cache.Snapshot{rc.Snapshot(0)} + child.user = true + + caches := []cache.Cache{rc} + + newParent := splitNode(child, 3, caches, nil) + + if !newParent.hasSnapshots() { + t.Fatal("newParent should have snapshots after split") + } + if newParent.user { + t.Fatal("newParent should not be a user snapshot after splitNode") + } + if !child.hasSnapshots() { + t.Fatal("child should have snapshots after split") + } + if !child.user { + t.Fatal("child should remain a user snapshot") + } +} + +func TestFindSplitAppendSequence(t *testing.T) { + root := newTestTrie([]int32{1, 2, 3, 4, 5}) + + path, matched := findBestMatch(root, []int32{1, 2, 3, 6, 7}) + if matched != 3 { + t.Fatalf("expected 3 matched, got %d", matched) + } + + lastNode := path[len(path)-1] + matchedInEdge := matched - lastNode.startOffset() + split := splitNode(lastNode, matchedInEdge, nil, nil) + + split.appendTokens(root, []int32{6, 7}, 5) + + if len(root.children) != 1 { + t.Fatalf("root should have 1 child, got %d", len(root.children)) + } + shared := root.children[0] + if !slices.Equal(shared.tokens, []int32{1, 2, 3}) { + t.Fatalf("shared tokens = %v, want [1,2,3]", shared.tokens) + } + if len(shared.children) != 2 { + t.Fatalf("shared should have 2 children, got %d", len(shared.children)) + } + + _, m1 := findBestMatch(root, []int32{1, 2, 3, 4, 5}) + if m1 != 5 { + t.Fatalf("original branch: expected 5 matched, got %d", m1) + } + _, m2 := findBestMatch(root, []int32{1, 2, 3, 6, 7}) + if m2 != 5 { + t.Fatalf("new branch: expected 5 matched, got %d", m2) + } + _, m3 := findBestMatch(root, []int32{1, 2, 3, 9, 9}) + if m3 != 3 { + t.Fatalf("unrelated input: expected 3 matched, got %d", m3) + } +} + +func TestRepeatedBranching(t *testing.T) { + root := &trieNode{lastUsed: time.Now()} + + root.appendTokens(root, []int32{1, 2, 3, 4, 5}, 5) + + _, matchedB := findBestMatch(root, []int32{1, 2, 3, 6, 7}) + if matchedB != 3 { + t.Fatalf("B: expected 3 matched, got %d", matchedB) + } + nodeA := root.children[0] + split1 := splitNode(nodeA, 3, nil, nil) + split1.appendTokens(root, []int32{6, 7}, 5) + + _, matchedC := findBestMatch(root, []int32{1, 2, 8, 9}) + if matchedC != 2 { + t.Fatalf("C: expected 2 matched, got %d", matchedC) + } + split2 := splitNode(split1, 2, nil, nil) + split2.appendTokens(root, []int32{8, 9}, 4) + + _, mA := findBestMatch(root, []int32{1, 2, 3, 4, 5}) + if mA != 5 { + t.Fatalf("A: expected 5 matched, got %d", mA) + } + _, mB := findBestMatch(root, []int32{1, 2, 3, 6, 7}) + if mB != 5 { + t.Fatalf("B: expected 5 matched, got %d", mB) + } + _, mC := findBestMatch(root, []int32{1, 2, 8, 9}) + if mC != 4 { + t.Fatalf("C: expected 4 matched, got %d", mC) + } + + checkTrieInvariants(t, root) +} + +func TestMergeWithChild(t *testing.T) { + t.Run("Basic", func(t *testing.T) { + // root -> A[1,2,3] -> B[4,5] -> {C[6], D[7]} + now := time.Now() + root := &trieNode{lastUsed: now} + a := &trieNode{ + tokens: []int32{1, 2, 3}, + endOffset: 3, + parent: root, + lastUsed: now, + snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3}, from: 0, to: 3}}, + } + b := &trieNode{ + tokens: []int32{4, 5}, + endOffset: 5, + parent: a, + lastUsed: now, + snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{4, 5}, from: 3, to: 5}}, + } + c := &trieNode{tokens: []int32{6}, endOffset: 6, parent: b, lastUsed: now} + d := &trieNode{tokens: []int32{7}, endOffset: 6, parent: b, lastUsed: now} + root.children = []*trieNode{a} + a.children = []*trieNode{b} + b.children = []*trieNode{c, d} + + mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}} + mergeWithChild(a, []cache.Cache{mc}, nil) + + // Tokens concatenated. + if !slices.Equal(a.tokens, []int32{1, 2, 3, 4, 5}) { + t.Fatalf("merged tokens = %v, want [1,2,3,4,5]", a.tokens) + } + if a.endOffset != 5 { + t.Fatalf("merged endOffset = %d, want 5", a.endOffset) + } + // Grandchildren reparented. + if len(a.children) != 2 { + t.Fatalf("merged children count = %d, want 2", len(a.children)) + } + if c.parent != a || d.parent != a { + t.Fatal("grandchildren should be reparented to merged node") + } + // B detached. + if b.parent != nil || b.children != nil || b.snapshots != nil { + t.Fatal("child B should be fully detached after merge") + } + // Merged snapshot should cover [0,5). + if !a.hasSnapshots() { + t.Fatal("merged node should have snapshots") + } + ms := a.snapshots[0].(*fakeSnapshot) + if ms.from != 0 || ms.to != 5 { + t.Fatalf("merged snapshot = [%d,%d), want [0,5)", ms.from, ms.to) + } + + checkTrieInvariants(t, root) + }) + + t.Run("UserFlag", func(t *testing.T) { + root := &trieNode{lastUsed: time.Now()} + parent := &trieNode{ + tokens: []int32{1, 2}, endOffset: 2, parent: root, + lastUsed: time.Now(), user: false, + } + child := &trieNode{ + tokens: []int32{3, 4}, endOffset: 4, parent: parent, + lastUsed: time.Now(), user: true, + } + root.children = []*trieNode{parent} + parent.children = []*trieNode{child} + + mergeWithChild(parent, nil, nil) + + if !parent.user { + t.Fatal("merged node should inherit user=true from child") + } + }) + + t.Run("LastUsed", func(t *testing.T) { + now := time.Now() + root := &trieNode{lastUsed: now} + parent := &trieNode{ + tokens: []int32{1}, endOffset: 1, parent: root, + lastUsed: now.Add(-1 * time.Hour), + } + child := &trieNode{ + tokens: []int32{2}, endOffset: 2, parent: parent, + lastUsed: now.Add(1 * time.Hour), + } + root.children = []*trieNode{parent} + parent.children = []*trieNode{child} + + mergeWithChild(parent, nil, nil) + + if !parent.lastUsed.Equal(now.Add(1 * time.Hour)) { + t.Fatal("merged node should pick the more recent lastUsed") + } + }) + + t.Run("PanicOnMultipleChildren", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on node with 2 children") + } + }() + root := &trieNode{lastUsed: time.Now()} + node := &trieNode{ + tokens: []int32{1}, endOffset: 1, parent: root, lastUsed: time.Now(), + children: []*trieNode{ + {tokens: []int32{2}, endOffset: 2, lastUsed: time.Now()}, + {tokens: []int32{3}, endOffset: 2, lastUsed: time.Now()}, + }, + } + root.children = []*trieNode{node} + mergeWithChild(node, nil, nil) + }) +} + +func TestSplitMergeRoundTrip(t *testing.T) { + root := &trieNode{lastUsed: time.Now()} + leaf := &trieNode{ + tokens: []int32{1, 2, 3, 4, 5}, + endOffset: 5, + parent: root, + lastUsed: time.Now(), + snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3, 4, 5}, from: 0, to: 5}}, + } + root.children = []*trieNode{leaf} + + mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}} + caches := []cache.Cache{mc} + + // Split at 3: [1,2,3] -> [4,5] + newParent := splitNode(leaf, 3, caches, nil) + if !slices.Equal(newParent.tokens, []int32{1, 2, 3}) { + t.Fatalf("after split: parent tokens = %v, want [1,2,3]", newParent.tokens) + } + if !slices.Equal(leaf.tokens, []int32{4, 5}) { + t.Fatalf("after split: child tokens = %v, want [4,5]", leaf.tokens) + } + checkTrieInvariants(t, root) + + // Merge back: should restore [1,2,3,4,5] + mergeWithChild(newParent, caches, nil) + if !slices.Equal(newParent.tokens, []int32{1, 2, 3, 4, 5}) { + t.Fatalf("after merge: tokens = %v, want [1,2,3,4,5]", newParent.tokens) + } + if newParent.endOffset != 5 { + t.Fatalf("after merge: endOffset = %d, want 5", newParent.endOffset) + } + if len(newParent.children) != 0 { + t.Fatalf("after merge: children count = %d, want 0", len(newParent.children)) + } + // Merged snapshot should cover [0,5). + if !newParent.hasSnapshots() { + t.Fatal("after merge: should have snapshots") + } + ms := newParent.snapshots[0].(*fakeSnapshot) + if ms.from != 0 || ms.to != 5 { + t.Fatalf("after merge: snapshot = [%d,%d), want [0,5)", ms.from, ms.to) + } + + checkTrieInvariants(t, root) +} + +func TestRemoveNode(t *testing.T) { + t.Run("Leaf", func(t *testing.T) { + root := &trieNode{lastUsed: time.Now()} + shared := &trieNode{ + tokens: []int32{1, 2, 3}, endOffset: 3, parent: root, lastUsed: time.Now(), + } + leafA := &trieNode{ + tokens: []int32{4, 5}, endOffset: 5, parent: shared, lastUsed: time.Now(), + snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}}, + } + leafB := &trieNode{ + tokens: []int32{6, 7}, endOffset: 5, parent: shared, lastUsed: time.Now(), + snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}}, + } + root.children = []*trieNode{shared} + shared.children = []*trieNode{leafA, leafB} + + removeNode(leafA, nil) + + if len(shared.children) != 1 { + t.Fatalf("parent should have 1 child, got %d", len(shared.children)) + } + if shared.children[0] != leafB { + t.Fatal("remaining child should be leafB") + } + if leafA.parent != nil { + t.Fatal("removed node parent should be nil") + } + if leafA.snapshots != nil { + t.Fatal("removed node snapshots should be nil") + } + + checkTrieInvariants(t, root) + }) + + t.Run("PanicOnRoot", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic when removing root") + } + }() + removeNode(&trieNode{}, nil) + }) + + t.Run("PanicOnNonLeaf", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic when removing non-leaf") + } + }() + parent := &trieNode{parent: &trieNode{}} + parent.children = []*trieNode{{}} + removeNode(parent, nil) + }) +} diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go index 43aec769a..f2daa2e28 100644 --- a/x/mlxrunner/mlx/mlx.go +++ b/x/mlxrunner/mlx/mlx.go @@ -18,6 +18,10 @@ func Version() string { } func doEval(outputs []*Array, async bool) { + if len(outputs) == 0 { + return + } + vector := C.mlx_vector_array_new() defer C.mlx_vector_array_free(vector) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 3ce148c02..ea7e12a30 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -50,7 +50,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) { mlx.LogArrays() - r.cache.log() + r.cache.dumpTree() } slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory())) }() @@ -86,7 +86,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { materializeCaches := func() { state := make([]*mlx.Array, 0, 2*len(caches)) for _, c := range caches { - state = appendCacheState(state, c) + state = append(state, c.State()...) } if len(state) == 0 { return @@ -102,11 +102,32 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } n := min(prefillChunk, total-processed-1) + + // If there's a pending intermediate snapshot, split the batch + // so we can capture it at the exact offset. The cache offset + // after this batch will be: baseOffset + processed + n. + if session.snapshotOffset > 0 { + baseOffset := len(session.inputs) - len(tokens) + tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed) + if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n { + n = tokensUntilSnapshot + } + } + r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) mlx.Sweep() materializeCaches() processed += n slog.Info("Prompt processing progress", "processed", processed, "total", total) + + // Create snapshot at branch point for future diverging requests. + if session.snapshotOffset > 0 { + baseOffset := len(session.inputs) - len(tokens) + if baseOffset+processed >= session.snapshotOffset { + session.snapshot(false) + } + } + mlx.ClearCache() }