Mirror of https://github.com/ollama/ollama.git (synced 2026-03-27 02:58:43 +07:00)
mlxrunner: panic on double unpin
This commit is contained in:
14
x/mlxrunner/cache/recurrent.go
vendored
14
x/mlxrunner/cache/recurrent.go
vendored
@@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
|
|||||||
if v == nil || !v.Valid() {
|
if v == nil || !v.Valid() {
|
||||||
return old
|
return old
|
||||||
}
|
}
|
||||||
if old == v {
|
|
||||||
return old
|
|
||||||
}
|
|
||||||
|
|
||||||
mlx.Pin(v)
|
mlx.Pin(v)
|
||||||
if old != nil && old != v {
|
mlx.Unpin(old)
|
||||||
mlx.Unpin(old)
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
@@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
|||||||
if v == nil || !v.Valid() {
|
if v == nil || !v.Valid() {
|
||||||
return old
|
return old
|
||||||
}
|
}
|
||||||
if old == v {
|
|
||||||
return old
|
|
||||||
}
|
|
||||||
|
|
||||||
root := v
|
root := v
|
||||||
if ensureContiguous {
|
if ensureContiguous {
|
||||||
@@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
|||||||
detached := root.Clone()
|
detached := root.Clone()
|
||||||
|
|
||||||
mlx.Pin(detached)
|
mlx.Pin(detached)
|
||||||
if old != nil && old != detached {
|
mlx.Unpin(old)
|
||||||
mlx.Unpin(old)
|
|
||||||
}
|
|
||||||
|
|
||||||
return detached
|
return detached
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -137,6 +137,9 @@ func Unpin(s ...*Array) {
|
|||||||
for _, t := range s {
|
for _, t := range s {
|
||||||
if t != nil {
|
if t != nil {
|
||||||
t.pinned--
|
t.pinned--
|
||||||
|
if t.pinned < 0 {
|
||||||
|
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -261,7 +264,7 @@ func LogArrays() {
|
|||||||
|
|
||||||
for _, t := range arrays {
|
for _, t := range arrays {
|
||||||
nb := t.NumBytes()
|
nb := t.NumBytes()
|
||||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
|
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
|
||||||
}
|
}
|
||||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user