mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
mlxrunner: schedule periodic snapshots during prefill
Add periodic snapshots every 8k tokens and near the end of the prompt so that long prompts can be partially restored and thinking/generation can be retried without full reprocessing.
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -37,6 +38,12 @@ type kvCache struct {
|
||||
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
|
||||
}
|
||||
|
||||
// pendingSnapshot is a snapshot scheduled to be taken during prefill.
|
||||
type pendingSnapshot struct {
|
||||
offset int
|
||||
user bool
|
||||
}
|
||||
|
||||
// cacheSession manages caches for a single pipeline run.
|
||||
// Callers should append generated tokens to outputs and
|
||||
// defer close to save the cache state.
|
||||
@@ -48,11 +55,10 @@ type cacheSession struct {
|
||||
caches []cache.Cache
|
||||
remaining []int32
|
||||
|
||||
// snapshotOffset, if > 0, is a trie node boundary where we need to
|
||||
// capture a snapshot during prefill. This enables future requests
|
||||
// branching at this point to restore non-rewindable caches (e.g.
|
||||
// RecurrentCache) instead of re-evaluating from scratch.
|
||||
snapshotOffset int
|
||||
// pendingSnapshots lists offsets where snapshots should be captured
|
||||
// during prefill, sorted by offset. Entries are consumed as the
|
||||
// cache advances past them.
|
||||
pendingSnapshots []pendingSnapshot
|
||||
}
|
||||
|
||||
func (c *kvCache) ensureCaches(m base.Model) {
|
||||
@@ -100,31 +106,26 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||
prefix := c.minCacheOffset()
|
||||
remaining := inputs[prefix:]
|
||||
|
||||
// Schedule a snapshot at the branch point during prefill so future
|
||||
// requests diverging here can restore instead of re-evaluating.
|
||||
var snapshotAt int
|
||||
if prefix < matched {
|
||||
snapshotAt = matched
|
||||
}
|
||||
|
||||
args := []any{"total", len(inputs), "matched", originalMatched}
|
||||
args = append(args, "cached", prefix, "left", len(remaining))
|
||||
if snapshotAt > 0 {
|
||||
args = append(args, "pending_snapshot", snapshotAt)
|
||||
}
|
||||
if prefix == 0 {
|
||||
slog.Info("cache miss", args...)
|
||||
} else {
|
||||
slog.Info("cache hit", args...)
|
||||
}
|
||||
|
||||
return &cacheSession{
|
||||
session := &cacheSession{
|
||||
cache: c,
|
||||
inputs: inputs,
|
||||
snapshotOffset: snapshotAt,
|
||||
caches: c.caches,
|
||||
remaining: remaining,
|
||||
}
|
||||
|
||||
// Schedule a snapshot at the branch point during prefill so future
|
||||
// requests diverging here can restore instead of re-evaluating.
|
||||
if prefix < matched {
|
||||
session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false})
|
||||
}
|
||||
|
||||
msg := "cache hit"
|
||||
if prefix == 0 {
|
||||
msg = "cache miss"
|
||||
}
|
||||
slog.Info(msg, "total", len(inputs), "matched", originalMatched, "cached", prefix, "left", len(remaining))
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// switchToPath transitions from the current active path to a new path,
|
||||
@@ -250,20 +251,54 @@ pageIn:
|
||||
}
|
||||
}
|
||||
|
||||
// snapshot creates a snapshot at the current cache position. During prefill,
|
||||
// it is called at branch points (user=false) to create restore points for
|
||||
// future diverging requests and with user=true to mark an explicit reusable
|
||||
// restore point.
|
||||
func (s *cacheSession) snapshot(user bool) {
|
||||
// requestSnapshot schedules a user snapshot at the given absolute token
|
||||
// offset. The snapshot will be captured during prefill when the cache
|
||||
// reaches this offset.
|
||||
func (s *cacheSession) requestSnapshot(offset int) {
|
||||
baseOffset := len(s.inputs) - len(s.remaining)
|
||||
if offset <= baseOffset || offset > len(s.inputs) {
|
||||
return
|
||||
}
|
||||
// Deduplicate: if this offset already exists, upgrade to user.
|
||||
for i := range s.pendingSnapshots {
|
||||
if s.pendingSnapshots[i].offset == offset {
|
||||
s.pendingSnapshots[i].user = true
|
||||
return
|
||||
}
|
||||
}
|
||||
s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true})
|
||||
slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int {
|
||||
return a.offset - b.offset
|
||||
})
|
||||
}
|
||||
|
||||
// nextPendingSnapshot returns the offset of the next pending snapshot,
|
||||
// or 0 if there are none.
|
||||
func (s *cacheSession) nextPendingSnapshot() int {
|
||||
if len(s.pendingSnapshots) == 0 {
|
||||
return 0
|
||||
}
|
||||
return s.pendingSnapshots[0].offset
|
||||
}
|
||||
|
||||
// snapshot creates a snapshot at the current cache position. It determines
|
||||
// whether this is a user snapshot by consuming pending entries whose offset
|
||||
// has been reached.
|
||||
func (s *cacheSession) snapshot() {
|
||||
c := s.cache
|
||||
cacheOffset := c.minCacheOffset()
|
||||
if cacheOffset <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear pending intermediate snapshot if we've reached or passed it.
|
||||
if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
|
||||
s.snapshotOffset = 0
|
||||
// Consume pending snapshots up to the current offset and derive
|
||||
// the user flag from them.
|
||||
user := false
|
||||
for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset {
|
||||
if s.pendingSnapshots[0].user {
|
||||
user = true
|
||||
}
|
||||
s.pendingSnapshots = s.pendingSnapshots[1:]
|
||||
}
|
||||
|
||||
// The last node in activePath is the frontier where caches are advancing.
|
||||
|
||||
@@ -378,23 +378,24 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
|
||||
|
||||
type requestResult struct {
|
||||
remaining []int32
|
||||
snapshotOffset int
|
||||
pendingSnapshots int
|
||||
}
|
||||
|
||||
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
|
||||
// a user snapshot (snapshot(true)) is created at that offset during prefill.
|
||||
// a user snapshot is requested at that offset during prefill.
|
||||
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
|
||||
t.Helper()
|
||||
|
||||
userSnapAt := 0
|
||||
if len(userSnapshotAt) > 0 {
|
||||
userSnapAt = userSnapshotAt[0]
|
||||
session := kvc.begin(nil, inputs)
|
||||
for _, at := range userSnapshotAt {
|
||||
if at > 0 {
|
||||
session.requestSnapshot(at)
|
||||
}
|
||||
}
|
||||
|
||||
session := kvc.begin(nil, inputs)
|
||||
result := requestResult{
|
||||
remaining: slices.Clone(session.remaining),
|
||||
snapshotOffset: session.snapshotOffset,
|
||||
pendingSnapshots: len(session.pendingSnapshots),
|
||||
}
|
||||
|
||||
assertCacheOffsetAlignment(t, kvc, "after begin")
|
||||
@@ -402,22 +403,9 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user
|
||||
baseOffset := kvc.minCacheOffset()
|
||||
remaining := inputs[baseOffset:]
|
||||
|
||||
// Collect snapshot points (offset -> user flag) in ascending order.
|
||||
type snapPoint struct {
|
||||
offset int
|
||||
user bool
|
||||
}
|
||||
var points []snapPoint
|
||||
if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset {
|
||||
points = append(points, snapPoint{session.snapshotOffset, false})
|
||||
}
|
||||
if userSnapAt > 0 && userSnapAt > baseOffset {
|
||||
points = append(points, snapPoint{userSnapAt, true})
|
||||
}
|
||||
slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset })
|
||||
|
||||
// Prefill: feed tokens, pausing at each snapshot point.
|
||||
for _, sp := range points {
|
||||
// Prefill: feed tokens, pausing at each pending snapshot.
|
||||
for len(session.pendingSnapshots) > 0 {
|
||||
sp := session.pendingSnapshots[0]
|
||||
count := sp.offset - baseOffset
|
||||
if count > len(remaining) {
|
||||
break
|
||||
@@ -428,7 +416,7 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user
|
||||
baseOffset = sp.offset
|
||||
}
|
||||
assertCacheOffsetAlignment(t, kvc, "at snapshot point")
|
||||
session.snapshot(sp.user)
|
||||
session.snapshot()
|
||||
}
|
||||
|
||||
// Feed rest of input tokens.
|
||||
@@ -615,15 +603,15 @@ func TestBranchCreationAndReuse(t *testing.T) {
|
||||
// caches (RecurrentCache), the rewind fails and freeAll fires.
|
||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
||||
if env.rewindable {
|
||||
if resB.snapshotOffset != 0 {
|
||||
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
||||
if resB.pendingSnapshots != 0 {
|
||||
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
|
||||
}
|
||||
if len(resB.remaining) != 3 {
|
||||
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
|
||||
}
|
||||
} else {
|
||||
if resB.snapshotOffset != 5 {
|
||||
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
|
||||
if resB.pendingSnapshots != 1 {
|
||||
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
|
||||
}
|
||||
if len(resB.remaining) != 8 {
|
||||
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
|
||||
@@ -672,15 +660,15 @@ func TestExactMatchSeedBehavior(t *testing.T) {
|
||||
if len(resB.remaining) != 1 {
|
||||
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
|
||||
}
|
||||
if resB.snapshotOffset != 0 {
|
||||
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
||||
if resB.pendingSnapshots != 0 {
|
||||
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
|
||||
}
|
||||
} else {
|
||||
if len(resB.remaining) != 5 {
|
||||
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
|
||||
}
|
||||
if resB.snapshotOffset != 4 {
|
||||
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
|
||||
if resB.pendingSnapshots != 1 {
|
||||
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
|
||||
}
|
||||
}
|
||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
||||
|
||||
@@ -79,10 +79,24 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
session := r.cache.begin(r.Model, inputs)
|
||||
defer session.close()
|
||||
|
||||
caches := session.caches
|
||||
tokens := session.remaining
|
||||
prefillChunk := prefillChunkSize()
|
||||
|
||||
// Request periodic snapshots during prefill and near the end of the
|
||||
// prompt so that long prompts can be partially restored and
|
||||
// thinking/generation can be retried without full reprocessing.
|
||||
const snapshotInterval = 8192
|
||||
for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
|
||||
session.requestSnapshot(offset)
|
||||
}
|
||||
|
||||
const preThinking = 4
|
||||
if end := len(inputs) - preThinking; end > 0 {
|
||||
session.requestSnapshot(end)
|
||||
}
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
@@ -103,12 +117,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
|
||||
// If there's a pending intermediate snapshot, split the batch
|
||||
// so we can capture it at the exact offset. The cache offset
|
||||
// after this batch will be: baseOffset + processed + n.
|
||||
if session.snapshotOffset > 0 {
|
||||
// If there's a pending snapshot, split the batch so we can
|
||||
// capture it at the exact offset.
|
||||
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
|
||||
baseOffset := len(session.inputs) - len(tokens)
|
||||
tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed)
|
||||
tokensUntilSnapshot := snapOffset - (baseOffset + processed)
|
||||
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
|
||||
n = tokensUntilSnapshot
|
||||
}
|
||||
@@ -120,11 +133,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
processed += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
|
||||
// Create snapshot at branch point for future diverging requests.
|
||||
if session.snapshotOffset > 0 {
|
||||
// Create snapshot if we've reached a pending offset.
|
||||
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
|
||||
baseOffset := len(session.inputs) - len(tokens)
|
||||
if baseOffset+processed >= session.snapshotOffset {
|
||||
session.snapshot(false)
|
||||
if baseOffset+processed >= snapOffset {
|
||||
session.snapshot()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user