mlx.Copy shares the backing buffer with its source (via copy_shared_buffer) rather than allocating independent storage. When used to snapshot a slice of the KV cache, the snapshot array therefore keeps the entire original cache buffer alive through the shared data pointer, even after eval detaches the computation graph. Replace Copy with Contiguous in Snapshot and Split: Contiguous allocates a compact buffer whenever the source buffer is significantly larger than the logical slice (Contiguous::eval checks buffer_size > nbytes + 16384), which is always the case for KV cache slices.
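
A minimal sketch of the change, assuming a hypothetical Snapshot helper built on this package's SliceStartStop and Contiguous wrappers (shown further below):

	func Snapshot(cache *Array, start, stop []int32) *Array {
		view := SliceStartStop(cache, start, stop) // strided view sharing the cache buffer
		// Copy(view) would keep sharing that buffer and pin the whole cache;
		// Contiguous materializes a compact buffer holding only this slice.
		return Contiguous(view, false)
	}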

package mlx

// #include "generated.h"
import "C"

import (
	"reflect"
	"unsafe"
)

// Quantization operations

func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	res := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(res)
	var globalScale C.mlx_array
	C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)

	vecSize := int(C.mlx_vector_array_size(res))
	w0 := New("QUANTIZE_W")
	C.mlx_vector_array_get(&w0.ctx, res, 0)
	w1 := New("QUANTIZE_S")
	C.mlx_vector_array_get(&w1.ctx, res, 1)
	if vecSize >= 3 {
		w2 := New("QUANTIZE_B")
		C.mlx_vector_array_get(&w2.ctx, res, 2)
		return w0, w1, w2
	}
	return w0, w1, nil
}
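
// A hypothetical round trip over these wrappers (the group size, bit width,
// and "affine" mode string are illustrative, not prescriptive):
//
//	qw, scales, biases := Quantize(w, 64, 4, "affine")
//	wHat := Dequantize(qw, scales, biases, 64, 4, "affine")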

func FromFP8(x *Array, dtype DType) *Array {
	out := New("FROM_FP8")
	C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
	return out
}

func ToFP8(x *Array) *Array {
	out := New("TO_FP8")
	C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
	return out
}

func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	optDtype := C.mlx_optional_dtype{has_value: false}

	var b C.mlx_array
	if biases != nil {
		b = biases.ctx
	}

	out := New("DEQUANTIZE")
	var globalScale C.mlx_array
	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
	return out
}

func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}

	var b C.mlx_array
	if biases != nil {
		b = biases.ctx
	}

	out := New("QUANTIZED_MATMUL")
	C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
	return out
}

func GatherQMM(x, w, scales, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}

	var b, lhs, rhs C.mlx_array
	if biases != nil {
		b = biases.ctx
	}
	if lhsIndices != nil {
		lhs = lhsIndices.ctx
	}
	if rhsIndices != nil {
		rhs = rhsIndices.ctx
	}

	out := New("GATHER_QMM")
	C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
	return out
}

// Missing tensor ops

func Tile(a *Array, reps []int32) *Array {
	cReps := make([]C.int, len(reps))
	for i, r := range reps {
		cReps[i] = C.int(r)
	}
	out := New("TILE")
	C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
	return out
}
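
// Assuming numpy.tile-style semantics (a is repeated reps[i] times along
// axis i), a hypothetical call on a 2-D array:
//
//	t := Tile(a, []int32{2, 1}) // two copies of a stacked along axis 0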

// Tri returns an n-by-m float32 matrix with ones on and below the k-th
// diagonal and zeros elsewhere.
func Tri(n, m int32, k int) *Array {
	out := New("TRI")
	C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
	return out
}

func Where(condition, a, b *Array) *Array {
	out := New("WHERE")
	C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

func Conv1d(x, weight, bias *Array, stride, padding, dilation, groups int32) *Array {
	out := New("CONV1D")
	C.mlx_conv1d(
		&out.ctx,
		x.ctx,
		weight.ctx,
		C.int(stride),
		C.int(padding),
		C.int(dilation),
		C.int(groups),
		DefaultStream().ctx,
	)
	if bias != nil && bias.Valid() {
		out = Add(out, bias)
	}
	return out
}

func Contiguous(a *Array, allowColMajor bool) *Array {
	out := New("CONTIGUOUS")
	C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
	return out
}

// Pad zero-pads a in "constant" mode. paddings holds interleaved (low, high)
// pairs, one pair per axis: paddings[2*i] zeros are added before axis i and
// paddings[2*i+1] after it.
func Pad(a *Array, paddings []int32) *Array {
	numAxes := len(paddings) / 2
	axes := make([]C.int, numAxes)
	lowPad := make([]C.int, numAxes)
	highPad := make([]C.int, numAxes)
	for i := range numAxes {
		axes[i] = C.int(i)
		lowPad[i] = C.int(paddings[i*2])
		highPad[i] = C.int(paddings[i*2+1])
	}

	padValue := C.mlx_array_new_float(C.float(0))
	defer C.mlx_array_free(padValue)

	cMode := C.CString("constant")
	defer C.free(unsafe.Pointer(cMode))

	out := New("PAD")
	C.mlx_pad(
		&out.ctx,
		a.ctx,
		unsafe.SliceData(axes),
		C.size_t(len(axes)),
		unsafe.SliceData(lowPad),
		C.size_t(len(lowPad)),
		unsafe.SliceData(highPad),
		C.size_t(len(highPad)),
		padValue,
		cMode,
		DefaultStream().ctx,
	)
	return out
}
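
// A hypothetical call on a 2-D array, padding only the second axis:
//
//	padded := Pad(x, []int32{0, 0, 2, 3}) // 2 leading, 3 trailing zeros on axis 1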

// DepthwiseConv1d convolves each channel independently by setting groups to
// the size of x's last axis.
func DepthwiseConv1d(x, weight, bias *Array) *Array {
	groups := int32(x.Dim(x.NumDims() - 1))
	return Conv1d(x, weight, bias, 1, 0, 1, groups)
}

// Convenience wrappers (function-style for the model code)

func Stack(arrays []*Array, axis int) *Array {
	vectorData := make([]C.mlx_array, len(arrays))
	for i := range arrays {
		vectorData[i] = arrays[i].ctx
	}
	vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
	defer C.mlx_vector_array_free(vector)

	out := New("STACK")
	C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}

func Neg(a *Array) *Array {
	return a.Negative()
}

func Sum(a *Array, axis int, keepDims bool) *Array {
	return a.SumAxis(axis, keepDims)
}

func Argsort(a *Array, axis int) *Array {
	return a.ArgsortAxis(axis)
}

func Take(a, indices *Array, axis int) *Array {
	return a.TakeAxis(indices, axis)
}

func RSqrt(a *Array) *Array {
	out := New("RSQRT")
	C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

func Mean(a *Array, axis int, keepDims bool) *Array {
	out := New("MEAN_AXIS")
	C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
	return out
}

func Argpartition(a *Array, kth int, axis int) *Array {
	return a.ArgpartitionAxis(kth, axis)
}

func TakeAlongAxis(a, indices *Array, axis int) *Array {
	return a.TakeAlongAxis(indices, axis)
}

// Function-style wrappers matching imagegen API

func Add(a, b *Array) *Array {
	return a.Add(b)
}

func Sub(a, b *Array) *Array {
	return a.Subtract(b)
}

func Mul(a, b *Array) *Array {
	return a.Multiply(b)
}

func Div(a, b *Array) *Array {
	return a.Divide(b)
}

func Matmul(a, b *Array) *Array {
	return a.Matmul(b)
}

func Reshape(a *Array, shape ...int32) *Array {
	axes := make([]int, len(shape))
	for i, s := range shape {
		axes[i] = int(s)
	}
	return a.Reshape(axes...)
}

func Transpose(a *Array, axes ...int) *Array {
	return a.Transpose(axes...)
}

func ExpandDims(a *Array, axis int) *Array {
	return a.ExpandDims(axis)
}

func Squeeze(a *Array, axis int) *Array {
	return a.Squeeze(axis)
}

func Flatten(a *Array) *Array {
	return a.Flatten(0, -1)
}

func Concatenate(arrays []*Array, axis int) *Array {
	if len(arrays) == 0 {
		return nil
	}
	return arrays[0].Concatenate(axis, arrays[1:]...)
}

func SliceStartStop(a *Array, start, stop []int32) *Array {
	n := len(start)
	cStart := make([]C.int, n)
	cStop := make([]C.int, n)
	cStrides := make([]C.int, n)
	for i := 0; i < n; i++ {
		cStart[i] = C.int(start[i])
		cStop[i] = C.int(stop[i])
		cStrides[i] = 1
	}
	out := New("SLICE")
	C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
	return out
}
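
// A hypothetical slice of rows 2..5 (exclusive stop) and all columns of a
// two-dimensional array, with the implicit unit strides:
//
//	sub := SliceStartStop(a, []int32{2, 0}, []int32{5, int32(a.Dim(1))})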

func GatherMM(a, b, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
	// Substitute empty arrays for missing index sets.
	if lhsIndices == nil {
		lhsIndices = New("")
	}
	if rhsIndices == nil {
		rhsIndices = New("")
	}
	return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}

// SiLU computes x * sigmoid(x) (the swish activation).
func SiLU(a *Array) *Array {
	sig := a.Sigmoid()
	return a.Multiply(sig)
}

func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
	freqs := New("") // no explicit frequencies; base and scale drive the rotation
	out := New("FAST_ROPE")
	C.mlx_fast_rope(
		&out.ctx,
		x.ctx,
		C.int(dims),
		C.bool(traditional),
		C.mlx_optional_float{
			value:     C.float(base),
			has_value: C.bool(base != 0),
		},
		C.float(scale),
		C.int(offset),
		freqs.ctx,
		DefaultStream().ctx,
	)
	return out
}

func Sigmoid(a *Array) *Array {
	return a.Sigmoid()
}

func Exp(a *Array) *Array {
	out := New("EXP")
	C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

func Log(a *Array) *Array {
	out := New("LOG")
	C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

func Logaddexp(a, b *Array) *Array {
	out := New("LOGADDEXP")
	C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
	out := New("SOFTMAX_AXIS")
	C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
	return out
}

func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
	mask := New("")
	sinks := New("")
	mode := ""
	if causalMask {
		mode = "causal"
	}
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))

	out := New("FAST_SDPA")
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return out
}
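
// A typical call, assuming [batch, heads, seq, headDim] inputs and the
// conventional 1/sqrt(headDim) scale:
//
//	scale := 1 / float32(math.Sqrt(float64(headDim)))
//	attn := ScaledDotProductAttentionCausal(q, k, v, scale, true)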

func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
	out := New("FAST_LAYERNORM")
	var w, b C.mlx_array
	if weight != nil {
		w = weight.ctx
	}
	if bias != nil {
		b = bias.ctx
	}
	C.mlx_fast_layer_norm(&out.ctx, x.ctx, w, b, C.float(eps), DefaultStream().ctx)
	return out
}

func RMSNormFn(x, weight *Array, eps float32) *Array {
	out := New("FAST_RMSNORM")
	var w C.mlx_array
	if weight != nil {
		w = weight.ctx
	}
	C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
	return out
}

func AddMM(c, a, b *Array, alpha, beta float32) *Array {
	return c.Addmm(a, b, alpha, beta)
}

// Scalar helpers

// scalarWithDtype creates a scalar array matching the dtype of a.
// Matching dtype is important for graph fusion and avoiding implicit casts.
func scalarWithDtype(s float32, a *Array) C.mlx_array {
	f32 := C.mlx_array_new_float(C.float(s))
	dtype := a.DType()
	if dtype == DTypeFloat32 {
		return f32
	}
	casted := C.mlx_array_new()
	C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
	C.mlx_array_free(f32)
	return casted
}

func AddScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	out := New("ADD_SCALAR")
	C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
	C.mlx_array_free(scalar)
	return out
}

func MulScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	out := New("MUL_SCALAR")
	C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
	C.mlx_array_free(scalar)
	return out
}

func DivScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	out := New("DIV_SCALAR")
	C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
	C.mlx_array_free(scalar)
	return out
}

func FloorDivideScalar(a *Array, s int32) *Array {
	scalar := FromValue(int(s))
	return a.FloorDivide(scalar)
}

// Array constructors

func NewArrayInt32(data []int32, shape []int32) *Array {
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	out := New("NEW_ARRAY_INT32")
	out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
	return out
}

func NewScalarArray(value float32) *Array {
	out := New("SCALAR")
	out.ctx = C.mlx_array_new_float32(C.float(value))
	return out
}

func ZerosF32(shape []int32) *Array {
	ints := make([]int, len(shape))
	for i, s := range shape {
		ints[i] = int(s)
	}
	return Zeros(DTypeFloat32, ints...)
}

// Utility

func Collect(v any) []*Array {
	var arrays []*Array
	seen := make(map[uintptr]bool)
	collect(reflect.ValueOf(v), &arrays, seen)
	return arrays
}
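
// Collect walks the value reflectively, so a hypothetical model struct whose
// fields (or nested slices and maps of fields) hold *Array values can be
// swept in one call, e.g. to evaluate or free the arrays as a batch:
//
//	params := Collect(model)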

func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
	if !v.IsValid() {
		return
	}

	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return
		}
		ptr := v.Pointer()
		if seen[ptr] {
			return
		}
		seen[ptr] = true

		if arr, ok := v.Interface().(*Array); ok {
			if arr != nil && arr.Valid() {
				*arrays = append(*arrays, arr)
			}
			return
		}
		collect(v.Elem(), arrays, seen)
		return
	}

	switch v.Kind() {
	case reflect.Struct:
		// Check if this struct IS an Array (not a pointer to one). Addr is
		// only legal on addressable values; unaddressable structs (e.g. map
		// values) fall through to the field walk below.
		if v.CanAddr() {
			if arr, ok := v.Addr().Interface().(*Array); ok {
				if arr != nil && arr.Valid() {
					*arrays = append(*arrays, arr)
				}
				return
			}
		}
		for i := 0; i < v.NumField(); i++ {
			field := v.Field(i)
			if field.CanInterface() {
				collect(field, arrays, seen)
			}
		}
	case reflect.Slice:
		for i := 0; i < v.Len(); i++ {
			collect(v.Index(i), arrays, seen)
		}
	case reflect.Map:
		for _, key := range v.MapKeys() {
			collect(v.MapIndex(key), arrays, seen)
		}
	case reflect.Interface:
		if !v.IsNil() {
			collect(v.Elem(), arrays, seen)
		}
	}
}

func EnableCompile() {
	C.mlx_enable_compile()
}

func DisableCompile() {
	C.mlx_disable_compile()
}