mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
mlxrunner: improve eviction and LRU tracking
Update the LRU last-used time only on the nodes that were actually used during processing, rather than on all snapshots along the path. This allows eviction to remove nodes more accurately, so we can avoid relying on other heuristics to auto-merge nodes.
This commit is contained in:
@@ -238,6 +238,13 @@ pageIn:
|
||||
}
|
||||
}
|
||||
|
||||
// Update last-used time on only the final used node. For recurrent
|
||||
// caches we don't need the intermediate snapshots and for KV caches
|
||||
// we can reslice the data out of merged edges.
|
||||
if len(c.activePath) > 0 {
|
||||
c.activePath[len(c.activePath)-1].lastUsed = time.Now()
|
||||
}
|
||||
|
||||
if pageOutCount > 0 || pageInCount > 0 {
|
||||
slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
|
||||
}
|
||||
@@ -355,6 +362,7 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
||||
}
|
||||
}
|
||||
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||
node.lastUsed = time.Now()
|
||||
slog.Debug("created snapshot", "offset", cacheOffset)
|
||||
c.enforceEvictionPolicy()
|
||||
}
|
||||
@@ -412,10 +420,7 @@ func (s *cacheSession) close() {
|
||||
newTokens := stored[frontier.endOffset:offset]
|
||||
c.advancePath(frontier, newTokens, offset)
|
||||
}
|
||||
now := time.Now()
|
||||
for _, node := range c.activePath {
|
||||
node.lastUsed = now
|
||||
}
|
||||
c.activePath[len(c.activePath)-1].lastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -433,7 +438,7 @@ func (c *kvCache) enforceEvictionPolicy() {
|
||||
for c.pagedOutBytes > maxPagedOutBytes {
|
||||
var best *trieNode
|
||||
walkNodes(c.root, func(n *trieNode) bool {
|
||||
if n == c.root || activeSet[n] || !n.hasSnapshots() {
|
||||
if n == c.root || activeSet[n] || len(n.children) > 1 {
|
||||
return true
|
||||
}
|
||||
// Evict: oldest, then deepest, then largest.
|
||||
@@ -457,27 +462,16 @@ func (c *kvCache) enforceEvictionPolicy() {
|
||||
func (c *kvCache) evictNode(node *trieNode) {
|
||||
if len(node.children) == 0 {
|
||||
// Leaf: remove entirely.
|
||||
parent := node.parent
|
||||
kind := "evicting leaf"
|
||||
if node.user {
|
||||
kind = "evicting user snapshot"
|
||||
}
|
||||
slog.Debug(kind, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||
slog.Debug("evicting leaf", "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||
removeNode(node, &c.pagedOutBytes)
|
||||
|
||||
// If parent is a regular (non-user-snapshot) node with one remaining child, auto-merge.
|
||||
if parent != nil && !parent.user && len(parent.children) == 1 && parent != c.root {
|
||||
logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset))
|
||||
mergeWithChild(parent, c.caches, &c.pagedOutBytes)
|
||||
}
|
||||
} else if len(node.children) == 1 {
|
||||
// Interior snapshot node with one child: merge with child.
|
||||
slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||
// Interior node with one child: merge with child.
|
||||
before := c.pagedOutBytes
|
||||
tokens := len(node.tokens)
|
||||
mergeWithChild(node, c.caches, &c.pagedOutBytes)
|
||||
slog.Debug("evicting interior node", "offset", node.startOffset(), "tokens", tokens, "freed", mlx.PrettyBytes(int(before-c.pagedOutBytes)))
|
||||
} else {
|
||||
// Multi-child branch point: drop snapshots but keep the node.
|
||||
slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||
node.setSnapshots(nil, &c.pagedOutBytes)
|
||||
panic("evictNode called on multi-child branch point")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package mlxrunner
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
@@ -761,8 +762,8 @@ func TestEvictionPreservesActiveConversations(t *testing.T) {
|
||||
t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath))
|
||||
}
|
||||
|
||||
// System prompt prefix should still be findable (evicting a
|
||||
// multi-child branch point only drops snapshots, not the node).
|
||||
// System prompt prefix should still be findable (multi-child
|
||||
// branch points are protected from eviction entirely).
|
||||
_, matched := findBestMatch(kvc.root, systemPrompt)
|
||||
if matched < len(systemPrompt) {
|
||||
t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt))
|
||||
@@ -895,3 +896,55 @@ func TestBranchSwitchRestoresCorrectState(t *testing.T) {
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLRUOnlyUpdatesUsedNodes verifies that intermediate nodes on the active
|
||||
// path whose snapshots were not actually restored don't get their lastUsed
|
||||
// refreshed, allowing them to age out and collapse.
|
||||
func TestLRUOnlyUpdatesUsedNodes(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
|
||||
// Request A: creates path [1,2,3,4,5] + generate [10,11]
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
||||
|
||||
// Request B: diverges at token 4, creating a branch point at offset 3
|
||||
// with a split snapshot.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20, 21})
|
||||
|
||||
// Set all lastUsed to a known old time.
|
||||
oldTime := time.Now().Add(-1 * time.Hour)
|
||||
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||
n.lastUsed = oldTime
|
||||
return true
|
||||
})
|
||||
|
||||
// Request C: continue on B's branch. This will match B's path
|
||||
// and extend it. The branch point's snapshot may be paged in
|
||||
// for some cache types but not others.
|
||||
beforeRequest := time.Now()
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7, 20, 21, 30}, nil)
|
||||
|
||||
// The path must have enough depth to exercise intermediate nodes.
|
||||
if len(kvc.activePath) < 3 {
|
||||
t.Fatalf("activePath too short to test intermediate nodes: got %d nodes", len(kvc.activePath))
|
||||
}
|
||||
|
||||
// The frontier (deepest node on the active path) must be updated.
|
||||
frontier := kvc.activePath[len(kvc.activePath)-1]
|
||||
if frontier.lastUsed.Before(beforeRequest) {
|
||||
t.Errorf("frontier lastUsed was not updated: got %v, want >= %v",
|
||||
frontier.lastUsed, beforeRequest)
|
||||
}
|
||||
|
||||
// Every non-frontier node on the active path (including root)
|
||||
// should retain its old lastUsed — only the frontier gets refreshed.
|
||||
for i, node := range kvc.activePath[:len(kvc.activePath)-1] {
|
||||
if !node.lastUsed.Before(beforeRequest) {
|
||||
t.Errorf("activePath[%d] (endOffset=%d) lastUsed was refreshed: got %v, want < %v",
|
||||
i, node.endOffset, node.lastUsed, beforeRequest)
|
||||
}
|
||||
}
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user