This commit is contained in:
Michael Yang
2026-02-04 17:26:22 -08:00
parent da4d04b0e8
commit 20299cb1da

View File

@@ -22,7 +22,7 @@ type KVCache struct {
}
// NewKVCache returns an empty KVCache whose step is 256.
//
// keys and values are deliberately left at their zero value (nil) rather than
// being seeded with empty placeholder arrays: Update checks c.keys == nil to
// decide whether a buffer has been allocated yet, and a non-nil placeholder
// would defeat that check.
func NewKVCache() *KVCache {
	return &KVCache{step: 256}
}
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
@@ -31,12 +31,12 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
prev := c.offset
// Grow buffer if needed
if !c.keys.Valid() || (prev+L) > c.keys.Dim(2) {
if c.keys == nil || (prev+L) > c.keys.Dim(2) {
steps := (c.step + L - 1) / c.step
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
if c.keys.Valid() {
if c.keys != nil {
if prev%c.step != 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
@@ -103,7 +103,7 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if !c.keys.Valid() {
if c.keys == nil {
c.keys, c.values = keys, values
} else {
if c.idx < c.keys.Dim(2) {
@@ -134,11 +134,11 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
prev := c.offset
// Grow buffer if not yet at max
if !c.keys.Valid() || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
newSize := min(c.step, c.maxSize-prev)
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
if c.keys.Valid() {
if c.keys != nil {
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {