diff --git a/x/mlxrunner/mlx/slice.go b/x/mlxrunner/mlx/slice.go index ea642ebf7..d6ae215a2 100644 --- a/x/mlxrunner/mlx/slice.go +++ b/x/mlxrunner/mlx/slice.go @@ -4,10 +4,14 @@ package mlx import "C" import ( - "cmp" + "math" "unsafe" ) +// End is a sentinel value meaning "to the end of the dimension", +// equivalent to an omitted stop in Python (e.g. a[i:]). +const End = math.MaxInt32 + type slice struct { args []int } @@ -16,6 +20,16 @@ func Slice(args ...int) slice { return slice{args: args} } +func resolve(val, dim int) C.int { + if val == End { + return C.int(dim) + } + if val < 0 { + return C.int(dim + val) + } + return C.int(val) +} + func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) { if len(slices) != len(dims) { panic("number of slice arguments must match number of tensor dimensions") @@ -28,26 +42,28 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) { } for i, s := range slices { + dim := dims[i] switch len(s.args) { case 0: // slice[:] args[0][i] = C.int(0) - args[1][i] = C.int(dims[i]) + args[1][i] = C.int(dim) args[2][i] = C.int(1) case 1: // slice[i] - args[0][i] = C.int(s.args[0]) - args[1][i] = C.int(s.args[0] + 1) + start := resolve(s.args[0], dim) + args[0][i] = start + args[1][i] = start + 1 args[2][i] = C.int(1) case 2: // slice[i:j] - args[0][i] = C.int(s.args[0]) - args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i])) + args[0][i] = resolve(s.args[0], dim) + args[1][i] = resolve(s.args[1], dim) args[2][i] = C.int(1) case 3: // slice[i:j:k] - args[0][i] = C.int(s.args[0]) - args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i])) + args[0][i] = resolve(s.args[0], dim) + args[1][i] = resolve(s.args[1], dim) args[2][i] = C.int(s.args[2]) default: panic("invalid slice arguments") diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index df9da7a99..fe07cf7b3 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -169,7 +169,7 @@ func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array { return logprobs } - mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, 0)) + mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) }