mlx: fix KV cache snapshot memory leak
mlx.Copy shares the backing buffer with its source (via copy_shared_buffer) rather than allocating independent storage. When used to snapshot a slice of the KV cache, the snapshot array holds the entire original cache buffer alive through the shared data pointer, even after eval detaches the computation graph.

Replace Copy with Contiguous in Snapshot and Split. Contiguous allocates a compact buffer when the source buffer is significantly larger than the logical slice (Contiguous::eval checks buffer_size > nbytes + 16384), which is always the case for KV cache slices.
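For illustration only (not part of the commit): a minimal sketch of the before/after pattern, using the mlx Go bindings exactly as they appear in the diff below. The helper name, the assumed [layers, heads, maxSeq, headDim] cache layout, and the omitted import are assumptions, not code from the repository.

// snapshotKeys is a hypothetical helper showing the pattern used by
// KVCache.Snapshot. keys is assumed to be the key cache tensor with
// layout [layers, heads, maxSeq, headDim].
func snapshotKeys(keys *mlx.Array, from, to int) *mlx.Array {
	// Slicing produces a view; it still references keys' full backing buffer.
	kSlice := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())

	// Before this commit: kSnap := mlx.Copy(kSlice) went through
	// copy_shared_buffer, so the snapshot kept the entire cache buffer
	// alive through its data pointer even after the graph was detached.
	//
	// After: Contiguous materializes a compact buffer, because for a cache
	// slice the source buffer is much larger than the slice's nbytes
	// (Contiguous::eval checks buffer_size > nbytes + 16384).
	kSnap := mlx.Contiguous(kSlice, false)

	// Same follow-up as KVCache.Snapshot: pin the snapshot and evaluate it
	// asynchronously.
	mlx.Pin(kSnap)
	mlx.AsyncEval(kSnap)
	return kSnap
}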
x/mlxrunner/cache/cache.go (vendored): 12 lines changed
@@ -109,8 +109,8 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot {
 
 	kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
 	vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
-	kCopy := mlx.Copy(kSlice)
-	vCopy := mlx.Copy(vSlice)
+	kCopy := mlx.Contiguous(kSlice, false)
+	vCopy := mlx.Contiguous(vSlice, false)
 	mlx.Pin(kCopy, vCopy)
 	mlx.AsyncEval(kCopy, vCopy)
@@ -196,10 +196,10 @@ func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
 		return snapshot, nil
 	}
 
-	pk := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
-	pv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
-	ck := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
-	cv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
+	pk := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
+	pv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
+	ck := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
+	cv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
 	mlx.Pin(pk, pv, ck, cv)
 	mlx.AsyncEval(pk, pv, ck, cv)
 
@@ -262,9 +262,11 @@ func LogArrays() {
 		return arrays[i].NumBytes() > arrays[j].NumBytes()
 	})
 
+	var total int
 	for _, t := range arrays {
 		nb := t.NumBytes()
+		total += nb
 		logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
 	}
-	logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
+	logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
 }
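The LogArrays change is diagnostic: the trace now reports the summed logical size of all live tensors alongside ActiveMemory(), which presumably makes this kind of leak visible as a gap between the two figures (snapshots that are logically small but keep a much larger shared buffer resident).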
@@ -494,15 +494,6 @@ func Collect(v any) []*Array {
 	return arrays
 }
 
-func Copy(a *Array) *Array {
-	if a == nil || !a.Valid() {
-		return a
-	}
-	out := New("COPY")
-	C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
-	return out
-}
-
 func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
 	if !v.IsValid() {
 		return
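Finally, the Copy wrapper around mlx_copy is deleted outright; with Snapshot and Split moved to Contiguous it presumably has no remaining callers, and removing it prevents the shared-buffer copy from being reintroduced by accident.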