mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
mlxrunner: panic on double unpin
This commit is contained in:
14
x/mlxrunner/cache/recurrent.go
vendored
14
x/mlxrunner/cache/recurrent.go
vendored
@@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
|
||||
if v == nil || !v.Valid() {
|
||||
return old
|
||||
}
|
||||
if old == v {
|
||||
return old
|
||||
}
|
||||
|
||||
mlx.Pin(v)
|
||||
if old != nil && old != v {
|
||||
mlx.Unpin(old)
|
||||
}
|
||||
mlx.Unpin(old)
|
||||
|
||||
return v
|
||||
}
|
||||
@@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
||||
if v == nil || !v.Valid() {
|
||||
return old
|
||||
}
|
||||
if old == v {
|
||||
return old
|
||||
}
|
||||
|
||||
root := v
|
||||
if ensureContiguous {
|
||||
@@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
||||
detached := root.Clone()
|
||||
|
||||
mlx.Pin(detached)
|
||||
if old != nil && old != detached {
|
||||
mlx.Unpin(old)
|
||||
}
|
||||
mlx.Unpin(old)
|
||||
|
||||
return detached
|
||||
}
|
||||
|
||||
@@ -137,6 +137,9 @@ func Unpin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned--
|
||||
if t.pinned < 0 {
|
||||
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -261,7 +264,7 @@ func LogArrays() {
|
||||
|
||||
for _, t := range arrays {
|
||||
nb := t.NumBytes()
|
||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
|
||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user