diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 87b42e6cc..39f5c1f5a 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -109,8 +109,8 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot { kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice()) vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice()) - kCopy := mlx.Copy(kSlice) - vCopy := mlx.Copy(vSlice) + kCopy := mlx.Contiguous(kSlice, false) + vCopy := mlx.Contiguous(vSlice, false) mlx.Pin(kCopy, vCopy) mlx.AsyncEval(kCopy, vCopy) @@ -196,10 +196,10 @@ func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) { return snapshot, nil } - pk := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice())) - pv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice())) - ck := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice())) - cv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice())) + pk := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false) + pv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false) + ck := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false) + cv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false) mlx.Pin(pk, pv, ck, cv) mlx.AsyncEval(pk, pv, ck, cv) diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index 5604da24a..198162efd 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -262,9 +262,11 @@ func LogArrays() { return arrays[i].NumBytes() > arrays[j].NumBytes() }) + var total int for _, t := range arrays { nb := t.NumBytes() + total += nb logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims())) } - logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory()))) + logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory()))) } diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index 107279441..9de0037da 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -494,15 +494,6 @@ func Collect(v any) []*Array { return arrays } -func Copy(a *Array) *Array { - if a == nil || !a.Valid() { - return a - } - out := New("COPY") - C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) { if !v.IsValid() { return