mlxrunner: schedule periodic snapshots during prefill

Add periodic snapshots every 8192 tokens, plus one a few tokens before
the end of the prompt, so that long prompts can be partially restored
and thinking/generation can be retried without reprocessing the full
prompt.
This commit is contained in:
Jesse Gross
2026-03-24 16:55:49 -07:00
parent 74c4a72fe8
commit 015546fded
3 changed files with 109 additions and 73 deletions

View File

@@ -20,6 +20,7 @@ import (
"cmp" "cmp"
"fmt" "fmt"
"log/slog" "log/slog"
"slices"
"time" "time"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
@@ -37,6 +38,12 @@ type kvCache struct {
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
} }
// pendingSnapshot is a snapshot scheduled to be taken during prefill.
type pendingSnapshot struct {
offset int
user bool
}
// cacheSession manages caches for a single pipeline run. // cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and // Callers should append generated tokens to outputs and
// defer close to save the cache state. // defer close to save the cache state.
@@ -48,11 +55,10 @@ type cacheSession struct {
caches []cache.Cache caches []cache.Cache
remaining []int32 remaining []int32
// snapshotOffset, if > 0, is a trie node boundary where we need to // pendingSnapshots lists offsets where snapshots should be captured
// capture a snapshot during prefill. This enables future requests // during prefill, sorted by offset. Entries are consumed as the
// branching at this point to restore non-rewindable caches (e.g. // cache advances past them.
// RecurrentCache) instead of re-evaluating from scratch. pendingSnapshots []pendingSnapshot
snapshotOffset int
} }
func (c *kvCache) ensureCaches(m base.Model) { func (c *kvCache) ensureCaches(m base.Model) {
@@ -100,31 +106,26 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
prefix := c.minCacheOffset() prefix := c.minCacheOffset()
remaining := inputs[prefix:] remaining := inputs[prefix:]
session := &cacheSession{
cache: c,
inputs: inputs,
caches: c.caches,
remaining: remaining,
}
// Schedule a snapshot at the branch point during prefill so future // Schedule a snapshot at the branch point during prefill so future
// requests diverging here can restore instead of re-evaluating. // requests diverging here can restore instead of re-evaluating.
var snapshotAt int
if prefix < matched { if prefix < matched {
snapshotAt = matched session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false})
} }
args := []any{"total", len(inputs), "matched", originalMatched} msg := "cache hit"
args = append(args, "cached", prefix, "left", len(remaining))
if snapshotAt > 0 {
args = append(args, "pending_snapshot", snapshotAt)
}
if prefix == 0 { if prefix == 0 {
slog.Info("cache miss", args...) msg = "cache miss"
} else {
slog.Info("cache hit", args...)
} }
slog.Info(msg, "total", len(inputs), "matched", originalMatched, "cached", prefix, "left", len(remaining))
return &cacheSession{ return session
cache: c,
inputs: inputs,
snapshotOffset: snapshotAt,
caches: c.caches,
remaining: remaining,
}
} }
// switchToPath transitions from the current active path to a new path, // switchToPath transitions from the current active path to a new path,
@@ -250,20 +251,54 @@ pageIn:
} }
} }
// snapshot creates a snapshot at the current cache position. During prefill, // requestSnapshot schedules a user snapshot at the given absolute token
// it is called at branch points (user=false) to create restore points for // offset. The snapshot will be captured during prefill when the cache
// future diverging requests and with user=true to mark an explicit reusable // reaches this offset.
// restore point. func (s *cacheSession) requestSnapshot(offset int) {
func (s *cacheSession) snapshot(user bool) { baseOffset := len(s.inputs) - len(s.remaining)
if offset <= baseOffset || offset > len(s.inputs) {
return
}
// Deduplicate: if this offset already exists, upgrade to user.
for i := range s.pendingSnapshots {
if s.pendingSnapshots[i].offset == offset {
s.pendingSnapshots[i].user = true
return
}
}
s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true})
slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int {
return a.offset - b.offset
})
}
// nextPendingSnapshot returns the offset of the next pending snapshot,
// or 0 if there are none.
func (s *cacheSession) nextPendingSnapshot() int {
if len(s.pendingSnapshots) == 0 {
return 0
}
return s.pendingSnapshots[0].offset
}
// snapshot creates a snapshot at the current cache position. It determines
// whether this is a user snapshot by consuming pending entries whose offset
// has been reached.
func (s *cacheSession) snapshot() {
c := s.cache c := s.cache
cacheOffset := c.minCacheOffset() cacheOffset := c.minCacheOffset()
if cacheOffset <= 0 { if cacheOffset <= 0 {
return return
} }
// Clear pending intermediate snapshot if we've reached or passed it. // Consume pending snapshots up to the current offset and derive
if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset { // the user flag from them.
s.snapshotOffset = 0 user := false
for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset {
if s.pendingSnapshots[0].user {
user = true
}
s.pendingSnapshots = s.pendingSnapshots[1:]
} }
// The last node in activePath is the frontier where caches are advancing. // The last node in activePath is the frontier where caches are advancing.

View File

@@ -377,24 +377,25 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
// begin -> prefill with snapshot(false) at branch points -> generate -> close // begin -> prefill with snapshot(false) at branch points -> generate -> close
type requestResult struct { type requestResult struct {
remaining []int32 remaining []int32
snapshotOffset int pendingSnapshots int
} }
// simulateRequest runs a request through the harness. If userSnapshotAt > 0, // simulateRequest runs a request through the harness. If userSnapshotAt > 0,
// a user snapshot (snapshot(true)) is created at that offset during prefill. // a user snapshot is requested at that offset during prefill.
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult { func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
t.Helper() t.Helper()
userSnapAt := 0 session := kvc.begin(nil, inputs)
if len(userSnapshotAt) > 0 { for _, at := range userSnapshotAt {
userSnapAt = userSnapshotAt[0] if at > 0 {
session.requestSnapshot(at)
}
} }
session := kvc.begin(nil, inputs)
result := requestResult{ result := requestResult{
remaining: slices.Clone(session.remaining), remaining: slices.Clone(session.remaining),
snapshotOffset: session.snapshotOffset, pendingSnapshots: len(session.pendingSnapshots),
} }
assertCacheOffsetAlignment(t, kvc, "after begin") assertCacheOffsetAlignment(t, kvc, "after begin")
@@ -402,22 +403,9 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user
baseOffset := kvc.minCacheOffset() baseOffset := kvc.minCacheOffset()
remaining := inputs[baseOffset:] remaining := inputs[baseOffset:]
// Collect snapshot points (offset -> user flag) in ascending order. // Prefill: feed tokens, pausing at each pending snapshot.
type snapPoint struct { for len(session.pendingSnapshots) > 0 {
offset int sp := session.pendingSnapshots[0]
user bool
}
var points []snapPoint
if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset {
points = append(points, snapPoint{session.snapshotOffset, false})
}
if userSnapAt > 0 && userSnapAt > baseOffset {
points = append(points, snapPoint{userSnapAt, true})
}
slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset })
// Prefill: feed tokens, pausing at each snapshot point.
for _, sp := range points {
count := sp.offset - baseOffset count := sp.offset - baseOffset
if count > len(remaining) { if count > len(remaining) {
break break
@@ -428,7 +416,7 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user
baseOffset = sp.offset baseOffset = sp.offset
} }
assertCacheOffsetAlignment(t, kvc, "at snapshot point") assertCacheOffsetAlignment(t, kvc, "at snapshot point")
session.snapshot(sp.user) session.snapshot()
} }
// Feed rest of input tokens. // Feed rest of input tokens.
@@ -615,15 +603,15 @@ func TestBranchCreationAndReuse(t *testing.T) {
// caches (RecurrentCache), the rewind fails and freeAll fires. // caches (RecurrentCache), the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31}) resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
if env.rewindable { if env.rewindable {
if resB.snapshotOffset != 0 { if resB.pendingSnapshots != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset) t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
} }
if len(resB.remaining) != 3 { if len(resB.remaining) != 3 {
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining)) t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
} }
} else { } else {
if resB.snapshotOffset != 5 { if resB.pendingSnapshots != 1 {
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset) t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
} }
if len(resB.remaining) != 8 { if len(resB.remaining) != 8 {
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining)) t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
@@ -672,15 +660,15 @@ func TestExactMatchSeedBehavior(t *testing.T) {
if len(resB.remaining) != 1 { if len(resB.remaining) != 1 {
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining)) t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
} }
if resB.snapshotOffset != 0 { if resB.pendingSnapshots != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset) t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
} }
} else { } else {
if len(resB.remaining) != 5 { if len(resB.remaining) != 5 {
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining)) t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
} }
if resB.snapshotOffset != 4 { if resB.pendingSnapshots != 1 {
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset) t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
} }
} }
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21}) env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})

View File

@@ -79,10 +79,24 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
session := r.cache.begin(r.Model, inputs) session := r.cache.begin(r.Model, inputs)
defer session.close() defer session.close()
caches := session.caches caches := session.caches
tokens := session.remaining tokens := session.remaining
prefillChunk := prefillChunkSize() prefillChunk := prefillChunkSize()
// Request periodic snapshots during prefill and near the end of the
// prompt so that long prompts can be partially restored and
// thinking/generation can be retried without full reprocessing.
const snapshotInterval = 8192
for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
session.requestSnapshot(offset)
}
const preThinking = 4
if end := len(inputs) - preThinking; end > 0 {
session.requestSnapshot(end)
}
materializeCaches := func() { materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches)) state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches { for _, c := range caches {
@@ -103,12 +117,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
n := min(prefillChunk, total-processed-1) n := min(prefillChunk, total-processed-1)
// If there's a pending intermediate snapshot, split the batch // If there's a pending snapshot, split the batch so we can
// so we can capture it at the exact offset. The cache offset // capture it at the exact offset.
// after this batch will be: baseOffset + processed + n. if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
if session.snapshotOffset > 0 {
baseOffset := len(session.inputs) - len(tokens) baseOffset := len(session.inputs) - len(tokens)
tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed) tokensUntilSnapshot := snapOffset - (baseOffset + processed)
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n { if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
n = tokensUntilSnapshot n = tokensUntilSnapshot
} }
@@ -120,11 +133,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
processed += n processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total) slog.Info("Prompt processing progress", "processed", processed, "total", total)
// Create snapshot at branch point for future diverging requests. // Create snapshot if we've reached a pending offset.
if session.snapshotOffset > 0 { if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
baseOffset := len(session.inputs) - len(tokens) baseOffset := len(session.inputs) - len(tokens)
if baseOffset+processed >= session.snapshotOffset { if baseOffset+processed >= snapOffset {
session.snapshot(false) session.snapshot()
} }
} }