mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
mlxrunner: fix Slice(0, 0) returning full dimension instead of empty
Slice used cmp.Or to resolve a zero stop value to the dimension size, intended to support open-ended slices like a[i:]. This made Slice(0, 0) indistinguishable from Slice(), so any slice with a zero stop would silently include the entire dimension instead of being empty. Replace cmp.Or with an explicit End sentinel and resolve negative indices against the dimension size, matching Python/PyTorch semantics.
This commit is contained in:
@@ -4,10 +4,14 @@ package mlx
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// End marks an open-ended slice bound: used as a stop it means "up to the
// end of the dimension", like leaving the stop out in Python (a[i:]).
// Its value is math.MaxInt32.
const End = math.MaxInt32
|
||||
|
||||
// slice carries the raw integer arguments given to Slice for one dimension.
// makeSlices interprets them by count: none selects the whole dimension, one
// selects a single index, two are start and stop, three add a stride.
type slice struct{ args []int }
|
||||
@@ -16,6 +20,16 @@ func Slice(args ...int) slice {
|
||||
return slice{args: args}
|
||||
}
|
||||
|
||||
func resolve(val, dim int) C.int {
|
||||
if val == End {
|
||||
return C.int(dim)
|
||||
}
|
||||
if val < 0 {
|
||||
return C.int(dim + val)
|
||||
}
|
||||
return C.int(val)
|
||||
}
|
||||
|
||||
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
||||
if len(slices) != len(dims) {
|
||||
panic("number of slice arguments must match number of tensor dimensions")
|
||||
@@ -28,26 +42,28 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
||||
}
|
||||
|
||||
for i, s := range slices {
|
||||
dim := dims[i]
|
||||
switch len(s.args) {
|
||||
case 0:
|
||||
// slice[:]
|
||||
args[0][i] = C.int(0)
|
||||
args[1][i] = C.int(dims[i])
|
||||
args[1][i] = C.int(dim)
|
||||
args[2][i] = C.int(1)
|
||||
case 1:
|
||||
// slice[i]
|
||||
args[0][i] = C.int(s.args[0])
|
||||
args[1][i] = C.int(s.args[0] + 1)
|
||||
start := resolve(s.args[0], dim)
|
||||
args[0][i] = start
|
||||
args[1][i] = start + 1
|
||||
args[2][i] = C.int(1)
|
||||
case 2:
|
||||
// slice[i:j]
|
||||
args[0][i] = C.int(s.args[0])
|
||||
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
|
||||
args[0][i] = resolve(s.args[0], dim)
|
||||
args[1][i] = resolve(s.args[1], dim)
|
||||
args[2][i] = C.int(1)
|
||||
case 3:
|
||||
// slice[i:j:k]
|
||||
args[0][i] = C.int(s.args[0])
|
||||
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
|
||||
args[0][i] = resolve(s.args[0], dim)
|
||||
args[1][i] = resolve(s.args[1], dim)
|
||||
args[2][i] = C.int(s.args[2])
|
||||
default:
|
||||
panic("invalid slice arguments")
|
||||
|
||||
@@ -169,7 +169,7 @@ func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||
return logprobs
|
||||
}
|
||||
|
||||
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, 0))
|
||||
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user