mlxrunner: panic on double unpin

This commit is contained in:
Jesse Gross
2026-03-20 16:10:19 -07:00
parent ec55536734
commit 95ee7fbd29
2 changed files with 6 additions and 13 deletions

View File

@@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
if v == nil || !v.Valid() { if v == nil || !v.Valid() {
return old return old
} }
if old == v {
return old
}
mlx.Pin(v) mlx.Pin(v)
if old != nil && old != v { mlx.Unpin(old)
mlx.Unpin(old)
}
return v return v
} }
@@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
if v == nil || !v.Valid() { if v == nil || !v.Valid() {
return old return old
} }
if old == v {
return old
}
root := v root := v
if ensureContiguous { if ensureContiguous {
@@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
detached := root.Clone() detached := root.Clone()
mlx.Pin(detached) mlx.Pin(detached)
if old != nil && old != detached { mlx.Unpin(old)
mlx.Unpin(old)
}
return detached return detached
} }

View File

@@ -137,6 +137,9 @@ func Unpin(s ...*Array) {
for _, t := range s { for _, t := range s {
if t != nil { if t != nil {
t.pinned-- t.pinned--
if t.pinned < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
}
} }
} }
} }
@@ -261,7 +264,7 @@ func LogArrays() {
for _, t := range arrays { for _, t := range arrays {
nb := t.NumBytes() nb := t.NumBytes()
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims())) logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
} }
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory()))) logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
} }