From c2b0bb7a52b02a50e63274f07a21a4539e3cfe19 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 23 Mar 2026 11:28:44 -0700 Subject: [PATCH] mlx: update as of 3/23 (#14789) * mlx: update to HEAD on 3/23 Also fixes a few misc vendoring bugs uncovered with this first update. This also renames the version files to make them clearer. * CUDA Fast Gated Delta kernel * mlx: detect eval errors and panic On model errors or missing kernels, don't mask the error, bubble it up. --- Dockerfile | 2 +- MLX_CORE_VERSION | 1 - MLX_C_VERSION | 1 + MLX_VERSION | 2 +- x/imagegen/mlx/CMakeLists.txt | 17 +- x/imagegen/mlx/mlx.c | 70 ++++- x/imagegen/mlx/mlx.go | 6 +- x/imagegen/mlx/mlx.h | 36 ++- x/mlxrunner/client.go | 11 +- x/mlxrunner/mlx/CMakeLists.txt | 4 +- x/mlxrunner/mlx/gated_delta.go | 259 +++++++++++++++++- x/mlxrunner/mlx/generated.c | 18 +- x/mlxrunner/mlx/generated.h | 54 +++- x/mlxrunner/mlx/include/mlx/c/README.md | 2 +- .../mlx/include/mlx/c/distributed_group.h | 6 +- x/mlxrunner/mlx/include/mlx/c/ops.h | 8 + x/mlxrunner/mlx/mlx.go | 49 +++- x/mlxrunner/mlx/ops_extra.go | 6 +- 18 files changed, 497 insertions(+), 55 deletions(-) delete mode 100644 MLX_CORE_VERSION create mode 100644 MLX_C_VERSION diff --git a/Dockerfile b/Dockerfile index 0743cc45d..5b0b09dcc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -157,7 +157,7 @@ COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY x/imagegen/mlx x/imagegen/mlx COPY go.mod go.sum . -COPY MLX_VERSION MLX_CORE_VERSION . +COPY MLX_VERSION MLX_C_VERSION . 
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local ENV PATH=/usr/local/go/bin:$PATH RUN go mod download diff --git a/MLX_CORE_VERSION b/MLX_CORE_VERSION deleted file mode 100644 index 912750052..000000000 --- a/MLX_CORE_VERSION +++ /dev/null @@ -1 +0,0 @@ -v0.30.6 diff --git a/MLX_C_VERSION b/MLX_C_VERSION new file mode 100644 index 000000000..5dc11c479 --- /dev/null +++ b/MLX_C_VERSION @@ -0,0 +1 @@ +0726ca922fc902c4c61ef9c27d94132be418e945 diff --git a/MLX_VERSION b/MLX_VERSION index b043aa648..12d7c829a 100644 --- a/MLX_VERSION +++ b/MLX_VERSION @@ -1 +1 @@ -v0.5.0 +38ad257088fb2193ad47e527cf6534a689f30943 diff --git a/x/imagegen/mlx/CMakeLists.txt b/x/imagegen/mlx/CMakeLists.txt index 70246ef4b..8f5491be1 100644 --- a/x/imagegen/mlx/CMakeLists.txt +++ b/x/imagegen/mlx/CMakeLists.txt @@ -1,11 +1,11 @@ include(FetchContent) -# Read MLX version from top-level file (shared with Dockerfile) -file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG) +# Read MLX-C version from top-level file (shared with Dockerfile) +file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG) string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG) -# Read MLX core version from top-level file -file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG) +# Read MLX version from top-level file +file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_GIT_TAG) string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG) set(MLX_C_BUILD_EXAMPLES OFF) @@ -98,6 +98,15 @@ FetchContent_MakeAvailable(mlx-c) file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h") file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/") +# Regenerate Go/C shim wrappers from the (possibly updated) headers. +find_program(GO_EXECUTABLE go REQUIRED) +message(STATUS "Regenerating MLX Go wrappers") +execute_process( + COMMAND ${GO_EXECUTABLE} generate ./x/... 
+ WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + COMMAND_ERROR_IS_FATAL ANY +) + # For local dev builds, override MLX_VERSION with git describe output if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX) execute_process( diff --git a/x/imagegen/mlx/mlx.c b/x/imagegen/mlx/mlx.c index b0ccbacdf..8d7ec0e0a 100644 --- a/x/imagegen/mlx/mlx.c +++ b/x/imagegen/mlx/mlx.c @@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL; mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL; -bool (*mlx_distributed_is_available_ptr)(void) = NULL; -mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL; +bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL; +mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL; void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL; void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) 
= NULL; int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL; @@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL; int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL; int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL; int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL; int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL; @@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL; -int (*mlx_dequantize_ptr)(mlx_array* 
res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL; int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL; int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; @@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL; +int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL; +int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL; int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; @@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_put_along_axis_ptr)(mlx_array* 
res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL; -int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; -int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; +int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL; +int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL; int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; @@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n"); return -1; } + mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett"); + if (mlx_bartlett_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n"); + return -1; + } mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and"); if (mlx_bitwise_and_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n"); @@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n"); return -1; } + mlx_blackman_ptr = 
GET_SYM(handle, "mlx_blackman"); + if (mlx_blackman_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n"); + return -1; + } mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm"); if (mlx_block_masked_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n"); @@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n"); return -1; } + mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming"); + if (mlx_hamming_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n"); + return -1; + } + mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning"); + if (mlx_hanning_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n"); + return -1; + } mlx_identity_ptr = GET_SYM(handle, "mlx_identity"); if (mlx_identity_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n"); @@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i return mlx_distributed_group_split_ptr(group, color, key); } -bool mlx_distributed_is_available(void) { - return mlx_distributed_is_available_ptr(); +bool mlx_distributed_is_available(const char* bk) { + return mlx_distributed_is_available_ptr(bk); } -mlx_distributed_group mlx_distributed_init(bool strict) { - return mlx_distributed_init_ptr(strict); +mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) { + return mlx_distributed_init_ptr(strict, bk); } void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { @@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_3d_ptr(res, a, s); } +int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) { + return mlx_bartlett_ptr(res, M, s); +} + int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_bitwise_and_ptr(res, 
a, b, s); } @@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const return mlx_bitwise_xor_ptr(res, a, b, s); } +int mlx_blackman(mlx_array* res, int M, const mlx_stream s) { + return mlx_blackman_ptr(res, M, s); +} + int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) { return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); } @@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_ return mlx_depends_ptr(res, inputs, dependencies); } -int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) { - return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s); +int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) { + return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s); } int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { @@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float return mlx_hadamard_transform_ptr(res, a, scale, s); } +int mlx_hamming(mlx_array* res, int M, const mlx_stream s) { + return mlx_hamming_ptr(res, M, s); +} + +int mlx_hanning(mlx_array* res, int M, const mlx_stream s) { + return mlx_hanning_ptr(res, M, s); +} + int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { return mlx_identity_ptr(res, n, dtype, s); } @@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* 
res, const mlx_array a, const mlx_array indice return mlx_put_along_axis_ptr(res, a, indices, values, axis, s); } -int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { - return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s); +int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) { + return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s); } -int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { - return mlx_quantize_ptr(res, w, group_size, bits, mode, s); +int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) { + return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s); } int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go index dc925570b..bf6665263 100644 --- a/x/imagegen/mlx/mlx.go +++ b/x/imagegen/mlx/mlx.go @@ -2125,7 +2125,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} res := C.mlx_vector_array_new() - C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream()) + var globalScale C.mlx_array + C.mlx_quantize(&res, 
w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream()) // Result is a vector of arrays: [weights, scales, biases?] // mxfp8 mode returns only 2 elements (no biases) @@ -2161,7 +2162,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr } res := C.mlx_array_new() - C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream()) + var globalScale C.mlx_array + C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream()) return newArray(res) } diff --git a/x/imagegen/mlx/mlx.h b/x/imagegen/mlx/mlx.h index 34829d732..3f53c8941 100644 --- a/x/imagegen/mlx/mlx.h +++ b/x/imagegen/mlx/mlx.h @@ -309,10 +309,12 @@ #undef mlx_atleast_1d #undef mlx_atleast_2d #undef mlx_atleast_3d +#undef mlx_bartlett #undef mlx_bitwise_and #undef mlx_bitwise_invert #undef mlx_bitwise_or #undef mlx_bitwise_xor +#undef mlx_blackman #undef mlx_block_masked_mm #undef mlx_broadcast_arrays #undef mlx_broadcast_to @@ -365,6 +367,8 @@ #undef mlx_greater #undef mlx_greater_equal #undef mlx_hadamard_transform +#undef mlx_hamming +#undef mlx_hanning #undef mlx_identity #undef mlx_imag #undef mlx_inner @@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group); extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group); extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key); -extern bool (*mlx_distributed_is_available_ptr)(void); -extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict); +extern bool (*mlx_distributed_is_available_ptr)(const char* bk); +extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk); extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); extern void 
(*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...); extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless); @@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s); extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s); @@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const 
mlx_vector_array dependencies); -extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s); +extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s); extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s); extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); @@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); +extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s); +extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); @@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool 
keepdims, const mlx_stream s); extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); -extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); -extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); +extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s); +extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s); extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); @@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group); mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key); -bool mlx_distributed_is_available(void); +bool mlx_distributed_is_available(const char* bk); -mlx_distributed_group mlx_distributed_init(bool strict); +mlx_distributed_group mlx_distributed_init(bool strict, const char* bk); void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); @@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const 
mlx_array a, const mlx_stream s); int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_bartlett(mlx_array* res, int M, const mlx_stream s); + int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s); @@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +int mlx_blackman(mlx_array* res, int M, const mlx_stream s); + int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s); int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); @@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies); -int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s); +int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s); int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); @@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); +int mlx_hamming(mlx_array* res, int M, const mlx_stream s); + +int mlx_hanning(mlx_array* res, int M, const mlx_stream s); + int 
mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); @@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); -int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); +int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s); -int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); +int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s); int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index 989581113..e8d90147a 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -230,6 +230,9 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f resp, err := c.client.Do(httpReq) if err != nil { + if errMsg := c.status.getLastErr(); errMsg != "" { + return fmt.Errorf("mlx runner failed: %s", errMsg) + } return err } defer resp.Body.Close() @@ -267,7 +270,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f } } - return scanner.Err() + if err := scanner.Err(); err != 
nil { + if errMsg := c.status.getLastErr(); errMsg != "" { + return fmt.Errorf("mlx runner failed: %s", errMsg) + } + return err + } + return nil } func (c *Client) ContextLength() int { diff --git a/x/mlxrunner/mlx/CMakeLists.txt b/x/mlxrunner/mlx/CMakeLists.txt index 9825c441b..3db230012 100644 --- a/x/mlxrunner/mlx/CMakeLists.txt +++ b/x/mlxrunner/mlx/CMakeLists.txt @@ -15,7 +15,9 @@ set(CMAKE_INSTALL_RPATH "@loader_path") include(FetchContent) -set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "") +# Read MLX-C version from top-level file (shared with imagegen CMakeLists) +file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG) +string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG) FetchContent_Declare( mlx-c diff --git a/x/mlxrunner/mlx/gated_delta.go b/x/mlxrunner/mlx/gated_delta.go index 31550cef1..c691f0512 100644 --- a/x/mlxrunner/mlx/gated_delta.go +++ b/x/mlxrunner/mlx/gated_delta.go @@ -13,6 +13,10 @@ var ( gatedDeltaMetalKernelOnce sync.Once gatedDeltaMetalKernel C.mlx_fast_metal_kernel gatedDeltaMetalDisabled bool + + gatedDeltaCUDAKernelOnce sync.Once + gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel + gatedDeltaCUDADisabled bool ) const gatedDeltaMetalKernelSource = ` @@ -83,6 +87,86 @@ for (int i = 0; i < n_per_t; ++i) { } ` +const gatedDeltaCUDAKernelSource = ` +auto tid_x = threadIdx.x; +auto tid_y = threadIdx.y; +auto grid_y = blockIdx.y * blockDim.y + tid_y; +auto grid_z = blockIdx.z; + +int T_val = static_cast<int>(*T); + +auto n = grid_z; +auto b_idx = n / Hv; +auto hv_idx = n % Hv; +auto hk_idx = hv_idx / (Hv / Hk); +constexpr int n_per_t = Dk / 32; + +// q, k: [B, T, Hk, Dk] +auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk; +auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk; + +// v, y: [B, T, Hv, Dv] +auto dv_idx = grid_y; +auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv; +y += b_idx * T_val * Hv * Dv + hv_idx * Dv; + +auto dk_idx = tid_x; + +// state_in, state_out: [B, Hv, Dv, Dk] +auto i_state = state_in + (n * Dv + dv_idx) * Dk; +auto o_state = 
state_out + (n * Dv + dv_idx) * Dk; + +float state[n_per_t]; +for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = static_cast<float>(i_state[s_idx]); +} + +// g: [B, T, Hv] +auto g_ = g + b_idx * T_val * Hv; +auto beta_ = beta + b_idx * T_val * Hv; + +for (int t = 0; t < T_val; ++t) { + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] * static_cast<float>(g_[hv_idx]); + kv_mem += state[i] * static_cast<float>(k_[s_idx]); + } + // Warp reduction (full warp, 32 threads in x) + for (int offset = 16; offset > 0; offset >>= 1) + kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset); + kv_mem = __shfl_sync(0xffffffff, kv_mem, 0); + + auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]); + + float out = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta; + out += state[i] * static_cast<float>(q_[s_idx]); + } + // Warp reduction + for (int offset = 16; offset > 0; offset >>= 1) + out += __shfl_down_sync(0xffffffff, out, offset); + if (tid_x == 0) { + y[dv_idx] = static_cast<InT>(out); + } + + q_ += Hk * Dk; + k_ += Hk * Dk; + v_ += Hv * Dv; + y += Hv * Dv; + g_ += Hv; + beta_ += Hv; +} + +for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + o_state[s_idx] = static_cast<InT>(state[i]); +} +` + func cStringVector(values []string) (C.mlx_vector_string, func(), bool) { vec := C.mlx_vector_string_new() ok := true @@ -352,11 +436,184 @@ func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) { return Concatenate(outs, 1), nextState } +func initGatedDeltaCUDAKernel() { + var cudaAvail C.bool + if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) { + gatedDeltaCUDADisabled = true + return + } + + inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"}) + if !ok { + gatedDeltaCUDADisabled = true + freeInputs() + return 
+ } + defer freeInputs() + + outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"}) + if !ok { + gatedDeltaCUDADisabled = true + freeOutputs() + return + } + defer freeOutputs() + + cName := C.CString("gated_delta_step") + defer C.free(unsafe.Pointer(cName)) + cSource := C.CString(gatedDeltaCUDAKernelSource) + defer C.free(unsafe.Pointer(cSource)) + cHeader := C.CString("") + defer C.free(unsafe.Pointer(cHeader)) + + gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new( + cName, + inputs, + outputs, + cSource, + cHeader, + C.bool(true), + C.int(0), + ) +} + +func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) { + if gatedDeltaCUDADisabled { + return nil, nil, false + } + if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil { + return nil, nil, false + } + + qd := q.Dims() + kd := k.Dims() + vd := v.Dims() + gd := g.Dims() + bd := beta.Dims() + sd := state.Dims() + if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 { + return nil, nil, false + } + + B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3] + if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 { + return nil, nil, false + } + if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk { + return nil, nil, false + } + Hv, Dv := vd[2], vd[3] + if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 { + return nil, nil, false + } + if gd[0] != B || gd[1] != T || gd[2] != Hv { + return nil, nil, false + } + if bd[0] != B || bd[1] != T || bd[2] != Hv { + return nil, nil, false + } + if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk { + return nil, nil, false + } + + dtype := q.DType() + if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype { + return nil, nil, false + } + + gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel) + if gatedDeltaCUDADisabled { + return nil, nil, false + } + + cfg := 
C.mlx_fast_cuda_kernel_config_new() + defer C.mlx_fast_cuda_kernel_config_free(cfg) + + cInT := C.CString("InT") + defer C.free(unsafe.Pointer(cInT)) + if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 { + gatedDeltaCUDADisabled = true + return nil, nil, false + } + for _, tpl := range []struct { + name string + value int + }{ + {name: "Dk", value: Dk}, + {name: "Dv", value: Dv}, + {name: "Hk", value: Hk}, + {name: "Hv", value: Hv}, + } { + cn := C.CString(tpl.name) + rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value)) + C.free(unsafe.Pointer(cn)) + if rc != 0 { + gatedDeltaCUDADisabled = true + return nil, nil, false + } + } + + yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)} + stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)} + if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 { + gatedDeltaCUDADisabled = true + return nil, nil, false + } + if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 { + gatedDeltaCUDADisabled = true + return nil, nil, false + } + if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 { + gatedDeltaCUDADisabled = true + return nil, nil, false + } + threadY := Dv + if threadY > 4 { + threadY = 4 + } + if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 { + gatedDeltaCUDADisabled = true + return nil, nil, false + } + + tScalar := FromValue(T) + inputs := []C.mlx_array{ + q.ctx, + k.ctx, + v.ctx, + g.ctx, + beta.ctx, + state.ctx, + tScalar.ctx, + } + inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs))) + defer C.mlx_vector_array_free(inVec) + + outVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, 
DefaultStream().ctx) != 0 { + gatedDeltaCUDADisabled = true + return nil, nil, false + } + if int(C.mlx_vector_array_size(outVec)) < 2 { + return nil, nil, false + } + + y = New("GATED_DELTA_CUDA_Y") + nextState = New("GATED_DELTA_CUDA_STATE") + C.mlx_vector_array_get(&y.ctx, outVec, 0) + C.mlx_vector_array_get(&nextState.ctx, outVec, 1) + return y, nextState, true +} + // GatedDelta runs the recurrent update operation. // -// It uses the fused Metal kernel when available and otherwise falls back to a +// It tries the fused CUDA kernel first, then Metal, then falls back to a // backend-agnostic MLX implementation with identical inputs/outputs. func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) { + if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok { + return y, nextState + } if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok { return y, nextState } diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c index ecf9d30c8..f1333971c 100644 --- a/x/mlxrunner/mlx/generated.c +++ b/x/mlxrunner/mlx/generated.c @@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)( int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; -bool (*mlx_distributed_is_available_)(void) = NULL; -mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL; +bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL; +mlx_distributed_group (*mlx_distributed_init_)( + bool strict, + const char* bk /* may be null */) = NULL; void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, @@ -924,6 +926,7 @@ int (*mlx_astype_)( int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream 
s) = NULL; int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL; int (*mlx_bitwise_and_)( mlx_array* res, const mlx_array a, @@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)( const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL; int (*mlx_block_masked_mm_)( mlx_array* res, const mlx_array a, @@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale /* may be null */, mlx_optional_dtype dtype, const mlx_stream s) = NULL; int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; @@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)( const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL; +int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL; +int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL; int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_inner_)( @@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale_x /* may be null */, + const mlx_array global_scale_w /* may be null */, const mlx_stream s) = NULL; int (*mlx_quantize_)( mlx_vector_array* res, @@ -1555,6 +1564,7 @@ int (*mlx_quantize_)( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale /* may be null */, const mlx_stream s) = NULL; int (*mlx_quantized_matmul_)( mlx_array* res, @@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_atleast_1d); CHECK_LOAD(handle, mlx_atleast_2d); CHECK_LOAD(handle, mlx_atleast_3d); + CHECK_LOAD(handle, mlx_bartlett); 
CHECK_LOAD(handle, mlx_bitwise_and); CHECK_LOAD(handle, mlx_bitwise_invert); CHECK_LOAD(handle, mlx_bitwise_or); CHECK_LOAD(handle, mlx_bitwise_xor); + CHECK_LOAD(handle, mlx_blackman); CHECK_LOAD(handle, mlx_block_masked_mm); CHECK_LOAD(handle, mlx_broadcast_arrays); CHECK_LOAD(handle, mlx_broadcast_to); @@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_greater); CHECK_LOAD(handle, mlx_greater_equal); CHECK_LOAD(handle, mlx_hadamard_transform); + CHECK_LOAD(handle, mlx_hamming); + CHECK_LOAD(handle, mlx_hanning); CHECK_LOAD(handle, mlx_identity); CHECK_LOAD(handle, mlx_imag); CHECK_LOAD(handle, mlx_inner); diff --git a/x/mlxrunner/mlx/generated.h b/x/mlxrunner/mlx/generated.h index e8dfa7b90..26119f2ff 100644 --- a/x/mlxrunner/mlx/generated.h +++ b/x/mlxrunner/mlx/generated.h @@ -300,10 +300,12 @@ #define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_ #define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_ #define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_ +#define mlx_bartlett mlx_bartlett_mlx_gen_orig_ #define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_ #define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_ #define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_ #define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_ +#define mlx_blackman mlx_blackman_mlx_gen_orig_ #define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_ #define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_ #define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_ @@ -356,6 +358,8 @@ #define mlx_greater mlx_greater_mlx_gen_orig_ #define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_ #define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_ +#define mlx_hamming mlx_hamming_mlx_gen_orig_ +#define mlx_hanning mlx_hanning_mlx_gen_orig_ #define mlx_identity mlx_identity_mlx_gen_orig_ #define mlx_imag mlx_imag_mlx_gen_orig_ #define mlx_inner mlx_inner_mlx_gen_orig_ @@ -889,10 +893,12 @@ #undef mlx_atleast_1d #undef 
mlx_atleast_2d #undef mlx_atleast_3d +#undef mlx_bartlett #undef mlx_bitwise_and #undef mlx_bitwise_invert #undef mlx_bitwise_or #undef mlx_bitwise_xor +#undef mlx_blackman #undef mlx_block_masked_mm #undef mlx_broadcast_arrays #undef mlx_broadcast_to @@ -945,6 +951,8 @@ #undef mlx_greater #undef mlx_greater_equal #undef mlx_hadamard_transform +#undef mlx_hamming +#undef mlx_hanning #undef mlx_identity #undef mlx_imag #undef mlx_inner @@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)( extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); -extern bool (*mlx_distributed_is_available_)(void); -extern mlx_distributed_group (*mlx_distributed_init_)(bool strict); +extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */); +extern mlx_distributed_group (*mlx_distributed_init_)( + bool strict, + const char* bk /* may be null */); extern void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, @@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)( extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_bitwise_and_)( mlx_array* res, const mlx_array a, @@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)( const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_block_masked_mm_)( mlx_array* res, const mlx_array a, @@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array 
global_scale /* may be null */, mlx_optional_dtype dtype, const mlx_stream s); extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); @@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)( const mlx_array a, mlx_optional_float scale, const mlx_stream s); +extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s); +extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_inner_)( @@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale_x /* may be null */, + const mlx_array global_scale_w /* may be null */, const mlx_stream s); extern int (*mlx_quantize_)( mlx_vector_array* res, @@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale /* may be null */, const mlx_stream s); extern int (*mlx_quantized_matmul_)( mlx_array* res, @@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) { static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { return mlx_distributed_group_split_(group, color, key); } -static inline bool mlx_distributed_is_available(void) { - return mlx_distributed_is_available_(); +static inline bool mlx_distributed_is_available(const char* bk /* may be null */) { + return mlx_distributed_is_available_(bk); } -static inline mlx_distributed_group mlx_distributed_init(bool strict) { - return mlx_distributed_init_(strict); +static inline mlx_distributed_group mlx_distributed_init( + bool strict, + const char* bk /* may be null */) { + return mlx_distributed_init_(strict, bk); } static inline void mlx_set_error_handler( 
mlx_error_handler_func handler, @@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_3d_(res, a, s); } +static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) { + return mlx_bartlett_(res, M, s); +} static inline int mlx_bitwise_and( mlx_array* res, const mlx_array a, @@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor( const mlx_stream s) { return mlx_bitwise_xor_(res, a, b, s); } +static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) { + return mlx_blackman_(res, M, s); +} static inline int mlx_block_masked_mm( mlx_array* res, const mlx_array a, @@ -5193,9 +5219,10 @@ static inline int mlx_dequantize( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale /* may be null */, mlx_optional_dtype dtype, const mlx_stream s) { - return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s); + return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s); } static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_diag_(res, a, k, s); @@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform( const mlx_stream s) { return mlx_hadamard_transform_(res, a, scale, s); } +static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) { + return mlx_hamming_(res, M, s); +} +static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) { + return mlx_hanning_(res, M, s); +} static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { return mlx_identity_(res, n, dtype, s); } @@ -5793,8 +5826,10 @@ static inline int mlx_qqmm( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale_x /* may be null */, + const mlx_array global_scale_w /* 
may be null */, const mlx_stream s) { - return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s); + return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s); } static inline int mlx_quantize( mlx_vector_array* res, @@ -5802,8 +5837,9 @@ static inline int mlx_quantize( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale /* may be null */, const mlx_stream s) { - return mlx_quantize_(res, w, group_size, bits, mode, s); + return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s); } static inline int mlx_quantized_matmul( mlx_array* res, diff --git a/x/mlxrunner/mlx/include/mlx/c/README.md b/x/mlxrunner/mlx/include/mlx/c/README.md index 905ca451c..1d693359d 100644 --- a/x/mlxrunner/mlx/include/mlx/c/README.md +++ b/x/mlxrunner/mlx/include/mlx/c/README.md @@ -1,7 +1,7 @@ # Vendored MLX-C Headers These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c). -The pinned version is in `MLX_VERSION` at the repo root. +The pinned version is in `MLX_C_VERSION` at the repo root. Headers are automatically refreshed when you run a CMake build: diff --git a/x/mlxrunner/mlx/include/mlx/c/distributed_group.h b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h index 3cfccc806..43aa2ae56 100644 --- a/x/mlxrunner/mlx/include/mlx/c/distributed_group.h +++ b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h @@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key); /** * Check if distributed is available. */ -bool mlx_distributed_is_available(void); +bool mlx_distributed_is_available(const char* bk /* may be null */); /** * Initialize distributed. 
*/ -mlx_distributed_group mlx_distributed_init(bool strict); +mlx_distributed_group mlx_distributed_init( + bool strict, + const char* bk /* may be null */); /**@}*/ diff --git a/x/mlxrunner/mlx/include/mlx/c/ops.h b/x/mlxrunner/mlx/include/mlx/c/ops.h index a1446fb9e..64d70e2f4 100644 --- a/x/mlxrunner/mlx/include/mlx/c/ops.h +++ b/x/mlxrunner/mlx/include/mlx/c/ops.h @@ -166,6 +166,7 @@ int mlx_astype( int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_bartlett(mlx_array* res, int M, const mlx_stream s); int mlx_bitwise_and( mlx_array* res, const mlx_array a, @@ -182,6 +183,7 @@ int mlx_bitwise_xor( const mlx_array a, const mlx_array b, const mlx_stream s); +int mlx_blackman(mlx_array* res, int M, const mlx_stream s); int mlx_block_masked_mm( mlx_array* res, const mlx_array a, @@ -362,6 +364,7 @@ int mlx_dequantize( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale /* may be null */, mlx_optional_dtype dtype, const mlx_stream s); int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); @@ -498,6 +501,8 @@ int mlx_hadamard_transform( const mlx_array a, mlx_optional_float scale, const mlx_stream s); +int mlx_hamming(mlx_array* res, int M, const mlx_stream s); +int mlx_hanning(mlx_array* res, int M, const mlx_stream s); int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_inner( @@ -790,6 +795,8 @@ int mlx_qqmm( mlx_optional_int group_size, mlx_optional_int bits, const char* mode, + const mlx_array global_scale_x /* may be null */, + const mlx_array global_scale_w /* may be null */, const mlx_stream s); int mlx_quantize( mlx_vector_array* res, @@ -797,6 +804,7 @@ int mlx_quantize( mlx_optional_int group_size, 
mlx_optional_int bits, const char* mode, + const mlx_array global_scale /* may be null */, const mlx_stream s); int mlx_quantized_matmul( mlx_array* res, diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go index f2daa2e28..5ec3fc850 100644 --- a/x/mlxrunner/mlx/mlx.go +++ b/x/mlxrunner/mlx/mlx.go @@ -7,8 +7,44 @@ package mlx // #cgo LDFLAGS: -lstdc++ // #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate // #include "generated.h" +// #include +// +// static char _mlx_last_error_msg[1024] = {0}; +// static int _mlx_last_error_flag = 0; +// +// static void _mlx_capture_error_handler(const char* msg, void* data) { +// (void)data; +// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1); +// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0'; +// _mlx_last_error_flag = 1; +// } +// +// static void mlx_install_capture_handler(void) { +// if (mlx_set_error_handler_) { +// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL); +// } +// } +// +// static void mlx_clear_last_error(void) { +// _mlx_last_error_flag = 0; +// _mlx_last_error_msg[0] = '\0'; +// } +// +// static int mlx_had_last_error(void) { +// return _mlx_last_error_flag; +// } +// +// static const char* mlx_get_last_error(void) { +// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL; +// } import "C" +func init() { + // Replace the default exit(-1) error handler with one that captures + // the error message so we can surface it in Go. + C.mlx_install_capture_handler() +} + // Version returns the MLX core library version string. 
func Version() string { str := C.mlx_string_new() @@ -31,10 +67,19 @@ func doEval(outputs []*Array, async bool) { } } + C.mlx_clear_last_error() + var rc C.int if async { - C.mlx_async_eval(vector) + rc = C.mlx_async_eval(vector) } else { - C.mlx_eval(vector) + rc = C.mlx_eval(vector) + } + if rc != 0 { + msg := "mlx eval failed" + if C.mlx_had_last_error() != 0 { + msg = C.GoString(C.mlx_get_last_error()) + } + panic("mlx: " + msg) } } diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index ff06092e9..a4b294776 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -17,7 +17,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} res := C.mlx_vector_array_new() defer C.mlx_vector_array_free(res) - C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx) + var globalScale C.mlx_array + C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx) vecSize := int(C.mlx_vector_array_size(res)) w0 := New("QUANTIZE_W") @@ -45,7 +46,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr } out := New("DEQUANTIZE") - C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx) + var globalScale C.mlx_array + C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx) return out }