mlxrunner: schedule periodic snapshots during prefill

Add periodic snapshots every 8k tokens and near the end of the prompt
so that long prompts can be partially restored and thinking/generation
can be retried without full reprocessing.
This commit is contained in:
Jesse Gross
2026-03-24 16:55:49 -07:00
parent 74c4a72fe8
commit 015546fded
3 changed files with 109 additions and 73 deletions

View File

@@ -20,6 +20,7 @@ import (
"cmp"
"fmt"
"log/slog"
"slices"
"time"
"github.com/ollama/ollama/logutil"
@@ -37,6 +38,12 @@ type kvCache struct {
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
}
// pendingSnapshot is a snapshot scheduled to be taken during prefill.
type pendingSnapshot struct {
// offset is the absolute token offset (into the session's inputs)
// at which the snapshot should be captured once the cache reaches it.
offset int
// user is true for snapshots explicitly requested via requestSnapshot
// (reusable restore points); false for snapshots scheduled
// automatically at branch points during begin.
user bool
}
// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
@@ -48,11 +55,10 @@ type cacheSession struct {
caches []cache.Cache
remaining []int32
// snapshotOffset, if > 0, is a trie node boundary where we need to
// capture a snapshot during prefill. This enables future requests
// branching at this point to restore non-rewindable caches (e.g.
// RecurrentCache) instead of re-evaluating from scratch.
snapshotOffset int
// pendingSnapshots lists offsets where snapshots should be captured
// during prefill, sorted by offset. Entries are consumed as the
// cache advances past them.
pendingSnapshots []pendingSnapshot
}
func (c *kvCache) ensureCaches(m base.Model) {
@@ -100,31 +106,26 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
prefix := c.minCacheOffset()
remaining := inputs[prefix:]
// Schedule a snapshot at the branch point during prefill so future
// requests diverging here can restore instead of re-evaluating.
var snapshotAt int
if prefix < matched {
snapshotAt = matched
}
args := []any{"total", len(inputs), "matched", originalMatched}
args = append(args, "cached", prefix, "left", len(remaining))
if snapshotAt > 0 {
args = append(args, "pending_snapshot", snapshotAt)
}
if prefix == 0 {
slog.Info("cache miss", args...)
} else {
slog.Info("cache hit", args...)
}
return &cacheSession{
session := &cacheSession{
cache: c,
inputs: inputs,
snapshotOffset: snapshotAt,
caches: c.caches,
remaining: remaining,
}
// Schedule a snapshot at the branch point during prefill so future
// requests diverging here can restore instead of re-evaluating.
if prefix < matched {
session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false})
}
msg := "cache hit"
if prefix == 0 {
msg = "cache miss"
}
slog.Info(msg, "total", len(inputs), "matched", originalMatched, "cached", prefix, "left", len(remaining))
return session
}
// switchToPath transitions from the current active path to a new path,
@@ -250,20 +251,54 @@ pageIn:
}
}
// snapshot creates a snapshot at the current cache position. During prefill,
// it is called at branch points (user=false) to create restore points for
// future diverging requests and with user=true to mark an explicit reusable
// restore point.
func (s *cacheSession) snapshot(user bool) {
// requestSnapshot schedules a user snapshot at the given absolute token
// offset. The snapshot will be captured during prefill when the cache
// reaches this offset.
func (s *cacheSession) requestSnapshot(offset int) {
	baseOffset := len(s.inputs) - len(s.remaining)
	// Ignore offsets already covered by the cache (<= baseOffset) or
	// beyond the end of the prompt.
	if offset <= baseOffset || offset > len(s.inputs) {
		return
	}
	// Deduplicate: if this offset already exists, upgrade to user.
	for i := range s.pendingSnapshots {
		if s.pendingSnapshots[i].offset == offset {
			s.pendingSnapshots[i].user = true
			return
		}
	}
	s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true})
	// Keep entries sorted by offset so consumers can pop from the front.
	// cmp.Compare avoids the (theoretical) overflow of a raw subtraction
	// comparator; the file already imports "cmp".
	slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int {
		return cmp.Compare(a.offset, b.offset)
	})
}
// nextPendingSnapshot reports the offset of the earliest pending
// snapshot, or 0 when none are scheduled. The pendingSnapshots slice
// is kept sorted by offset, so the first entry is always the next one.
func (s *cacheSession) nextPendingSnapshot() int {
	if len(s.pendingSnapshots) > 0 {
		return s.pendingSnapshots[0].offset
	}
	return 0
}
// snapshot creates a snapshot at the current cache position. It determines
// whether this is a user snapshot by consuming pending entries whose offset
// has been reached.
func (s *cacheSession) snapshot() {
c := s.cache
cacheOffset := c.minCacheOffset()
if cacheOffset <= 0 {
return
}
// Clear pending intermediate snapshot if we've reached or passed it.
if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
s.snapshotOffset = 0
// Consume pending snapshots up to the current offset and derive
// the user flag from them.
user := false
for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset {
if s.pendingSnapshots[0].user {
user = true
}
s.pendingSnapshots = s.pendingSnapshots[1:]
}
// The last node in activePath is the frontier where caches are advancing.

View File

@@ -378,23 +378,24 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
type requestResult struct {
remaining []int32
snapshotOffset int
pendingSnapshots int
}
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
// a user snapshot (snapshot(true)) is created at that offset during prefill.
// a user snapshot is requested at that offset during prefill.
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
t.Helper()
userSnapAt := 0
if len(userSnapshotAt) > 0 {
userSnapAt = userSnapshotAt[0]
session := kvc.begin(nil, inputs)
for _, at := range userSnapshotAt {
if at > 0 {
session.requestSnapshot(at)
}
}
session := kvc.begin(nil, inputs)
result := requestResult{
remaining: slices.Clone(session.remaining),
snapshotOffset: session.snapshotOffset,
pendingSnapshots: len(session.pendingSnapshots),
}
assertCacheOffsetAlignment(t, kvc, "after begin")
@@ -402,22 +403,9 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user
baseOffset := kvc.minCacheOffset()
remaining := inputs[baseOffset:]
// Collect snapshot points (offset -> user flag) in ascending order.
type snapPoint struct {
offset int
user bool
}
var points []snapPoint
if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset {
points = append(points, snapPoint{session.snapshotOffset, false})
}
if userSnapAt > 0 && userSnapAt > baseOffset {
points = append(points, snapPoint{userSnapAt, true})
}
slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset })
// Prefill: feed tokens, pausing at each snapshot point.
for _, sp := range points {
// Prefill: feed tokens, pausing at each pending snapshot.
for len(session.pendingSnapshots) > 0 {
sp := session.pendingSnapshots[0]
count := sp.offset - baseOffset
if count > len(remaining) {
break
@@ -428,7 +416,7 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user
baseOffset = sp.offset
}
assertCacheOffsetAlignment(t, kvc, "at snapshot point")
session.snapshot(sp.user)
session.snapshot()
}
// Feed rest of input tokens.
@@ -615,15 +603,15 @@ func TestBranchCreationAndReuse(t *testing.T) {
// caches (RecurrentCache), the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
if env.rewindable {
if resB.snapshotOffset != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
if resB.pendingSnapshots != 0 {
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
}
if len(resB.remaining) != 3 {
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
}
} else {
if resB.snapshotOffset != 5 {
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
if resB.pendingSnapshots != 1 {
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
}
if len(resB.remaining) != 8 {
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
@@ -672,15 +660,15 @@ func TestExactMatchSeedBehavior(t *testing.T) {
if len(resB.remaining) != 1 {
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
}
if resB.snapshotOffset != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
if resB.pendingSnapshots != 0 {
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
}
} else {
if len(resB.remaining) != 5 {
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
}
if resB.snapshotOffset != 4 {
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
if resB.pendingSnapshots != 1 {
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
}
}
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})

View File

@@ -79,10 +79,24 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
session := r.cache.begin(r.Model, inputs)
defer session.close()
caches := session.caches
tokens := session.remaining
prefillChunk := prefillChunkSize()
// Request periodic snapshots during prefill and near the end of the
// prompt so that long prompts can be partially restored and
// thinking/generation can be retried without full reprocessing.
const snapshotInterval = 8192
for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
session.requestSnapshot(offset)
}
const preThinking = 4
if end := len(inputs) - preThinking; end > 0 {
session.requestSnapshot(end)
}
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
@@ -103,12 +117,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
n := min(prefillChunk, total-processed-1)
// If there's a pending intermediate snapshot, split the batch
// so we can capture it at the exact offset. The cache offset
// after this batch will be: baseOffset + processed + n.
if session.snapshotOffset > 0 {
// If there's a pending snapshot, split the batch so we can
// capture it at the exact offset.
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
baseOffset := len(session.inputs) - len(tokens)
tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed)
tokensUntilSnapshot := snapOffset - (baseOffset + processed)
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
n = tokensUntilSnapshot
}
@@ -120,11 +133,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
// Create snapshot at branch point for future diverging requests.
if session.snapshotOffset > 0 {
// Create snapshot if we've reached a pending offset.
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
baseOffset := len(session.inputs) - len(tokens)
if baseOffset+processed >= session.snapshotOffset {
session.snapshot(false)
if baseOffset+processed >= snapOffset {
session.snapshot()
}
}