mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
nil keys
This commit is contained in:
12
x/mlxrunner/cache/cache.go
vendored
12
x/mlxrunner/cache/cache.go
vendored
@@ -22,7 +22,7 @@ type KVCache struct {
|
||||
}
|
||||
|
||||
func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256, keys: &mlx.Array{}, values: &mlx.Array{}}
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
@@ -31,12 +31,12 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
prev := c.offset
|
||||
|
||||
// Grow buffer if needed
|
||||
if !c.keys.Valid() || (prev+L) > c.keys.Dim(2) {
|
||||
if c.keys == nil || (prev+L) > c.keys.Dim(2) {
|
||||
steps := (c.step + L - 1) / c.step
|
||||
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
|
||||
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
|
||||
|
||||
if c.keys.Valid() {
|
||||
if c.keys != nil {
|
||||
if prev%c.step != 0 {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
|
||||
@@ -103,7 +103,7 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
|
||||
|
||||
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
|
||||
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
|
||||
if !c.keys.Valid() {
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = keys, values
|
||||
} else {
|
||||
if c.idx < c.keys.Dim(2) {
|
||||
@@ -134,11 +134,11 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
|
||||
prev := c.offset
|
||||
|
||||
// Grow buffer if not yet at max
|
||||
if !c.keys.Valid() || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
|
||||
if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
|
||||
newSize := min(c.step, c.maxSize-prev)
|
||||
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
|
||||
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
|
||||
if c.keys.Valid() {
|
||||
if c.keys != nil {
|
||||
c.keys.Set(c.keys.Concatenate(2, newKeys))
|
||||
c.values.Set(c.values.Concatenate(2, newValues))
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user