mlx: fix KV cache snapshot memory leak

mlx.Copy shares the backing buffer with its source (via
copy_shared_buffer) rather than allocating independent storage.
When used to snapshot a slice of the KV cache, the snapshot array
holds the entire original cache buffer alive through the shared
data pointer — even after eval detaches the computation graph.

Replace Copy with Contiguous in Snapshot and Split. Contiguous
allocates a compact buffer when the source buffer is significantly
larger than the logical slice (Contiguous::eval checks
buffer_size > nbytes + 16384), which is always the case for KV
cache slices.
This commit is contained in:
Jesse Gross
2026-03-25 10:47:59 -07:00
parent ebbce136c7
commit d1151e18a1
3 changed files with 9 additions and 16 deletions

View File

@@ -109,8 +109,8 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot {
kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
kCopy := mlx.Copy(kSlice)
vCopy := mlx.Copy(vSlice)
kCopy := mlx.Contiguous(kSlice, false)
vCopy := mlx.Contiguous(vSlice, false)
mlx.Pin(kCopy, vCopy)
mlx.AsyncEval(kCopy, vCopy)
@@ -196,10 +196,10 @@ func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
return snapshot, nil
}
pk := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
pv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
ck := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
cv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
pk := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
pv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
ck := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
cv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
mlx.Pin(pk, pv, ck, cv)
mlx.AsyncEval(pk, pv, ck, cv)

View File

@@ -262,9 +262,11 @@ func LogArrays() {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
var total int
for _, t := range arrays {
nb := t.NumBytes()
total += nb
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
}
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
}

View File

@@ -494,15 +494,6 @@ func Collect(v any) []*Array {
return arrays
}
func Copy(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("COPY")
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return