diff --git a/x/mlxrunner/cache/recurrent.go b/x/mlxrunner/cache/recurrent.go index 2c02ff3d4..a3c7f90d8 100644 --- a/x/mlxrunner/cache/recurrent.go +++ b/x/mlxrunner/cache/recurrent.go @@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array { if v == nil || !v.Valid() { return old } - if old == v { - return old - } mlx.Pin(v) - if old != nil && old != v { - mlx.Unpin(old) - } + mlx.Unpin(old) return v } @@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo if v == nil || !v.Valid() { return old } - if old == v { - return old - } root := v if ensureContiguous { @@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo detached := root.Clone() mlx.Pin(detached) - if old != nil && old != detached { - mlx.Unpin(old) - } + mlx.Unpin(old) return detached } diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index de91813fc..5604da24a 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -137,6 +137,9 @@ func Unpin(s ...*Array) { for _, t := range s { if t != nil { t.pinned-- + if t.pinned < 0 { + panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name)) + } } } } @@ -261,7 +264,7 @@ func LogArrays() { for _, t := range arrays { nb := t.NumBytes() - logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims())) + logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims())) } logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory()))) }