mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
mlxrunner: schedule periodic snapshots during prefill
Add periodic snapshots every 8k tokens and near the end of the prompt so that long prompts can be partially restored and thinking/generation can be retried without full reprocessing.
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
|||||||
"cmp"
|
"cmp"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
@@ -37,6 +38,12 @@ type kvCache struct {
|
|||||||
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
|
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pendingSnapshot is a snapshot scheduled to be taken during prefill.
|
||||||
|
type pendingSnapshot struct {
|
||||||
|
offset int
|
||||||
|
user bool
|
||||||
|
}
|
||||||
|
|
||||||
// cacheSession manages caches for a single pipeline run.
|
// cacheSession manages caches for a single pipeline run.
|
||||||
// Callers should append generated tokens to outputs and
|
// Callers should append generated tokens to outputs and
|
||||||
// defer close to save the cache state.
|
// defer close to save the cache state.
|
||||||
@@ -48,11 +55,10 @@ type cacheSession struct {
|
|||||||
caches []cache.Cache
|
caches []cache.Cache
|
||||||
remaining []int32
|
remaining []int32
|
||||||
|
|
||||||
// snapshotOffset, if > 0, is a trie node boundary where we need to
|
// pendingSnapshots lists offsets where snapshots should be captured
|
||||||
// capture a snapshot during prefill. This enables future requests
|
// during prefill, sorted by offset. Entries are consumed as the
|
||||||
// branching at this point to restore non-rewindable caches (e.g.
|
// cache advances past them.
|
||||||
// RecurrentCache) instead of re-evaluating from scratch.
|
pendingSnapshots []pendingSnapshot
|
||||||
snapshotOffset int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *kvCache) ensureCaches(m base.Model) {
|
func (c *kvCache) ensureCaches(m base.Model) {
|
||||||
@@ -100,31 +106,26 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
prefix := c.minCacheOffset()
|
prefix := c.minCacheOffset()
|
||||||
remaining := inputs[prefix:]
|
remaining := inputs[prefix:]
|
||||||
|
|
||||||
|
session := &cacheSession{
|
||||||
|
cache: c,
|
||||||
|
inputs: inputs,
|
||||||
|
caches: c.caches,
|
||||||
|
remaining: remaining,
|
||||||
|
}
|
||||||
|
|
||||||
// Schedule a snapshot at the branch point during prefill so future
|
// Schedule a snapshot at the branch point during prefill so future
|
||||||
// requests diverging here can restore instead of re-evaluating.
|
// requests diverging here can restore instead of re-evaluating.
|
||||||
var snapshotAt int
|
|
||||||
if prefix < matched {
|
if prefix < matched {
|
||||||
snapshotAt = matched
|
session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false})
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []any{"total", len(inputs), "matched", originalMatched}
|
msg := "cache hit"
|
||||||
args = append(args, "cached", prefix, "left", len(remaining))
|
|
||||||
if snapshotAt > 0 {
|
|
||||||
args = append(args, "pending_snapshot", snapshotAt)
|
|
||||||
}
|
|
||||||
if prefix == 0 {
|
if prefix == 0 {
|
||||||
slog.Info("cache miss", args...)
|
msg = "cache miss"
|
||||||
} else {
|
|
||||||
slog.Info("cache hit", args...)
|
|
||||||
}
|
}
|
||||||
|
slog.Info(msg, "total", len(inputs), "matched", originalMatched, "cached", prefix, "left", len(remaining))
|
||||||
|
|
||||||
return &cacheSession{
|
return session
|
||||||
cache: c,
|
|
||||||
inputs: inputs,
|
|
||||||
snapshotOffset: snapshotAt,
|
|
||||||
caches: c.caches,
|
|
||||||
remaining: remaining,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// switchToPath transitions from the current active path to a new path,
|
// switchToPath transitions from the current active path to a new path,
|
||||||
@@ -250,20 +251,54 @@ pageIn:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// snapshot creates a snapshot at the current cache position. During prefill,
|
// requestSnapshot schedules a user snapshot at the given absolute token
|
||||||
// it is called at branch points (user=false) to create restore points for
|
// offset. The snapshot will be captured during prefill when the cache
|
||||||
// future diverging requests and with user=true to mark an explicit reusable
|
// reaches this offset.
|
||||||
// restore point.
|
func (s *cacheSession) requestSnapshot(offset int) {
|
||||||
func (s *cacheSession) snapshot(user bool) {
|
baseOffset := len(s.inputs) - len(s.remaining)
|
||||||
|
if offset <= baseOffset || offset > len(s.inputs) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Deduplicate: if this offset already exists, upgrade to user.
|
||||||
|
for i := range s.pendingSnapshots {
|
||||||
|
if s.pendingSnapshots[i].offset == offset {
|
||||||
|
s.pendingSnapshots[i].user = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true})
|
||||||
|
slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int {
|
||||||
|
return a.offset - b.offset
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextPendingSnapshot returns the offset of the next pending snapshot,
|
||||||
|
// or 0 if there are none.
|
||||||
|
func (s *cacheSession) nextPendingSnapshot() int {
|
||||||
|
if len(s.pendingSnapshots) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return s.pendingSnapshots[0].offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// snapshot creates a snapshot at the current cache position. It determines
|
||||||
|
// whether this is a user snapshot by consuming pending entries whose offset
|
||||||
|
// has been reached.
|
||||||
|
func (s *cacheSession) snapshot() {
|
||||||
c := s.cache
|
c := s.cache
|
||||||
cacheOffset := c.minCacheOffset()
|
cacheOffset := c.minCacheOffset()
|
||||||
if cacheOffset <= 0 {
|
if cacheOffset <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear pending intermediate snapshot if we've reached or passed it.
|
// Consume pending snapshots up to the current offset and derive
|
||||||
if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
|
// the user flag from them.
|
||||||
s.snapshotOffset = 0
|
user := false
|
||||||
|
for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset {
|
||||||
|
if s.pendingSnapshots[0].user {
|
||||||
|
user = true
|
||||||
|
}
|
||||||
|
s.pendingSnapshots = s.pendingSnapshots[1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
// The last node in activePath is the frontier where caches are advancing.
|
// The last node in activePath is the frontier where caches are advancing.
|
||||||
|
|||||||
@@ -377,24 +377,25 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
|
|||||||
// begin -> prefill with snapshot(false) at branch points -> generate -> close
|
// begin -> prefill with snapshot(false) at branch points -> generate -> close
|
||||||
|
|
||||||
type requestResult struct {
|
type requestResult struct {
|
||||||
remaining []int32
|
remaining []int32
|
||||||
snapshotOffset int
|
pendingSnapshots int
|
||||||
}
|
}
|
||||||
|
|
||||||
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
|
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
|
||||||
// a user snapshot (snapshot(true)) is created at that offset during prefill.
|
// a user snapshot is requested at that offset during prefill.
|
||||||
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
|
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
userSnapAt := 0
|
session := kvc.begin(nil, inputs)
|
||||||
if len(userSnapshotAt) > 0 {
|
for _, at := range userSnapshotAt {
|
||||||
userSnapAt = userSnapshotAt[0]
|
if at > 0 {
|
||||||
|
session.requestSnapshot(at)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
session := kvc.begin(nil, inputs)
|
|
||||||
result := requestResult{
|
result := requestResult{
|
||||||
remaining: slices.Clone(session.remaining),
|
remaining: slices.Clone(session.remaining),
|
||||||
snapshotOffset: session.snapshotOffset,
|
pendingSnapshots: len(session.pendingSnapshots),
|
||||||
}
|
}
|
||||||
|
|
||||||
assertCacheOffsetAlignment(t, kvc, "after begin")
|
assertCacheOffsetAlignment(t, kvc, "after begin")
|
||||||
@@ -402,22 +403,9 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user
|
|||||||
baseOffset := kvc.minCacheOffset()
|
baseOffset := kvc.minCacheOffset()
|
||||||
remaining := inputs[baseOffset:]
|
remaining := inputs[baseOffset:]
|
||||||
|
|
||||||
// Collect snapshot points (offset -> user flag) in ascending order.
|
// Prefill: feed tokens, pausing at each pending snapshot.
|
||||||
type snapPoint struct {
|
for len(session.pendingSnapshots) > 0 {
|
||||||
offset int
|
sp := session.pendingSnapshots[0]
|
||||||
user bool
|
|
||||||
}
|
|
||||||
var points []snapPoint
|
|
||||||
if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset {
|
|
||||||
points = append(points, snapPoint{session.snapshotOffset, false})
|
|
||||||
}
|
|
||||||
if userSnapAt > 0 && userSnapAt > baseOffset {
|
|
||||||
points = append(points, snapPoint{userSnapAt, true})
|
|
||||||
}
|
|
||||||
slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset })
|
|
||||||
|
|
||||||
// Prefill: feed tokens, pausing at each snapshot point.
|
|
||||||
for _, sp := range points {
|
|
||||||
count := sp.offset - baseOffset
|
count := sp.offset - baseOffset
|
||||||
if count > len(remaining) {
|
if count > len(remaining) {
|
||||||
break
|
break
|
||||||
@@ -428,7 +416,7 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user
|
|||||||
baseOffset = sp.offset
|
baseOffset = sp.offset
|
||||||
}
|
}
|
||||||
assertCacheOffsetAlignment(t, kvc, "at snapshot point")
|
assertCacheOffsetAlignment(t, kvc, "at snapshot point")
|
||||||
session.snapshot(sp.user)
|
session.snapshot()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Feed rest of input tokens.
|
// Feed rest of input tokens.
|
||||||
@@ -615,15 +603,15 @@ func TestBranchCreationAndReuse(t *testing.T) {
|
|||||||
// caches (RecurrentCache), the rewind fails and freeAll fires.
|
// caches (RecurrentCache), the rewind fails and freeAll fires.
|
||||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
||||||
if env.rewindable {
|
if env.rewindable {
|
||||||
if resB.snapshotOffset != 0 {
|
if resB.pendingSnapshots != 0 {
|
||||||
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
|
||||||
}
|
}
|
||||||
if len(resB.remaining) != 3 {
|
if len(resB.remaining) != 3 {
|
||||||
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
|
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if resB.snapshotOffset != 5 {
|
if resB.pendingSnapshots != 1 {
|
||||||
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
|
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
|
||||||
}
|
}
|
||||||
if len(resB.remaining) != 8 {
|
if len(resB.remaining) != 8 {
|
||||||
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
|
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
|
||||||
@@ -672,15 +660,15 @@ func TestExactMatchSeedBehavior(t *testing.T) {
|
|||||||
if len(resB.remaining) != 1 {
|
if len(resB.remaining) != 1 {
|
||||||
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
|
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
|
||||||
}
|
}
|
||||||
if resB.snapshotOffset != 0 {
|
if resB.pendingSnapshots != 0 {
|
||||||
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(resB.remaining) != 5 {
|
if len(resB.remaining) != 5 {
|
||||||
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
|
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
|
||||||
}
|
}
|
||||||
if resB.snapshotOffset != 4 {
|
if resB.pendingSnapshots != 1 {
|
||||||
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
|
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
||||||
|
|||||||
@@ -79,10 +79,24 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
session := r.cache.begin(r.Model, inputs)
|
session := r.cache.begin(r.Model, inputs)
|
||||||
defer session.close()
|
defer session.close()
|
||||||
|
|
||||||
caches := session.caches
|
caches := session.caches
|
||||||
tokens := session.remaining
|
tokens := session.remaining
|
||||||
prefillChunk := prefillChunkSize()
|
prefillChunk := prefillChunkSize()
|
||||||
|
|
||||||
|
// Request periodic snapshots during prefill and near the end of the
|
||||||
|
// prompt so that long prompts can be partially restored and
|
||||||
|
// thinking/generation can be retried without full reprocessing.
|
||||||
|
const snapshotInterval = 8192
|
||||||
|
for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
|
||||||
|
session.requestSnapshot(offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
const preThinking = 4
|
||||||
|
if end := len(inputs) - preThinking; end > 0 {
|
||||||
|
session.requestSnapshot(end)
|
||||||
|
}
|
||||||
|
|
||||||
materializeCaches := func() {
|
materializeCaches := func() {
|
||||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||||
for _, c := range caches {
|
for _, c := range caches {
|
||||||
@@ -103,12 +117,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
n := min(prefillChunk, total-processed-1)
|
n := min(prefillChunk, total-processed-1)
|
||||||
|
|
||||||
// If there's a pending intermediate snapshot, split the batch
|
// If there's a pending snapshot, split the batch so we can
|
||||||
// so we can capture it at the exact offset. The cache offset
|
// capture it at the exact offset.
|
||||||
// after this batch will be: baseOffset + processed + n.
|
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
|
||||||
if session.snapshotOffset > 0 {
|
|
||||||
baseOffset := len(session.inputs) - len(tokens)
|
baseOffset := len(session.inputs) - len(tokens)
|
||||||
tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed)
|
tokensUntilSnapshot := snapOffset - (baseOffset + processed)
|
||||||
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
|
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
|
||||||
n = tokensUntilSnapshot
|
n = tokensUntilSnapshot
|
||||||
}
|
}
|
||||||
@@ -120,11 +133,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
processed += n
|
processed += n
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||||
|
|
||||||
// Create snapshot at branch point for future diverging requests.
|
// Create snapshot if we've reached a pending offset.
|
||||||
if session.snapshotOffset > 0 {
|
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
|
||||||
baseOffset := len(session.inputs) - len(tokens)
|
baseOffset := len(session.inputs) - len(tokens)
|
||||||
if baseOffset+processed >= session.snapshotOffset {
|
if baseOffset+processed >= snapOffset {
|
||||||
session.snapshot(false)
|
session.snapshot()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user