package mlx // #include "generated.h" import "C" import ( "reflect" "unsafe" ) // Quantization operations func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) { cMode := C.CString(mode) defer C.free(unsafe.Pointer(cMode)) optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} res := C.mlx_vector_array_new() defer C.mlx_vector_array_free(res) var globalScale C.mlx_array C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx) vecSize := int(C.mlx_vector_array_size(res)) w0 := New("QUANTIZE_W") C.mlx_vector_array_get(&w0.ctx, res, 0) w1 := New("QUANTIZE_S") C.mlx_vector_array_get(&w1.ctx, res, 1) if vecSize >= 3 { w2 := New("QUANTIZE_B") C.mlx_vector_array_get(&w2.ctx, res, 2) return w0, w1, w2 } return w0, w1, nil } func FromFP8(x *Array, dtype DType) *Array { out := New("FROM_FP8") C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx) return out } func ToFP8(x *Array) *Array { out := New("TO_FP8") C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx) return out } func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array { cMode := C.CString(mode) defer C.free(unsafe.Pointer(cMode)) optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} optDtype := C.mlx_optional_dtype{has_value: false} var b C.mlx_array if biases != nil { b = biases.ctx } out := New("DEQUANTIZE") var globalScale C.mlx_array C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx) return out } func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array { cMode := C.CString(mode) defer C.free(unsafe.Pointer(cMode)) optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} var b C.mlx_array if biases != nil { b = biases.ctx } out := New("QUANTIZED_MATMUL") C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx) return out } func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array { cMode := C.CString(mode) defer C.free(unsafe.Pointer(cMode)) optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} var b, lhs, rhs C.mlx_array if biases != nil { b = biases.ctx } if lhsIndices != nil { lhs = lhsIndices.ctx } if rhsIndices != nil { rhs = rhsIndices.ctx } out := New("GATHER_QMM") C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx) return out } // Missing tensor ops func Tile(a *Array, reps []int32) *Array { cReps := make([]C.int, len(reps)) for i, r := range reps { cReps[i] = C.int(r) } out := New("TILE") C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx) return out } func Tri(n, m int32, k int) *Array { out := New("TRI") C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx) return out } func Where(condition, a, b *Array) *Array { out := New("WHERE") C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array { out := New("CONV1D") C.mlx_conv1d( &out.ctx, x.ctx, weight.ctx, C.int(stride), C.int(padding), C.int(dilation), C.int(groups), DefaultStream().ctx, ) if bias != nil && bias.Valid() { out = Add(out, bias) } return out } func Contiguous(a *Array, allowColMajor bool) *Array { out := New("CONTIGUOUS") C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx) return out } func Pad(a *Array, paddings []int32) *Array { numAxes := len(paddings) / 2 axes := make([]C.int, numAxes) lowPad := make([]C.int, numAxes) highPad := make([]C.int, numAxes) for i := range numAxes { axes[i] = C.int(i) lowPad[i] = C.int(paddings[i*2]) highPad[i] = C.int(paddings[i*2+1]) } padValue := C.mlx_array_new_float(C.float(0)) defer C.mlx_array_free(padValue) cMode := C.CString("constant") defer C.free(unsafe.Pointer(cMode)) out := New("PAD") C.mlx_pad( &out.ctx, a.ctx, unsafe.SliceData(axes), C.size_t(len(axes)), unsafe.SliceData(lowPad), C.size_t(len(lowPad)), unsafe.SliceData(highPad), C.size_t(len(highPad)), padValue, cMode, DefaultStream().ctx, ) return out } func DepthwiseConv1d(x, weight *Array, bias *Array) *Array { groups := int32(x.Dim(x.NumDims() - 1)) return Conv1d(x, weight, bias, 1, 0, 1, groups) } // Convenience wrappers (function-style for the model code) func Stack(arrays []*Array, axis int) *Array { vectorData := make([]C.mlx_array, len(arrays)) for i := range arrays { vectorData[i] = arrays[i].ctx } vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData))) defer C.mlx_vector_array_free(vector) out := New("STACK") C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) return out } func Neg(a *Array) *Array { return a.Negative() } func Sum(a *Array, axis int, keepDims bool) *Array { return a.SumAxis(axis, keepDims) } func Argsort(a *Array, axis int) *Array { return a.ArgsortAxis(axis) } func Take(a *Array, indices *Array, axis int) *Array { return a.TakeAxis(indices, axis) } func RSqrt(a *Array) *Array { out := New("RSQRT") C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx) return out } func Mean(a *Array, axis int, keepDims bool) *Array { out := New("MEAN_AXIS") C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) return out } func Argpartition(a *Array, kth int, axis int) *Array { return a.ArgpartitionAxis(kth, axis) } func TakeAlongAxis(a, indices *Array, axis int) *Array { return a.TakeAlongAxis(indices, axis) } // Function-style wrappers matching imagegen API func Add(a, b *Array) *Array { return a.Add(b) } func Sub(a, b *Array) *Array { return a.Subtract(b) } func Mul(a, b *Array) *Array { return a.Multiply(b) } func Div(a, b *Array) *Array { return a.Divide(b) } func Matmul(a, b *Array) *Array { return a.Matmul(b) } func Reshape(a *Array, shape ...int32) *Array { axes := make([]int, len(shape)) for i, s := range shape { axes[i] = int(s) } return a.Reshape(axes...) } func Transpose(a *Array, axes ...int) *Array { return a.Transpose(axes...) } func ExpandDims(a *Array, axis int) *Array { return a.ExpandDims(axis) } func Squeeze(a *Array, axis int) *Array { return a.Squeeze(axis) } func Flatten(a *Array) *Array { return a.Flatten(0, -1) } func Concatenate(arrays []*Array, axis int) *Array { if len(arrays) == 0 { return nil } return arrays[0].Concatenate(axis, arrays[1:]...) } func SliceStartStop(a *Array, start, stop []int32) *Array { n := len(start) cStart := make([]C.int, n) cStop := make([]C.int, n) cStrides := make([]C.int, n) for i := 0; i < n; i++ { cStart[i] = C.int(start[i]) cStop[i] = C.int(stop[i]) cStrides[i] = 1 } out := New("SLICE") C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx) return out } func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array { if lhsIndices == nil { lhsIndices = New("") } if rhsIndices == nil { rhsIndices = New("") } return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices) } func SiLU(a *Array) *Array { sig := a.Sigmoid() return a.Multiply(sig) } func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { freqs := New("") out := New("FAST_ROPE") C.mlx_fast_rope( &out.ctx, x.ctx, C.int(dims), C.bool(traditional), C.mlx_optional_float{ value: C.float(base), has_value: C.bool(func() bool { return base != 0 }()), }, C.float(scale), C.int(offset), freqs.ctx, DefaultStream().ctx, ) return out } func Sigmoid(a *Array) *Array { return a.Sigmoid() } func Exp(a *Array) *Array { out := New("EXP") C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx) return out } func Log(a *Array) *Array { out := New("LOG") C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx) return out } func Logaddexp(a, b *Array) *Array { out := New("LOGADDEXP") C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } func SoftmaxAxis(a *Array, axis int, precise bool) *Array { out := New("SOFTMAX_AXIS") C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx) return out } func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array { mask := New("") sinks := New("") mode := "" if causalMask { mode = "causal" } cMode := C.CString(mode) defer C.free(unsafe.Pointer(cMode)) out := New("FAST_SDPA") C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx) return out } func LayerNormFn(x, weight, bias *Array, eps float32) *Array { out := New("FAST_LAYERNORM") var w, b C.mlx_array if weight != nil { w = weight.ctx } if bias != nil { b = bias.ctx } C.mlx_fast_layer_norm(&out.ctx, x.ctx, w, b, C.float(eps), DefaultStream().ctx) return out } func RMSNormFn(x, weight *Array, eps float32) *Array { out := New("FAST_RMSNORM") var w C.mlx_array if weight != nil { w = weight.ctx } C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx) return out } func AddMM(c, a, b *Array, alpha, beta float32) *Array { return c.Addmm(a, b, alpha, beta) } // Scalar helpers // scalarWithDtype creates a scalar array matching the dtype of a. // Matching dtype is important for graph fusion and avoiding implicit casts. func scalarWithDtype(s float32, a *Array) C.mlx_array { f32 := C.mlx_array_new_float(C.float(s)) dtype := a.DType() if dtype == DTypeFloat32 { return f32 } casted := C.mlx_array_new() C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx) C.mlx_array_free(f32) return casted } func AddScalar(a *Array, s float32) *Array { scalar := scalarWithDtype(s, a) out := New("ADD_SCALAR") C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx) C.mlx_array_free(scalar) return out } func MulScalar(a *Array, s float32) *Array { scalar := scalarWithDtype(s, a) out := New("MUL_SCALAR") C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx) C.mlx_array_free(scalar) return out } func DivScalar(a *Array, s float32) *Array { scalar := scalarWithDtype(s, a) out := New("DIV_SCALAR") C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx) C.mlx_array_free(scalar) return out } func FloorDivideScalar(a *Array, s int32) *Array { scalar := FromValue(int(s)) return a.FloorDivide(scalar) } // Array constructors func NewArrayInt32(data []int32, shape []int32) *Array { cShape := make([]C.int, len(shape)) for i, s := range shape { cShape[i] = C.int(s) } out := New("NEW_ARRAY_INT32") out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32)) return out } func NewScalarArray(value float32) *Array { out := New("SCALAR") out.ctx = C.mlx_array_new_float32(C.float(value)) return out } func ZerosF32(shape []int32) *Array { return Zeros(DTypeFloat32, func() []int { ints := make([]int, len(shape)) for i, s := range shape { ints[i] = int(s) } return ints }()...) } // Utility func Collect(v any) []*Array { var arrays []*Array seen := make(map[uintptr]bool) collect(reflect.ValueOf(v), &arrays, seen) return arrays } func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) { if !v.IsValid() { return } if v.Kind() == reflect.Ptr { if v.IsNil() { return } ptr := v.Pointer() if seen[ptr] { return } seen[ptr] = true if arr, ok := v.Interface().(*Array); ok { if arr != nil && arr.Valid() { *arrays = append(*arrays, arr) } return } collect(v.Elem(), arrays, seen) return } switch v.Kind() { case reflect.Struct: // Check if this struct IS an Array (not a pointer to one) if arr, ok := v.Addr().Interface().(*Array); ok { if arr != nil && arr.Valid() { *arrays = append(*arrays, arr) } return } for i := 0; i < v.NumField(); i++ { field := v.Field(i) if field.CanInterface() { collect(field, arrays, seen) } } case reflect.Slice: for i := 0; i < v.Len(); i++ { collect(v.Index(i), arrays, seen) } case reflect.Map: for _, key := range v.MapKeys() { collect(v.MapIndex(key), arrays, seen) } case reflect.Interface: if !v.IsNil() { collect(v.Elem(), arrays, seen) } } } func EnableCompile() { C.mlx_enable_compile() } func DisableCompile() { C.mlx_disable_compile() }