diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 5bb4b440c..06a04b453 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -22,7 +22,7 @@ type KVCache struct { } func NewKVCache() *KVCache { - return &KVCache{step: 256, keys: &mlx.Array{}, values: &mlx.Array{}} + return &KVCache{step: 256} } func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { @@ -31,12 +31,12 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { prev := c.offset // Grow buffer if needed - if !c.keys.Valid() || (prev+L) > c.keys.Dim(2) { + if c.keys == nil || (prev+L) > c.keys.Dim(2) { steps := (c.step + L - 1) / c.step newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk) newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv) - if c.keys.Valid() { + if c.keys != nil { if prev%c.step != 0 { c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice())) c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice())) @@ -103,7 +103,7 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) { slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize) - if !c.keys.Valid() { + if c.keys == nil { c.keys, c.values = keys, values } else { if c.idx < c.keys.Dim(2) { @@ -134,11 +134,11 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra prev := c.offset // Grow buffer if not yet at max - if !c.keys.Valid() || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) { + if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) { newSize := min(c.step, c.maxSize-prev) newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk) newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv) - if c.keys.Valid() { + if c.keys != nil { c.keys.Set(c.keys.Concatenate(2, newKeys)) c.values.Set(c.values.Concatenate(2, newValues)) } else {