mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
mlx.Copy shares the backing buffer with its source (via copy_shared_buffer) rather than allocating independent storage. When used to snapshot a slice of the KV cache, the snapshot array holds the entire original cache buffer alive through the shared data pointer — even after eval detaches the computation graph. Replace Copy with Contiguous in Snapshot and Split. Contiguous allocates a compact buffer when the source buffer is significantly larger than the logical slice (Contiguous::eval checks buffer_size > nbytes + 16384), which is always the case for KV cache slices.
431 lines
12 KiB
Go
431 lines
12 KiB
Go
package cache
|
|
|
|
import (
|
|
"github.com/ollama/ollama/logutil"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
type Cache interface {
|
|
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
|
// State returns the cache-owned state roots that should be kept/evaluated.
|
|
State() []*mlx.Array
|
|
Free()
|
|
Offset() int
|
|
|
|
// Snapshot copies cache state from fromOffset to current offset into
|
|
// pinned VRAM arrays. The active cache is unchanged.
|
|
Snapshot(fromOffset int) Snapshot
|
|
|
|
// Restore brings the cache to target. If snapshot is nil, rewinds
|
|
// using the cache's own live state. Returns false if the target is
|
|
// unreachable (e.g. target > current offset, or negative).
|
|
Restore(snapshot Snapshot, target int) bool
|
|
|
|
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
|
|
// Takes ownership of both inputs.
|
|
Merge(parent, child Snapshot) Snapshot
|
|
|
|
// Split divides a snapshot [a,c) at offset b into [a,b) and [b,c).
|
|
// Takes ownership of the input. Cache types that cannot split
|
|
// (e.g. recurrent) return (nil, snapshot).
|
|
Split(snapshot Snapshot, at int) (parent, child Snapshot)
|
|
}
|
|
|
|
// Snapshot is paged-out cache state that can be restored later.
|
|
type Snapshot interface {
|
|
// Size returns the byte size of the paged-out data (in VRAM).
|
|
Size() int
|
|
// Close unpins the snapshot's arrays so they can be freed by Sweep.
|
|
Close()
|
|
}
|
|
|
|
type KVCache struct {
|
|
keys, values *mlx.Array
|
|
offset int
|
|
step int
|
|
}
|
|
|
|
func NewKVCache() *KVCache {
|
|
return &KVCache{step: 256}
|
|
}
|
|
|
|
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
|
|
|
|
prev := c.offset
|
|
|
|
// Grow buffer if needed
|
|
if c.keys == nil || (prev+L) > c.keys.Dim(2) {
|
|
steps := (c.step + L - 1) / c.step
|
|
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
|
|
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
|
|
|
|
if c.keys != nil {
|
|
if prev%c.step != 0 {
|
|
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
|
|
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
|
|
}
|
|
c.keys.Set(c.keys.Concatenate(2, newKeys))
|
|
c.values.Set(c.values.Concatenate(2, newValues))
|
|
} else {
|
|
c.keys, c.values = newKeys, newValues
|
|
mlx.Pin(c.keys, c.values)
|
|
}
|
|
}
|
|
|
|
c.offset += L
|
|
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
|
|
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
|
|
|
|
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
|
}
|
|
|
|
func (c *KVCache) State() []*mlx.Array {
|
|
if c.keys == nil || c.values == nil {
|
|
return nil
|
|
}
|
|
return []*mlx.Array{
|
|
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
|
}
|
|
}
|
|
|
|
// kvSnapshot holds paged-out KV data for a range [fromOffset, toOffset).
|
|
type kvSnapshot struct {
|
|
keys, values *mlx.Array
|
|
fromOffset, toOffset int
|
|
}
|
|
|
|
func (s *kvSnapshot) Size() int { return s.keys.NumBytes() + s.values.NumBytes() }
|
|
func (s *kvSnapshot) Close() { mlx.Unpin(s.keys, s.values) }
|
|
|
|
func (c *KVCache) Snapshot(fromOffset int) Snapshot {
|
|
if c.keys == nil || c.offset <= fromOffset {
|
|
return nil
|
|
}
|
|
from := max(0, fromOffset)
|
|
to := c.offset
|
|
|
|
kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
|
|
vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
|
|
kCopy := mlx.Contiguous(kSlice, false)
|
|
vCopy := mlx.Contiguous(vSlice, false)
|
|
mlx.Pin(kCopy, vCopy)
|
|
mlx.AsyncEval(kCopy, vCopy)
|
|
|
|
return &kvSnapshot{
|
|
keys: kCopy,
|
|
values: vCopy,
|
|
fromOffset: from,
|
|
toOffset: to,
|
|
}
|
|
}
|
|
|
|
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
|
if target < 0 {
|
|
return false
|
|
}
|
|
|
|
if snapshot == nil {
|
|
if target > c.offset {
|
|
return false
|
|
}
|
|
c.offset = target
|
|
return true
|
|
}
|
|
|
|
snap := snapshot.(*kvSnapshot)
|
|
|
|
if target > snap.toOffset || c.offset < snap.fromOffset {
|
|
return false
|
|
}
|
|
|
|
// Rewind to snapshot start, then feed snapshot data through Update.
|
|
c.offset = snap.fromOffset
|
|
c.Update(snap.keys, snap.values)
|
|
|
|
// Clamp to target if needed (target may be less than full snapshot).
|
|
if target < c.offset {
|
|
c.offset = target
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (c *KVCache) Merge(parent, child Snapshot) Snapshot {
|
|
if parent == nil || child == nil {
|
|
if parent != nil {
|
|
parent.Close()
|
|
}
|
|
if child != nil {
|
|
child.Close()
|
|
}
|
|
return nil
|
|
}
|
|
p := parent.(*kvSnapshot)
|
|
ch := child.(*kvSnapshot)
|
|
|
|
mk := p.keys.Concatenate(2, ch.keys)
|
|
mv := p.values.Concatenate(2, ch.values)
|
|
mlx.Pin(mk, mv)
|
|
mlx.AsyncEval(mk, mv)
|
|
|
|
p.Close()
|
|
ch.Close()
|
|
|
|
return &kvSnapshot{
|
|
keys: mk,
|
|
values: mv,
|
|
fromOffset: p.fromOffset,
|
|
toOffset: ch.toOffset,
|
|
}
|
|
}
|
|
|
|
func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
|
if snapshot == nil {
|
|
return nil, nil
|
|
}
|
|
snap := snapshot.(*kvSnapshot)
|
|
splitIdx := at - snap.fromOffset
|
|
seqLen := snap.toOffset - snap.fromOffset
|
|
if splitIdx <= 0 {
|
|
return nil, snapshot
|
|
}
|
|
if splitIdx >= seqLen {
|
|
return snapshot, nil
|
|
}
|
|
|
|
pk := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
|
|
pv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
|
|
ck := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
|
|
cv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
|
|
mlx.Pin(pk, pv, ck, cv)
|
|
mlx.AsyncEval(pk, pv, ck, cv)
|
|
|
|
snap.Close()
|
|
|
|
p := &kvSnapshot{
|
|
keys: pk,
|
|
values: pv,
|
|
fromOffset: snap.fromOffset,
|
|
toOffset: at,
|
|
}
|
|
ch := &kvSnapshot{
|
|
keys: ck,
|
|
values: cv,
|
|
fromOffset: at,
|
|
toOffset: snap.toOffset,
|
|
}
|
|
return p, ch
|
|
}
|
|
|
|
func (c *KVCache) Free() {
|
|
mlx.Unpin(c.keys, c.values)
|
|
c.keys, c.values = nil, nil
|
|
c.offset = 0
|
|
}
|
|
|
|
func (c *KVCache) Offset() int { return c.offset }
|
|
|
|
// RotatingKVCache implements sliding window attention with bounded memory
|
|
type RotatingKVCache struct {
|
|
maxSize int
|
|
idx int
|
|
|
|
*KVCache
|
|
}
|
|
|
|
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
|
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
|
|
}
|
|
|
|
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
if keys.Dim(2) > 1 {
|
|
return c.concat(keys, values)
|
|
}
|
|
return c.update(keys, values)
|
|
}
|
|
|
|
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
|
|
logutil.Trace("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
|
|
if c.keys == nil {
|
|
c.keys, c.values = keys.Clone(), values.Clone()
|
|
mlx.Pin(c.keys, c.values)
|
|
} else {
|
|
if c.idx < c.keys.Dim(2) {
|
|
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
|
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
|
}
|
|
|
|
// Trim to max_size to maintain sliding window
|
|
if trim := c.idx - c.maxSize + 1; trim > 0 {
|
|
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
|
|
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
|
|
}
|
|
|
|
c.keys.Set(c.keys.Concatenate(2, keys))
|
|
c.values.Set(c.values.Concatenate(2, values))
|
|
c.idx = c.keys.Dim(2)
|
|
}
|
|
|
|
c.offset += keys.Dim(2)
|
|
c.idx = c.keys.Dim(2)
|
|
return c.keys, c.values
|
|
}
|
|
|
|
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
logutil.Trace("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
|
|
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
|
|
|
|
prev := c.offset
|
|
|
|
// Grow buffer if not yet at max
|
|
if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
|
|
newSize := min(c.step, c.maxSize-prev)
|
|
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
|
|
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
|
|
if c.keys != nil {
|
|
c.keys.Set(c.keys.Concatenate(2, newKeys))
|
|
c.values.Set(c.values.Concatenate(2, newValues))
|
|
} else {
|
|
c.keys, c.values = newKeys, newValues
|
|
mlx.Pin(c.keys, c.values)
|
|
}
|
|
c.idx = prev
|
|
}
|
|
|
|
// Trim to max_size to maintain sliding window
|
|
if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
|
|
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
|
|
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
|
|
c.idx = c.maxSize
|
|
}
|
|
|
|
// Rotate when hitting max
|
|
if c.idx >= c.maxSize {
|
|
c.idx = 0
|
|
}
|
|
|
|
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
|
|
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
|
|
|
|
c.offset += L
|
|
c.idx += L
|
|
|
|
validLen := min(c.offset, c.maxSize)
|
|
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
|
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
|
|
}
|
|
|
|
func (c *RotatingKVCache) State() []*mlx.Array {
|
|
if c.keys == nil || c.values == nil {
|
|
return nil
|
|
}
|
|
return []*mlx.Array{
|
|
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
|
}
|
|
}
|
|
|
|
// rotatingSnapshot holds paged-out data for a RotatingKVCache.
|
|
type rotatingSnapshot struct {
|
|
kvSnapshot // embedded KV data
|
|
idx int // buffer write position at snapshot time
|
|
}
|
|
|
|
func (s *rotatingSnapshot) Size() int { return s.kvSnapshot.Size() }
|
|
func (s *rotatingSnapshot) Close() { s.kvSnapshot.Close() }
|
|
|
|
func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
|
|
if c.keys == nil || c.offset <= fromOffset {
|
|
return nil
|
|
}
|
|
|
|
state := c.State()
|
|
k := state[0].Clone()
|
|
v := state[1].Clone()
|
|
mlx.Pin(k, v)
|
|
|
|
return &rotatingSnapshot{
|
|
kvSnapshot: kvSnapshot{
|
|
keys: k,
|
|
values: v,
|
|
fromOffset: fromOffset,
|
|
toOffset: c.offset,
|
|
},
|
|
idx: c.idx,
|
|
}
|
|
}
|
|
|
|
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
|
if target < 0 {
|
|
return false
|
|
}
|
|
|
|
if snapshot == nil {
|
|
if target >= c.offset {
|
|
return target == c.offset
|
|
}
|
|
// Live rewind is only safe when the buffer hasn't filled yet
|
|
// (offset <= maxSize). Once the window has shifted, rewinding
|
|
// leaves fewer than maxSize trailing tokens to attend to —
|
|
// a snapshot is required to restore the full window.
|
|
if c.offset > c.maxSize {
|
|
return false
|
|
}
|
|
c.offset = target
|
|
c.idx = target
|
|
return true
|
|
}
|
|
|
|
snap := snapshot.(*rotatingSnapshot)
|
|
|
|
if target > snap.toOffset {
|
|
return false
|
|
}
|
|
|
|
// Reject if clamping would leave an incomplete window.
|
|
if target < snap.toOffset && snap.toOffset > c.maxSize {
|
|
return false
|
|
}
|
|
|
|
// Restore from snapshot: rebuild buffer state.
|
|
// Free existing state first.
|
|
if c.keys != nil {
|
|
mlx.Unpin(c.keys, c.values)
|
|
}
|
|
c.keys = snap.keys.Clone()
|
|
c.values = snap.values.Clone()
|
|
mlx.Pin(c.keys, c.values)
|
|
c.offset = snap.toOffset
|
|
c.idx = snap.idx
|
|
|
|
// Clamp to target if needed.
|
|
if target < c.offset {
|
|
c.offset = target
|
|
c.idx = target
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (c *RotatingKVCache) Merge(parent, child Snapshot) Snapshot {
|
|
// For rotating caches, the child snapshot supersedes the parent
|
|
// since it contains the full window state.
|
|
if parent != nil {
|
|
parent.Close()
|
|
}
|
|
return child
|
|
}
|
|
|
|
func (c *RotatingKVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
|
// Rotating cache snapshots contain the full window state.
|
|
// Cannot cleanly split a ring buffer at an arbitrary point.
|
|
return nil, snapshot
|
|
}
|
|
|
|
func (c *RotatingKVCache) Free() {
|
|
c.KVCache.Free()
|
|
c.idx = 0
|
|
}
|