diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 6d9596807..a51ca1b9c 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -117,6 +117,25 @@ jobs:
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: ''
runner_dir: 'vulkan'
+ - os: windows
+ arch: amd64
+ preset: 'MLX CUDA 13'
+ install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
+ cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
+ cuda-components:
+ - '"cudart"'
+ - '"nvcc"'
+ - '"cublas"'
+ - '"cublas_dev"'
+ - '"cufft"'
+ - '"cufft_dev"'
+ - '"nvrtc"'
+ - '"nvrtc_dev"'
+ - '"crt"'
+ - '"nvvm"'
+ - '"nvptxcompiler"'
+ cuda-version: '13.0'
+ flags: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -125,8 +144,10 @@ jobs:
- name: Install system dependencies
run: |
choco install -y --no-progress ccache ninja
- ccache -o cache_dir=${{ github.workspace }}\.ccache
- - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
+ if (Get-Command ccache -ErrorAction SilentlyContinue) {
+ ccache -o cache_dir=${{ github.workspace }}\.ccache
+ }
+ - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
id: cache-install
uses: actions/cache/restore@v4
with:
@@ -134,8 +155,9 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
- key: ${{ matrix.install }}
- - if: startsWith(matrix.preset, 'CUDA ')
+ C:\Program Files\NVIDIA\CUDNN
+ key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
+ - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
name: Install CUDA ${{ matrix.cuda-version }}
run: |
$ErrorActionPreference = "Stop"
@@ -179,6 +201,23 @@ jobs:
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+ - if: startsWith(matrix.preset, 'MLX ')
+ name: Install cuDNN for MLX
+ run: |
+ $ErrorActionPreference = "Stop"
+ $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
+ if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
+ Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
+ Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
+ $cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
+ New-Item -ItemType Directory -Force -Path $cudnnRoot
+ Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
+ }
+
+ echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
+ echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
+ echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
+ echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -186,7 +225,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
- key: ${{ matrix.install }}
+ C:\Program Files\NVIDIA\CUDNN
+ key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
@@ -198,7 +238,7 @@ jobs:
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
- cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
+ cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index a47156516..cf0545b56 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -37,7 +37,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
}
- echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
+ echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux:
@@ -60,6 +60,10 @@ jobs:
mesa-vulkan-drivers vulkan-tools
libvulkan1 libvulkan-dev
vulkan-sdk cmake ccache g++ make
+ - preset: 'MLX CUDA 13'
+ container: nvidia/cuda:13.0.0-devel-ubuntu22.04
+ extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
+ flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
runs-on: linux
container: ${{ matrix.container }}
steps:
@@ -76,6 +80,10 @@ jobs:
$sudo apt-get update
fi
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
+ # MLX requires CMake 3.25+, install from official releases
+ if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
+ curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
+ fi
# Export VULKAN_SDK if provided by LunarG package (defensive)
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
@@ -87,8 +95,8 @@ jobs:
path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
- run: |
- cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
- cmake --build --preset ${{ matrix.preset }} --parallel
+ cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
+ cmake --build --preset "${{ matrix.preset }}" --parallel
windows:
needs: [changes]
@@ -114,12 +122,31 @@ jobs:
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
- preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
+ - preset: 'MLX CUDA 13'
+ install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
+ cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
+ flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
+ cuda-components:
+ - '"cudart"'
+ - '"nvcc"'
+ - '"cublas"'
+ - '"cublas_dev"'
+ - '"cufft"'
+ - '"cufft_dev"'
+ - '"nvrtc"'
+ - '"nvrtc_dev"'
+ - '"crt"'
+ - '"nvvm"'
+ - '"nvptxcompiler"'
+ cuda-version: '13.0'
runs-on: windows
steps:
- run: |
choco install -y --no-progress ccache ninja
- ccache -o cache_dir=${{ github.workspace }}\.ccache
- - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
+ if (Get-Command ccache -ErrorAction SilentlyContinue) {
+ ccache -o cache_dir=${{ github.workspace }}\.ccache
+ }
+ - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
id: cache-install
uses: actions/cache/restore@v4
with:
@@ -127,8 +154,9 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
- key: ${{ matrix.install }}
- - if: matrix.preset == 'CUDA'
+ C:\Program Files\NVIDIA\CUDNN
+ key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
+ - if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
name: Install CUDA ${{ matrix.cuda-version }}
run: |
$ErrorActionPreference = "Stop"
@@ -164,10 +192,27 @@ jobs:
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
-
+
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
+ - if: matrix.preset == 'MLX CUDA 13'
+ name: Install cuDNN for MLX
+ run: |
+ $ErrorActionPreference = "Stop"
+ $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
+ if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
+ Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
+ Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
+ $cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
+ New-Item -ItemType Directory -Force -Path $cudnnRoot
+ Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
+ }
+
+ echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
+ echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
+ echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
+ echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -175,7 +220,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
- key: ${{ matrix.install }}
+ C:\Program Files\NVIDIA\CUDNN
+ key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a9e4471d8..8c37d3374 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
+# Store ggml include paths for use with target_include_directories later.
+# We avoid global include_directories() to prevent polluting the include path
+# for other projects like MLX (whose openblas dependency has its own common.h).
+set(GGML_INCLUDE_DIRS
+ ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
+ ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
+ ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
+ ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
+)
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
set(CPU_VARIANTS "ggml-cpu")
endif()
+# Apply ggml include directories to ggml targets only (not globally)
+target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
+foreach(variant ${CPU_VARIANTS})
+ if(TARGET ${variant})
+ target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
+ endif()
+endforeach()
+
install(TARGETS ggml-base ${CPU_VARIANTS}
RUNTIME_DEPENDENCIES
PRE_EXCLUDE_REGEXES ".*"
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
+ target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
if(AMDGPU_TARGETS)
find_package(hip REQUIRED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
+ target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
@@ -168,6 +183,7 @@ if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
+ target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
@@ -179,7 +195,6 @@ if(NOT APPLE)
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
-
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
+ # Build list of directories for runtime dependency resolution
+ set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
+ # Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
+ # CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
+ if(DEFINED ENV{CUDNN_ROOT_DIR})
+ # cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
+ file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
+ list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
+ endif()
+ # Add build output directory and MLX dependency build directories
+ list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
+ # OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
+ list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
+ # NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
+ # regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
+ # so there is no runtime dependency. On Windows, this path covers the stub
+ # build dir so that the stub DLL is included in our dependencies.
+ list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
+
+ # Base regexes for runtime dependencies (cross-platform)
+ set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
+ # On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
+ if(WIN32)
+ list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
+ endif()
+
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
- DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
- PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
+ DIRECTORIES ${MLX_RUNTIME_DIRS}
+ PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
@@ -205,13 +246,54 @@ if(MLX_ENGINE)
COMPONENT MLX)
endif()
- # Manually install cudart and cublas since they might not be picked up as direct dependencies
+ # Install CCCL headers for NVRTC JIT compilation at runtime.
+ # MLX's own install rules use the default component so they get skipped by
+ # --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
+ # On Linux, MLX's jit_module.cpp resolves CCCL via
+ # current_binary_dir().parent_path() / "include" / "cccl", so we create a
+ # symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
+ # This will need refinement if we add multiple CUDA versions for MLX in the future.
+ if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
+ install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
+ DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
+ COMPONENT MLX)
+ install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
+ DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
+ COMPONENT MLX)
+ if(NOT WIN32 AND NOT APPLE)
+ install(CODE "
+ set(_link \"\$ENV{DESTDIR}\${CMAKE_INSTALL_PREFIX}/lib/ollama/include\") # resolved at install time so --prefix/DESTDIR are honored
+ set(_target \"${OLLAMA_RUNNER_DIR}/include\")
+ if(NOT EXISTS \${_link})
+ execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
+ endif()
+ " COMPONENT MLX)
+ endif()
+ endif()
+
+ # On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
+ # RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
+ # dlfcn-win32 is a known CMake target with its own install rules (which install
+ # to the wrong destination). We must install it explicitly here.
+ if(WIN32)
+ install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
+ DESTINATION ${OLLAMA_INSTALL_DIR}
+ COMPONENT MLX)
+ endif()
+
+ # Manually install CUDA runtime libraries that MLX loads via dlopen
+ # (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
if(CUDAToolkit_FOUND)
- file(GLOB CUDART_LIBS
+ file(GLOB MLX_CUDA_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
- "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
- if(CUDART_LIBS)
- install(FILES ${CUDART_LIBS}
+ "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
+ "${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
+ "${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
+ "${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
+ "${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
+ "${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
+ if(MLX_CUDA_LIBS)
+ install(FILES ${MLX_CUDA_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
diff --git a/CMakePresets.json b/CMakePresets.json
index 0d643038a..d099d3f16 100644
--- a/CMakePresets.json
+++ b/CMakePresets.json
@@ -112,6 +112,7 @@
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
+ "MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
diff --git a/Dockerfile b/Dockerfile
index cabb6cc82..0743cc45d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -9,6 +9,11 @@ ARG CMAKEVERSION=3.31.2
ARG NINJAVERSION=1.12.1
ARG VULKANVERSION=1.4.321.1
+# Default empty stages for local MLX source overrides.
+# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
+FROM scratch AS local-mlx
+FROM scratch AS local-mlx-c
+
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
@@ -152,12 +157,20 @@ COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum .
-COPY MLX_VERSION .
+COPY MLX_VERSION MLX_CORE_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
- cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
+ --mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
+ --mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
+ if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
+ export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
+ fi \
+ && if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
+ export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
+ fi \
+ && cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
&& cmake --install build --component MLX --strip
@@ -168,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
-# Clone mlx-c headers for CGO (version from MLX_VERSION file)
-RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
-ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
+ENV CGO_CFLAGS="${CGO_CFLAGS}"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \
- go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
+ go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
diff --git a/MLX_CORE_VERSION b/MLX_CORE_VERSION
new file mode 100644
index 000000000..912750052
--- /dev/null
+++ b/MLX_CORE_VERSION
@@ -0,0 +1 @@
+v0.30.6
diff --git a/docs/development.md b/docs/development.md
index d0120a191..12c69c204 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -51,6 +51,9 @@ Install prerequisites:
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
+- (Optional) MLX engine support
+ - [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
+ - [cuDNN 9+](https://developer.nvidia.com/cudnn)
Then, configure and build the project:
@@ -101,6 +104,10 @@ Install prerequisites:
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
+- (Optional) MLX engine support
+ - [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
+ - [cuDNN 9+](https://developer.nvidia.com/cudnn)
+ - OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.
@@ -118,6 +125,67 @@ Lastly, run Ollama:
go run . serve
```
+## MLX Engine (Optional)
+
+The MLX engine enables running safetensors-based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On macOS, MLX uses Metal to run on the GPU; on Windows and Linux, it runs on NVIDIA GPUs via CUDA 13.
+
+### macOS (Apple Silicon)
+
+Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
+
+```shell
+xcodebuild -downloadComponent MetalToolchain
+```
+
+Verify it's installed correctly (should print "no input files"):
+
+```shell
+xcrun metal
+```
+
+Then build:
+
+```shell
+cmake -B build --preset MLX
+cmake --build build --preset MLX --parallel
+cmake --install build --component MLX
+```
+
+> [!NOTE]
+> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
+
+### Windows / Linux (CUDA)
+
+Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
+
+```shell
+cmake -B build --preset "MLX CUDA 13"
+cmake --build build --target mlx --target mlxc --config Release --parallel
+cmake --install build --component MLX --strip
+```
+
+### Local MLX source overrides
+
+To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
+
+```shell
+export OLLAMA_MLX_SOURCE=/path/to/mlx
+export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
+```
+
+For example, using the helper scripts with local mlx and mlx-c repos:
+```shell
+OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
+
+OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
+```
+
+```powershell
+$env:OLLAMA_MLX_SOURCE="../mlx"
+$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
+./scripts/build_windows.ps1
+```
+
## Docker
```shell
diff --git a/parser/parser.go b/parser/parser.go
index 5ef918bf2..f3b6dcb55 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -181,6 +181,9 @@ func fileDigestMap(path string) (map[string]string, error) {
}
if !filepath.IsLocal(rel) {
+ if strings.Contains(rel, ".cache") {
+ return nil, fmt.Errorf("insecure path: %s\n\nUse --local-dir
when downloading model to disable caching", rel)
+ }
return nil, fmt.Errorf("insecure path: %s", rel)
}
diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh
index 4325a9787..4bec54cf8 100755
--- a/scripts/build_darwin.sh
+++ b/scripts/build_darwin.sh
@@ -59,7 +59,7 @@ _build_darwin() {
cmake --install $BUILD_DIR --component CPU
cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory
- MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
+ MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else
BUILD_DIR=build
@@ -70,10 +70,10 @@ _build_darwin() {
cmake --build --preset MLX --parallel
cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64
- MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
+ MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
- GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
+ GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -o $INSTALL_PREFIX .
# Copy MLX libraries to same directory as executable for dlopen
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1
index 21e6f3be0..f162e4f7e 100644
--- a/scripts/build_windows.ps1
+++ b/scripts/build_windows.ps1
@@ -4,7 +4,10 @@
#
# gcloud auth application-default login
-$ErrorActionPreference = "Stop"
+# Use "Continue" so that stderr output from native commands (e.g. CGo warnings)
+# is not promoted to a terminating exception by the try/catch block.
+# All native commands already check $LASTEXITCODE explicitly.
+$ErrorActionPreference = "Continue"
mkdir -Force -path .\dist | Out-Null
@@ -16,13 +19,13 @@ function checkEnv {
if ($null -ne $arch) {
$script:ARCH = ($arch.ToString().ToLower()).Replace("x64", "amd64")
} else {
- write-host "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
+ Write-Output "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
$script:ARCH="amd64"
}
}
$script:TARGET_ARCH=$script:ARCH
Write-host "Building for ${script:TARGET_ARCH}"
- write-host "Locating required tools and paths"
+ Write-Output "Locating required tools and paths"
$script:SRC_DIR=$PWD
# Locate CUDA versions
@@ -37,16 +40,17 @@ function checkEnv {
$script:CUDA_DIRS=($cudaList | sort-object -Descending)
}
if ($script:CUDA_DIRS.length -gt 0) {
- write-host "Available CUDA Versions: $script:CUDA_DIRS"
+ Write-Output "Available CUDA Versions: $script:CUDA_DIRS"
} else {
- write-host "No CUDA versions detected"
+ Write-Output "No CUDA versions detected"
}
- # Locate ROCm version
- if ($null -ne $env:HIP_PATH) {
+ # Locate ROCm v6
+ $rocmDir=(get-item "C:\Program Files\AMD\ROCm\6.*" -ea 'silentlycontinue' | sort-object -Descending | select-object -First 1)
+ if ($null -ne $rocmDir) {
+ $script:HIP_PATH=$rocmDir.FullName
+ } elseif ($null -ne $env:HIP_PATH -and $env:HIP_PATH -match '[/\\]6\.') {
$script:HIP_PATH=$env:HIP_PATH
- } else {
- $script:HIP_PATH=(get-item "C:\Program Files\AMD\ROCm\*\bin\" -ea 'silentlycontinue' | sort-object -Descending)
}
$inoSetup=(get-item "C:\Program Files*\Inno Setup*\")
@@ -78,7 +82,7 @@ function checkEnv {
} else {
$script:PKG_VERSION="0.0.0"
}
- write-host "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
+ Write-Output "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
# Note: Windows Kits 10 signtool crashes with GCP's plugin
if ($null -eq $env:SIGN_TOOL) {
@@ -87,12 +91,32 @@ function checkEnv {
${script:SignTool}=${env:SIGN_TOOL}
}
if ("${env:KEY_CONTAINER}") {
- ${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
- Write-host "Code signing enabled"
+ if (Test-Path "${script:SRC_DIR}\ollama_inc.crt") {
+ ${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
+ Write-host "Code signing enabled"
+ } else {
+ Write-Output "WARNING: KEY_CONTAINER is set but ollama_inc.crt not found at ${script:SRC_DIR}\ollama_inc.crt - code signing disabled"
+ }
} else {
- write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
+ Write-Output "Code signing disabled - please set KEY_CONTAINER to sign and copy ollama_inc.crt to the top of the source tree"
}
- $script:JOBS=([Environment]::ProcessorCount)
+ if ($env:OLLAMA_BUILD_PARALLEL) {
+ $script:JOBS=[int]$env:OLLAMA_BUILD_PARALLEL
+ } else {
+ # Use physical core count rather than logical processors (hyperthreads)
+ # to avoid saturating the system during builds
+ try {
+ $cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfCores -Sum).Sum
+ } catch {
+ $cores = 0
+ }
+ if ($cores -gt 0) {
+ $script:JOBS = $cores
+ } else {
+ $script:JOBS = [Environment]::ProcessorCount
+ }
+ }
+ Write-Output "Build parallelism: $script:JOBS (set OLLAMA_BUILD_PARALLEL to override)"
}
@@ -127,7 +151,7 @@ function cuda11 {
}
}
}
- write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
+ Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix "$script:DIST_DIR"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -136,12 +160,12 @@ function cuda11 {
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
- write-host "CUDA v$cudaMajorVer not detected, skipping"
+ Write-Output "CUDA v$cudaMajorVer not detected, skipping"
}
} else {
- write-host "not arch we wanted"
+ Write-Output "not arch we wanted"
}
- write-host "done"
+ Write-Output "done"
}
function cudaCommon {
@@ -159,7 +183,7 @@ function cudaCommon {
}
}
}
- write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
+ Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -168,7 +192,7 @@ function cudaCommon {
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
- write-host "CUDA v$cudaMajorVer not detected, skipping"
+ Write-Output "CUDA v$cudaMajorVer not detected, skipping"
}
}
}
@@ -181,11 +205,11 @@ function cuda13 {
cudaCommon("13")
}
-function rocm {
+function rocm6 {
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
if ($script:ARCH -ne "arm64") {
if ($script:HIP_PATH) {
- write-host "Building ROCm backend libraries $script:HIP_PATH"
+ Write-Output "Building ROCm backend libraries $script:HIP_PATH"
if (-Not (get-command -ErrorAction silent ninja)) {
$NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe).Directory.FullName
$env:PATH="$NINJA_DIR;$env:PATH"
@@ -193,9 +217,11 @@ function rocm {
$env:HIPCXX="${script:HIP_PATH}\bin\clang++.exe"
$env:HIP_PLATFORM="amd"
$env:CMAKE_PREFIX_PATH="${script:HIP_PATH}"
+ # Set CC/CXX via environment instead of -D flags to avoid triggering
+ # spurious compiler-change reconfigures that reset CMAKE_INSTALL_PREFIX
+ $env:CC="${script:HIP_PATH}\bin\clang.exe"
+ $env:CXX="${script:HIP_PATH}\bin\clang++.exe"
& cmake -B build\rocm --preset "ROCm 6" -G Ninja `
- -DCMAKE_C_COMPILER=clang `
- -DCMAKE_CXX_COMPILER=clang++ `
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
--install-prefix $script:DIST_DIR
@@ -203,20 +229,22 @@ function rocm {
$env:HIPCXX=""
$env:HIP_PLATFORM=""
$env:CMAKE_PREFIX_PATH=""
+ $env:CC=""
+ $env:CXX=""
& cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build\rocm --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
} else {
- write-host "ROCm not detected, skipping"
+ Write-Output "ROCm not detected, skipping"
}
}
}
function vulkan {
if ($env:VULKAN_SDK) {
- write-host "Building Vulkan backend libraries"
+ Write-Output "Building Vulkan backend libraries"
& cmake -B build\vulkan --preset Vulkan --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build build\vulkan --target ggml-vulkan --config Release --parallel $script:JOBS
@@ -224,33 +252,91 @@ function vulkan {
& cmake --install build\vulkan --component Vulkan --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
- write-host "Vulkan not detected, skipping"
+ Write-Output "Vulkan not detected, skipping"
+ }
+}
+
+function mlxCuda13 {
+    # Builds and installs the MLX CUDA 13 backend using the "MLX CUDA 13" CMake preset.
+    # Does nothing on arm64. Skips (with a message) when no CUDA v13 toolkit is listed in
+    # $script:CUDA_DIRS, when nvcc.exe cannot be found under it, or when cuDNN is missing.
+    # Any failing cmake invocation terminates the script with cmake's exit code.
+    mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
+    $cudaMajorVer="13"
+    if ($script:ARCH -ne "arm64") {
+        if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) {
+            # Start from a known-empty value so a toolkit directory without nvcc.exe is
+            # detected below instead of handing cmake an empty (or inherited) toolkit path.
+            $cuda=$null
+            foreach ($d in $Script:CUDA_DIRS){
+                if ($d.FullName.Contains("v$cudaMajorVer")) {
+                    if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) {
+                        # $d holds nvcc.exe; its parent directory is passed on as the toolkit root
+                        $cuda=($d.FullName|split-path -parent)
+                        break
+                    }
+                }
+            }
+            if (-not $cuda) {
+                Write-Output "CUDA v$cudaMajorVer nvcc.exe not found, skipping MLX build"
+                return
+            }
+
+            # Check for cuDNN - required for MLX CUDA backend
+            # Supports two layouts:
+            # 1. CI/zip extract: CUDNN\include\cudnn.h, lib\x64\, bin\x64\
+            # 2. Official installer: CUDNN\v*\include\{cuda-ver}\cudnn.h, lib\{cuda-ver}\x64\, bin\{cuda-ver}\
+            if ($env:CUDNN_INCLUDE_PATH -and $env:CUDNN_LIBRARY_PATH) {
+                # Caller-provided locations win; they are used as-is without further checks
+                Write-Output "Using cuDNN from environment: $env:CUDNN_INCLUDE_PATH"
+            } elseif (Test-Path "C:\Program Files\NVIDIA\CUDNN\include\cudnn.h") {
+                # CI/zip layout (flat)
+                $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
+                $env:CUDNN_ROOT_DIR = $cudnnRoot
+                $env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include"
+                $env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\x64"
+                Write-Output "Found cuDNN at $cudnnRoot (flat layout)"
+            } else {
+                # Official installer layout (versioned)
+                $cudnnRoot = $null
+                # Take the first v* directory after a descending sort
+                $resolved = Resolve-Path -Path "C:\Program Files\NVIDIA\CUDNN\v*" -ErrorAction SilentlyContinue | Sort-Object -Descending | Select-Object -First 1
+                if ($resolved -and (Test-Path "$($resolved.Path)\include\$cudaMajorVer.0\cudnn.h")) {
+                    $cudnnRoot = $resolved.Path
+                    $env:CUDNN_ROOT_DIR = $cudnnRoot
+                    $env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include\$cudaMajorVer.0"
+                    $env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\$cudaMajorVer.0\x64"
+                    Write-Output "Found cuDNN at $cudnnRoot (official installer, CUDA $cudaMajorVer.0)"
+                } else {
+                    Write-Output "cuDNN not found - set CUDNN_INCLUDE_PATH and CUDNN_LIBRARY_PATH environment variables"
+                    Write-Output "Skipping MLX build"
+                    return
+                }
+            }
+
+            Write-Output "Building MLX CUDA v$cudaMajorVer backend libraries $cuda"
+            $env:CUDAToolkit_ROOT=$cuda
+            # Configure, build the mlx and mlxc targets, then install only the MLX component
+            & cmake -B build\mlx_cuda_v$cudaMajorVer --preset "MLX CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --build build\mlx_cuda_v$cudaMajorVer --target mlx --target mlxc --config Release --parallel $script:JOBS -- /nodeReuse:false
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --install build\mlx_cuda_v$cudaMajorVer --component "MLX" --strip
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+        } else {
+            Write-Output "CUDA v$cudaMajorVer not detected, skipping MLX build"
+        }
}
}
function ollama {
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
- write-host "Building ollama CLI"
+ Write-Output "Building ollama CLI"
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
cp .\ollama.exe "${script:DIST_DIR}\"
}
function app {
- write-host "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
+ Write-Output "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
if (!(Get-Command npm -ErrorAction SilentlyContinue)) {
- write-host "npm is not installed. Please install Node.js and npm first:"
- write-host " Visit: https://nodejs.org/"
+ Write-Output "npm is not installed. Please install Node.js and npm first:"
+ Write-Output " Visit: https://nodejs.org/"
exit 1
}
if (!(Get-Command tsc -ErrorAction SilentlyContinue)) {
- write-host "Installing TypeScript compiler..."
+ Write-Output "Installing TypeScript compiler..."
npm install -g typescript
}
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
- write-host "Installing tscriptify..."
+ Write-Output "Installing tscriptify..."
go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest
}
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
@@ -260,32 +346,32 @@ function app {
Push-Location app/ui/app
npm install
if ($LASTEXITCODE -ne 0) {
- write-host "ERROR: npm install failed with exit code $LASTEXITCODE"
+ Write-Output "ERROR: npm install failed with exit code $LASTEXITCODE"
exit $LASTEXITCODE
}
- write-host "Building React application..."
+ Write-Output "Building React application..."
npm run build
if ($LASTEXITCODE -ne 0) {
- write-host "ERROR: npm run build failed with exit code $LASTEXITCODE"
+ Write-Output "ERROR: npm run build failed with exit code $LASTEXITCODE"
exit $LASTEXITCODE
}
# Check if dist directory exists and has content
if (!(Test-Path "dist")) {
- write-host "ERROR: dist directory was not created by npm run build"
+ Write-Output "ERROR: dist directory was not created by npm run build"
exit 1
}
$distFiles = Get-ChildItem "dist" -Recurse
if ($distFiles.Count -eq 0) {
- write-host "ERROR: dist directory is empty after npm run build"
+ Write-Output "ERROR: dist directory is empty after npm run build"
exit 1
}
Pop-Location
- write-host "Running go generate"
+ Write-Output "Running go generate"
& go generate ./...
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/ollama/ollama/app/version.Version=$script:VERSION" -o .\dist\windows-ollama-app-${script:ARCH}.exe ./app/cmd/app/
@@ -293,42 +379,42 @@ function app {
}
function deps {
- write-host "Download MSVC Redistributables"
+ Write-Output "Download MSVC Redistributables"
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-arm64" | Out-Null
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-amd64" | Out-Null
- invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe"
- invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe"
- write-host "Done."
+ invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" -ErrorAction Stop
+ invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" -ErrorAction Stop
+ Write-Output "Done."
}
function sign {
# Copy install.ps1 to dist for release packaging
- write-host "Copying install.ps1 to dist"
- Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
+ Write-Output "Copying install.ps1 to dist"
+ Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" -ErrorAction Stop
if ("${env:KEY_CONTAINER}") {
- write-host "Signing Ollama executables, scripts and libraries"
+ Write-Output "Signing Ollama executables, scripts and libraries"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
- write-host "Signing install.ps1"
+ Write-Output "Signing install.ps1"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
"${script:SRC_DIR}\dist\install.ps1"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
- write-host "Signing not enabled"
+ Write-Output "Signing not enabled"
}
}
function installer {
if ($null -eq ${script:INNO_SETUP_DIR}) {
- write-host "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
+ Write-Output "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
exit 1
}
- write-host "Building Ollama Installer"
+ Write-Output "Building Ollama Installer"
cd "${script:SRC_DIR}\app"
$env:PKG_VERSION=$script:PKG_VERSION
if ("${env:KEY_CONTAINER}") {
@@ -342,24 +428,24 @@ function installer {
function zip {
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") {
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") {
- write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
+ Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
# Temporarily adjust paths so we can retain the same directory structure
Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm"
mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt"
- Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
+ Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" -ErrorAction Stop
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force
}
- write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
+ Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") {
- Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm"
+ Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" -ErrorAction Stop
}
}
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") {
- write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
+ Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force
}
}
@@ -375,8 +461,9 @@ try {
cpu
cuda12
cuda13
- rocm
+ rocm6
vulkan
+ mlxCuda13
ollama
app
deps
@@ -385,13 +472,13 @@ try {
zip
} else {
for ( $i = 0; $i -lt $args.count; $i++ ) {
- write-host "running build step $($args[$i])"
+ Write-Output "running build step $($args[$i])"
& $($args[$i])
}
}
} catch {
- write-host "Build Failed"
- write-host $_
+ Write-Error "Build Failed: $($_.Exception.Message)"
+ Write-Error "$($_.ScriptStackTrace)"
} finally {
set-location $script:SRC_DIR
$env:PKG_VERSION=""
diff --git a/scripts/env.sh b/scripts/env.sh
index 65a970bdc..3e6d69cc9 100644
--- a/scripts/env.sh
+++ b/scripts/env.sh
@@ -18,6 +18,14 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
--build-arg=GPU_RUNNER_CPU_FLAGS \
--build-arg=AMDGPU_TARGETS"
+# Forward local MLX source overrides as Docker build contexts
+if [ -n "${OLLAMA_MLX_SOURCE:-}" ]; then
+ OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx=$(cd "$OLLAMA_MLX_SOURCE" && pwd)"
+fi
+if [ -n "${OLLAMA_MLX_C_SOURCE:-}" ]; then
+ OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx-c=$(cd "$OLLAMA_MLX_C_SOURCE" && pwd)"
+fi
+
echo "Building Ollama"
echo "VERSION=$VERSION"
echo "PLATFORM=$PLATFORM"
\ No newline at end of file
diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go
index 6ccb1ad6d..bb379bbac 100644
--- a/x/create/client/quantize.go
+++ b/x/create/client/quantize.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package client
import (
@@ -194,9 +192,10 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
return blobData, nil
}
-// QuantizeSupported returns true if quantization is supported (MLX build)
+// QuantizeSupported returns true if quantization is supported (MLX library available)
func QuantizeSupported() bool {
- return true
+ mlx.InitMLX()
+ return mlx.IsMLXAvailable()
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
diff --git a/x/create/client/quantize_stub.go b/x/create/client/quantize_stub.go
deleted file mode 100644
index 7a75671a0..000000000
--- a/x/create/client/quantize_stub.go
+++ /dev/null
@@ -1,25 +0,0 @@
-//go:build !mlx
-
-package client
-
-import (
- "fmt"
- "io"
-
- "github.com/ollama/ollama/x/create"
-)
-
-// quantizeTensor is not available without MLX
-func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
- return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
-}
-
-// quantizePackedGroup is not available without MLX
-func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
- return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
-}
-
-// QuantizeSupported returns false when MLX is not available
-func QuantizeSupported() bool {
- return false
-}
diff --git a/x/imagegen/cache/cache.go b/x/imagegen/cache/cache.go
index 8a25193cd..d4e19ba55 100644
--- a/x/imagegen/cache/cache.go
+++ b/x/imagegen/cache/cache.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
diff --git a/x/imagegen/cache/step.go b/x/imagegen/cache/step.go
index f91f22fa0..066f2f645 100644
--- a/x/imagegen/cache/step.go
+++ b/x/imagegen/cache/step.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
diff --git a/x/imagegen/cache/teacache.go b/x/imagegen/cache/teacache.go
index 60031d8cb..fb06047ea 100644
--- a/x/imagegen/cache/teacache.go
+++ b/x/imagegen/cache/teacache.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package cache provides caching mechanisms for diffusion model inference.
package cache
diff --git a/x/imagegen/cmd/engine/README.md b/x/imagegen/cmd/engine/README.md
index 3991c02a8..e6ab2f5d3 100644
--- a/x/imagegen/cmd/engine/README.md
+++ b/x/imagegen/cmd/engine/README.md
@@ -5,7 +5,7 @@ Experimental MLX backend for running models on Apple Silicon and CUDA.
## Build
```bash
-go build -tags mlx -o engine ./x/imagegen/cmd/engine
+go build -o engine ./x/imagegen/cmd/engine
```
## Text Generation
diff --git a/x/imagegen/cmd/engine/generate.go b/x/imagegen/cmd/engine/generate.go
index 51118afc1..95173cd80 100644
--- a/x/imagegen/cmd/engine/generate.go
+++ b/x/imagegen/cmd/engine/generate.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package main
import (
diff --git a/x/imagegen/cmd/engine/image.go b/x/imagegen/cmd/engine/image.go
index e8af2222a..3c393cf66 100644
--- a/x/imagegen/cmd/engine/image.go
+++ b/x/imagegen/cmd/engine/image.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package main
import (
diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go
index 6ec7de9e1..31411f466 100644
--- a/x/imagegen/cmd/engine/main.go
+++ b/x/imagegen/cmd/engine/main.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package main
import (
diff --git a/x/imagegen/cmd/engine/sample.go b/x/imagegen/cmd/engine/sample.go
index 5d723e6dc..165c40774 100644
--- a/x/imagegen/cmd/engine/sample.go
+++ b/x/imagegen/cmd/engine/sample.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package main
import "github.com/ollama/ollama/x/imagegen/mlx"
diff --git a/x/imagegen/image.go b/x/imagegen/image.go
index 2dca0ee1d..9bdd95304 100644
--- a/x/imagegen/image.go
+++ b/x/imagegen/image.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package imagegen
import (
diff --git a/x/imagegen/image_processor.go b/x/imagegen/image_processor.go
index 7a562feb5..c3e68ebb3 100644
--- a/x/imagegen/image_processor.go
+++ b/x/imagegen/image_processor.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package imagegen
import (
diff --git a/x/imagegen/imagegen.go b/x/imagegen/imagegen.go
index d870bed9b..c4b586505 100644
--- a/x/imagegen/imagegen.go
+++ b/x/imagegen/imagegen.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package imagegen
import (
diff --git a/x/imagegen/manifest/weights.go b/x/imagegen/manifest/weights.go
index e1209c9db..fcb30449e 100644
--- a/x/imagegen/manifest/weights.go
+++ b/x/imagegen/manifest/weights.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package manifest
import (
diff --git a/x/imagegen/mlx/CMakeLists.txt b/x/imagegen/mlx/CMakeLists.txt
index b62cbf2eb..70246ef4b 100644
--- a/x/imagegen/mlx/CMakeLists.txt
+++ b/x/imagegen/mlx/CMakeLists.txt
@@ -4,6 +4,10 @@ include(FetchContent)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
+# Read MLX core version from top-level file
+file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG)
+string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
+
set(MLX_C_BUILD_EXAMPLES OFF)
set(MLX_BUILD_GGUF OFF)
@@ -43,6 +47,17 @@ if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()
+# Forward cuDNN environment variables to cmake variables so MLX's FindCUDNN.cmake
+# can find them via HINTS ${CUDNN_INCLUDE_PATH} / ${CUDNN_LIBRARY_PATH}.
+if(DEFINED ENV{CUDNN_INCLUDE_PATH} AND NOT CUDNN_INCLUDE_PATH)
+ set(CUDNN_INCLUDE_PATH "$ENV{CUDNN_INCLUDE_PATH}" CACHE PATH "cuDNN include path")
+ message(STATUS "Using CUDNN_INCLUDE_PATH from environment: ${CUDNN_INCLUDE_PATH}")
+endif()
+if(DEFINED ENV{CUDNN_LIBRARY_PATH} AND NOT CUDNN_LIBRARY_PATH)
+ set(CUDNN_LIBRARY_PATH "$ENV{CUDNN_LIBRARY_PATH}" CACHE PATH "cuDNN library path")
+ message(STATUS "Using CUDNN_LIBRARY_PATH from environment: ${CUDNN_LIBRARY_PATH}")
+endif()
+
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
@@ -51,11 +66,58 @@ elseif(MLX_CUDA_ARCHITECTURES)
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()
+# Allow local source overrides via environment variables
+# Resolve to absolute paths so FetchContent doesn't break on relative dirs.
+if(DEFINED ENV{OLLAMA_MLX_SOURCE})
+ get_filename_component(_mlx_src "$ENV{OLLAMA_MLX_SOURCE}" ABSOLUTE BASE_DIR ${CMAKE_SOURCE_DIR})
+ set(FETCHCONTENT_SOURCE_DIR_MLX "${_mlx_src}" CACHE PATH "" FORCE)
+ message(STATUS "Using local MLX source: ${_mlx_src}")
+endif()
+if(DEFINED ENV{OLLAMA_MLX_C_SOURCE})
+ get_filename_component(_mlx_c_src "$ENV{OLLAMA_MLX_C_SOURCE}" ABSOLUTE BASE_DIR ${CMAKE_SOURCE_DIR})
+ set(FETCHCONTENT_SOURCE_DIR_MLX-C "${_mlx_c_src}" CACHE PATH "" FORCE)
+ message(STATUS "Using local MLX-C source: ${_mlx_c_src}")
+endif()
+
+# Pre-declare mlx so our pinned version takes precedence over the one
+# hardcoded in mlx-c's CMakeLists.txt (first FetchContent_Declare wins).
FetchContent_Declare(
- mlx-c
- GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
- GIT_TAG ${MLX_C_GIT_TAG})
+ mlx
+ GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
+ GIT_TAG ${MLX_GIT_TAG}
+)
+
+FetchContent_Declare(
+ mlx-c
+ GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
+ GIT_TAG ${MLX_C_GIT_TAG}
+)
FetchContent_MakeAvailable(mlx-c)
+# Sync vendored headers with fetched version
+file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
+file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
+
+# For local dev builds, override MLX_VERSION with git describe output
+if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
+ execute_process(
+ COMMAND git describe --tags --first-parent --abbrev=7 --long --dirty --always
+ WORKING_DIRECTORY ${mlx_SOURCE_DIR}
+ OUTPUT_VARIABLE _mlx_git_version
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ ERROR_QUIET
+ RESULT_VARIABLE _mlx_git_result
+ )
+ if(_mlx_git_result EQUAL 0 AND _mlx_git_version)
+ # Strip leading "v" prefix for consistency
+ string(REGEX REPLACE "^v" "" _mlx_git_version "${_mlx_git_version}")
+ get_target_property(_mlx_defs mlx_version COMPILE_DEFINITIONS)
+ list(FILTER _mlx_defs EXCLUDE REGEX "^MLX_VERSION=")
+ set_target_properties(mlx_version PROPERTIES COMPILE_DEFINITIONS "${_mlx_defs}")
+ target_compile_definitions(mlx_version PRIVATE "MLX_VERSION=\"${_mlx_git_version}\"")
+ message(STATUS "MLX version (local dev): ${_mlx_git_version}")
+ endif()
+endif()
+
set_target_output_directory(mlx)
set_target_output_directory(mlxc)
diff --git a/x/imagegen/mlx/compile.go b/x/imagegen/mlx/compile.go
index 0dd2dd02a..746e6eaf6 100644
--- a/x/imagegen/mlx/compile.go
+++ b/x/imagegen/mlx/compile.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
/*
diff --git a/x/imagegen/mlx/doc.go b/x/imagegen/mlx/doc.go
index ced1802b0..5410f3b2d 100644
--- a/x/imagegen/mlx/doc.go
+++ b/x/imagegen/mlx/doc.go
@@ -1,6 +1,4 @@
-//go:build mlx
-
// Package mlx provides Go bindings for the MLX-C library with dynamic loading support.
//
-//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c
+//go:generate go run generate_wrappers.go ../../mlxrunner/mlx/include/mlx/c mlx.h mlx.c
package mlx
diff --git a/x/imagegen/mlx/generate_wrappers.go b/x/imagegen/mlx/generate_wrappers.go
index 8aa5bd0c8..114ac2a15 100644
--- a/x/imagegen/mlx/generate_wrappers.go
+++ b/x/imagegen/mlx/generate_wrappers.go
@@ -291,8 +291,15 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err
implBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
implBuf.WriteString("#include \"mlx_dynamic.h\"\n")
- implBuf.WriteString("#include <stdio.h>\n")
- implBuf.WriteString("#include <dlfcn.h>\n\n")
+ implBuf.WriteString("#include <stdio.h>\n\n")
+ implBuf.WriteString("// Platform-specific dynamic loading\n")
+ implBuf.WriteString("#ifdef _WIN32\n")
+ implBuf.WriteString("#include <windows.h>\n")
+ implBuf.WriteString("#define GET_SYM(handle, name) (void*)GetProcAddress((HMODULE)(handle), name)\n")
+ implBuf.WriteString("#else\n")
+ implBuf.WriteString("#include <dlfcn.h>\n")
+ implBuf.WriteString("#define GET_SYM(handle, name) dlsym(handle, name)\n")
+ implBuf.WriteString("#endif\n\n")
// Function pointer definitions
implBuf.WriteString("// Function pointer definitions\n")
@@ -308,7 +315,7 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err
implBuf.WriteString("\n")
// Initialization function
- implBuf.WriteString("// Initialize all function pointers via dlsym\n")
+ implBuf.WriteString("// Initialize all function pointers\n")
implBuf.WriteString("int mlx_load_functions(void* handle) {\n")
implBuf.WriteString(" if (handle == NULL) {\n")
implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n")
@@ -319,7 +326,7 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err
if fn.NeedsARM64Guard {
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
- implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name))
+ implBuf.WriteString(fmt.Sprintf(" %s_ptr = GET_SYM(handle, \"%s\");\n", fn.Name, fn.Name))
implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name))
implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name))
implBuf.WriteString(" return -1;\n")
diff --git a/x/imagegen/mlx/mlx.c b/x/imagegen/mlx/mlx.c
index 770b60922..b0ccbacdf 100644
--- a/x/imagegen/mlx/mlx.c
+++ b/x/imagegen/mlx/mlx.c
@@ -5,7 +5,15 @@
#include "mlx/c/mlx.h"
#include "mlx_dynamic.h"
#include <stdio.h>
+
+// Platform-specific dynamic loading
+#ifdef _WIN32
+#include <windows.h>
+#define GET_SYM(handle, name) (void*)GetProcAddress((HMODULE)(handle), name)
+#else
#include <dlfcn.h>
+#define GET_SYM(handle, name) dlsym(handle, name)
+#endif
// Function pointer definitions
size_t (*mlx_dtype_size_ptr)(mlx_dtype dtype) = NULL;
@@ -603,2947 +611,2947 @@ size_t (*mlx_vector_string_size_ptr)(mlx_vector_string vec) = NULL;
int (*mlx_vector_string_get_ptr)(char** res, const mlx_vector_string vec, size_t idx) = NULL;
int (*mlx_version_ptr)(mlx_string* str_) = NULL;
-// Initialize all function pointers via dlsym
+// Initialize all function pointers
int mlx_load_functions(void* handle) {
if (handle == NULL) {
fprintf(stderr, "MLX: Invalid library handle\n");
return -1;
}
- mlx_dtype_size_ptr = dlsym(handle, "mlx_dtype_size");
+ mlx_dtype_size_ptr = GET_SYM(handle, "mlx_dtype_size");
if (mlx_dtype_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_dtype_size\n");
return -1;
}
- mlx_array_tostring_ptr = dlsym(handle, "mlx_array_tostring");
+ mlx_array_tostring_ptr = GET_SYM(handle, "mlx_array_tostring");
if (mlx_array_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_tostring\n");
return -1;
}
- mlx_array_new_ptr = dlsym(handle, "mlx_array_new");
+ mlx_array_new_ptr = GET_SYM(handle, "mlx_array_new");
if (mlx_array_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new\n");
return -1;
}
- mlx_array_free_ptr = dlsym(handle, "mlx_array_free");
+ mlx_array_free_ptr = GET_SYM(handle, "mlx_array_free");
if (mlx_array_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_free\n");
return -1;
}
- mlx_array_new_bool_ptr = dlsym(handle, "mlx_array_new_bool");
+ mlx_array_new_bool_ptr = GET_SYM(handle, "mlx_array_new_bool");
if (mlx_array_new_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_bool\n");
return -1;
}
- mlx_array_new_int_ptr = dlsym(handle, "mlx_array_new_int");
+ mlx_array_new_int_ptr = GET_SYM(handle, "mlx_array_new_int");
if (mlx_array_new_int_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_int\n");
return -1;
}
- mlx_array_new_float32_ptr = dlsym(handle, "mlx_array_new_float32");
+ mlx_array_new_float32_ptr = GET_SYM(handle, "mlx_array_new_float32");
if (mlx_array_new_float32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float32\n");
return -1;
}
- mlx_array_new_float_ptr = dlsym(handle, "mlx_array_new_float");
+ mlx_array_new_float_ptr = GET_SYM(handle, "mlx_array_new_float");
if (mlx_array_new_float_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float\n");
return -1;
}
- mlx_array_new_float64_ptr = dlsym(handle, "mlx_array_new_float64");
+ mlx_array_new_float64_ptr = GET_SYM(handle, "mlx_array_new_float64");
if (mlx_array_new_float64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float64\n");
return -1;
}
- mlx_array_new_double_ptr = dlsym(handle, "mlx_array_new_double");
+ mlx_array_new_double_ptr = GET_SYM(handle, "mlx_array_new_double");
if (mlx_array_new_double_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_double\n");
return -1;
}
- mlx_array_new_complex_ptr = dlsym(handle, "mlx_array_new_complex");
+ mlx_array_new_complex_ptr = GET_SYM(handle, "mlx_array_new_complex");
if (mlx_array_new_complex_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_complex\n");
return -1;
}
- mlx_array_new_data_ptr = dlsym(handle, "mlx_array_new_data");
+ mlx_array_new_data_ptr = GET_SYM(handle, "mlx_array_new_data");
if (mlx_array_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
return -1;
}
- mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
+ mlx_array_new_data_managed_ptr = GET_SYM(handle, "mlx_array_new_data_managed");
if (mlx_array_new_data_managed_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
return -1;
}
- mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
+ mlx_array_new_data_managed_payload_ptr = GET_SYM(handle, "mlx_array_new_data_managed_payload");
if (mlx_array_new_data_managed_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
return -1;
}
- mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
+ mlx_array_set_ptr = GET_SYM(handle, "mlx_array_set");
if (mlx_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
return -1;
}
- mlx_array_set_bool_ptr = dlsym(handle, "mlx_array_set_bool");
+ mlx_array_set_bool_ptr = GET_SYM(handle, "mlx_array_set_bool");
if (mlx_array_set_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_bool\n");
return -1;
}
- mlx_array_set_int_ptr = dlsym(handle, "mlx_array_set_int");
+ mlx_array_set_int_ptr = GET_SYM(handle, "mlx_array_set_int");
if (mlx_array_set_int_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_int\n");
return -1;
}
- mlx_array_set_float32_ptr = dlsym(handle, "mlx_array_set_float32");
+ mlx_array_set_float32_ptr = GET_SYM(handle, "mlx_array_set_float32");
if (mlx_array_set_float32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float32\n");
return -1;
}
- mlx_array_set_float_ptr = dlsym(handle, "mlx_array_set_float");
+ mlx_array_set_float_ptr = GET_SYM(handle, "mlx_array_set_float");
if (mlx_array_set_float_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float\n");
return -1;
}
- mlx_array_set_float64_ptr = dlsym(handle, "mlx_array_set_float64");
+ mlx_array_set_float64_ptr = GET_SYM(handle, "mlx_array_set_float64");
if (mlx_array_set_float64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float64\n");
return -1;
}
- mlx_array_set_double_ptr = dlsym(handle, "mlx_array_set_double");
+ mlx_array_set_double_ptr = GET_SYM(handle, "mlx_array_set_double");
if (mlx_array_set_double_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_double\n");
return -1;
}
- mlx_array_set_complex_ptr = dlsym(handle, "mlx_array_set_complex");
+ mlx_array_set_complex_ptr = GET_SYM(handle, "mlx_array_set_complex");
if (mlx_array_set_complex_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_complex\n");
return -1;
}
- mlx_array_set_data_ptr = dlsym(handle, "mlx_array_set_data");
+ mlx_array_set_data_ptr = GET_SYM(handle, "mlx_array_set_data");
if (mlx_array_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_data\n");
return -1;
}
- mlx_array_itemsize_ptr = dlsym(handle, "mlx_array_itemsize");
+ mlx_array_itemsize_ptr = GET_SYM(handle, "mlx_array_itemsize");
if (mlx_array_itemsize_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_itemsize\n");
return -1;
}
- mlx_array_size_ptr = dlsym(handle, "mlx_array_size");
+ mlx_array_size_ptr = GET_SYM(handle, "mlx_array_size");
if (mlx_array_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_size\n");
return -1;
}
- mlx_array_nbytes_ptr = dlsym(handle, "mlx_array_nbytes");
+ mlx_array_nbytes_ptr = GET_SYM(handle, "mlx_array_nbytes");
if (mlx_array_nbytes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_nbytes\n");
return -1;
}
- mlx_array_ndim_ptr = dlsym(handle, "mlx_array_ndim");
+ mlx_array_ndim_ptr = GET_SYM(handle, "mlx_array_ndim");
if (mlx_array_ndim_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_ndim\n");
return -1;
}
- mlx_array_shape_ptr = dlsym(handle, "mlx_array_shape");
+ mlx_array_shape_ptr = GET_SYM(handle, "mlx_array_shape");
if (mlx_array_shape_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_shape\n");
return -1;
}
- mlx_array_strides_ptr = dlsym(handle, "mlx_array_strides");
+ mlx_array_strides_ptr = GET_SYM(handle, "mlx_array_strides");
if (mlx_array_strides_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_strides\n");
return -1;
}
- mlx_array_dim_ptr = dlsym(handle, "mlx_array_dim");
+ mlx_array_dim_ptr = GET_SYM(handle, "mlx_array_dim");
if (mlx_array_dim_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dim\n");
return -1;
}
- mlx_array_dtype_ptr = dlsym(handle, "mlx_array_dtype");
+ mlx_array_dtype_ptr = GET_SYM(handle, "mlx_array_dtype");
if (mlx_array_dtype_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dtype\n");
return -1;
}
- mlx_array_eval_ptr = dlsym(handle, "mlx_array_eval");
+ mlx_array_eval_ptr = GET_SYM(handle, "mlx_array_eval");
if (mlx_array_eval_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_eval\n");
return -1;
}
- mlx_array_item_bool_ptr = dlsym(handle, "mlx_array_item_bool");
+ mlx_array_item_bool_ptr = GET_SYM(handle, "mlx_array_item_bool");
if (mlx_array_item_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bool\n");
return -1;
}
- mlx_array_item_uint8_ptr = dlsym(handle, "mlx_array_item_uint8");
+ mlx_array_item_uint8_ptr = GET_SYM(handle, "mlx_array_item_uint8");
if (mlx_array_item_uint8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint8\n");
return -1;
}
- mlx_array_item_uint16_ptr = dlsym(handle, "mlx_array_item_uint16");
+ mlx_array_item_uint16_ptr = GET_SYM(handle, "mlx_array_item_uint16");
if (mlx_array_item_uint16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint16\n");
return -1;
}
- mlx_array_item_uint32_ptr = dlsym(handle, "mlx_array_item_uint32");
+ mlx_array_item_uint32_ptr = GET_SYM(handle, "mlx_array_item_uint32");
if (mlx_array_item_uint32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint32\n");
return -1;
}
- mlx_array_item_uint64_ptr = dlsym(handle, "mlx_array_item_uint64");
+ mlx_array_item_uint64_ptr = GET_SYM(handle, "mlx_array_item_uint64");
if (mlx_array_item_uint64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint64\n");
return -1;
}
- mlx_array_item_int8_ptr = dlsym(handle, "mlx_array_item_int8");
+ mlx_array_item_int8_ptr = GET_SYM(handle, "mlx_array_item_int8");
if (mlx_array_item_int8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int8\n");
return -1;
}
- mlx_array_item_int16_ptr = dlsym(handle, "mlx_array_item_int16");
+ mlx_array_item_int16_ptr = GET_SYM(handle, "mlx_array_item_int16");
if (mlx_array_item_int16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int16\n");
return -1;
}
- mlx_array_item_int32_ptr = dlsym(handle, "mlx_array_item_int32");
+ mlx_array_item_int32_ptr = GET_SYM(handle, "mlx_array_item_int32");
if (mlx_array_item_int32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int32\n");
return -1;
}
- mlx_array_item_int64_ptr = dlsym(handle, "mlx_array_item_int64");
+ mlx_array_item_int64_ptr = GET_SYM(handle, "mlx_array_item_int64");
if (mlx_array_item_int64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int64\n");
return -1;
}
- mlx_array_item_float32_ptr = dlsym(handle, "mlx_array_item_float32");
+ mlx_array_item_float32_ptr = GET_SYM(handle, "mlx_array_item_float32");
if (mlx_array_item_float32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float32\n");
return -1;
}
- mlx_array_item_float64_ptr = dlsym(handle, "mlx_array_item_float64");
+ mlx_array_item_float64_ptr = GET_SYM(handle, "mlx_array_item_float64");
if (mlx_array_item_float64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float64\n");
return -1;
}
- mlx_array_item_complex64_ptr = dlsym(handle, "mlx_array_item_complex64");
+ mlx_array_item_complex64_ptr = GET_SYM(handle, "mlx_array_item_complex64");
if (mlx_array_item_complex64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_complex64\n");
return -1;
}
#if defined(__aarch64__) || defined(_M_ARM64)
- mlx_array_item_float16_ptr = dlsym(handle, "mlx_array_item_float16");
+ mlx_array_item_float16_ptr = GET_SYM(handle, "mlx_array_item_float16");
if (mlx_array_item_float16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float16\n");
return -1;
}
#endif
#if defined(__aarch64__) || defined(_M_ARM64)
- mlx_array_item_bfloat16_ptr = dlsym(handle, "mlx_array_item_bfloat16");
+ mlx_array_item_bfloat16_ptr = GET_SYM(handle, "mlx_array_item_bfloat16");
if (mlx_array_item_bfloat16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bfloat16\n");
return -1;
}
#endif
- mlx_array_data_bool_ptr = dlsym(handle, "mlx_array_data_bool");
+ mlx_array_data_bool_ptr = GET_SYM(handle, "mlx_array_data_bool");
if (mlx_array_data_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bool\n");
return -1;
}
- mlx_array_data_uint8_ptr = dlsym(handle, "mlx_array_data_uint8");
+ mlx_array_data_uint8_ptr = GET_SYM(handle, "mlx_array_data_uint8");
if (mlx_array_data_uint8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint8\n");
return -1;
}
- mlx_array_data_uint16_ptr = dlsym(handle, "mlx_array_data_uint16");
+ mlx_array_data_uint16_ptr = GET_SYM(handle, "mlx_array_data_uint16");
if (mlx_array_data_uint16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint16\n");
return -1;
}
- mlx_array_data_uint32_ptr = dlsym(handle, "mlx_array_data_uint32");
+ mlx_array_data_uint32_ptr = GET_SYM(handle, "mlx_array_data_uint32");
if (mlx_array_data_uint32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint32\n");
return -1;
}
- mlx_array_data_uint64_ptr = dlsym(handle, "mlx_array_data_uint64");
+ mlx_array_data_uint64_ptr = GET_SYM(handle, "mlx_array_data_uint64");
if (mlx_array_data_uint64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint64\n");
return -1;
}
- mlx_array_data_int8_ptr = dlsym(handle, "mlx_array_data_int8");
+ mlx_array_data_int8_ptr = GET_SYM(handle, "mlx_array_data_int8");
if (mlx_array_data_int8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int8\n");
return -1;
}
- mlx_array_data_int16_ptr = dlsym(handle, "mlx_array_data_int16");
+ mlx_array_data_int16_ptr = GET_SYM(handle, "mlx_array_data_int16");
if (mlx_array_data_int16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int16\n");
return -1;
}
- mlx_array_data_int32_ptr = dlsym(handle, "mlx_array_data_int32");
+ mlx_array_data_int32_ptr = GET_SYM(handle, "mlx_array_data_int32");
if (mlx_array_data_int32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int32\n");
return -1;
}
- mlx_array_data_int64_ptr = dlsym(handle, "mlx_array_data_int64");
+ mlx_array_data_int64_ptr = GET_SYM(handle, "mlx_array_data_int64");
if (mlx_array_data_int64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int64\n");
return -1;
}
- mlx_array_data_float32_ptr = dlsym(handle, "mlx_array_data_float32");
+ mlx_array_data_float32_ptr = GET_SYM(handle, "mlx_array_data_float32");
if (mlx_array_data_float32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float32\n");
return -1;
}
- mlx_array_data_float64_ptr = dlsym(handle, "mlx_array_data_float64");
+ mlx_array_data_float64_ptr = GET_SYM(handle, "mlx_array_data_float64");
if (mlx_array_data_float64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float64\n");
return -1;
}
- mlx_array_data_complex64_ptr = dlsym(handle, "mlx_array_data_complex64");
+ mlx_array_data_complex64_ptr = GET_SYM(handle, "mlx_array_data_complex64");
if (mlx_array_data_complex64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_complex64\n");
return -1;
}
#if defined(__aarch64__) || defined(_M_ARM64)
- mlx_array_data_float16_ptr = dlsym(handle, "mlx_array_data_float16");
+ mlx_array_data_float16_ptr = GET_SYM(handle, "mlx_array_data_float16");
if (mlx_array_data_float16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float16\n");
return -1;
}
#endif
#if defined(__aarch64__) || defined(_M_ARM64)
- mlx_array_data_bfloat16_ptr = dlsym(handle, "mlx_array_data_bfloat16");
+ mlx_array_data_bfloat16_ptr = GET_SYM(handle, "mlx_array_data_bfloat16");
if (mlx_array_data_bfloat16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bfloat16\n");
return -1;
}
#endif
- _mlx_array_is_available_ptr = dlsym(handle, "_mlx_array_is_available");
+ _mlx_array_is_available_ptr = GET_SYM(handle, "_mlx_array_is_available");
if (_mlx_array_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_available\n");
return -1;
}
- _mlx_array_wait_ptr = dlsym(handle, "_mlx_array_wait");
+ _mlx_array_wait_ptr = GET_SYM(handle, "_mlx_array_wait");
if (_mlx_array_wait_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_wait\n");
return -1;
}
- _mlx_array_is_contiguous_ptr = dlsym(handle, "_mlx_array_is_contiguous");
+ _mlx_array_is_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_contiguous");
if (_mlx_array_is_contiguous_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_contiguous\n");
return -1;
}
- _mlx_array_is_row_contiguous_ptr = dlsym(handle, "_mlx_array_is_row_contiguous");
+ _mlx_array_is_row_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_row_contiguous");
if (_mlx_array_is_row_contiguous_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_row_contiguous\n");
return -1;
}
- _mlx_array_is_col_contiguous_ptr = dlsym(handle, "_mlx_array_is_col_contiguous");
+ _mlx_array_is_col_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_col_contiguous");
if (_mlx_array_is_col_contiguous_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_col_contiguous\n");
return -1;
}
- mlx_closure_new_ptr = dlsym(handle, "mlx_closure_new");
+ mlx_closure_new_ptr = GET_SYM(handle, "mlx_closure_new");
if (mlx_closure_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new\n");
return -1;
}
- mlx_closure_free_ptr = dlsym(handle, "mlx_closure_free");
+ mlx_closure_free_ptr = GET_SYM(handle, "mlx_closure_free");
if (mlx_closure_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_free\n");
return -1;
}
- mlx_closure_new_func_ptr = dlsym(handle, "mlx_closure_new_func");
+ mlx_closure_new_func_ptr = GET_SYM(handle, "mlx_closure_new_func");
if (mlx_closure_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func\n");
return -1;
}
- mlx_closure_new_func_payload_ptr = dlsym(handle, "mlx_closure_new_func_payload");
+ mlx_closure_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_new_func_payload");
if (mlx_closure_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func_payload\n");
return -1;
}
- mlx_closure_set_ptr = dlsym(handle, "mlx_closure_set");
+ mlx_closure_set_ptr = GET_SYM(handle, "mlx_closure_set");
if (mlx_closure_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_set\n");
return -1;
}
- mlx_closure_apply_ptr = dlsym(handle, "mlx_closure_apply");
+ mlx_closure_apply_ptr = GET_SYM(handle, "mlx_closure_apply");
if (mlx_closure_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_apply\n");
return -1;
}
- mlx_closure_new_unary_ptr = dlsym(handle, "mlx_closure_new_unary");
+ mlx_closure_new_unary_ptr = GET_SYM(handle, "mlx_closure_new_unary");
if (mlx_closure_new_unary_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_unary\n");
return -1;
}
- mlx_closure_kwargs_new_ptr = dlsym(handle, "mlx_closure_kwargs_new");
+ mlx_closure_kwargs_new_ptr = GET_SYM(handle, "mlx_closure_kwargs_new");
if (mlx_closure_kwargs_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new\n");
return -1;
}
- mlx_closure_kwargs_free_ptr = dlsym(handle, "mlx_closure_kwargs_free");
+ mlx_closure_kwargs_free_ptr = GET_SYM(handle, "mlx_closure_kwargs_free");
if (mlx_closure_kwargs_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_free\n");
return -1;
}
- mlx_closure_kwargs_new_func_ptr = dlsym(handle, "mlx_closure_kwargs_new_func");
+ mlx_closure_kwargs_new_func_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func");
if (mlx_closure_kwargs_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func\n");
return -1;
}
- mlx_closure_kwargs_new_func_payload_ptr = dlsym(handle, "mlx_closure_kwargs_new_func_payload");
+ mlx_closure_kwargs_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func_payload");
if (mlx_closure_kwargs_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func_payload\n");
return -1;
}
- mlx_closure_kwargs_set_ptr = dlsym(handle, "mlx_closure_kwargs_set");
+ mlx_closure_kwargs_set_ptr = GET_SYM(handle, "mlx_closure_kwargs_set");
if (mlx_closure_kwargs_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_set\n");
return -1;
}
- mlx_closure_kwargs_apply_ptr = dlsym(handle, "mlx_closure_kwargs_apply");
+ mlx_closure_kwargs_apply_ptr = GET_SYM(handle, "mlx_closure_kwargs_apply");
if (mlx_closure_kwargs_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_apply\n");
return -1;
}
- mlx_closure_value_and_grad_new_ptr = dlsym(handle, "mlx_closure_value_and_grad_new");
+ mlx_closure_value_and_grad_new_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new");
if (mlx_closure_value_and_grad_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new\n");
return -1;
}
- mlx_closure_value_and_grad_free_ptr = dlsym(handle, "mlx_closure_value_and_grad_free");
+ mlx_closure_value_and_grad_free_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_free");
if (mlx_closure_value_and_grad_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_free\n");
return -1;
}
- mlx_closure_value_and_grad_new_func_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func");
+ mlx_closure_value_and_grad_new_func_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func");
if (mlx_closure_value_and_grad_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func\n");
return -1;
}
- mlx_closure_value_and_grad_new_func_payload_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func_payload");
+ mlx_closure_value_and_grad_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func_payload");
if (mlx_closure_value_and_grad_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func_payload\n");
return -1;
}
- mlx_closure_value_and_grad_set_ptr = dlsym(handle, "mlx_closure_value_and_grad_set");
+ mlx_closure_value_and_grad_set_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_set");
if (mlx_closure_value_and_grad_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_set\n");
return -1;
}
- mlx_closure_value_and_grad_apply_ptr = dlsym(handle, "mlx_closure_value_and_grad_apply");
+ mlx_closure_value_and_grad_apply_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_apply");
if (mlx_closure_value_and_grad_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_apply\n");
return -1;
}
- mlx_closure_custom_new_ptr = dlsym(handle, "mlx_closure_custom_new");
+ mlx_closure_custom_new_ptr = GET_SYM(handle, "mlx_closure_custom_new");
if (mlx_closure_custom_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new\n");
return -1;
}
- mlx_closure_custom_free_ptr = dlsym(handle, "mlx_closure_custom_free");
+ mlx_closure_custom_free_ptr = GET_SYM(handle, "mlx_closure_custom_free");
if (mlx_closure_custom_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_free\n");
return -1;
}
- mlx_closure_custom_new_func_ptr = dlsym(handle, "mlx_closure_custom_new_func");
+ mlx_closure_custom_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_new_func");
if (mlx_closure_custom_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func\n");
return -1;
}
- mlx_closure_custom_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_new_func_payload");
+ mlx_closure_custom_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_new_func_payload");
if (mlx_closure_custom_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func_payload\n");
return -1;
}
- mlx_closure_custom_set_ptr = dlsym(handle, "mlx_closure_custom_set");
+ mlx_closure_custom_set_ptr = GET_SYM(handle, "mlx_closure_custom_set");
if (mlx_closure_custom_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_set\n");
return -1;
}
- mlx_closure_custom_apply_ptr = dlsym(handle, "mlx_closure_custom_apply");
+ mlx_closure_custom_apply_ptr = GET_SYM(handle, "mlx_closure_custom_apply");
if (mlx_closure_custom_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_apply\n");
return -1;
}
- mlx_closure_custom_jvp_new_ptr = dlsym(handle, "mlx_closure_custom_jvp_new");
+ mlx_closure_custom_jvp_new_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new");
if (mlx_closure_custom_jvp_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new\n");
return -1;
}
- mlx_closure_custom_jvp_free_ptr = dlsym(handle, "mlx_closure_custom_jvp_free");
+ mlx_closure_custom_jvp_free_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_free");
if (mlx_closure_custom_jvp_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_free\n");
return -1;
}
- mlx_closure_custom_jvp_new_func_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func");
+ mlx_closure_custom_jvp_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func");
if (mlx_closure_custom_jvp_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func\n");
return -1;
}
- mlx_closure_custom_jvp_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func_payload");
+ mlx_closure_custom_jvp_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func_payload");
if (mlx_closure_custom_jvp_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func_payload\n");
return -1;
}
- mlx_closure_custom_jvp_set_ptr = dlsym(handle, "mlx_closure_custom_jvp_set");
+ mlx_closure_custom_jvp_set_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_set");
if (mlx_closure_custom_jvp_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_set\n");
return -1;
}
- mlx_closure_custom_jvp_apply_ptr = dlsym(handle, "mlx_closure_custom_jvp_apply");
+ mlx_closure_custom_jvp_apply_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_apply");
if (mlx_closure_custom_jvp_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_apply\n");
return -1;
}
- mlx_closure_custom_vmap_new_ptr = dlsym(handle, "mlx_closure_custom_vmap_new");
+ mlx_closure_custom_vmap_new_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new");
if (mlx_closure_custom_vmap_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new\n");
return -1;
}
- mlx_closure_custom_vmap_free_ptr = dlsym(handle, "mlx_closure_custom_vmap_free");
+ mlx_closure_custom_vmap_free_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_free");
if (mlx_closure_custom_vmap_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_free\n");
return -1;
}
- mlx_closure_custom_vmap_new_func_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func");
+ mlx_closure_custom_vmap_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func");
if (mlx_closure_custom_vmap_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func\n");
return -1;
}
- mlx_closure_custom_vmap_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func_payload");
+ mlx_closure_custom_vmap_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func_payload");
if (mlx_closure_custom_vmap_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func_payload\n");
return -1;
}
- mlx_closure_custom_vmap_set_ptr = dlsym(handle, "mlx_closure_custom_vmap_set");
+ mlx_closure_custom_vmap_set_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_set");
if (mlx_closure_custom_vmap_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_set\n");
return -1;
}
- mlx_closure_custom_vmap_apply_ptr = dlsym(handle, "mlx_closure_custom_vmap_apply");
+ mlx_closure_custom_vmap_apply_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_apply");
if (mlx_closure_custom_vmap_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_apply\n");
return -1;
}
- mlx_compile_ptr = dlsym(handle, "mlx_compile");
+ mlx_compile_ptr = GET_SYM(handle, "mlx_compile");
if (mlx_compile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_compile\n");
return -1;
}
- mlx_detail_compile_ptr = dlsym(handle, "mlx_detail_compile");
+ mlx_detail_compile_ptr = GET_SYM(handle, "mlx_detail_compile");
if (mlx_detail_compile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile\n");
return -1;
}
- mlx_detail_compile_clear_cache_ptr = dlsym(handle, "mlx_detail_compile_clear_cache");
+ mlx_detail_compile_clear_cache_ptr = GET_SYM(handle, "mlx_detail_compile_clear_cache");
if (mlx_detail_compile_clear_cache_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_clear_cache\n");
return -1;
}
- mlx_detail_compile_erase_ptr = dlsym(handle, "mlx_detail_compile_erase");
+ mlx_detail_compile_erase_ptr = GET_SYM(handle, "mlx_detail_compile_erase");
if (mlx_detail_compile_erase_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_erase\n");
return -1;
}
- mlx_disable_compile_ptr = dlsym(handle, "mlx_disable_compile");
+ mlx_disable_compile_ptr = GET_SYM(handle, "mlx_disable_compile");
if (mlx_disable_compile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_disable_compile\n");
return -1;
}
- mlx_enable_compile_ptr = dlsym(handle, "mlx_enable_compile");
+ mlx_enable_compile_ptr = GET_SYM(handle, "mlx_enable_compile");
if (mlx_enable_compile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_enable_compile\n");
return -1;
}
- mlx_set_compile_mode_ptr = dlsym(handle, "mlx_set_compile_mode");
+ mlx_set_compile_mode_ptr = GET_SYM(handle, "mlx_set_compile_mode");
if (mlx_set_compile_mode_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
return -1;
}
- mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
+ mlx_cuda_is_available_ptr = GET_SYM(handle, "mlx_cuda_is_available");
if (mlx_cuda_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
return -1;
}
- mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
+ mlx_device_new_ptr = GET_SYM(handle, "mlx_device_new");
if (mlx_device_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
return -1;
}
- mlx_device_new_type_ptr = dlsym(handle, "mlx_device_new_type");
+ mlx_device_new_type_ptr = GET_SYM(handle, "mlx_device_new_type");
if (mlx_device_new_type_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new_type\n");
return -1;
}
- mlx_device_free_ptr = dlsym(handle, "mlx_device_free");
+ mlx_device_free_ptr = GET_SYM(handle, "mlx_device_free");
if (mlx_device_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_free\n");
return -1;
}
- mlx_device_set_ptr = dlsym(handle, "mlx_device_set");
+ mlx_device_set_ptr = GET_SYM(handle, "mlx_device_set");
if (mlx_device_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_set\n");
return -1;
}
- mlx_device_tostring_ptr = dlsym(handle, "mlx_device_tostring");
+ mlx_device_tostring_ptr = GET_SYM(handle, "mlx_device_tostring");
if (mlx_device_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_tostring\n");
return -1;
}
- mlx_device_equal_ptr = dlsym(handle, "mlx_device_equal");
+ mlx_device_equal_ptr = GET_SYM(handle, "mlx_device_equal");
if (mlx_device_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_equal\n");
return -1;
}
- mlx_device_get_index_ptr = dlsym(handle, "mlx_device_get_index");
+ mlx_device_get_index_ptr = GET_SYM(handle, "mlx_device_get_index");
if (mlx_device_get_index_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_index\n");
return -1;
}
- mlx_device_get_type_ptr = dlsym(handle, "mlx_device_get_type");
+ mlx_device_get_type_ptr = GET_SYM(handle, "mlx_device_get_type");
if (mlx_device_get_type_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_type\n");
return -1;
}
- mlx_get_default_device_ptr = dlsym(handle, "mlx_get_default_device");
+ mlx_get_default_device_ptr = GET_SYM(handle, "mlx_get_default_device");
if (mlx_get_default_device_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_device\n");
return -1;
}
- mlx_set_default_device_ptr = dlsym(handle, "mlx_set_default_device");
+ mlx_set_default_device_ptr = GET_SYM(handle, "mlx_set_default_device");
if (mlx_set_default_device_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
return -1;
}
- mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
+ mlx_device_is_available_ptr = GET_SYM(handle, "mlx_device_is_available");
if (mlx_device_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
return -1;
}
- mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
+ mlx_device_count_ptr = GET_SYM(handle, "mlx_device_count");
if (mlx_device_count_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
return -1;
}
- mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
+ mlx_device_info_new_ptr = GET_SYM(handle, "mlx_device_info_new");
if (mlx_device_info_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
return -1;
}
- mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
+ mlx_device_info_get_ptr = GET_SYM(handle, "mlx_device_info_get");
if (mlx_device_info_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
return -1;
}
- mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
+ mlx_device_info_free_ptr = GET_SYM(handle, "mlx_device_info_free");
if (mlx_device_info_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
return -1;
}
- mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
+ mlx_device_info_has_key_ptr = GET_SYM(handle, "mlx_device_info_has_key");
if (mlx_device_info_has_key_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
return -1;
}
- mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
+ mlx_device_info_is_string_ptr = GET_SYM(handle, "mlx_device_info_is_string");
if (mlx_device_info_is_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
return -1;
}
- mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
+ mlx_device_info_get_string_ptr = GET_SYM(handle, "mlx_device_info_get_string");
if (mlx_device_info_get_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
return -1;
}
- mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
+ mlx_device_info_get_size_ptr = GET_SYM(handle, "mlx_device_info_get_size");
if (mlx_device_info_get_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
return -1;
}
- mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
+ mlx_device_info_get_keys_ptr = GET_SYM(handle, "mlx_device_info_get_keys");
if (mlx_device_info_get_keys_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
return -1;
}
- mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
+ mlx_distributed_all_gather_ptr = GET_SYM(handle, "mlx_distributed_all_gather");
if (mlx_distributed_all_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
return -1;
}
- mlx_distributed_all_max_ptr = dlsym(handle, "mlx_distributed_all_max");
+ mlx_distributed_all_max_ptr = GET_SYM(handle, "mlx_distributed_all_max");
if (mlx_distributed_all_max_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_max\n");
return -1;
}
- mlx_distributed_all_min_ptr = dlsym(handle, "mlx_distributed_all_min");
+ mlx_distributed_all_min_ptr = GET_SYM(handle, "mlx_distributed_all_min");
if (mlx_distributed_all_min_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_min\n");
return -1;
}
- mlx_distributed_all_sum_ptr = dlsym(handle, "mlx_distributed_all_sum");
+ mlx_distributed_all_sum_ptr = GET_SYM(handle, "mlx_distributed_all_sum");
if (mlx_distributed_all_sum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_sum\n");
return -1;
}
- mlx_distributed_recv_ptr = dlsym(handle, "mlx_distributed_recv");
+ mlx_distributed_recv_ptr = GET_SYM(handle, "mlx_distributed_recv");
if (mlx_distributed_recv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv\n");
return -1;
}
- mlx_distributed_recv_like_ptr = dlsym(handle, "mlx_distributed_recv_like");
+ mlx_distributed_recv_like_ptr = GET_SYM(handle, "mlx_distributed_recv_like");
if (mlx_distributed_recv_like_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv_like\n");
return -1;
}
- mlx_distributed_send_ptr = dlsym(handle, "mlx_distributed_send");
+ mlx_distributed_send_ptr = GET_SYM(handle, "mlx_distributed_send");
if (mlx_distributed_send_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_send\n");
return -1;
}
- mlx_distributed_sum_scatter_ptr = dlsym(handle, "mlx_distributed_sum_scatter");
+ mlx_distributed_sum_scatter_ptr = GET_SYM(handle, "mlx_distributed_sum_scatter");
if (mlx_distributed_sum_scatter_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_sum_scatter\n");
return -1;
}
- mlx_distributed_group_rank_ptr = dlsym(handle, "mlx_distributed_group_rank");
+ mlx_distributed_group_rank_ptr = GET_SYM(handle, "mlx_distributed_group_rank");
if (mlx_distributed_group_rank_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_rank\n");
return -1;
}
- mlx_distributed_group_size_ptr = dlsym(handle, "mlx_distributed_group_size");
+ mlx_distributed_group_size_ptr = GET_SYM(handle, "mlx_distributed_group_size");
if (mlx_distributed_group_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_size\n");
return -1;
}
- mlx_distributed_group_split_ptr = dlsym(handle, "mlx_distributed_group_split");
+ mlx_distributed_group_split_ptr = GET_SYM(handle, "mlx_distributed_group_split");
if (mlx_distributed_group_split_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_split\n");
return -1;
}
- mlx_distributed_is_available_ptr = dlsym(handle, "mlx_distributed_is_available");
+ mlx_distributed_is_available_ptr = GET_SYM(handle, "mlx_distributed_is_available");
if (mlx_distributed_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_is_available\n");
return -1;
}
- mlx_distributed_init_ptr = dlsym(handle, "mlx_distributed_init");
+ mlx_distributed_init_ptr = GET_SYM(handle, "mlx_distributed_init");
if (mlx_distributed_init_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n");
return -1;
}
- mlx_set_error_handler_ptr = dlsym(handle, "mlx_set_error_handler");
+ mlx_set_error_handler_ptr = GET_SYM(handle, "mlx_set_error_handler");
if (mlx_set_error_handler_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_error_handler\n");
return -1;
}
- _mlx_error_ptr = dlsym(handle, "_mlx_error");
+ _mlx_error_ptr = GET_SYM(handle, "_mlx_error");
if (_mlx_error_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_error\n");
return -1;
}
- mlx_export_function_ptr = dlsym(handle, "mlx_export_function");
+ mlx_export_function_ptr = GET_SYM(handle, "mlx_export_function");
if (mlx_export_function_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function\n");
return -1;
}
- mlx_export_function_kwargs_ptr = dlsym(handle, "mlx_export_function_kwargs");
+ mlx_export_function_kwargs_ptr = GET_SYM(handle, "mlx_export_function_kwargs");
if (mlx_export_function_kwargs_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function_kwargs\n");
return -1;
}
- mlx_function_exporter_new_ptr = dlsym(handle, "mlx_function_exporter_new");
+ mlx_function_exporter_new_ptr = GET_SYM(handle, "mlx_function_exporter_new");
if (mlx_function_exporter_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_new\n");
return -1;
}
- mlx_function_exporter_free_ptr = dlsym(handle, "mlx_function_exporter_free");
+ mlx_function_exporter_free_ptr = GET_SYM(handle, "mlx_function_exporter_free");
if (mlx_function_exporter_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_free\n");
return -1;
}
- mlx_function_exporter_apply_ptr = dlsym(handle, "mlx_function_exporter_apply");
+ mlx_function_exporter_apply_ptr = GET_SYM(handle, "mlx_function_exporter_apply");
if (mlx_function_exporter_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply\n");
return -1;
}
- mlx_function_exporter_apply_kwargs_ptr = dlsym(handle, "mlx_function_exporter_apply_kwargs");
+ mlx_function_exporter_apply_kwargs_ptr = GET_SYM(handle, "mlx_function_exporter_apply_kwargs");
if (mlx_function_exporter_apply_kwargs_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply_kwargs\n");
return -1;
}
- mlx_imported_function_new_ptr = dlsym(handle, "mlx_imported_function_new");
+ mlx_imported_function_new_ptr = GET_SYM(handle, "mlx_imported_function_new");
if (mlx_imported_function_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_new\n");
return -1;
}
- mlx_imported_function_free_ptr = dlsym(handle, "mlx_imported_function_free");
+ mlx_imported_function_free_ptr = GET_SYM(handle, "mlx_imported_function_free");
if (mlx_imported_function_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_free\n");
return -1;
}
- mlx_imported_function_apply_ptr = dlsym(handle, "mlx_imported_function_apply");
+ mlx_imported_function_apply_ptr = GET_SYM(handle, "mlx_imported_function_apply");
if (mlx_imported_function_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply\n");
return -1;
}
- mlx_imported_function_apply_kwargs_ptr = dlsym(handle, "mlx_imported_function_apply_kwargs");
+ mlx_imported_function_apply_kwargs_ptr = GET_SYM(handle, "mlx_imported_function_apply_kwargs");
if (mlx_imported_function_apply_kwargs_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply_kwargs\n");
return -1;
}
- mlx_fast_cuda_kernel_config_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_new");
+ mlx_fast_cuda_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_new");
if (mlx_fast_cuda_kernel_config_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_new\n");
return -1;
}
- mlx_fast_cuda_kernel_config_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_free");
+ mlx_fast_cuda_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_free");
if (mlx_fast_cuda_kernel_config_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_free\n");
return -1;
}
- mlx_fast_cuda_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_output_arg");
+ mlx_fast_cuda_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_output_arg");
if (mlx_fast_cuda_kernel_config_add_output_arg_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_output_arg\n");
return -1;
}
- mlx_fast_cuda_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_grid");
+ mlx_fast_cuda_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_grid");
if (mlx_fast_cuda_kernel_config_set_grid_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_grid\n");
return -1;
}
- mlx_fast_cuda_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_thread_group");
+ mlx_fast_cuda_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_thread_group");
if (mlx_fast_cuda_kernel_config_set_thread_group_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_thread_group\n");
return -1;
}
- mlx_fast_cuda_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_init_value");
+ mlx_fast_cuda_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_init_value");
if (mlx_fast_cuda_kernel_config_set_init_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_init_value\n");
return -1;
}
- mlx_fast_cuda_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_verbose");
+ mlx_fast_cuda_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_verbose");
if (mlx_fast_cuda_kernel_config_set_verbose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_verbose\n");
return -1;
}
- mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype");
+ mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype");
if (mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_dtype\n");
return -1;
}
- mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int");
+ mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int");
if (mlx_fast_cuda_kernel_config_add_template_arg_int_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_int\n");
return -1;
}
- mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool");
+ mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool");
if (mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_bool\n");
return -1;
}
- mlx_fast_cuda_kernel_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_new");
+ mlx_fast_cuda_kernel_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_new");
if (mlx_fast_cuda_kernel_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_new\n");
return -1;
}
- mlx_fast_cuda_kernel_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_free");
+ mlx_fast_cuda_kernel_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_free");
if (mlx_fast_cuda_kernel_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_free\n");
return -1;
}
- mlx_fast_cuda_kernel_apply_ptr = dlsym(handle, "mlx_fast_cuda_kernel_apply");
+ mlx_fast_cuda_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_apply");
if (mlx_fast_cuda_kernel_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_apply\n");
return -1;
}
- mlx_fast_layer_norm_ptr = dlsym(handle, "mlx_fast_layer_norm");
+ mlx_fast_layer_norm_ptr = GET_SYM(handle, "mlx_fast_layer_norm");
if (mlx_fast_layer_norm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_layer_norm\n");
return -1;
}
- mlx_fast_metal_kernel_config_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_new");
+ mlx_fast_metal_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_new");
if (mlx_fast_metal_kernel_config_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_new\n");
return -1;
}
- mlx_fast_metal_kernel_config_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_free");
+ mlx_fast_metal_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_free");
if (mlx_fast_metal_kernel_config_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_free\n");
return -1;
}
- mlx_fast_metal_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_output_arg");
+ mlx_fast_metal_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_output_arg");
if (mlx_fast_metal_kernel_config_add_output_arg_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_output_arg\n");
return -1;
}
- mlx_fast_metal_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_grid");
+ mlx_fast_metal_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_grid");
if (mlx_fast_metal_kernel_config_set_grid_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_grid\n");
return -1;
}
- mlx_fast_metal_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_thread_group");
+ mlx_fast_metal_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_thread_group");
if (mlx_fast_metal_kernel_config_set_thread_group_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_thread_group\n");
return -1;
}
- mlx_fast_metal_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_init_value");
+ mlx_fast_metal_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_init_value");
if (mlx_fast_metal_kernel_config_set_init_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_init_value\n");
return -1;
}
- mlx_fast_metal_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_verbose");
+ mlx_fast_metal_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_verbose");
if (mlx_fast_metal_kernel_config_set_verbose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_verbose\n");
return -1;
}
- mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype");
+ mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype");
if (mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_dtype\n");
return -1;
}
- mlx_fast_metal_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_int");
+ mlx_fast_metal_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_int");
if (mlx_fast_metal_kernel_config_add_template_arg_int_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_int\n");
return -1;
}
- mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool");
+ mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool");
if (mlx_fast_metal_kernel_config_add_template_arg_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_bool\n");
return -1;
}
- mlx_fast_metal_kernel_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_new");
+ mlx_fast_metal_kernel_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_new");
if (mlx_fast_metal_kernel_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_new\n");
return -1;
}
- mlx_fast_metal_kernel_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_free");
+ mlx_fast_metal_kernel_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_free");
if (mlx_fast_metal_kernel_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_free\n");
return -1;
}
- mlx_fast_metal_kernel_apply_ptr = dlsym(handle, "mlx_fast_metal_kernel_apply");
+ mlx_fast_metal_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_apply");
if (mlx_fast_metal_kernel_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_apply\n");
return -1;
}
- mlx_fast_rms_norm_ptr = dlsym(handle, "mlx_fast_rms_norm");
+ mlx_fast_rms_norm_ptr = GET_SYM(handle, "mlx_fast_rms_norm");
if (mlx_fast_rms_norm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rms_norm\n");
return -1;
}
- mlx_fast_rope_ptr = dlsym(handle, "mlx_fast_rope");
+ mlx_fast_rope_ptr = GET_SYM(handle, "mlx_fast_rope");
if (mlx_fast_rope_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope\n");
return -1;
}
- mlx_fast_rope_dynamic_ptr = dlsym(handle, "mlx_fast_rope_dynamic");
+ mlx_fast_rope_dynamic_ptr = GET_SYM(handle, "mlx_fast_rope_dynamic");
if (mlx_fast_rope_dynamic_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope_dynamic\n");
return -1;
}
- mlx_fast_scaled_dot_product_attention_ptr = dlsym(handle, "mlx_fast_scaled_dot_product_attention");
+ mlx_fast_scaled_dot_product_attention_ptr = GET_SYM(handle, "mlx_fast_scaled_dot_product_attention");
if (mlx_fast_scaled_dot_product_attention_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_scaled_dot_product_attention\n");
return -1;
}
- mlx_fft_fft_ptr = dlsym(handle, "mlx_fft_fft");
+ mlx_fft_fft_ptr = GET_SYM(handle, "mlx_fft_fft");
if (mlx_fft_fft_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft\n");
return -1;
}
- mlx_fft_fft2_ptr = dlsym(handle, "mlx_fft_fft2");
+ mlx_fft_fft2_ptr = GET_SYM(handle, "mlx_fft_fft2");
if (mlx_fft_fft2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft2\n");
return -1;
}
- mlx_fft_fftn_ptr = dlsym(handle, "mlx_fft_fftn");
+ mlx_fft_fftn_ptr = GET_SYM(handle, "mlx_fft_fftn");
if (mlx_fft_fftn_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftn\n");
return -1;
}
- mlx_fft_fftshift_ptr = dlsym(handle, "mlx_fft_fftshift");
+ mlx_fft_fftshift_ptr = GET_SYM(handle, "mlx_fft_fftshift");
if (mlx_fft_fftshift_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftshift\n");
return -1;
}
- mlx_fft_ifft_ptr = dlsym(handle, "mlx_fft_ifft");
+ mlx_fft_ifft_ptr = GET_SYM(handle, "mlx_fft_ifft");
if (mlx_fft_ifft_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft\n");
return -1;
}
- mlx_fft_ifft2_ptr = dlsym(handle, "mlx_fft_ifft2");
+ mlx_fft_ifft2_ptr = GET_SYM(handle, "mlx_fft_ifft2");
if (mlx_fft_ifft2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft2\n");
return -1;
}
- mlx_fft_ifftn_ptr = dlsym(handle, "mlx_fft_ifftn");
+ mlx_fft_ifftn_ptr = GET_SYM(handle, "mlx_fft_ifftn");
if (mlx_fft_ifftn_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftn\n");
return -1;
}
- mlx_fft_ifftshift_ptr = dlsym(handle, "mlx_fft_ifftshift");
+ mlx_fft_ifftshift_ptr = GET_SYM(handle, "mlx_fft_ifftshift");
if (mlx_fft_ifftshift_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftshift\n");
return -1;
}
- mlx_fft_irfft_ptr = dlsym(handle, "mlx_fft_irfft");
+ mlx_fft_irfft_ptr = GET_SYM(handle, "mlx_fft_irfft");
if (mlx_fft_irfft_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft\n");
return -1;
}
- mlx_fft_irfft2_ptr = dlsym(handle, "mlx_fft_irfft2");
+ mlx_fft_irfft2_ptr = GET_SYM(handle, "mlx_fft_irfft2");
if (mlx_fft_irfft2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft2\n");
return -1;
}
- mlx_fft_irfftn_ptr = dlsym(handle, "mlx_fft_irfftn");
+ mlx_fft_irfftn_ptr = GET_SYM(handle, "mlx_fft_irfftn");
if (mlx_fft_irfftn_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfftn\n");
return -1;
}
- mlx_fft_rfft_ptr = dlsym(handle, "mlx_fft_rfft");
+ mlx_fft_rfft_ptr = GET_SYM(handle, "mlx_fft_rfft");
if (mlx_fft_rfft_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft\n");
return -1;
}
- mlx_fft_rfft2_ptr = dlsym(handle, "mlx_fft_rfft2");
+ mlx_fft_rfft2_ptr = GET_SYM(handle, "mlx_fft_rfft2");
if (mlx_fft_rfft2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft2\n");
return -1;
}
- mlx_fft_rfftn_ptr = dlsym(handle, "mlx_fft_rfftn");
+ mlx_fft_rfftn_ptr = GET_SYM(handle, "mlx_fft_rfftn");
if (mlx_fft_rfftn_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftn\n");
return -1;
}
- mlx_load_reader_ptr = dlsym(handle, "mlx_load_reader");
+ mlx_load_reader_ptr = GET_SYM(handle, "mlx_load_reader");
if (mlx_load_reader_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_reader\n");
return -1;
}
- mlx_load_ptr = dlsym(handle, "mlx_load");
+ mlx_load_ptr = GET_SYM(handle, "mlx_load");
if (mlx_load_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_load\n");
return -1;
}
- mlx_load_safetensors_reader_ptr = dlsym(handle, "mlx_load_safetensors_reader");
+ mlx_load_safetensors_reader_ptr = GET_SYM(handle, "mlx_load_safetensors_reader");
if (mlx_load_safetensors_reader_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors_reader\n");
return -1;
}
- mlx_load_safetensors_ptr = dlsym(handle, "mlx_load_safetensors");
+ mlx_load_safetensors_ptr = GET_SYM(handle, "mlx_load_safetensors");
if (mlx_load_safetensors_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors\n");
return -1;
}
- mlx_save_writer_ptr = dlsym(handle, "mlx_save_writer");
+ mlx_save_writer_ptr = GET_SYM(handle, "mlx_save_writer");
if (mlx_save_writer_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_writer\n");
return -1;
}
- mlx_save_ptr = dlsym(handle, "mlx_save");
+ mlx_save_ptr = GET_SYM(handle, "mlx_save");
if (mlx_save_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_save\n");
return -1;
}
- mlx_save_safetensors_writer_ptr = dlsym(handle, "mlx_save_safetensors_writer");
+ mlx_save_safetensors_writer_ptr = GET_SYM(handle, "mlx_save_safetensors_writer");
if (mlx_save_safetensors_writer_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors_writer\n");
return -1;
}
- mlx_save_safetensors_ptr = dlsym(handle, "mlx_save_safetensors");
+ mlx_save_safetensors_ptr = GET_SYM(handle, "mlx_save_safetensors");
if (mlx_save_safetensors_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors\n");
return -1;
}
- mlx_io_reader_new_ptr = dlsym(handle, "mlx_io_reader_new");
+ mlx_io_reader_new_ptr = GET_SYM(handle, "mlx_io_reader_new");
if (mlx_io_reader_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_new\n");
return -1;
}
- mlx_io_reader_descriptor_ptr = dlsym(handle, "mlx_io_reader_descriptor");
+ mlx_io_reader_descriptor_ptr = GET_SYM(handle, "mlx_io_reader_descriptor");
if (mlx_io_reader_descriptor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_descriptor\n");
return -1;
}
- mlx_io_reader_tostring_ptr = dlsym(handle, "mlx_io_reader_tostring");
+ mlx_io_reader_tostring_ptr = GET_SYM(handle, "mlx_io_reader_tostring");
if (mlx_io_reader_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_tostring\n");
return -1;
}
- mlx_io_reader_free_ptr = dlsym(handle, "mlx_io_reader_free");
+ mlx_io_reader_free_ptr = GET_SYM(handle, "mlx_io_reader_free");
if (mlx_io_reader_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_free\n");
return -1;
}
- mlx_io_writer_new_ptr = dlsym(handle, "mlx_io_writer_new");
+ mlx_io_writer_new_ptr = GET_SYM(handle, "mlx_io_writer_new");
if (mlx_io_writer_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_new\n");
return -1;
}
- mlx_io_writer_descriptor_ptr = dlsym(handle, "mlx_io_writer_descriptor");
+ mlx_io_writer_descriptor_ptr = GET_SYM(handle, "mlx_io_writer_descriptor");
if (mlx_io_writer_descriptor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_descriptor\n");
return -1;
}
- mlx_io_writer_tostring_ptr = dlsym(handle, "mlx_io_writer_tostring");
+ mlx_io_writer_tostring_ptr = GET_SYM(handle, "mlx_io_writer_tostring");
if (mlx_io_writer_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_tostring\n");
return -1;
}
- mlx_io_writer_free_ptr = dlsym(handle, "mlx_io_writer_free");
+ mlx_io_writer_free_ptr = GET_SYM(handle, "mlx_io_writer_free");
if (mlx_io_writer_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_free\n");
return -1;
}
- mlx_linalg_cholesky_ptr = dlsym(handle, "mlx_linalg_cholesky");
+ mlx_linalg_cholesky_ptr = GET_SYM(handle, "mlx_linalg_cholesky");
if (mlx_linalg_cholesky_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky\n");
return -1;
}
- mlx_linalg_cholesky_inv_ptr = dlsym(handle, "mlx_linalg_cholesky_inv");
+ mlx_linalg_cholesky_inv_ptr = GET_SYM(handle, "mlx_linalg_cholesky_inv");
if (mlx_linalg_cholesky_inv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky_inv\n");
return -1;
}
- mlx_linalg_cross_ptr = dlsym(handle, "mlx_linalg_cross");
+ mlx_linalg_cross_ptr = GET_SYM(handle, "mlx_linalg_cross");
if (mlx_linalg_cross_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cross\n");
return -1;
}
- mlx_linalg_eig_ptr = dlsym(handle, "mlx_linalg_eig");
+ mlx_linalg_eig_ptr = GET_SYM(handle, "mlx_linalg_eig");
if (mlx_linalg_eig_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eig\n");
return -1;
}
- mlx_linalg_eigh_ptr = dlsym(handle, "mlx_linalg_eigh");
+ mlx_linalg_eigh_ptr = GET_SYM(handle, "mlx_linalg_eigh");
if (mlx_linalg_eigh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigh\n");
return -1;
}
- mlx_linalg_eigvals_ptr = dlsym(handle, "mlx_linalg_eigvals");
+ mlx_linalg_eigvals_ptr = GET_SYM(handle, "mlx_linalg_eigvals");
if (mlx_linalg_eigvals_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvals\n");
return -1;
}
- mlx_linalg_eigvalsh_ptr = dlsym(handle, "mlx_linalg_eigvalsh");
+ mlx_linalg_eigvalsh_ptr = GET_SYM(handle, "mlx_linalg_eigvalsh");
if (mlx_linalg_eigvalsh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvalsh\n");
return -1;
}
- mlx_linalg_inv_ptr = dlsym(handle, "mlx_linalg_inv");
+ mlx_linalg_inv_ptr = GET_SYM(handle, "mlx_linalg_inv");
if (mlx_linalg_inv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_inv\n");
return -1;
}
- mlx_linalg_lu_ptr = dlsym(handle, "mlx_linalg_lu");
+ mlx_linalg_lu_ptr = GET_SYM(handle, "mlx_linalg_lu");
if (mlx_linalg_lu_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu\n");
return -1;
}
- mlx_linalg_lu_factor_ptr = dlsym(handle, "mlx_linalg_lu_factor");
+ mlx_linalg_lu_factor_ptr = GET_SYM(handle, "mlx_linalg_lu_factor");
if (mlx_linalg_lu_factor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu_factor\n");
return -1;
}
- mlx_linalg_norm_ptr = dlsym(handle, "mlx_linalg_norm");
+ mlx_linalg_norm_ptr = GET_SYM(handle, "mlx_linalg_norm");
if (mlx_linalg_norm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm\n");
return -1;
}
- mlx_linalg_norm_matrix_ptr = dlsym(handle, "mlx_linalg_norm_matrix");
+ mlx_linalg_norm_matrix_ptr = GET_SYM(handle, "mlx_linalg_norm_matrix");
if (mlx_linalg_norm_matrix_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_matrix\n");
return -1;
}
- mlx_linalg_norm_l2_ptr = dlsym(handle, "mlx_linalg_norm_l2");
+ mlx_linalg_norm_l2_ptr = GET_SYM(handle, "mlx_linalg_norm_l2");
if (mlx_linalg_norm_l2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_l2\n");
return -1;
}
- mlx_linalg_pinv_ptr = dlsym(handle, "mlx_linalg_pinv");
+ mlx_linalg_pinv_ptr = GET_SYM(handle, "mlx_linalg_pinv");
if (mlx_linalg_pinv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_pinv\n");
return -1;
}
- mlx_linalg_qr_ptr = dlsym(handle, "mlx_linalg_qr");
+ mlx_linalg_qr_ptr = GET_SYM(handle, "mlx_linalg_qr");
if (mlx_linalg_qr_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_qr\n");
return -1;
}
- mlx_linalg_solve_ptr = dlsym(handle, "mlx_linalg_solve");
+ mlx_linalg_solve_ptr = GET_SYM(handle, "mlx_linalg_solve");
if (mlx_linalg_solve_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve\n");
return -1;
}
- mlx_linalg_solve_triangular_ptr = dlsym(handle, "mlx_linalg_solve_triangular");
+ mlx_linalg_solve_triangular_ptr = GET_SYM(handle, "mlx_linalg_solve_triangular");
if (mlx_linalg_solve_triangular_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve_triangular\n");
return -1;
}
- mlx_linalg_svd_ptr = dlsym(handle, "mlx_linalg_svd");
+ mlx_linalg_svd_ptr = GET_SYM(handle, "mlx_linalg_svd");
if (mlx_linalg_svd_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_svd\n");
return -1;
}
- mlx_linalg_tri_inv_ptr = dlsym(handle, "mlx_linalg_tri_inv");
+ mlx_linalg_tri_inv_ptr = GET_SYM(handle, "mlx_linalg_tri_inv");
if (mlx_linalg_tri_inv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_tri_inv\n");
return -1;
}
- mlx_map_string_to_array_new_ptr = dlsym(handle, "mlx_map_string_to_array_new");
+ mlx_map_string_to_array_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_new");
if (mlx_map_string_to_array_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_new\n");
return -1;
}
- mlx_map_string_to_array_set_ptr = dlsym(handle, "mlx_map_string_to_array_set");
+ mlx_map_string_to_array_set_ptr = GET_SYM(handle, "mlx_map_string_to_array_set");
if (mlx_map_string_to_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_set\n");
return -1;
}
- mlx_map_string_to_array_free_ptr = dlsym(handle, "mlx_map_string_to_array_free");
+ mlx_map_string_to_array_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_free");
if (mlx_map_string_to_array_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_free\n");
return -1;
}
- mlx_map_string_to_array_insert_ptr = dlsym(handle, "mlx_map_string_to_array_insert");
+ mlx_map_string_to_array_insert_ptr = GET_SYM(handle, "mlx_map_string_to_array_insert");
if (mlx_map_string_to_array_insert_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_insert\n");
return -1;
}
- mlx_map_string_to_array_get_ptr = dlsym(handle, "mlx_map_string_to_array_get");
+ mlx_map_string_to_array_get_ptr = GET_SYM(handle, "mlx_map_string_to_array_get");
if (mlx_map_string_to_array_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_get\n");
return -1;
}
- mlx_map_string_to_array_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_new");
+ mlx_map_string_to_array_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_new");
if (mlx_map_string_to_array_iterator_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_new\n");
return -1;
}
- mlx_map_string_to_array_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_free");
+ mlx_map_string_to_array_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_free");
if (mlx_map_string_to_array_iterator_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_free\n");
return -1;
}
- mlx_map_string_to_array_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_next");
+ mlx_map_string_to_array_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_next");
if (mlx_map_string_to_array_iterator_next_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_next\n");
return -1;
}
- mlx_map_string_to_string_new_ptr = dlsym(handle, "mlx_map_string_to_string_new");
+ mlx_map_string_to_string_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_new");
if (mlx_map_string_to_string_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_new\n");
return -1;
}
- mlx_map_string_to_string_set_ptr = dlsym(handle, "mlx_map_string_to_string_set");
+ mlx_map_string_to_string_set_ptr = GET_SYM(handle, "mlx_map_string_to_string_set");
if (mlx_map_string_to_string_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_set\n");
return -1;
}
- mlx_map_string_to_string_free_ptr = dlsym(handle, "mlx_map_string_to_string_free");
+ mlx_map_string_to_string_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_free");
if (mlx_map_string_to_string_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_free\n");
return -1;
}
- mlx_map_string_to_string_insert_ptr = dlsym(handle, "mlx_map_string_to_string_insert");
+ mlx_map_string_to_string_insert_ptr = GET_SYM(handle, "mlx_map_string_to_string_insert");
if (mlx_map_string_to_string_insert_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_insert\n");
return -1;
}
- mlx_map_string_to_string_get_ptr = dlsym(handle, "mlx_map_string_to_string_get");
+ mlx_map_string_to_string_get_ptr = GET_SYM(handle, "mlx_map_string_to_string_get");
if (mlx_map_string_to_string_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_get\n");
return -1;
}
- mlx_map_string_to_string_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_new");
+ mlx_map_string_to_string_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_new");
if (mlx_map_string_to_string_iterator_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_new\n");
return -1;
}
- mlx_map_string_to_string_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_free");
+ mlx_map_string_to_string_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_free");
if (mlx_map_string_to_string_iterator_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_free\n");
return -1;
}
- mlx_map_string_to_string_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_next");
+ mlx_map_string_to_string_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_next");
if (mlx_map_string_to_string_iterator_next_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_next\n");
return -1;
}
- mlx_clear_cache_ptr = dlsym(handle, "mlx_clear_cache");
+ mlx_clear_cache_ptr = GET_SYM(handle, "mlx_clear_cache");
if (mlx_clear_cache_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_clear_cache\n");
return -1;
}
- mlx_get_active_memory_ptr = dlsym(handle, "mlx_get_active_memory");
+ mlx_get_active_memory_ptr = GET_SYM(handle, "mlx_get_active_memory");
if (mlx_get_active_memory_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_active_memory\n");
return -1;
}
- mlx_get_cache_memory_ptr = dlsym(handle, "mlx_get_cache_memory");
+ mlx_get_cache_memory_ptr = GET_SYM(handle, "mlx_get_cache_memory");
if (mlx_get_cache_memory_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_cache_memory\n");
return -1;
}
- mlx_get_memory_limit_ptr = dlsym(handle, "mlx_get_memory_limit");
+ mlx_get_memory_limit_ptr = GET_SYM(handle, "mlx_get_memory_limit");
if (mlx_get_memory_limit_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_memory_limit\n");
return -1;
}
- mlx_get_peak_memory_ptr = dlsym(handle, "mlx_get_peak_memory");
+ mlx_get_peak_memory_ptr = GET_SYM(handle, "mlx_get_peak_memory");
if (mlx_get_peak_memory_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_peak_memory\n");
return -1;
}
- mlx_reset_peak_memory_ptr = dlsym(handle, "mlx_reset_peak_memory");
+ mlx_reset_peak_memory_ptr = GET_SYM(handle, "mlx_reset_peak_memory");
if (mlx_reset_peak_memory_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_reset_peak_memory\n");
return -1;
}
- mlx_set_cache_limit_ptr = dlsym(handle, "mlx_set_cache_limit");
+ mlx_set_cache_limit_ptr = GET_SYM(handle, "mlx_set_cache_limit");
if (mlx_set_cache_limit_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_cache_limit\n");
return -1;
}
- mlx_set_memory_limit_ptr = dlsym(handle, "mlx_set_memory_limit");
+ mlx_set_memory_limit_ptr = GET_SYM(handle, "mlx_set_memory_limit");
if (mlx_set_memory_limit_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_memory_limit\n");
return -1;
}
- mlx_set_wired_limit_ptr = dlsym(handle, "mlx_set_wired_limit");
+ mlx_set_wired_limit_ptr = GET_SYM(handle, "mlx_set_wired_limit");
if (mlx_set_wired_limit_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
return -1;
}
- mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
+ mlx_metal_is_available_ptr = GET_SYM(handle, "mlx_metal_is_available");
if (mlx_metal_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
return -1;
}
- mlx_metal_start_capture_ptr = dlsym(handle, "mlx_metal_start_capture");
+ mlx_metal_start_capture_ptr = GET_SYM(handle, "mlx_metal_start_capture");
if (mlx_metal_start_capture_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_start_capture\n");
return -1;
}
- mlx_metal_stop_capture_ptr = dlsym(handle, "mlx_metal_stop_capture");
+ mlx_metal_stop_capture_ptr = GET_SYM(handle, "mlx_metal_stop_capture");
if (mlx_metal_stop_capture_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_stop_capture\n");
return -1;
}
- mlx_abs_ptr = dlsym(handle, "mlx_abs");
+ mlx_abs_ptr = GET_SYM(handle, "mlx_abs");
if (mlx_abs_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_abs\n");
return -1;
}
- mlx_add_ptr = dlsym(handle, "mlx_add");
+ mlx_add_ptr = GET_SYM(handle, "mlx_add");
if (mlx_add_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_add\n");
return -1;
}
- mlx_addmm_ptr = dlsym(handle, "mlx_addmm");
+ mlx_addmm_ptr = GET_SYM(handle, "mlx_addmm");
if (mlx_addmm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_addmm\n");
return -1;
}
- mlx_all_axes_ptr = dlsym(handle, "mlx_all_axes");
+ mlx_all_axes_ptr = GET_SYM(handle, "mlx_all_axes");
if (mlx_all_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axes\n");
return -1;
}
- mlx_all_axis_ptr = dlsym(handle, "mlx_all_axis");
+ mlx_all_axis_ptr = GET_SYM(handle, "mlx_all_axis");
if (mlx_all_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axis\n");
return -1;
}
- mlx_all_ptr = dlsym(handle, "mlx_all");
+ mlx_all_ptr = GET_SYM(handle, "mlx_all");
if (mlx_all_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_all\n");
return -1;
}
- mlx_allclose_ptr = dlsym(handle, "mlx_allclose");
+ mlx_allclose_ptr = GET_SYM(handle, "mlx_allclose");
if (mlx_allclose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_allclose\n");
return -1;
}
- mlx_any_axes_ptr = dlsym(handle, "mlx_any_axes");
+ mlx_any_axes_ptr = GET_SYM(handle, "mlx_any_axes");
if (mlx_any_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axes\n");
return -1;
}
- mlx_any_axis_ptr = dlsym(handle, "mlx_any_axis");
+ mlx_any_axis_ptr = GET_SYM(handle, "mlx_any_axis");
if (mlx_any_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axis\n");
return -1;
}
- mlx_any_ptr = dlsym(handle, "mlx_any");
+ mlx_any_ptr = GET_SYM(handle, "mlx_any");
if (mlx_any_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_any\n");
return -1;
}
- mlx_arange_ptr = dlsym(handle, "mlx_arange");
+ mlx_arange_ptr = GET_SYM(handle, "mlx_arange");
if (mlx_arange_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arange\n");
return -1;
}
- mlx_arccos_ptr = dlsym(handle, "mlx_arccos");
+ mlx_arccos_ptr = GET_SYM(handle, "mlx_arccos");
if (mlx_arccos_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arccos\n");
return -1;
}
- mlx_arccosh_ptr = dlsym(handle, "mlx_arccosh");
+ mlx_arccosh_ptr = GET_SYM(handle, "mlx_arccosh");
if (mlx_arccosh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arccosh\n");
return -1;
}
- mlx_arcsin_ptr = dlsym(handle, "mlx_arcsin");
+ mlx_arcsin_ptr = GET_SYM(handle, "mlx_arcsin");
if (mlx_arcsin_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsin\n");
return -1;
}
- mlx_arcsinh_ptr = dlsym(handle, "mlx_arcsinh");
+ mlx_arcsinh_ptr = GET_SYM(handle, "mlx_arcsinh");
if (mlx_arcsinh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsinh\n");
return -1;
}
- mlx_arctan_ptr = dlsym(handle, "mlx_arctan");
+ mlx_arctan_ptr = GET_SYM(handle, "mlx_arctan");
if (mlx_arctan_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan\n");
return -1;
}
- mlx_arctan2_ptr = dlsym(handle, "mlx_arctan2");
+ mlx_arctan2_ptr = GET_SYM(handle, "mlx_arctan2");
if (mlx_arctan2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan2\n");
return -1;
}
- mlx_arctanh_ptr = dlsym(handle, "mlx_arctanh");
+ mlx_arctanh_ptr = GET_SYM(handle, "mlx_arctanh");
if (mlx_arctanh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctanh\n");
return -1;
}
- mlx_argmax_axis_ptr = dlsym(handle, "mlx_argmax_axis");
+ mlx_argmax_axis_ptr = GET_SYM(handle, "mlx_argmax_axis");
if (mlx_argmax_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax_axis\n");
return -1;
}
- mlx_argmax_ptr = dlsym(handle, "mlx_argmax");
+ mlx_argmax_ptr = GET_SYM(handle, "mlx_argmax");
if (mlx_argmax_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax\n");
return -1;
}
- mlx_argmin_axis_ptr = dlsym(handle, "mlx_argmin_axis");
+ mlx_argmin_axis_ptr = GET_SYM(handle, "mlx_argmin_axis");
if (mlx_argmin_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin_axis\n");
return -1;
}
- mlx_argmin_ptr = dlsym(handle, "mlx_argmin");
+ mlx_argmin_ptr = GET_SYM(handle, "mlx_argmin");
if (mlx_argmin_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin\n");
return -1;
}
- mlx_argpartition_axis_ptr = dlsym(handle, "mlx_argpartition_axis");
+ mlx_argpartition_axis_ptr = GET_SYM(handle, "mlx_argpartition_axis");
if (mlx_argpartition_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition_axis\n");
return -1;
}
- mlx_argpartition_ptr = dlsym(handle, "mlx_argpartition");
+ mlx_argpartition_ptr = GET_SYM(handle, "mlx_argpartition");
if (mlx_argpartition_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition\n");
return -1;
}
- mlx_argsort_axis_ptr = dlsym(handle, "mlx_argsort_axis");
+ mlx_argsort_axis_ptr = GET_SYM(handle, "mlx_argsort_axis");
if (mlx_argsort_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort_axis\n");
return -1;
}
- mlx_argsort_ptr = dlsym(handle, "mlx_argsort");
+ mlx_argsort_ptr = GET_SYM(handle, "mlx_argsort");
if (mlx_argsort_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort\n");
return -1;
}
- mlx_array_equal_ptr = dlsym(handle, "mlx_array_equal");
+ mlx_array_equal_ptr = GET_SYM(handle, "mlx_array_equal");
if (mlx_array_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_equal\n");
return -1;
}
- mlx_as_strided_ptr = dlsym(handle, "mlx_as_strided");
+ mlx_as_strided_ptr = GET_SYM(handle, "mlx_as_strided");
if (mlx_as_strided_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_as_strided\n");
return -1;
}
- mlx_astype_ptr = dlsym(handle, "mlx_astype");
+ mlx_astype_ptr = GET_SYM(handle, "mlx_astype");
if (mlx_astype_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_astype\n");
return -1;
}
- mlx_atleast_1d_ptr = dlsym(handle, "mlx_atleast_1d");
+ mlx_atleast_1d_ptr = GET_SYM(handle, "mlx_atleast_1d");
if (mlx_atleast_1d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_1d\n");
return -1;
}
- mlx_atleast_2d_ptr = dlsym(handle, "mlx_atleast_2d");
+ mlx_atleast_2d_ptr = GET_SYM(handle, "mlx_atleast_2d");
if (mlx_atleast_2d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_2d\n");
return -1;
}
- mlx_atleast_3d_ptr = dlsym(handle, "mlx_atleast_3d");
+ mlx_atleast_3d_ptr = GET_SYM(handle, "mlx_atleast_3d");
if (mlx_atleast_3d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
return -1;
}
- mlx_bitwise_and_ptr = dlsym(handle, "mlx_bitwise_and");
+ mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
if (mlx_bitwise_and_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
return -1;
}
- mlx_bitwise_invert_ptr = dlsym(handle, "mlx_bitwise_invert");
+ mlx_bitwise_invert_ptr = GET_SYM(handle, "mlx_bitwise_invert");
if (mlx_bitwise_invert_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_invert\n");
return -1;
}
- mlx_bitwise_or_ptr = dlsym(handle, "mlx_bitwise_or");
+ mlx_bitwise_or_ptr = GET_SYM(handle, "mlx_bitwise_or");
if (mlx_bitwise_or_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_or\n");
return -1;
}
- mlx_bitwise_xor_ptr = dlsym(handle, "mlx_bitwise_xor");
+ mlx_bitwise_xor_ptr = GET_SYM(handle, "mlx_bitwise_xor");
if (mlx_bitwise_xor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
return -1;
}
- mlx_block_masked_mm_ptr = dlsym(handle, "mlx_block_masked_mm");
+ mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
if (mlx_block_masked_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
return -1;
}
- mlx_broadcast_arrays_ptr = dlsym(handle, "mlx_broadcast_arrays");
+ mlx_broadcast_arrays_ptr = GET_SYM(handle, "mlx_broadcast_arrays");
if (mlx_broadcast_arrays_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_arrays\n");
return -1;
}
- mlx_broadcast_to_ptr = dlsym(handle, "mlx_broadcast_to");
+ mlx_broadcast_to_ptr = GET_SYM(handle, "mlx_broadcast_to");
if (mlx_broadcast_to_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_to\n");
return -1;
}
- mlx_ceil_ptr = dlsym(handle, "mlx_ceil");
+ mlx_ceil_ptr = GET_SYM(handle, "mlx_ceil");
if (mlx_ceil_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_ceil\n");
return -1;
}
- mlx_clip_ptr = dlsym(handle, "mlx_clip");
+ mlx_clip_ptr = GET_SYM(handle, "mlx_clip");
if (mlx_clip_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_clip\n");
return -1;
}
- mlx_concatenate_axis_ptr = dlsym(handle, "mlx_concatenate_axis");
+ mlx_concatenate_axis_ptr = GET_SYM(handle, "mlx_concatenate_axis");
if (mlx_concatenate_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate_axis\n");
return -1;
}
- mlx_concatenate_ptr = dlsym(handle, "mlx_concatenate");
+ mlx_concatenate_ptr = GET_SYM(handle, "mlx_concatenate");
if (mlx_concatenate_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate\n");
return -1;
}
- mlx_conjugate_ptr = dlsym(handle, "mlx_conjugate");
+ mlx_conjugate_ptr = GET_SYM(handle, "mlx_conjugate");
if (mlx_conjugate_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conjugate\n");
return -1;
}
- mlx_contiguous_ptr = dlsym(handle, "mlx_contiguous");
+ mlx_contiguous_ptr = GET_SYM(handle, "mlx_contiguous");
if (mlx_contiguous_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_contiguous\n");
return -1;
}
- mlx_conv1d_ptr = dlsym(handle, "mlx_conv1d");
+ mlx_conv1d_ptr = GET_SYM(handle, "mlx_conv1d");
if (mlx_conv1d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv1d\n");
return -1;
}
- mlx_conv2d_ptr = dlsym(handle, "mlx_conv2d");
+ mlx_conv2d_ptr = GET_SYM(handle, "mlx_conv2d");
if (mlx_conv2d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv2d\n");
return -1;
}
- mlx_conv3d_ptr = dlsym(handle, "mlx_conv3d");
+ mlx_conv3d_ptr = GET_SYM(handle, "mlx_conv3d");
if (mlx_conv3d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv3d\n");
return -1;
}
- mlx_conv_general_ptr = dlsym(handle, "mlx_conv_general");
+ mlx_conv_general_ptr = GET_SYM(handle, "mlx_conv_general");
if (mlx_conv_general_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_general\n");
return -1;
}
- mlx_conv_transpose1d_ptr = dlsym(handle, "mlx_conv_transpose1d");
+ mlx_conv_transpose1d_ptr = GET_SYM(handle, "mlx_conv_transpose1d");
if (mlx_conv_transpose1d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose1d\n");
return -1;
}
- mlx_conv_transpose2d_ptr = dlsym(handle, "mlx_conv_transpose2d");
+ mlx_conv_transpose2d_ptr = GET_SYM(handle, "mlx_conv_transpose2d");
if (mlx_conv_transpose2d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose2d\n");
return -1;
}
- mlx_conv_transpose3d_ptr = dlsym(handle, "mlx_conv_transpose3d");
+ mlx_conv_transpose3d_ptr = GET_SYM(handle, "mlx_conv_transpose3d");
if (mlx_conv_transpose3d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose3d\n");
return -1;
}
- mlx_copy_ptr = dlsym(handle, "mlx_copy");
+ mlx_copy_ptr = GET_SYM(handle, "mlx_copy");
if (mlx_copy_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_copy\n");
return -1;
}
- mlx_cos_ptr = dlsym(handle, "mlx_cos");
+ mlx_cos_ptr = GET_SYM(handle, "mlx_cos");
if (mlx_cos_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cos\n");
return -1;
}
- mlx_cosh_ptr = dlsym(handle, "mlx_cosh");
+ mlx_cosh_ptr = GET_SYM(handle, "mlx_cosh");
if (mlx_cosh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cosh\n");
return -1;
}
- mlx_cummax_ptr = dlsym(handle, "mlx_cummax");
+ mlx_cummax_ptr = GET_SYM(handle, "mlx_cummax");
if (mlx_cummax_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cummax\n");
return -1;
}
- mlx_cummin_ptr = dlsym(handle, "mlx_cummin");
+ mlx_cummin_ptr = GET_SYM(handle, "mlx_cummin");
if (mlx_cummin_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cummin\n");
return -1;
}
- mlx_cumprod_ptr = dlsym(handle, "mlx_cumprod");
+ mlx_cumprod_ptr = GET_SYM(handle, "mlx_cumprod");
if (mlx_cumprod_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cumprod\n");
return -1;
}
- mlx_cumsum_ptr = dlsym(handle, "mlx_cumsum");
+ mlx_cumsum_ptr = GET_SYM(handle, "mlx_cumsum");
if (mlx_cumsum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cumsum\n");
return -1;
}
- mlx_degrees_ptr = dlsym(handle, "mlx_degrees");
+ mlx_degrees_ptr = GET_SYM(handle, "mlx_degrees");
if (mlx_degrees_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_degrees\n");
return -1;
}
- mlx_depends_ptr = dlsym(handle, "mlx_depends");
+ mlx_depends_ptr = GET_SYM(handle, "mlx_depends");
if (mlx_depends_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_depends\n");
return -1;
}
- mlx_dequantize_ptr = dlsym(handle, "mlx_dequantize");
+ mlx_dequantize_ptr = GET_SYM(handle, "mlx_dequantize");
if (mlx_dequantize_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_dequantize\n");
return -1;
}
- mlx_diag_ptr = dlsym(handle, "mlx_diag");
+ mlx_diag_ptr = GET_SYM(handle, "mlx_diag");
if (mlx_diag_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_diag\n");
return -1;
}
- mlx_diagonal_ptr = dlsym(handle, "mlx_diagonal");
+ mlx_diagonal_ptr = GET_SYM(handle, "mlx_diagonal");
if (mlx_diagonal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_diagonal\n");
return -1;
}
- mlx_divide_ptr = dlsym(handle, "mlx_divide");
+ mlx_divide_ptr = GET_SYM(handle, "mlx_divide");
if (mlx_divide_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_divide\n");
return -1;
}
- mlx_divmod_ptr = dlsym(handle, "mlx_divmod");
+ mlx_divmod_ptr = GET_SYM(handle, "mlx_divmod");
if (mlx_divmod_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_divmod\n");
return -1;
}
- mlx_einsum_ptr = dlsym(handle, "mlx_einsum");
+ mlx_einsum_ptr = GET_SYM(handle, "mlx_einsum");
if (mlx_einsum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_einsum\n");
return -1;
}
- mlx_equal_ptr = dlsym(handle, "mlx_equal");
+ mlx_equal_ptr = GET_SYM(handle, "mlx_equal");
if (mlx_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_equal\n");
return -1;
}
- mlx_erf_ptr = dlsym(handle, "mlx_erf");
+ mlx_erf_ptr = GET_SYM(handle, "mlx_erf");
if (mlx_erf_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_erf\n");
return -1;
}
- mlx_erfinv_ptr = dlsym(handle, "mlx_erfinv");
+ mlx_erfinv_ptr = GET_SYM(handle, "mlx_erfinv");
if (mlx_erfinv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_erfinv\n");
return -1;
}
- mlx_exp_ptr = dlsym(handle, "mlx_exp");
+ mlx_exp_ptr = GET_SYM(handle, "mlx_exp");
if (mlx_exp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_exp\n");
return -1;
}
- mlx_expand_dims_axes_ptr = dlsym(handle, "mlx_expand_dims_axes");
+ mlx_expand_dims_axes_ptr = GET_SYM(handle, "mlx_expand_dims_axes");
if (mlx_expand_dims_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims_axes\n");
return -1;
}
- mlx_expand_dims_ptr = dlsym(handle, "mlx_expand_dims");
+ mlx_expand_dims_ptr = GET_SYM(handle, "mlx_expand_dims");
if (mlx_expand_dims_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims\n");
return -1;
}
- mlx_expm1_ptr = dlsym(handle, "mlx_expm1");
+ mlx_expm1_ptr = GET_SYM(handle, "mlx_expm1");
if (mlx_expm1_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_expm1\n");
return -1;
}
- mlx_eye_ptr = dlsym(handle, "mlx_eye");
+ mlx_eye_ptr = GET_SYM(handle, "mlx_eye");
if (mlx_eye_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_eye\n");
return -1;
}
- mlx_flatten_ptr = dlsym(handle, "mlx_flatten");
+ mlx_flatten_ptr = GET_SYM(handle, "mlx_flatten");
if (mlx_flatten_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_flatten\n");
return -1;
}
- mlx_floor_ptr = dlsym(handle, "mlx_floor");
+ mlx_floor_ptr = GET_SYM(handle, "mlx_floor");
if (mlx_floor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_floor\n");
return -1;
}
- mlx_floor_divide_ptr = dlsym(handle, "mlx_floor_divide");
+ mlx_floor_divide_ptr = GET_SYM(handle, "mlx_floor_divide");
if (mlx_floor_divide_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_floor_divide\n");
return -1;
}
- mlx_from_fp8_ptr = dlsym(handle, "mlx_from_fp8");
+ mlx_from_fp8_ptr = GET_SYM(handle, "mlx_from_fp8");
if (mlx_from_fp8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_from_fp8\n");
return -1;
}
- mlx_full_ptr = dlsym(handle, "mlx_full");
+ mlx_full_ptr = GET_SYM(handle, "mlx_full");
if (mlx_full_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_full\n");
return -1;
}
- mlx_full_like_ptr = dlsym(handle, "mlx_full_like");
+ mlx_full_like_ptr = GET_SYM(handle, "mlx_full_like");
if (mlx_full_like_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_full_like\n");
return -1;
}
- mlx_gather_ptr = dlsym(handle, "mlx_gather");
+ mlx_gather_ptr = GET_SYM(handle, "mlx_gather");
if (mlx_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather\n");
return -1;
}
- mlx_gather_single_ptr = dlsym(handle, "mlx_gather_single");
+ mlx_gather_single_ptr = GET_SYM(handle, "mlx_gather_single");
if (mlx_gather_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_single\n");
return -1;
}
- mlx_gather_mm_ptr = dlsym(handle, "mlx_gather_mm");
+ mlx_gather_mm_ptr = GET_SYM(handle, "mlx_gather_mm");
if (mlx_gather_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_mm\n");
return -1;
}
- mlx_gather_qmm_ptr = dlsym(handle, "mlx_gather_qmm");
+ mlx_gather_qmm_ptr = GET_SYM(handle, "mlx_gather_qmm");
if (mlx_gather_qmm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_qmm\n");
return -1;
}
- mlx_greater_ptr = dlsym(handle, "mlx_greater");
+ mlx_greater_ptr = GET_SYM(handle, "mlx_greater");
if (mlx_greater_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_greater\n");
return -1;
}
- mlx_greater_equal_ptr = dlsym(handle, "mlx_greater_equal");
+ mlx_greater_equal_ptr = GET_SYM(handle, "mlx_greater_equal");
if (mlx_greater_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_greater_equal\n");
return -1;
}
- mlx_hadamard_transform_ptr = dlsym(handle, "mlx_hadamard_transform");
+ mlx_hadamard_transform_ptr = GET_SYM(handle, "mlx_hadamard_transform");
if (mlx_hadamard_transform_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
return -1;
}
- mlx_identity_ptr = dlsym(handle, "mlx_identity");
+ mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
if (mlx_identity_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
return -1;
}
- mlx_imag_ptr = dlsym(handle, "mlx_imag");
+ mlx_imag_ptr = GET_SYM(handle, "mlx_imag");
if (mlx_imag_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imag\n");
return -1;
}
- mlx_inner_ptr = dlsym(handle, "mlx_inner");
+ mlx_inner_ptr = GET_SYM(handle, "mlx_inner");
if (mlx_inner_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_inner\n");
return -1;
}
- mlx_isclose_ptr = dlsym(handle, "mlx_isclose");
+ mlx_isclose_ptr = GET_SYM(handle, "mlx_isclose");
if (mlx_isclose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isclose\n");
return -1;
}
- mlx_isfinite_ptr = dlsym(handle, "mlx_isfinite");
+ mlx_isfinite_ptr = GET_SYM(handle, "mlx_isfinite");
if (mlx_isfinite_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isfinite\n");
return -1;
}
- mlx_isinf_ptr = dlsym(handle, "mlx_isinf");
+ mlx_isinf_ptr = GET_SYM(handle, "mlx_isinf");
if (mlx_isinf_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isinf\n");
return -1;
}
- mlx_isnan_ptr = dlsym(handle, "mlx_isnan");
+ mlx_isnan_ptr = GET_SYM(handle, "mlx_isnan");
if (mlx_isnan_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isnan\n");
return -1;
}
- mlx_isneginf_ptr = dlsym(handle, "mlx_isneginf");
+ mlx_isneginf_ptr = GET_SYM(handle, "mlx_isneginf");
if (mlx_isneginf_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isneginf\n");
return -1;
}
- mlx_isposinf_ptr = dlsym(handle, "mlx_isposinf");
+ mlx_isposinf_ptr = GET_SYM(handle, "mlx_isposinf");
if (mlx_isposinf_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isposinf\n");
return -1;
}
- mlx_kron_ptr = dlsym(handle, "mlx_kron");
+ mlx_kron_ptr = GET_SYM(handle, "mlx_kron");
if (mlx_kron_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_kron\n");
return -1;
}
- mlx_left_shift_ptr = dlsym(handle, "mlx_left_shift");
+ mlx_left_shift_ptr = GET_SYM(handle, "mlx_left_shift");
if (mlx_left_shift_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_left_shift\n");
return -1;
}
- mlx_less_ptr = dlsym(handle, "mlx_less");
+ mlx_less_ptr = GET_SYM(handle, "mlx_less");
if (mlx_less_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_less\n");
return -1;
}
- mlx_less_equal_ptr = dlsym(handle, "mlx_less_equal");
+ mlx_less_equal_ptr = GET_SYM(handle, "mlx_less_equal");
if (mlx_less_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_less_equal\n");
return -1;
}
- mlx_linspace_ptr = dlsym(handle, "mlx_linspace");
+ mlx_linspace_ptr = GET_SYM(handle, "mlx_linspace");
if (mlx_linspace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linspace\n");
return -1;
}
- mlx_log_ptr = dlsym(handle, "mlx_log");
+ mlx_log_ptr = GET_SYM(handle, "mlx_log");
if (mlx_log_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_log\n");
return -1;
}
- mlx_log10_ptr = dlsym(handle, "mlx_log10");
+ mlx_log10_ptr = GET_SYM(handle, "mlx_log10");
if (mlx_log10_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_log10\n");
return -1;
}
- mlx_log1p_ptr = dlsym(handle, "mlx_log1p");
+ mlx_log1p_ptr = GET_SYM(handle, "mlx_log1p");
if (mlx_log1p_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_log1p\n");
return -1;
}
- mlx_log2_ptr = dlsym(handle, "mlx_log2");
+ mlx_log2_ptr = GET_SYM(handle, "mlx_log2");
if (mlx_log2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_log2\n");
return -1;
}
- mlx_logaddexp_ptr = dlsym(handle, "mlx_logaddexp");
+ mlx_logaddexp_ptr = GET_SYM(handle, "mlx_logaddexp");
if (mlx_logaddexp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logaddexp\n");
return -1;
}
- mlx_logcumsumexp_ptr = dlsym(handle, "mlx_logcumsumexp");
+ mlx_logcumsumexp_ptr = GET_SYM(handle, "mlx_logcumsumexp");
if (mlx_logcumsumexp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logcumsumexp\n");
return -1;
}
- mlx_logical_and_ptr = dlsym(handle, "mlx_logical_and");
+ mlx_logical_and_ptr = GET_SYM(handle, "mlx_logical_and");
if (mlx_logical_and_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_and\n");
return -1;
}
- mlx_logical_not_ptr = dlsym(handle, "mlx_logical_not");
+ mlx_logical_not_ptr = GET_SYM(handle, "mlx_logical_not");
if (mlx_logical_not_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_not\n");
return -1;
}
- mlx_logical_or_ptr = dlsym(handle, "mlx_logical_or");
+ mlx_logical_or_ptr = GET_SYM(handle, "mlx_logical_or");
if (mlx_logical_or_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_or\n");
return -1;
}
- mlx_logsumexp_axes_ptr = dlsym(handle, "mlx_logsumexp_axes");
+ mlx_logsumexp_axes_ptr = GET_SYM(handle, "mlx_logsumexp_axes");
if (mlx_logsumexp_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axes\n");
return -1;
}
- mlx_logsumexp_axis_ptr = dlsym(handle, "mlx_logsumexp_axis");
+ mlx_logsumexp_axis_ptr = GET_SYM(handle, "mlx_logsumexp_axis");
if (mlx_logsumexp_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axis\n");
return -1;
}
- mlx_logsumexp_ptr = dlsym(handle, "mlx_logsumexp");
+ mlx_logsumexp_ptr = GET_SYM(handle, "mlx_logsumexp");
if (mlx_logsumexp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp\n");
return -1;
}
- mlx_masked_scatter_ptr = dlsym(handle, "mlx_masked_scatter");
+ mlx_masked_scatter_ptr = GET_SYM(handle, "mlx_masked_scatter");
if (mlx_masked_scatter_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_masked_scatter\n");
return -1;
}
- mlx_matmul_ptr = dlsym(handle, "mlx_matmul");
+ mlx_matmul_ptr = GET_SYM(handle, "mlx_matmul");
if (mlx_matmul_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_matmul\n");
return -1;
}
- mlx_max_axes_ptr = dlsym(handle, "mlx_max_axes");
+ mlx_max_axes_ptr = GET_SYM(handle, "mlx_max_axes");
if (mlx_max_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axes\n");
return -1;
}
- mlx_max_axis_ptr = dlsym(handle, "mlx_max_axis");
+ mlx_max_axis_ptr = GET_SYM(handle, "mlx_max_axis");
if (mlx_max_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axis\n");
return -1;
}
- mlx_max_ptr = dlsym(handle, "mlx_max");
+ mlx_max_ptr = GET_SYM(handle, "mlx_max");
if (mlx_max_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_max\n");
return -1;
}
- mlx_maximum_ptr = dlsym(handle, "mlx_maximum");
+ mlx_maximum_ptr = GET_SYM(handle, "mlx_maximum");
if (mlx_maximum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_maximum\n");
return -1;
}
- mlx_mean_axes_ptr = dlsym(handle, "mlx_mean_axes");
+ mlx_mean_axes_ptr = GET_SYM(handle, "mlx_mean_axes");
if (mlx_mean_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axes\n");
return -1;
}
- mlx_mean_axis_ptr = dlsym(handle, "mlx_mean_axis");
+ mlx_mean_axis_ptr = GET_SYM(handle, "mlx_mean_axis");
if (mlx_mean_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axis\n");
return -1;
}
- mlx_mean_ptr = dlsym(handle, "mlx_mean");
+ mlx_mean_ptr = GET_SYM(handle, "mlx_mean");
if (mlx_mean_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean\n");
return -1;
}
- mlx_median_ptr = dlsym(handle, "mlx_median");
+ mlx_median_ptr = GET_SYM(handle, "mlx_median");
if (mlx_median_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_median\n");
return -1;
}
- mlx_meshgrid_ptr = dlsym(handle, "mlx_meshgrid");
+ mlx_meshgrid_ptr = GET_SYM(handle, "mlx_meshgrid");
if (mlx_meshgrid_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_meshgrid\n");
return -1;
}
- mlx_min_axes_ptr = dlsym(handle, "mlx_min_axes");
+ mlx_min_axes_ptr = GET_SYM(handle, "mlx_min_axes");
if (mlx_min_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axes\n");
return -1;
}
- mlx_min_axis_ptr = dlsym(handle, "mlx_min_axis");
+ mlx_min_axis_ptr = GET_SYM(handle, "mlx_min_axis");
if (mlx_min_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axis\n");
return -1;
}
- mlx_min_ptr = dlsym(handle, "mlx_min");
+ mlx_min_ptr = GET_SYM(handle, "mlx_min");
if (mlx_min_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_min\n");
return -1;
}
- mlx_minimum_ptr = dlsym(handle, "mlx_minimum");
+ mlx_minimum_ptr = GET_SYM(handle, "mlx_minimum");
if (mlx_minimum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_minimum\n");
return -1;
}
- mlx_moveaxis_ptr = dlsym(handle, "mlx_moveaxis");
+ mlx_moveaxis_ptr = GET_SYM(handle, "mlx_moveaxis");
if (mlx_moveaxis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_moveaxis\n");
return -1;
}
- mlx_multiply_ptr = dlsym(handle, "mlx_multiply");
+ mlx_multiply_ptr = GET_SYM(handle, "mlx_multiply");
if (mlx_multiply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_multiply\n");
return -1;
}
- mlx_nan_to_num_ptr = dlsym(handle, "mlx_nan_to_num");
+ mlx_nan_to_num_ptr = GET_SYM(handle, "mlx_nan_to_num");
if (mlx_nan_to_num_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_nan_to_num\n");
return -1;
}
- mlx_negative_ptr = dlsym(handle, "mlx_negative");
+ mlx_negative_ptr = GET_SYM(handle, "mlx_negative");
if (mlx_negative_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_negative\n");
return -1;
}
- mlx_not_equal_ptr = dlsym(handle, "mlx_not_equal");
+ mlx_not_equal_ptr = GET_SYM(handle, "mlx_not_equal");
if (mlx_not_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_not_equal\n");
return -1;
}
- mlx_number_of_elements_ptr = dlsym(handle, "mlx_number_of_elements");
+ mlx_number_of_elements_ptr = GET_SYM(handle, "mlx_number_of_elements");
if (mlx_number_of_elements_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_number_of_elements\n");
return -1;
}
- mlx_ones_ptr = dlsym(handle, "mlx_ones");
+ mlx_ones_ptr = GET_SYM(handle, "mlx_ones");
if (mlx_ones_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_ones\n");
return -1;
}
- mlx_ones_like_ptr = dlsym(handle, "mlx_ones_like");
+ mlx_ones_like_ptr = GET_SYM(handle, "mlx_ones_like");
if (mlx_ones_like_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_ones_like\n");
return -1;
}
- mlx_outer_ptr = dlsym(handle, "mlx_outer");
+ mlx_outer_ptr = GET_SYM(handle, "mlx_outer");
if (mlx_outer_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_outer\n");
return -1;
}
- mlx_pad_ptr = dlsym(handle, "mlx_pad");
+ mlx_pad_ptr = GET_SYM(handle, "mlx_pad");
if (mlx_pad_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_pad\n");
return -1;
}
- mlx_pad_symmetric_ptr = dlsym(handle, "mlx_pad_symmetric");
+ mlx_pad_symmetric_ptr = GET_SYM(handle, "mlx_pad_symmetric");
if (mlx_pad_symmetric_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_pad_symmetric\n");
return -1;
}
- mlx_partition_axis_ptr = dlsym(handle, "mlx_partition_axis");
+ mlx_partition_axis_ptr = GET_SYM(handle, "mlx_partition_axis");
if (mlx_partition_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_partition_axis\n");
return -1;
}
- mlx_partition_ptr = dlsym(handle, "mlx_partition");
+ mlx_partition_ptr = GET_SYM(handle, "mlx_partition");
if (mlx_partition_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_partition\n");
return -1;
}
- mlx_power_ptr = dlsym(handle, "mlx_power");
+ mlx_power_ptr = GET_SYM(handle, "mlx_power");
if (mlx_power_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_power\n");
return -1;
}
- mlx_prod_axes_ptr = dlsym(handle, "mlx_prod_axes");
+ mlx_prod_axes_ptr = GET_SYM(handle, "mlx_prod_axes");
if (mlx_prod_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axes\n");
return -1;
}
- mlx_prod_axis_ptr = dlsym(handle, "mlx_prod_axis");
+ mlx_prod_axis_ptr = GET_SYM(handle, "mlx_prod_axis");
if (mlx_prod_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axis\n");
return -1;
}
- mlx_prod_ptr = dlsym(handle, "mlx_prod");
+ mlx_prod_ptr = GET_SYM(handle, "mlx_prod");
if (mlx_prod_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod\n");
return -1;
}
- mlx_put_along_axis_ptr = dlsym(handle, "mlx_put_along_axis");
+ mlx_put_along_axis_ptr = GET_SYM(handle, "mlx_put_along_axis");
if (mlx_put_along_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_put_along_axis\n");
return -1;
}
- mlx_qqmm_ptr = dlsym(handle, "mlx_qqmm");
+ mlx_qqmm_ptr = GET_SYM(handle, "mlx_qqmm");
if (mlx_qqmm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_qqmm\n");
return -1;
}
- mlx_quantize_ptr = dlsym(handle, "mlx_quantize");
+ mlx_quantize_ptr = GET_SYM(handle, "mlx_quantize");
if (mlx_quantize_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_quantize\n");
return -1;
}
- mlx_quantized_matmul_ptr = dlsym(handle, "mlx_quantized_matmul");
+ mlx_quantized_matmul_ptr = GET_SYM(handle, "mlx_quantized_matmul");
if (mlx_quantized_matmul_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_quantized_matmul\n");
return -1;
}
- mlx_radians_ptr = dlsym(handle, "mlx_radians");
+ mlx_radians_ptr = GET_SYM(handle, "mlx_radians");
if (mlx_radians_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_radians\n");
return -1;
}
- mlx_real_ptr = dlsym(handle, "mlx_real");
+ mlx_real_ptr = GET_SYM(handle, "mlx_real");
if (mlx_real_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_real\n");
return -1;
}
- mlx_reciprocal_ptr = dlsym(handle, "mlx_reciprocal");
+ mlx_reciprocal_ptr = GET_SYM(handle, "mlx_reciprocal");
if (mlx_reciprocal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_reciprocal\n");
return -1;
}
- mlx_remainder_ptr = dlsym(handle, "mlx_remainder");
+ mlx_remainder_ptr = GET_SYM(handle, "mlx_remainder");
if (mlx_remainder_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_remainder\n");
return -1;
}
- mlx_repeat_axis_ptr = dlsym(handle, "mlx_repeat_axis");
+ mlx_repeat_axis_ptr = GET_SYM(handle, "mlx_repeat_axis");
if (mlx_repeat_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat_axis\n");
return -1;
}
- mlx_repeat_ptr = dlsym(handle, "mlx_repeat");
+ mlx_repeat_ptr = GET_SYM(handle, "mlx_repeat");
if (mlx_repeat_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat\n");
return -1;
}
- mlx_reshape_ptr = dlsym(handle, "mlx_reshape");
+ mlx_reshape_ptr = GET_SYM(handle, "mlx_reshape");
if (mlx_reshape_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_reshape\n");
return -1;
}
- mlx_right_shift_ptr = dlsym(handle, "mlx_right_shift");
+ mlx_right_shift_ptr = GET_SYM(handle, "mlx_right_shift");
if (mlx_right_shift_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_right_shift\n");
return -1;
}
- mlx_roll_axis_ptr = dlsym(handle, "mlx_roll_axis");
+ mlx_roll_axis_ptr = GET_SYM(handle, "mlx_roll_axis");
if (mlx_roll_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axis\n");
return -1;
}
- mlx_roll_axes_ptr = dlsym(handle, "mlx_roll_axes");
+ mlx_roll_axes_ptr = GET_SYM(handle, "mlx_roll_axes");
if (mlx_roll_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axes\n");
return -1;
}
- mlx_roll_ptr = dlsym(handle, "mlx_roll");
+ mlx_roll_ptr = GET_SYM(handle, "mlx_roll");
if (mlx_roll_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll\n");
return -1;
}
- mlx_round_ptr = dlsym(handle, "mlx_round");
+ mlx_round_ptr = GET_SYM(handle, "mlx_round");
if (mlx_round_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_round\n");
return -1;
}
- mlx_rsqrt_ptr = dlsym(handle, "mlx_rsqrt");
+ mlx_rsqrt_ptr = GET_SYM(handle, "mlx_rsqrt");
if (mlx_rsqrt_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_rsqrt\n");
return -1;
}
- mlx_scatter_ptr = dlsym(handle, "mlx_scatter");
+ mlx_scatter_ptr = GET_SYM(handle, "mlx_scatter");
if (mlx_scatter_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter\n");
return -1;
}
- mlx_scatter_single_ptr = dlsym(handle, "mlx_scatter_single");
+ mlx_scatter_single_ptr = GET_SYM(handle, "mlx_scatter_single");
if (mlx_scatter_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_single\n");
return -1;
}
- mlx_scatter_add_ptr = dlsym(handle, "mlx_scatter_add");
+ mlx_scatter_add_ptr = GET_SYM(handle, "mlx_scatter_add");
if (mlx_scatter_add_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add\n");
return -1;
}
- mlx_scatter_add_single_ptr = dlsym(handle, "mlx_scatter_add_single");
+ mlx_scatter_add_single_ptr = GET_SYM(handle, "mlx_scatter_add_single");
if (mlx_scatter_add_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_single\n");
return -1;
}
- mlx_scatter_add_axis_ptr = dlsym(handle, "mlx_scatter_add_axis");
+ mlx_scatter_add_axis_ptr = GET_SYM(handle, "mlx_scatter_add_axis");
if (mlx_scatter_add_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_axis\n");
return -1;
}
- mlx_scatter_max_ptr = dlsym(handle, "mlx_scatter_max");
+ mlx_scatter_max_ptr = GET_SYM(handle, "mlx_scatter_max");
if (mlx_scatter_max_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max\n");
return -1;
}
- mlx_scatter_max_single_ptr = dlsym(handle, "mlx_scatter_max_single");
+ mlx_scatter_max_single_ptr = GET_SYM(handle, "mlx_scatter_max_single");
if (mlx_scatter_max_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max_single\n");
return -1;
}
- mlx_scatter_min_ptr = dlsym(handle, "mlx_scatter_min");
+ mlx_scatter_min_ptr = GET_SYM(handle, "mlx_scatter_min");
if (mlx_scatter_min_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min\n");
return -1;
}
- mlx_scatter_min_single_ptr = dlsym(handle, "mlx_scatter_min_single");
+ mlx_scatter_min_single_ptr = GET_SYM(handle, "mlx_scatter_min_single");
if (mlx_scatter_min_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min_single\n");
return -1;
}
- mlx_scatter_prod_ptr = dlsym(handle, "mlx_scatter_prod");
+ mlx_scatter_prod_ptr = GET_SYM(handle, "mlx_scatter_prod");
if (mlx_scatter_prod_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod\n");
return -1;
}
- mlx_scatter_prod_single_ptr = dlsym(handle, "mlx_scatter_prod_single");
+ mlx_scatter_prod_single_ptr = GET_SYM(handle, "mlx_scatter_prod_single");
if (mlx_scatter_prod_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod_single\n");
return -1;
}
- mlx_segmented_mm_ptr = dlsym(handle, "mlx_segmented_mm");
+ mlx_segmented_mm_ptr = GET_SYM(handle, "mlx_segmented_mm");
if (mlx_segmented_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_segmented_mm\n");
return -1;
}
- mlx_sigmoid_ptr = dlsym(handle, "mlx_sigmoid");
+ mlx_sigmoid_ptr = GET_SYM(handle, "mlx_sigmoid");
if (mlx_sigmoid_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sigmoid\n");
return -1;
}
- mlx_sign_ptr = dlsym(handle, "mlx_sign");
+ mlx_sign_ptr = GET_SYM(handle, "mlx_sign");
if (mlx_sign_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sign\n");
return -1;
}
- mlx_sin_ptr = dlsym(handle, "mlx_sin");
+ mlx_sin_ptr = GET_SYM(handle, "mlx_sin");
if (mlx_sin_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sin\n");
return -1;
}
- mlx_sinh_ptr = dlsym(handle, "mlx_sinh");
+ mlx_sinh_ptr = GET_SYM(handle, "mlx_sinh");
if (mlx_sinh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sinh\n");
return -1;
}
- mlx_slice_ptr = dlsym(handle, "mlx_slice");
+ mlx_slice_ptr = GET_SYM(handle, "mlx_slice");
if (mlx_slice_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice\n");
return -1;
}
- mlx_slice_dynamic_ptr = dlsym(handle, "mlx_slice_dynamic");
+ mlx_slice_dynamic_ptr = GET_SYM(handle, "mlx_slice_dynamic");
if (mlx_slice_dynamic_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_dynamic\n");
return -1;
}
- mlx_slice_update_ptr = dlsym(handle, "mlx_slice_update");
+ mlx_slice_update_ptr = GET_SYM(handle, "mlx_slice_update");
if (mlx_slice_update_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update\n");
return -1;
}
- mlx_slice_update_dynamic_ptr = dlsym(handle, "mlx_slice_update_dynamic");
+ mlx_slice_update_dynamic_ptr = GET_SYM(handle, "mlx_slice_update_dynamic");
if (mlx_slice_update_dynamic_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_dynamic\n");
return -1;
}
- mlx_softmax_axes_ptr = dlsym(handle, "mlx_softmax_axes");
+ mlx_softmax_axes_ptr = GET_SYM(handle, "mlx_softmax_axes");
if (mlx_softmax_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axes\n");
return -1;
}
- mlx_softmax_axis_ptr = dlsym(handle, "mlx_softmax_axis");
+ mlx_softmax_axis_ptr = GET_SYM(handle, "mlx_softmax_axis");
if (mlx_softmax_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axis\n");
return -1;
}
- mlx_softmax_ptr = dlsym(handle, "mlx_softmax");
+ mlx_softmax_ptr = GET_SYM(handle, "mlx_softmax");
if (mlx_softmax_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax\n");
return -1;
}
- mlx_sort_axis_ptr = dlsym(handle, "mlx_sort_axis");
+ mlx_sort_axis_ptr = GET_SYM(handle, "mlx_sort_axis");
if (mlx_sort_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sort_axis\n");
return -1;
}
- mlx_sort_ptr = dlsym(handle, "mlx_sort");
+ mlx_sort_ptr = GET_SYM(handle, "mlx_sort");
if (mlx_sort_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sort\n");
return -1;
}
- mlx_split_ptr = dlsym(handle, "mlx_split");
+ mlx_split_ptr = GET_SYM(handle, "mlx_split");
if (mlx_split_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_split\n");
return -1;
}
- mlx_split_sections_ptr = dlsym(handle, "mlx_split_sections");
+ mlx_split_sections_ptr = GET_SYM(handle, "mlx_split_sections");
if (mlx_split_sections_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_split_sections\n");
return -1;
}
- mlx_sqrt_ptr = dlsym(handle, "mlx_sqrt");
+ mlx_sqrt_ptr = GET_SYM(handle, "mlx_sqrt");
if (mlx_sqrt_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sqrt\n");
return -1;
}
- mlx_square_ptr = dlsym(handle, "mlx_square");
+ mlx_square_ptr = GET_SYM(handle, "mlx_square");
if (mlx_square_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_square\n");
return -1;
}
- mlx_squeeze_axes_ptr = dlsym(handle, "mlx_squeeze_axes");
+ mlx_squeeze_axes_ptr = GET_SYM(handle, "mlx_squeeze_axes");
if (mlx_squeeze_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axes\n");
return -1;
}
- mlx_squeeze_axis_ptr = dlsym(handle, "mlx_squeeze_axis");
+ mlx_squeeze_axis_ptr = GET_SYM(handle, "mlx_squeeze_axis");
if (mlx_squeeze_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axis\n");
return -1;
}
- mlx_squeeze_ptr = dlsym(handle, "mlx_squeeze");
+ mlx_squeeze_ptr = GET_SYM(handle, "mlx_squeeze");
if (mlx_squeeze_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze\n");
return -1;
}
- mlx_stack_axis_ptr = dlsym(handle, "mlx_stack_axis");
+ mlx_stack_axis_ptr = GET_SYM(handle, "mlx_stack_axis");
if (mlx_stack_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stack_axis\n");
return -1;
}
- mlx_stack_ptr = dlsym(handle, "mlx_stack");
+ mlx_stack_ptr = GET_SYM(handle, "mlx_stack");
if (mlx_stack_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stack\n");
return -1;
}
- mlx_std_axes_ptr = dlsym(handle, "mlx_std_axes");
+ mlx_std_axes_ptr = GET_SYM(handle, "mlx_std_axes");
if (mlx_std_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axes\n");
return -1;
}
- mlx_std_axis_ptr = dlsym(handle, "mlx_std_axis");
+ mlx_std_axis_ptr = GET_SYM(handle, "mlx_std_axis");
if (mlx_std_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axis\n");
return -1;
}
- mlx_std_ptr = dlsym(handle, "mlx_std");
+ mlx_std_ptr = GET_SYM(handle, "mlx_std");
if (mlx_std_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_std\n");
return -1;
}
- mlx_stop_gradient_ptr = dlsym(handle, "mlx_stop_gradient");
+ mlx_stop_gradient_ptr = GET_SYM(handle, "mlx_stop_gradient");
if (mlx_stop_gradient_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stop_gradient\n");
return -1;
}
- mlx_subtract_ptr = dlsym(handle, "mlx_subtract");
+ mlx_subtract_ptr = GET_SYM(handle, "mlx_subtract");
if (mlx_subtract_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_subtract\n");
return -1;
}
- mlx_sum_axes_ptr = dlsym(handle, "mlx_sum_axes");
+ mlx_sum_axes_ptr = GET_SYM(handle, "mlx_sum_axes");
if (mlx_sum_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axes\n");
return -1;
}
- mlx_sum_axis_ptr = dlsym(handle, "mlx_sum_axis");
+ mlx_sum_axis_ptr = GET_SYM(handle, "mlx_sum_axis");
if (mlx_sum_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axis\n");
return -1;
}
- mlx_sum_ptr = dlsym(handle, "mlx_sum");
+ mlx_sum_ptr = GET_SYM(handle, "mlx_sum");
if (mlx_sum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum\n");
return -1;
}
- mlx_swapaxes_ptr = dlsym(handle, "mlx_swapaxes");
+ mlx_swapaxes_ptr = GET_SYM(handle, "mlx_swapaxes");
if (mlx_swapaxes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_swapaxes\n");
return -1;
}
- mlx_take_axis_ptr = dlsym(handle, "mlx_take_axis");
+ mlx_take_axis_ptr = GET_SYM(handle, "mlx_take_axis");
if (mlx_take_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_take_axis\n");
return -1;
}
- mlx_take_ptr = dlsym(handle, "mlx_take");
+ mlx_take_ptr = GET_SYM(handle, "mlx_take");
if (mlx_take_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_take\n");
return -1;
}
- mlx_take_along_axis_ptr = dlsym(handle, "mlx_take_along_axis");
+ mlx_take_along_axis_ptr = GET_SYM(handle, "mlx_take_along_axis");
if (mlx_take_along_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_take_along_axis\n");
return -1;
}
- mlx_tan_ptr = dlsym(handle, "mlx_tan");
+ mlx_tan_ptr = GET_SYM(handle, "mlx_tan");
if (mlx_tan_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tan\n");
return -1;
}
- mlx_tanh_ptr = dlsym(handle, "mlx_tanh");
+ mlx_tanh_ptr = GET_SYM(handle, "mlx_tanh");
if (mlx_tanh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tanh\n");
return -1;
}
- mlx_tensordot_ptr = dlsym(handle, "mlx_tensordot");
+ mlx_tensordot_ptr = GET_SYM(handle, "mlx_tensordot");
if (mlx_tensordot_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot\n");
return -1;
}
- mlx_tensordot_axis_ptr = dlsym(handle, "mlx_tensordot_axis");
+ mlx_tensordot_axis_ptr = GET_SYM(handle, "mlx_tensordot_axis");
if (mlx_tensordot_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot_axis\n");
return -1;
}
- mlx_tile_ptr = dlsym(handle, "mlx_tile");
+ mlx_tile_ptr = GET_SYM(handle, "mlx_tile");
if (mlx_tile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tile\n");
return -1;
}
- mlx_to_fp8_ptr = dlsym(handle, "mlx_to_fp8");
+ mlx_to_fp8_ptr = GET_SYM(handle, "mlx_to_fp8");
if (mlx_to_fp8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_to_fp8\n");
return -1;
}
- mlx_topk_axis_ptr = dlsym(handle, "mlx_topk_axis");
+ mlx_topk_axis_ptr = GET_SYM(handle, "mlx_topk_axis");
if (mlx_topk_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_topk_axis\n");
return -1;
}
- mlx_topk_ptr = dlsym(handle, "mlx_topk");
+ mlx_topk_ptr = GET_SYM(handle, "mlx_topk");
if (mlx_topk_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_topk\n");
return -1;
}
- mlx_trace_ptr = dlsym(handle, "mlx_trace");
+ mlx_trace_ptr = GET_SYM(handle, "mlx_trace");
if (mlx_trace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_trace\n");
return -1;
}
- mlx_transpose_axes_ptr = dlsym(handle, "mlx_transpose_axes");
+ mlx_transpose_axes_ptr = GET_SYM(handle, "mlx_transpose_axes");
if (mlx_transpose_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose_axes\n");
return -1;
}
- mlx_transpose_ptr = dlsym(handle, "mlx_transpose");
+ mlx_transpose_ptr = GET_SYM(handle, "mlx_transpose");
if (mlx_transpose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose\n");
return -1;
}
- mlx_tri_ptr = dlsym(handle, "mlx_tri");
+ mlx_tri_ptr = GET_SYM(handle, "mlx_tri");
if (mlx_tri_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tri\n");
return -1;
}
- mlx_tril_ptr = dlsym(handle, "mlx_tril");
+ mlx_tril_ptr = GET_SYM(handle, "mlx_tril");
if (mlx_tril_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tril\n");
return -1;
}
- mlx_triu_ptr = dlsym(handle, "mlx_triu");
+ mlx_triu_ptr = GET_SYM(handle, "mlx_triu");
if (mlx_triu_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_triu\n");
return -1;
}
- mlx_unflatten_ptr = dlsym(handle, "mlx_unflatten");
+ mlx_unflatten_ptr = GET_SYM(handle, "mlx_unflatten");
if (mlx_unflatten_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_unflatten\n");
return -1;
}
- mlx_var_axes_ptr = dlsym(handle, "mlx_var_axes");
+ mlx_var_axes_ptr = GET_SYM(handle, "mlx_var_axes");
if (mlx_var_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axes\n");
return -1;
}
- mlx_var_axis_ptr = dlsym(handle, "mlx_var_axis");
+ mlx_var_axis_ptr = GET_SYM(handle, "mlx_var_axis");
if (mlx_var_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axis\n");
return -1;
}
- mlx_var_ptr = dlsym(handle, "mlx_var");
+ mlx_var_ptr = GET_SYM(handle, "mlx_var");
if (mlx_var_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_var\n");
return -1;
}
- mlx_view_ptr = dlsym(handle, "mlx_view");
+ mlx_view_ptr = GET_SYM(handle, "mlx_view");
if (mlx_view_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_view\n");
return -1;
}
- mlx_where_ptr = dlsym(handle, "mlx_where");
+ mlx_where_ptr = GET_SYM(handle, "mlx_where");
if (mlx_where_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_where\n");
return -1;
}
- mlx_zeros_ptr = dlsym(handle, "mlx_zeros");
+ mlx_zeros_ptr = GET_SYM(handle, "mlx_zeros");
if (mlx_zeros_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros\n");
return -1;
}
- mlx_zeros_like_ptr = dlsym(handle, "mlx_zeros_like");
+ mlx_zeros_like_ptr = GET_SYM(handle, "mlx_zeros_like");
if (mlx_zeros_like_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros_like\n");
return -1;
}
- mlx_random_bernoulli_ptr = dlsym(handle, "mlx_random_bernoulli");
+ mlx_random_bernoulli_ptr = GET_SYM(handle, "mlx_random_bernoulli");
if (mlx_random_bernoulli_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bernoulli\n");
return -1;
}
- mlx_random_bits_ptr = dlsym(handle, "mlx_random_bits");
+ mlx_random_bits_ptr = GET_SYM(handle, "mlx_random_bits");
if (mlx_random_bits_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bits\n");
return -1;
}
- mlx_random_categorical_shape_ptr = dlsym(handle, "mlx_random_categorical_shape");
+ mlx_random_categorical_shape_ptr = GET_SYM(handle, "mlx_random_categorical_shape");
if (mlx_random_categorical_shape_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_shape\n");
return -1;
}
- mlx_random_categorical_num_samples_ptr = dlsym(handle, "mlx_random_categorical_num_samples");
+ mlx_random_categorical_num_samples_ptr = GET_SYM(handle, "mlx_random_categorical_num_samples");
if (mlx_random_categorical_num_samples_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_num_samples\n");
return -1;
}
- mlx_random_categorical_ptr = dlsym(handle, "mlx_random_categorical");
+ mlx_random_categorical_ptr = GET_SYM(handle, "mlx_random_categorical");
if (mlx_random_categorical_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical\n");
return -1;
}
- mlx_random_gumbel_ptr = dlsym(handle, "mlx_random_gumbel");
+ mlx_random_gumbel_ptr = GET_SYM(handle, "mlx_random_gumbel");
if (mlx_random_gumbel_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_gumbel\n");
return -1;
}
- mlx_random_key_ptr = dlsym(handle, "mlx_random_key");
+ mlx_random_key_ptr = GET_SYM(handle, "mlx_random_key");
if (mlx_random_key_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_key\n");
return -1;
}
- mlx_random_laplace_ptr = dlsym(handle, "mlx_random_laplace");
+ mlx_random_laplace_ptr = GET_SYM(handle, "mlx_random_laplace");
if (mlx_random_laplace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_laplace\n");
return -1;
}
- mlx_random_multivariate_normal_ptr = dlsym(handle, "mlx_random_multivariate_normal");
+ mlx_random_multivariate_normal_ptr = GET_SYM(handle, "mlx_random_multivariate_normal");
if (mlx_random_multivariate_normal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_multivariate_normal\n");
return -1;
}
- mlx_random_normal_broadcast_ptr = dlsym(handle, "mlx_random_normal_broadcast");
+ mlx_random_normal_broadcast_ptr = GET_SYM(handle, "mlx_random_normal_broadcast");
if (mlx_random_normal_broadcast_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal_broadcast\n");
return -1;
}
- mlx_random_normal_ptr = dlsym(handle, "mlx_random_normal");
+ mlx_random_normal_ptr = GET_SYM(handle, "mlx_random_normal");
if (mlx_random_normal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal\n");
return -1;
}
- mlx_random_permutation_ptr = dlsym(handle, "mlx_random_permutation");
+ mlx_random_permutation_ptr = GET_SYM(handle, "mlx_random_permutation");
if (mlx_random_permutation_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation\n");
return -1;
}
- mlx_random_permutation_arange_ptr = dlsym(handle, "mlx_random_permutation_arange");
+ mlx_random_permutation_arange_ptr = GET_SYM(handle, "mlx_random_permutation_arange");
if (mlx_random_permutation_arange_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation_arange\n");
return -1;
}
- mlx_random_randint_ptr = dlsym(handle, "mlx_random_randint");
+ mlx_random_randint_ptr = GET_SYM(handle, "mlx_random_randint");
if (mlx_random_randint_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_randint\n");
return -1;
}
- mlx_random_seed_ptr = dlsym(handle, "mlx_random_seed");
+ mlx_random_seed_ptr = GET_SYM(handle, "mlx_random_seed");
if (mlx_random_seed_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_seed\n");
return -1;
}
- mlx_random_split_num_ptr = dlsym(handle, "mlx_random_split_num");
+ mlx_random_split_num_ptr = GET_SYM(handle, "mlx_random_split_num");
if (mlx_random_split_num_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split_num\n");
return -1;
}
- mlx_random_split_ptr = dlsym(handle, "mlx_random_split");
+ mlx_random_split_ptr = GET_SYM(handle, "mlx_random_split");
if (mlx_random_split_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split\n");
return -1;
}
- mlx_random_truncated_normal_ptr = dlsym(handle, "mlx_random_truncated_normal");
+ mlx_random_truncated_normal_ptr = GET_SYM(handle, "mlx_random_truncated_normal");
if (mlx_random_truncated_normal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_truncated_normal\n");
return -1;
}
- mlx_random_uniform_ptr = dlsym(handle, "mlx_random_uniform");
+ mlx_random_uniform_ptr = GET_SYM(handle, "mlx_random_uniform");
if (mlx_random_uniform_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_uniform\n");
return -1;
}
- mlx_stream_new_ptr = dlsym(handle, "mlx_stream_new");
+ mlx_stream_new_ptr = GET_SYM(handle, "mlx_stream_new");
if (mlx_stream_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new\n");
return -1;
}
- mlx_stream_new_device_ptr = dlsym(handle, "mlx_stream_new_device");
+ mlx_stream_new_device_ptr = GET_SYM(handle, "mlx_stream_new_device");
if (mlx_stream_new_device_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new_device\n");
return -1;
}
- mlx_stream_set_ptr = dlsym(handle, "mlx_stream_set");
+ mlx_stream_set_ptr = GET_SYM(handle, "mlx_stream_set");
if (mlx_stream_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_set\n");
return -1;
}
- mlx_stream_free_ptr = dlsym(handle, "mlx_stream_free");
+ mlx_stream_free_ptr = GET_SYM(handle, "mlx_stream_free");
if (mlx_stream_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_free\n");
return -1;
}
- mlx_stream_tostring_ptr = dlsym(handle, "mlx_stream_tostring");
+ mlx_stream_tostring_ptr = GET_SYM(handle, "mlx_stream_tostring");
if (mlx_stream_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_tostring\n");
return -1;
}
- mlx_stream_equal_ptr = dlsym(handle, "mlx_stream_equal");
+ mlx_stream_equal_ptr = GET_SYM(handle, "mlx_stream_equal");
if (mlx_stream_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_equal\n");
return -1;
}
- mlx_stream_get_device_ptr = dlsym(handle, "mlx_stream_get_device");
+ mlx_stream_get_device_ptr = GET_SYM(handle, "mlx_stream_get_device");
if (mlx_stream_get_device_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_device\n");
return -1;
}
- mlx_stream_get_index_ptr = dlsym(handle, "mlx_stream_get_index");
+ mlx_stream_get_index_ptr = GET_SYM(handle, "mlx_stream_get_index");
if (mlx_stream_get_index_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_index\n");
return -1;
}
- mlx_synchronize_ptr = dlsym(handle, "mlx_synchronize");
+ mlx_synchronize_ptr = GET_SYM(handle, "mlx_synchronize");
if (mlx_synchronize_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_synchronize\n");
return -1;
}
- mlx_get_default_stream_ptr = dlsym(handle, "mlx_get_default_stream");
+ mlx_get_default_stream_ptr = GET_SYM(handle, "mlx_get_default_stream");
if (mlx_get_default_stream_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_stream\n");
return -1;
}
- mlx_set_default_stream_ptr = dlsym(handle, "mlx_set_default_stream");
+ mlx_set_default_stream_ptr = GET_SYM(handle, "mlx_set_default_stream");
if (mlx_set_default_stream_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_stream\n");
return -1;
}
- mlx_default_cpu_stream_new_ptr = dlsym(handle, "mlx_default_cpu_stream_new");
+ mlx_default_cpu_stream_new_ptr = GET_SYM(handle, "mlx_default_cpu_stream_new");
if (mlx_default_cpu_stream_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_default_cpu_stream_new\n");
return -1;
}
- mlx_default_gpu_stream_new_ptr = dlsym(handle, "mlx_default_gpu_stream_new");
+ mlx_default_gpu_stream_new_ptr = GET_SYM(handle, "mlx_default_gpu_stream_new");
if (mlx_default_gpu_stream_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_default_gpu_stream_new\n");
return -1;
}
- mlx_string_new_ptr = dlsym(handle, "mlx_string_new");
+ mlx_string_new_ptr = GET_SYM(handle, "mlx_string_new");
if (mlx_string_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new\n");
return -1;
}
- mlx_string_new_data_ptr = dlsym(handle, "mlx_string_new_data");
+ mlx_string_new_data_ptr = GET_SYM(handle, "mlx_string_new_data");
if (mlx_string_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new_data\n");
return -1;
}
- mlx_string_set_ptr = dlsym(handle, "mlx_string_set");
+ mlx_string_set_ptr = GET_SYM(handle, "mlx_string_set");
if (mlx_string_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_set\n");
return -1;
}
- mlx_string_data_ptr = dlsym(handle, "mlx_string_data");
+ mlx_string_data_ptr = GET_SYM(handle, "mlx_string_data");
if (mlx_string_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_data\n");
return -1;
}
- mlx_string_free_ptr = dlsym(handle, "mlx_string_free");
+ mlx_string_free_ptr = GET_SYM(handle, "mlx_string_free");
if (mlx_string_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_free\n");
return -1;
}
- mlx_async_eval_ptr = dlsym(handle, "mlx_async_eval");
+ mlx_async_eval_ptr = GET_SYM(handle, "mlx_async_eval");
if (mlx_async_eval_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_async_eval\n");
return -1;
}
- mlx_checkpoint_ptr = dlsym(handle, "mlx_checkpoint");
+ mlx_checkpoint_ptr = GET_SYM(handle, "mlx_checkpoint");
if (mlx_checkpoint_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_checkpoint\n");
return -1;
}
- mlx_custom_function_ptr = dlsym(handle, "mlx_custom_function");
+ mlx_custom_function_ptr = GET_SYM(handle, "mlx_custom_function");
if (mlx_custom_function_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_function\n");
return -1;
}
- mlx_custom_vjp_ptr = dlsym(handle, "mlx_custom_vjp");
+ mlx_custom_vjp_ptr = GET_SYM(handle, "mlx_custom_vjp");
if (mlx_custom_vjp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_vjp\n");
return -1;
}
- mlx_eval_ptr = dlsym(handle, "mlx_eval");
+ mlx_eval_ptr = GET_SYM(handle, "mlx_eval");
if (mlx_eval_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_eval\n");
return -1;
}
- mlx_jvp_ptr = dlsym(handle, "mlx_jvp");
+ mlx_jvp_ptr = GET_SYM(handle, "mlx_jvp");
if (mlx_jvp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_jvp\n");
return -1;
}
- mlx_value_and_grad_ptr = dlsym(handle, "mlx_value_and_grad");
+ mlx_value_and_grad_ptr = GET_SYM(handle, "mlx_value_and_grad");
if (mlx_value_and_grad_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_value_and_grad\n");
return -1;
}
- mlx_vjp_ptr = dlsym(handle, "mlx_vjp");
+ mlx_vjp_ptr = GET_SYM(handle, "mlx_vjp");
if (mlx_vjp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vjp\n");
return -1;
}
- mlx_detail_vmap_replace_ptr = dlsym(handle, "mlx_detail_vmap_replace");
+ mlx_detail_vmap_replace_ptr = GET_SYM(handle, "mlx_detail_vmap_replace");
if (mlx_detail_vmap_replace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_replace\n");
return -1;
}
- mlx_detail_vmap_trace_ptr = dlsym(handle, "mlx_detail_vmap_trace");
+ mlx_detail_vmap_trace_ptr = GET_SYM(handle, "mlx_detail_vmap_trace");
if (mlx_detail_vmap_trace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_trace\n");
return -1;
}
- mlx_vector_array_new_ptr = dlsym(handle, "mlx_vector_array_new");
+ mlx_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_array_new");
if (mlx_vector_array_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new\n");
return -1;
}
- mlx_vector_array_set_ptr = dlsym(handle, "mlx_vector_array_set");
+ mlx_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_array_set");
if (mlx_vector_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set\n");
return -1;
}
- mlx_vector_array_free_ptr = dlsym(handle, "mlx_vector_array_free");
+ mlx_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_array_free");
if (mlx_vector_array_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_free\n");
return -1;
}
- mlx_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_array_new_data");
+ mlx_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_array_new_data");
if (mlx_vector_array_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_data\n");
return -1;
}
- mlx_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_array_new_value");
+ mlx_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_array_new_value");
if (mlx_vector_array_new_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_value\n");
return -1;
}
- mlx_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_array_set_data");
+ mlx_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_array_set_data");
if (mlx_vector_array_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_data\n");
return -1;
}
- mlx_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_array_set_value");
+ mlx_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_array_set_value");
if (mlx_vector_array_set_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_value\n");
return -1;
}
- mlx_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_array_append_data");
+ mlx_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_array_append_data");
if (mlx_vector_array_append_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_data\n");
return -1;
}
- mlx_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_array_append_value");
+ mlx_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_array_append_value");
if (mlx_vector_array_append_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_value\n");
return -1;
}
- mlx_vector_array_size_ptr = dlsym(handle, "mlx_vector_array_size");
+ mlx_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_array_size");
if (mlx_vector_array_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_size\n");
return -1;
}
- mlx_vector_array_get_ptr = dlsym(handle, "mlx_vector_array_get");
+ mlx_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_array_get");
if (mlx_vector_array_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_get\n");
return -1;
}
- mlx_vector_vector_array_new_ptr = dlsym(handle, "mlx_vector_vector_array_new");
+ mlx_vector_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_vector_array_new");
if (mlx_vector_vector_array_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new\n");
return -1;
}
- mlx_vector_vector_array_set_ptr = dlsym(handle, "mlx_vector_vector_array_set");
+ mlx_vector_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_vector_array_set");
if (mlx_vector_vector_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set\n");
return -1;
}
- mlx_vector_vector_array_free_ptr = dlsym(handle, "mlx_vector_vector_array_free");
+ mlx_vector_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_vector_array_free");
if (mlx_vector_vector_array_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_free\n");
return -1;
}
- mlx_vector_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_vector_array_new_data");
+ mlx_vector_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_data");
if (mlx_vector_vector_array_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_data\n");
return -1;
}
- mlx_vector_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_vector_array_new_value");
+ mlx_vector_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_value");
if (mlx_vector_vector_array_new_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_value\n");
return -1;
}
- mlx_vector_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_vector_array_set_data");
+ mlx_vector_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_data");
if (mlx_vector_vector_array_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_data\n");
return -1;
}
- mlx_vector_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_vector_array_set_value");
+ mlx_vector_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_value");
if (mlx_vector_vector_array_set_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_value\n");
return -1;
}
- mlx_vector_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_vector_array_append_data");
+ mlx_vector_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_data");
if (mlx_vector_vector_array_append_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_data\n");
return -1;
}
- mlx_vector_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_vector_array_append_value");
+ mlx_vector_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_value");
if (mlx_vector_vector_array_append_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_value\n");
return -1;
}
- mlx_vector_vector_array_size_ptr = dlsym(handle, "mlx_vector_vector_array_size");
+ mlx_vector_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_vector_array_size");
if (mlx_vector_vector_array_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_size\n");
return -1;
}
- mlx_vector_vector_array_get_ptr = dlsym(handle, "mlx_vector_vector_array_get");
+ mlx_vector_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_vector_array_get");
if (mlx_vector_vector_array_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_get\n");
return -1;
}
- mlx_vector_int_new_ptr = dlsym(handle, "mlx_vector_int_new");
+ mlx_vector_int_new_ptr = GET_SYM(handle, "mlx_vector_int_new");
if (mlx_vector_int_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new\n");
return -1;
}
- mlx_vector_int_set_ptr = dlsym(handle, "mlx_vector_int_set");
+ mlx_vector_int_set_ptr = GET_SYM(handle, "mlx_vector_int_set");
if (mlx_vector_int_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set\n");
return -1;
}
- mlx_vector_int_free_ptr = dlsym(handle, "mlx_vector_int_free");
+ mlx_vector_int_free_ptr = GET_SYM(handle, "mlx_vector_int_free");
if (mlx_vector_int_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_free\n");
return -1;
}
- mlx_vector_int_new_data_ptr = dlsym(handle, "mlx_vector_int_new_data");
+ mlx_vector_int_new_data_ptr = GET_SYM(handle, "mlx_vector_int_new_data");
if (mlx_vector_int_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_data\n");
return -1;
}
- mlx_vector_int_new_value_ptr = dlsym(handle, "mlx_vector_int_new_value");
+ mlx_vector_int_new_value_ptr = GET_SYM(handle, "mlx_vector_int_new_value");
if (mlx_vector_int_new_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_value\n");
return -1;
}
- mlx_vector_int_set_data_ptr = dlsym(handle, "mlx_vector_int_set_data");
+ mlx_vector_int_set_data_ptr = GET_SYM(handle, "mlx_vector_int_set_data");
if (mlx_vector_int_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_data\n");
return -1;
}
- mlx_vector_int_set_value_ptr = dlsym(handle, "mlx_vector_int_set_value");
+ mlx_vector_int_set_value_ptr = GET_SYM(handle, "mlx_vector_int_set_value");
if (mlx_vector_int_set_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_value\n");
return -1;
}
- mlx_vector_int_append_data_ptr = dlsym(handle, "mlx_vector_int_append_data");
+ mlx_vector_int_append_data_ptr = GET_SYM(handle, "mlx_vector_int_append_data");
if (mlx_vector_int_append_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_data\n");
return -1;
}
- mlx_vector_int_append_value_ptr = dlsym(handle, "mlx_vector_int_append_value");
+ mlx_vector_int_append_value_ptr = GET_SYM(handle, "mlx_vector_int_append_value");
if (mlx_vector_int_append_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_value\n");
return -1;
}
- mlx_vector_int_size_ptr = dlsym(handle, "mlx_vector_int_size");
+ mlx_vector_int_size_ptr = GET_SYM(handle, "mlx_vector_int_size");
if (mlx_vector_int_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_size\n");
return -1;
}
- mlx_vector_int_get_ptr = dlsym(handle, "mlx_vector_int_get");
+ mlx_vector_int_get_ptr = GET_SYM(handle, "mlx_vector_int_get");
if (mlx_vector_int_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_get\n");
return -1;
}
- mlx_vector_string_new_ptr = dlsym(handle, "mlx_vector_string_new");
+ mlx_vector_string_new_ptr = GET_SYM(handle, "mlx_vector_string_new");
if (mlx_vector_string_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new\n");
return -1;
}
- mlx_vector_string_set_ptr = dlsym(handle, "mlx_vector_string_set");
+ mlx_vector_string_set_ptr = GET_SYM(handle, "mlx_vector_string_set");
if (mlx_vector_string_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set\n");
return -1;
}
- mlx_vector_string_free_ptr = dlsym(handle, "mlx_vector_string_free");
+ mlx_vector_string_free_ptr = GET_SYM(handle, "mlx_vector_string_free");
if (mlx_vector_string_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_free\n");
return -1;
}
- mlx_vector_string_new_data_ptr = dlsym(handle, "mlx_vector_string_new_data");
+ mlx_vector_string_new_data_ptr = GET_SYM(handle, "mlx_vector_string_new_data");
if (mlx_vector_string_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_data\n");
return -1;
}
- mlx_vector_string_new_value_ptr = dlsym(handle, "mlx_vector_string_new_value");
+ mlx_vector_string_new_value_ptr = GET_SYM(handle, "mlx_vector_string_new_value");
if (mlx_vector_string_new_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_value\n");
return -1;
}
- mlx_vector_string_set_data_ptr = dlsym(handle, "mlx_vector_string_set_data");
+ mlx_vector_string_set_data_ptr = GET_SYM(handle, "mlx_vector_string_set_data");
if (mlx_vector_string_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_data\n");
return -1;
}
- mlx_vector_string_set_value_ptr = dlsym(handle, "mlx_vector_string_set_value");
+ mlx_vector_string_set_value_ptr = GET_SYM(handle, "mlx_vector_string_set_value");
if (mlx_vector_string_set_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_value\n");
return -1;
}
- mlx_vector_string_append_data_ptr = dlsym(handle, "mlx_vector_string_append_data");
+ mlx_vector_string_append_data_ptr = GET_SYM(handle, "mlx_vector_string_append_data");
if (mlx_vector_string_append_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_data\n");
return -1;
}
- mlx_vector_string_append_value_ptr = dlsym(handle, "mlx_vector_string_append_value");
+ mlx_vector_string_append_value_ptr = GET_SYM(handle, "mlx_vector_string_append_value");
if (mlx_vector_string_append_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_value\n");
return -1;
}
- mlx_vector_string_size_ptr = dlsym(handle, "mlx_vector_string_size");
+ mlx_vector_string_size_ptr = GET_SYM(handle, "mlx_vector_string_size");
if (mlx_vector_string_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_size\n");
return -1;
}
- mlx_vector_string_get_ptr = dlsym(handle, "mlx_vector_string_get");
+ mlx_vector_string_get_ptr = GET_SYM(handle, "mlx_vector_string_get");
if (mlx_vector_string_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_get\n");
return -1;
}
- mlx_version_ptr = dlsym(handle, "mlx_version");
+ mlx_version_ptr = GET_SYM(handle, "mlx_version");
if (mlx_version_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_version\n");
return -1;
diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go
index cf3e51572..b529b9088 100644
--- a/x/imagegen/mlx/mlx.go
+++ b/x/imagegen/mlx/mlx.go
@@ -1,9 +1,7 @@
-//go:build mlx
-
package mlx
/*
-#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR}
+#cgo CFLAGS: -O3 -I${SRCDIR}/../../mlxrunner/mlx/include -I${SRCDIR}
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
#cgo linux LDFLAGS: -lstdc++ -ldl
#cgo windows LDFLAGS: -lstdc++
@@ -32,7 +30,7 @@ static inline void set_default_stream(mlx_stream s) {
_default_stream = s;
}
-// CPU stream for file loading (Load primitive only runs on CPU)
+// CPU stream for operations that only support CPU evaluation
static inline mlx_stream cpu_stream() {
if (_cpu_stream.ctx == NULL) {
_cpu_stream = mlx_default_cpu_stream_new();
@@ -45,8 +43,11 @@ static inline mlx_stream cpu_stream() {
// nocallback: function won't call back into Go
*/
import "C"
+
import (
"fmt"
+ "os"
+ "path/filepath"
"reflect"
"runtime"
"sync"
@@ -1502,15 +1503,21 @@ type SafetensorsFile struct {
metadata C.mlx_map_string_to_string
}
-// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader
-// Note: Uses CPU stream because Load primitive only runs on CPU
+// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader.
+// On CUDA, Load::eval_gpu is implemented so we use the default (GPU) stream.
+// On Metal, Load::eval_gpu is not implemented so we must use the CPU stream.
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
+ stream := C.default_stream()
+ if runtime.GOOS == "darwin" {
+ stream = C.cpu_stream()
+ }
+
var arrays C.mlx_map_string_to_array
var metadata C.mlx_map_string_to_string
- if C.mlx_load_safetensors(&arrays, &metadata, cPath, C.cpu_stream()) != 0 {
+ if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
return nil, fmt.Errorf("failed to load safetensors: %s", path)
}
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
@@ -1689,11 +1696,91 @@ func ArgmaxKeepArray(logits *Array) *Array {
//
// Thread safety: Protected by randomStateMu, mimicking Python's GIL behavior.
// All random functions that use global state acquire this lock.
-var RandomState = []*Array{nil}
-var randomStateMu sync.Mutex
+var (
+ RandomState = []*Array{nil}
+ randomStateMu sync.Mutex
+)
-var mlxInitialized bool
-var mlxInitError error
+var (
+ mlxInitialized bool
+ mlxInitError error
+)
+
+// mlxLibName reports the file name of the MLX C shared library for the
+// operating system this binary runs on: "mlxc.dll" on Windows,
+// "libmlxc.dylib" on macOS, and "libmlxc.so" on everything else.
+func mlxLibName() string {
+	if runtime.GOOS == "windows" {
+		return "mlxc.dll"
+	}
+	if runtime.GOOS == "darwin" {
+		return "libmlxc.dylib"
+	}
+	return "libmlxc.so"
+}
+
+// findMLXLibrary searches for the MLX shared library (named by mlxLibName)
+// and returns the path of the first match, or "" if none is found.
+//
+// Search order:
+//  1. Each directory listed in OLLAMA_LIBRARY_PATH, then that directory's
+//     mlx* subdirectories. Empty list entries are skipped.
+//  2. The directory of the running executable (symlinks resolved when
+//     possible), then the mlx* subdirectories of <exeDir>/lib/ollama and
+//     <exeDir>/../lib/ollama.
+//  3. <cwd>/build/lib/ollama, for tests run from the repo root.
+//
+// A candidate only counts if it exists and is not a directory.
+func findMLXLibrary() string {
+	libName := mlxLibName()
+
+	// isLibFile reports whether path exists and is not a directory; a
+	// directory that merely has the library's name cannot be loaded.
+	isLibFile := func(path string) bool {
+		info, err := os.Stat(path)
+		return err == nil && !info.IsDir()
+	}
+
+	// inMLXSubdirs returns the library from the first mlx* entry of parent
+	// that contains it, or "". The directory is listed with os.ReadDir
+	// instead of globbing parent+"/mlx*" so that pattern characters in
+	// parent itself (for example '[' in an install path) are taken
+	// literally. Entries come back sorted by name, as they did with Glob.
+	inMLXSubdirs := func(parent string) string {
+		entries, err := os.ReadDir(parent)
+		if err != nil {
+			return ""
+		}
+		for _, entry := range entries {
+			if matched, _ := filepath.Match("mlx*", entry.Name()); !matched {
+				continue
+			}
+			if candidate := filepath.Join(parent, entry.Name(), libName); isLibFile(candidate) {
+				return candidate
+			}
+		}
+		return ""
+	}
+
+	// 1. OLLAMA_LIBRARY_PATH — check each dir and its mlx* subdirs
+	if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
+		for _, dir := range filepath.SplitList(paths) {
+			// An empty entry (e.g. from a trailing separator) would make the
+			// candidate a bare file name resolved against the current
+			// directory, which is not a configured library directory.
+			if dir == "" {
+				continue
+			}
+			if candidate := filepath.Join(dir, libName); isLibFile(candidate) {
+				return candidate
+			}
+			if candidate := inMLXSubdirs(dir); candidate != "" {
+				return candidate
+			}
+		}
+	}
+
+	// 2. Executable directory and lib/ollama/mlx* subdirs
+	if exe, err := os.Executable(); err == nil {
+		// Prefer the real location when the executable is a symlink; keep
+		// the unresolved path if resolution fails.
+		if resolved, err := filepath.EvalSymlinks(exe); err == nil {
+			exe = resolved
+		}
+		exeDir := filepath.Dir(exe)
+
+		// Check exe dir directly (macOS copies dylib here)
+		if candidate := filepath.Join(exeDir, libName); isLibFile(candidate) {
+			return candidate
+		}
+
+		// Check exe_dir/lib/ollama/mlx* subdirectories
+		// and exe_dir/../lib/ollama/mlx* (standard bin/lib sibling layout)
+		for _, libOllamaDir := range []string{
+			filepath.Join(exeDir, "lib", "ollama"),
+			filepath.Join(exeDir, "..", "lib", "ollama"),
+		} {
+			if candidate := inMLXSubdirs(libOllamaDir); candidate != "" {
+				return candidate
+			}
+		}
+	}
+
+	// 3. Build directory (for tests run from repo root)
+	if cwd, err := os.Getwd(); err == nil {
+		if candidate := filepath.Join(cwd, "build", "lib", "ollama", libName); isLibFile(candidate) {
+			return candidate
+		}
+	}
+
+	return ""
+}
// InitMLX initializes the MLX library by dynamically loading libmlxc.
// This must be called before using any MLX functions.
@@ -1703,9 +1790,16 @@ func InitMLX() error {
return mlxInitError
}
- // Try to load the MLX dynamic library
- ret := C.mlx_dynamic_init()
- if ret != 0 {
+ // Search for the library using Go path discovery
+ libPath := findMLXLibrary()
+ if libPath == "" {
+ mlxInitError = fmt.Errorf("failed to initialize MLX: %s not found", mlxLibName())
+ return mlxInitError
+ }
+
+ cPath := C.CString(libPath)
+ defer C.free(unsafe.Pointer(cPath))
+ if C.mlx_dynamic_init_path(cPath) != 0 {
errMsg := C.GoString(C.mlx_dynamic_error())
mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg)
return mlxInitError
@@ -1713,8 +1807,7 @@ func InitMLX() error {
// Initialize all function pointers via dlsym
handle := C.mlx_get_handle()
- ret = C.mlx_load_functions(handle)
- if ret != 0 {
+ if C.mlx_load_functions(handle) != 0 {
mlxInitError = fmt.Errorf("failed to load MLX function symbols")
return mlxInitError
}
diff --git a/x/imagegen/mlx/mlx_dynamic.c b/x/imagegen/mlx/mlx_dynamic.c
index aedef7a01..1281ec0e3 100644
--- a/x/imagegen/mlx/mlx_dynamic.c
+++ b/x/imagegen/mlx/mlx_dynamic.c
@@ -9,114 +9,76 @@
#ifdef _WIN32
#include <windows.h>
typedef HMODULE lib_handle_t;
-#define LOAD_LIB(path) LoadLibraryA(path)
-#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
-#define CLOSE_LIB(handle) FreeLibrary(handle)
-#define LIB_ERROR() "LoadLibrary failed"
+// Scratch buffer holding the text produced by the most recent
+// get_win_error() call.
+static char win_error_buffer[256] = {0};
+
+// Formats the calling thread's current GetLastError() value as
+// "error code <n>" into win_error_buffer and returns that buffer.
+// The text is overwritten by the next call.
+static const char* get_win_error(void) {
+    const DWORD last_error = GetLastError();
+    snprintf(win_error_buffer, sizeof win_error_buffer, "error code %lu", last_error);
+    return win_error_buffer;
+}
+#define LIB_ERROR() get_win_error()
#else
#include <dlfcn.h>
typedef void* lib_handle_t;
-#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
-#define GET_SYMBOL(handle, name) dlsym(handle, name)
-#define CLOSE_LIB(handle) dlclose(handle)
#define LIB_ERROR() dlerror()
-#ifdef __APPLE__
-#include <mach-o/dyld.h>
-#include <libgen.h>
-#endif
#endif
static lib_handle_t mlx_handle = NULL;
static int mlx_initialized = 0;
static char mlx_error_buffer[512] = {0};
-#ifdef __APPLE__
-// Get path to library in same directory as executable
-static char* get_exe_relative_path(const char* libname) {
- static char path[1024];
- uint32_t size = sizeof(path);
- if (_NSGetExecutablePath(path, &size) != 0) {
- return NULL;
+#ifdef _WIN32
+// Windows: Load library from a path with dependency resolution.
+// Temporarily adds the library's directory to the DLL search path
+// so that dependencies (like mlx.dll) in the same directory are found.
+// Stores the result in mlx_handle and returns 1 on success, 0 on failure
+// (including a NULL path). On a load failure the thread's last-error value
+// is restored to the code LoadLibraryA reported, so a following
+// GetLastError() describes the load failure rather than a later call.
+static int try_load_win(const char* path) {
+    if (!path) return 0;
+
+    // Extract directory and add to DLL search path for dependency resolution
+    char dir_path[MAX_PATH];
+    strncpy(dir_path, path, MAX_PATH - 1);
+    dir_path[MAX_PATH - 1] = '\0';
+    // Both separators are accepted in Windows paths; cut at whichever one
+    // occurs last so that a mixed path like C:\libs\mlx/mlxc.dll yields
+    // C:\libs\mlx and not C:\libs.
+    char* last_slash = strrchr(dir_path, '\\');
+    char* last_fwd_slash = strrchr(dir_path, '/');
+    if (last_fwd_slash && (!last_slash || last_fwd_slash > last_slash)) {
+        last_slash = last_fwd_slash;
+    }
+    if (last_slash) {
+        *last_slash = '\0';
+        SetDllDirectoryA(dir_path);
}
- // Get directory of executable
- char* dir = dirname(path);
- static char fullpath[1024];
- snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname);
- return fullpath;
+
+    mlx_handle = LoadLibraryA(path);
+    // Capture the failure code right away: the SetDllDirectoryA call below
+    // may change the thread's last-error value before the caller reads it.
+    DWORD load_error = mlx_handle ? ERROR_SUCCESS : GetLastError();
+    SetDllDirectoryA(NULL);
+    if (!mlx_handle) SetLastError(load_error);
+    return mlx_handle != NULL;
}
#endif
// Try to load library from a specific path
+// Stores the loader's result in mlx_handle. Returns 1 if the library was
+// loaded, 0 if path is NULL or loading failed.
static int try_load_lib(const char* path) {
if (!path) return 0;
- mlx_handle = LOAD_LIB(path);
+#ifdef _WIN32
+ // Windows goes through try_load_win, which sets mlx_handle itself.
+ return try_load_win(path);
+#else
+ mlx_handle = dlopen(path, RTLD_LAZY | RTLD_GLOBAL);
return mlx_handle != NULL;
+#endif
}
-// Initialize MLX dynamic library
-// Returns 0 on success, -1 on failure
-// On failure, call mlx_dynamic_error() to get error message
-int mlx_dynamic_init(void) {
+// Initialize the MLX dynamic library from a specific path.
+// Returns 0 on success, -1 on failure.
+// If a previous call already succeeded, returns 0 without loading anything.
+// Otherwise the load attempt overwrites mlx_error_buffer either way: with a
+// "Successfully loaded" note on success, or with the failing path and the
+// LIB_ERROR() text on failure.
+int mlx_dynamic_init_path(const char* path) {
if (mlx_initialized) {
- return 0; // Already initialized
+ return 0;
}
- const char* lib_path = NULL;
- const char* tried_paths[8] = {0};
- int num_tried = 0;
-
-#ifdef _WIN32
- // Windows: try same directory as executable
- lib_path = "libmlxc.dll";
- tried_paths[num_tried++] = lib_path;
- if (try_load_lib(lib_path)) goto success;
-#elif defined(__APPLE__)
- // macOS: try executable directory first
- lib_path = get_exe_relative_path("libmlxc.dylib");
- if (lib_path) {
- tried_paths[num_tried++] = lib_path;
- if (try_load_lib(lib_path)) goto success;
+ if (try_load_lib(path)) {
+ mlx_initialized = 1;
+ snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
+ "MLX: Successfully loaded %s", path ? path : "library");
+ return 0;
}
- // Try build directory (for tests run from repo root)
- lib_path = "./build/lib/ollama/libmlxc.dylib";
- tried_paths[num_tried++] = lib_path;
- if (try_load_lib(lib_path)) goto success;
- // Fallback to system paths
- lib_path = "libmlxc.dylib";
- tried_paths[num_tried++] = lib_path;
- if (try_load_lib(lib_path)) goto success;
-#else
- // Linux: try build directory first (for tests)
- lib_path = "./build/lib/ollama/libmlxc.so";
- tried_paths[num_tried++] = lib_path;
- if (try_load_lib(lib_path)) goto success;
- // Fallback to system paths
- lib_path = "libmlxc.so";
- tried_paths[num_tried++] = lib_path;
- if (try_load_lib(lib_path)) goto success;
-#endif
- // Failed to load library - build error message with all tried paths
- {
- const char* err = LIB_ERROR();
- int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
- "MLX: Failed to load libmlxc library. Tried: ");
- for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) {
- offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
- "%s%s", i > 0 ? ", " : "", tried_paths[i]);
- }
- if (err) {
- snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
- ". Last error: %s", err);
- }
- }
- return -1;
-
-success:
- mlx_initialized = 1;
+ // Load failed: record the path together with the loader's error text.
+ const char* err = LIB_ERROR();
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
- "MLX: Successfully loaded %s", lib_path ? lib_path : "library");
- return 0;
+ "MLX: Failed to load %s: %s", path ? path : "(null)", err ? err : "unknown error");
+ return -1;
}
// Get the last error message
@@ -124,21 +86,8 @@ const char* mlx_dynamic_error(void) {
return mlx_error_buffer;
}
-// Check if MLX is initialized
-int mlx_dynamic_is_initialized(void) {
- return mlx_initialized;
-}
-
// Get the library handle (for use by generated wrappers)
void* mlx_get_handle(void) {
return mlx_handle;
}
-// Cleanup (optional, called at program exit)
-void mlx_dynamic_cleanup(void) {
- if (mlx_handle != NULL) {
- CLOSE_LIB(mlx_handle);
- mlx_handle = NULL;
- mlx_initialized = 0;
- }
-}
diff --git a/x/imagegen/mlx/mlx_dynamic.h b/x/imagegen/mlx/mlx_dynamic.h
index 9ca1473f9..3f4d0fd74 100644
--- a/x/imagegen/mlx/mlx_dynamic.h
+++ b/x/imagegen/mlx/mlx_dynamic.h
@@ -6,22 +6,16 @@
extern "C" {
#endif
-// Initialize the MLX dynamic library
+// Initialize the MLX dynamic library from a specific path
// Returns 0 on success, -1 on failure
-int mlx_dynamic_init(void);
+int mlx_dynamic_init_path(const char* path);
// Get the last error message from dynamic loading
const char* mlx_dynamic_error(void);
-// Check if MLX is initialized
-int mlx_dynamic_is_initialized(void);
-
// Get the library handle (for use by generated wrappers)
void* mlx_get_handle(void);
-// Cleanup resources (optional, for clean shutdown)
-void mlx_dynamic_cleanup(void);
-
#ifdef __cplusplus
}
#endif
diff --git a/x/imagegen/mlx/mlx_test.go b/x/imagegen/mlx/mlx_test.go
index 37b3ac63b..82221cab7 100644
--- a/x/imagegen/mlx/mlx_test.go
+++ b/x/imagegen/mlx/mlx_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
import (
diff --git a/x/imagegen/models/flux2/flux2.go b/x/imagegen/models/flux2/flux2.go
index 894af41f8..908de3c87 100644
--- a/x/imagegen/models/flux2/flux2.go
+++ b/x/imagegen/models/flux2/flux2.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
// Klein is a 4B parameter distilled model that supports sub-second inference.
package flux2
diff --git a/x/imagegen/models/flux2/rope.go b/x/imagegen/models/flux2/rope.go
index c349e7010..bc245b7d8 100644
--- a/x/imagegen/models/flux2/rope.go
+++ b/x/imagegen/models/flux2/rope.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package flux2
import (
diff --git a/x/imagegen/models/flux2/scheduler.go b/x/imagegen/models/flux2/scheduler.go
index aba3c871f..033ee8c3c 100644
--- a/x/imagegen/models/flux2/scheduler.go
+++ b/x/imagegen/models/flux2/scheduler.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package flux2
import (
diff --git a/x/imagegen/models/flux2/transformer.go b/x/imagegen/models/flux2/transformer.go
index 93771a661..3a48f27b7 100644
--- a/x/imagegen/models/flux2/transformer.go
+++ b/x/imagegen/models/flux2/transformer.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package flux2
import (
diff --git a/x/imagegen/models/flux2/vae.go b/x/imagegen/models/flux2/vae.go
index 4b09b1ba4..057523d17 100644
--- a/x/imagegen/models/flux2/vae.go
+++ b/x/imagegen/models/flux2/vae.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package flux2
import (
diff --git a/x/imagegen/models/qwen3/text_encoder.go b/x/imagegen/models/qwen3/text_encoder.go
index de32bd347..46b5a5e29 100644
--- a/x/imagegen/models/qwen3/text_encoder.go
+++ b/x/imagegen/models/qwen3/text_encoder.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
package qwen3
diff --git a/x/imagegen/models/zimage/scheduler.go b/x/imagegen/models/zimage/scheduler.go
index 6c474f5a4..f5e1ccc46 100644
--- a/x/imagegen/models/zimage/scheduler.go
+++ b/x/imagegen/models/zimage/scheduler.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package zimage
import (
diff --git a/x/imagegen/models/zimage/text_encoder.go b/x/imagegen/models/zimage/text_encoder.go
index 2c2688c31..65d9ab596 100644
--- a/x/imagegen/models/zimage/text_encoder.go
+++ b/x/imagegen/models/zimage/text_encoder.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package zimage
import (
diff --git a/x/imagegen/models/zimage/transformer.go b/x/imagegen/models/zimage/transformer.go
index 2c42d8c25..b27064507 100644
--- a/x/imagegen/models/zimage/transformer.go
+++ b/x/imagegen/models/zimage/transformer.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
diff --git a/x/imagegen/models/zimage/vae.go b/x/imagegen/models/zimage/vae.go
index aca2b1bfc..a31ec210b 100644
--- a/x/imagegen/models/zimage/vae.go
+++ b/x/imagegen/models/zimage/vae.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package zimage
import (
diff --git a/x/imagegen/models/zimage/zimage.go b/x/imagegen/models/zimage/zimage.go
index e7ce8436d..4058819c7 100644
--- a/x/imagegen/models/zimage/zimage.go
+++ b/x/imagegen/models/zimage/zimage.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
diff --git a/x/imagegen/nn/nn.go b/x/imagegen/nn/nn.go
index d72474358..0a08f05be 100644
--- a/x/imagegen/nn/nn.go
+++ b/x/imagegen/nn/nn.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package nn provides neural network layer types.
package nn
diff --git a/x/imagegen/nn/nn_test.go b/x/imagegen/nn/nn_test.go
index 00e69ccb0..dbc5e2b30 100644
--- a/x/imagegen/nn/nn_test.go
+++ b/x/imagegen/nn/nn_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package nn
import (
diff --git a/x/imagegen/runner.go b/x/imagegen/runner.go
index 0409c4bf7..d92b59059 100644
--- a/x/imagegen/runner.go
+++ b/x/imagegen/runner.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
package imagegen
diff --git a/x/imagegen/runner_stub.go b/x/imagegen/runner_stub.go
deleted file mode 100644
index 866a4408c..000000000
--- a/x/imagegen/runner_stub.go
+++ /dev/null
@@ -1,10 +0,0 @@
-//go:build !mlx
-
-package imagegen
-
-import "errors"
-
-// Execute returns an error when not built with MLX support.
-func Execute(args []string) error {
- return errors.New("MLX runner not available: build with mlx tag")
-}
diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go
index fbd443e05..e6e74929b 100644
--- a/x/imagegen/safetensors/loader.go
+++ b/x/imagegen/safetensors/loader.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package safetensors
import (
diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go
index 4dbcf59a3..df7b52465 100644
--- a/x/imagegen/safetensors/safetensors.go
+++ b/x/imagegen/safetensors/safetensors.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package safetensors
import (
diff --git a/x/imagegen/safetensors/safetensors_test.go b/x/imagegen/safetensors/safetensors_test.go
index f00268751..5f3e10d71 100644
--- a/x/imagegen/safetensors/safetensors_test.go
+++ b/x/imagegen/safetensors/safetensors_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package safetensors
import (
diff --git a/x/imagegen/tokenizer/tokenizer.go b/x/imagegen/tokenizer/tokenizer.go
index bf8ff63af..d2f1aac18 100644
--- a/x/imagegen/tokenizer/tokenizer.go
+++ b/x/imagegen/tokenizer/tokenizer.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
diff --git a/x/imagegen/tokenizer/tokenizer_test.go b/x/imagegen/tokenizer/tokenizer_test.go
index 2ac79ab1e..a72e447c6 100644
--- a/x/imagegen/tokenizer/tokenizer_test.go
+++ b/x/imagegen/tokenizer/tokenizer_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import (
diff --git a/x/imagegen/vae/tiling.go b/x/imagegen/vae/tiling.go
index 1babfef98..fcb5af701 100644
--- a/x/imagegen/vae/tiling.go
+++ b/x/imagegen/vae/tiling.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
package vae
diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go
index a9ff8904c..0216ffeaa 100644
--- a/x/mlxrunner/cache.go
+++ b/x/mlxrunner/cache.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlxrunner
import (
diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go
index 7d0d0b060..a452fbcb2 100644
--- a/x/mlxrunner/cache/cache.go
+++ b/x/mlxrunner/cache/cache.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package cache
import (
diff --git a/x/mlxrunner/cache/recurrent.go b/x/mlxrunner/cache/recurrent.go
index 0cbbc01e2..86c592be5 100644
--- a/x/mlxrunner/cache/recurrent.go
+++ b/x/mlxrunner/cache/recurrent.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package cache
import "github.com/ollama/ollama/x/mlxrunner/mlx"
diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go
index f1a0e4cca..2f18105af 100644
--- a/x/mlxrunner/client.go
+++ b/x/mlxrunner/client.go
@@ -72,14 +72,23 @@ func NewClient(modelName string) (*Client, error) {
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
- // On Linux, set LD_LIBRARY_PATH to include MLX library directories
- if runtime.GOOS == "linux" {
+ // Set library path environment variable for MLX libraries
+ // Linux: LD_LIBRARY_PATH, Windows: PATH
+ var libPathEnvVar string
+ switch runtime.GOOS {
+ case "linux":
+ libPathEnvVar = "LD_LIBRARY_PATH"
+ case "windows":
+ libPathEnvVar = "PATH"
+ }
+
+ if libPathEnvVar != "" {
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
- if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
+ if existingPath, ok := os.LookupEnv(libPathEnvVar); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
@@ -87,16 +96,20 @@ func NewClient(modelName string) (*Client, error) {
found := false
for i := range cmd.Env {
- if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
- cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
+ envName := cmd.Env[i]
+ if runtime.GOOS == "windows" {
+ envName = strings.ToUpper(envName)
+ }
+ if strings.HasPrefix(envName, libPathEnvVar+"=") {
+ cmd.Env[i] = libPathEnvVar + "=" + pathEnvVal
found = true
break
}
}
if !found {
- cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
+ cmd.Env = append(cmd.Env, libPathEnvVar+"="+pathEnvVal)
}
- slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
+ slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal)
}
c := &Client{
diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go
index 9202dac34..6b6394d60 100644
--- a/x/mlxrunner/imports.go
+++ b/x/mlxrunner/imports.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlxrunner
import (
diff --git a/x/mlxrunner/mlx/CMakeLists.txt b/x/mlxrunner/mlx/CMakeLists.txt
index 1ca13bdaf..9825c441b 100644
--- a/x/mlxrunner/mlx/CMakeLists.txt
+++ b/x/mlxrunner/mlx/CMakeLists.txt
@@ -24,3 +24,7 @@ FetchContent_Declare(
)
FetchContent_MakeAvailable(mlx-c)
+
+# Sync vendored headers with fetched version
+file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
+file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/include/mlx/c/")
diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go
index 3134a127a..ce0e48eda 100644
--- a/x/mlxrunner/mlx/act.go
+++ b/x/mlxrunner/mlx/act.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go
index 6047aacec..de91813fc 100644
--- a/x/mlxrunner/mlx/array.go
+++ b/x/mlxrunner/mlx/array.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/array_test.go b/x/mlxrunner/mlx/array_test.go
index aab5db7ba..bc6a4ca4a 100644
--- a/x/mlxrunner/mlx/array_test.go
+++ b/x/mlxrunner/mlx/array_test.go
@@ -1,10 +1,16 @@
-//go:build mlx
-
package mlx
import "testing"
+func skipIfNoMLX(t *testing.T) {
+ t.Helper()
+ if err := CheckInit(); err != nil {
+ t.Skipf("MLX not available: %v", err)
+ }
+}
+
func TestFromValue(t *testing.T) {
+ skipIfNoMLX(t)
for got, want := range map[*Array]DType{
FromValue(true): DTypeBool,
FromValue(false): DTypeBool,
@@ -22,6 +28,7 @@ func TestFromValue(t *testing.T) {
}
func TestFromValues(t *testing.T) {
+ skipIfNoMLX(t)
for got, want := range map[*Array]DType{
FromValues([]bool{true, false, true}, 3): DTypeBool,
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
diff --git a/x/mlxrunner/mlx/dtype.go b/x/mlxrunner/mlx/dtype.go
index 95237c792..b0a0ce6c1 100644
--- a/x/mlxrunner/mlx/dtype.go
+++ b/x/mlxrunner/mlx/dtype.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/dynamic.go b/x/mlxrunner/mlx/dynamic.go
index a1286da59..38f825d24 100644
--- a/x/mlxrunner/mlx/dynamic.go
+++ b/x/mlxrunner/mlx/dynamic.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "dynamic.h"
@@ -24,10 +22,16 @@ func CheckInit() error {
return initError
}
-// tryLoadFromDir searches a directory for libmlxc.* and tries to load it.
+// tryLoadFromDir searches a directory for the mlxc shared library and tries to load it.
// Returns true if the library was successfully loaded.
func tryLoadFromDir(dir string) bool {
- matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*")
+ // On Windows, MSVC produces mlxc.dll (no lib prefix)
+ // On Unix, it's libmlxc.so or libmlxc.dylib
+ pattern := "libmlxc.*"
+ if runtime.GOOS == "windows" {
+ pattern = "mlxc.*"
+ }
+ matches, err := fs.Glob(os.DirFS(dir), pattern)
if err != nil || len(matches) == 0 {
return false
}
@@ -60,7 +64,10 @@ func tryLoadFromDir(dir string) bool {
// Returns true if the library was successfully loaded.
func tryLoadByName() bool {
libraryName := "libmlxc.dylib"
- if runtime.GOOS == "linux" {
+ switch runtime.GOOS {
+ case "windows":
+ libraryName = "mlxc.dll"
+ case "linux":
libraryName = "libmlxc.so"
}
@@ -81,19 +88,25 @@ func tryLoadByName() bool {
func init() {
switch runtime.GOOS {
- case "darwin":
+ case "darwin", "linux", "windows":
- case "windows":
default:
return
}
- // Try OLLAMA_LIBRARY_PATH first
+ // Try OLLAMA_LIBRARY_PATH first, including mlx_* subdirectories
if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
for _, dir := range filepath.SplitList(paths) {
if tryLoadFromDir(dir) {
return
}
+ if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*")); err == nil {
+ for _, mlxDir := range mlxDirs {
+ if tryLoadFromDir(mlxDir) {
+ return
+ }
+ }
+ }
}
}
@@ -115,12 +128,21 @@ func init() {
searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
}
+ // Also scan mlx_* subdirectories within each search dir
+ var expanded []string
for _, dir := range searchDirs {
+ expanded = append(expanded, dir)
+ if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*")); err == nil {
+ expanded = append(expanded, mlxDirs...)
+ }
+ }
+
+ for _, dir := range expanded {
if tryLoadFromDir(dir) {
return
}
}
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs)
- slog.Warn("MLX dynamic library not available", "error", initError)
+ slog.Debug("MLX dynamic library not available", "error", initError)
}
diff --git a/x/mlxrunner/mlx/dynamic.h b/x/mlxrunner/mlx/dynamic.h
index f93d8fab7..f29825ce6 100644
--- a/x/mlxrunner/mlx/dynamic.h
+++ b/x/mlxrunner/mlx/dynamic.h
@@ -3,7 +3,7 @@
#ifdef _WIN32
#include
-#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
+#define DLSYM(handle, symbol) (void*)GetProcAddress((HMODULE)(handle.ctx), symbol)
#else
#include
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
@@ -23,9 +23,15 @@ typedef uint16_t float16_t;
typedef uint16_t bfloat16_t;
#endif
-#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
-#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
-#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
+// Undef ERROR to avoid conflict with wingdi.h on Windows
+#ifdef ERROR
+#undef ERROR
+#endif
+#define MLX_ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
+#define CHECK(x) if (!(x)) { MLX_ERROR("CHECK failed: " #x); }
+#define CHECK_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x); CHECK(x##_)
+// OPTIONAL_LOAD: load symbol if available, leave function pointer NULL otherwise
+#define OPTIONAL_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x)
typedef struct {
void* ctx;
diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go
index 0570840d6..7feca3b1e 100644
--- a/x/mlxrunner/mlx/fast.go
+++ b/x/mlxrunner/mlx/fast.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/gated_delta.go b/x/mlxrunner/mlx/gated_delta.go
index 7ace1f6d3..31550cef1 100644
--- a/x/mlxrunner/mlx/gated_delta.go
+++ b/x/mlxrunner/mlx/gated_delta.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include
diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c
index 29d1330af..ecf9d30c8 100644
--- a/x/mlxrunner/mlx/generated.c
+++ b/x/mlxrunner/mlx/generated.c
@@ -2299,8 +2299,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_array_item_float32);
CHECK_LOAD(handle, mlx_array_item_float64);
CHECK_LOAD(handle, mlx_array_item_complex64);
- CHECK_LOAD(handle, mlx_array_item_float16);
- CHECK_LOAD(handle, mlx_array_item_bfloat16);
+ OPTIONAL_LOAD(handle, mlx_array_item_float16);
+ OPTIONAL_LOAD(handle, mlx_array_item_bfloat16);
CHECK_LOAD(handle, mlx_array_data_bool);
CHECK_LOAD(handle, mlx_array_data_uint8);
CHECK_LOAD(handle, mlx_array_data_uint16);
@@ -2313,8 +2313,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_array_data_float32);
CHECK_LOAD(handle, mlx_array_data_float64);
CHECK_LOAD(handle, mlx_array_data_complex64);
- CHECK_LOAD(handle, mlx_array_data_float16);
- CHECK_LOAD(handle, mlx_array_data_bfloat16);
+ OPTIONAL_LOAD(handle, mlx_array_data_float16);
+ OPTIONAL_LOAD(handle, mlx_array_data_bfloat16);
CHECK_LOAD(handle, _mlx_array_is_available);
CHECK_LOAD(handle, _mlx_array_wait);
CHECK_LOAD(handle, _mlx_array_is_contiguous);
diff --git a/x/mlxrunner/mlx/generator/generated.c.gotmpl b/x/mlxrunner/mlx/generator/generated.c.gotmpl
index c31b34a76..227589aa8 100644
--- a/x/mlxrunner/mlx/generator/generated.c.gotmpl
+++ b/x/mlxrunner/mlx/generator/generated.c.gotmpl
@@ -11,7 +11,7 @@
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
- CHECK_LOAD(handle, {{ .Name }});
+ {{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }});
{{- end }}
return 0;
}
diff --git a/x/mlxrunner/mlx/generator/main.go b/x/mlxrunner/mlx/generator/main.go
index a98046a2f..d1203add4 100644
--- a/x/mlxrunner/mlx/generator/main.go
+++ b/x/mlxrunner/mlx/generator/main.go
@@ -17,11 +17,21 @@ import (
//go:embed *.gotmpl
var fsys embed.FS
+// optionalSymbols lists symbols that may not be present in all builds
+// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX).
+var optionalSymbols = map[string]bool{
+ "mlx_array_item_float16": true,
+ "mlx_array_item_bfloat16": true,
+ "mlx_array_data_float16": true,
+ "mlx_array_data_bfloat16": true,
+}
+
type Function struct {
Type,
Name,
Parameters,
Args string
+ Optional bool
}
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
@@ -104,7 +114,9 @@ func main() {
matches := qc.Matches(query, tree.RootNode(), bts)
for match := matches.Next(); match != nil; match = matches.Next() {
for _, capture := range match.Captures {
- funs = append(funs, ParseFunction(&capture.Node, tc, bts))
+ fn := ParseFunction(&capture.Node, tc, bts)
+ fn.Optional = optionalSymbols[fn.Name]
+ funs = append(funs, fn)
}
}
}
diff --git a/x/mlxrunner/mlx/include/mlx/c/README.md b/x/mlxrunner/mlx/include/mlx/c/README.md
new file mode 100644
index 000000000..905ca451c
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/README.md
@@ -0,0 +1,12 @@
+# Vendored MLX-C Headers
+
+These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
+The pinned version is in `MLX_VERSION` at the repo root.
+
+Headers are automatically refreshed whenever you configure the project with CMake:
+
+```shell
+cmake --preset 'MLX CUDA 13'
+```
+
+See the [MLX Engine](../../../../../../docs/development.md#mlx-engine-optional) section of the development docs for full build instructions.
diff --git a/x/mlxrunner/mlx/include/mlx/c/array.h b/x/mlxrunner/mlx/include/mlx/c/array.h
new file mode 100644
index 000000000..a3b382bb2
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/array.h
@@ -0,0 +1,420 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_ARRAY_H
+#define MLX_ARRAY_H
+
+#include "mlx/c/string.h"
+
+#include
+#include
+#include
+#include
+
+// Complex number support
+#ifdef _MSC_VER
+#define _CRT_USE_C_COMPLEX_H
+#include
+typedef _Fcomplex mlx_complex64_t;
+#else
+#include
+typedef float _Complex mlx_complex64_t;
+#endif
+
+#include "half.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_array Array
+ * MLX N-dimensional array object.
+ */
+/**@{*/
+
+/**
+ * A N-dimensional array object.
+ */
+typedef struct mlx_array_ {
+ void* ctx;
+} mlx_array;
+
+static mlx_array mlx_array_empty;
+
+/**
+ * Array element type.
+ */
+typedef enum mlx_dtype_ {
+ MLX_BOOL,
+ MLX_UINT8,
+ MLX_UINT16,
+ MLX_UINT32,
+ MLX_UINT64,
+ MLX_INT8,
+ MLX_INT16,
+ MLX_INT32,
+ MLX_INT64,
+ MLX_FLOAT16,
+ MLX_FLOAT32,
+ MLX_FLOAT64,
+ MLX_BFLOAT16,
+ MLX_COMPLEX64,
+} mlx_dtype;
+
+/**
+ * Size of given mlx_dtype datatype in bytes.
+ */
+size_t mlx_dtype_size(mlx_dtype dtype);
+
+/**
+ * Get array description.
+ */
+int mlx_array_tostring(mlx_string* str, const mlx_array arr);
+
+/**
+ * New empty array.
+ */
+mlx_array mlx_array_new(void);
+
+/**
+ * Free an array.
+ */
+int mlx_array_free(mlx_array arr);
+
+/**
+ * New array from a bool scalar.
+ */
+mlx_array mlx_array_new_bool(bool val);
+/**
+ * New array from a int scalar.
+ */
+mlx_array mlx_array_new_int(int val);
+/**
+ * New array from a float32 scalar.
+ */
+mlx_array mlx_array_new_float32(float val);
+/**
+ * New array from a float scalar.
+ * Same as float32.
+ */
+mlx_array mlx_array_new_float(float val);
+/**
+ * New array from a float64 scalar.
+ */
+mlx_array mlx_array_new_float64(double val);
+/**
+ * New array from a double scalar.
+ * Same as float64.
+ */
+mlx_array mlx_array_new_double(double val);
+/**
+ * New array from a complex scalar.
+ */
+mlx_array mlx_array_new_complex(float real_val, float imag_val);
+/**
+ * New array from existing buffer.
+ * @param data A buffer which will be copied.
+ * @param shape Shape of the array.
+ * @param dim Number of dimensions (size of `shape`).
+ * @param dtype Type of array elements.
+ */
+mlx_array mlx_array_new_data(
+ const void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype);
+/**
+ * New array from existing buffer.
+ * @param data A buffer which will be copied.
+ * @param shape Shape of the array.
+ * @param dim Number of dimensions (size of `shape`).
+ * @param dtype Type of array elements.
+ * @param dtor Callback for when the buffer is no longer needed.
+ */
+mlx_array mlx_array_new_data_managed(
+ void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype,
+ void (*dtor)(void*));
+/**
+ * New array from existing buffer.
+ * @param data A buffer which will be copied.
+ * @param shape Shape of the array.
+ * @param dim Number of dimensions (size of `shape`).
+ * @param dtype Type of array elements.
+ * @param payload Payload pointer passed to the `dtor` callback instead of
+ * `data`.
+ * @param dtor Callback for when the buffer is no longer needed.
+ */
+mlx_array mlx_array_new_data_managed_payload(
+ void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype,
+ void* payload,
+ void (*dtor)(void*));
+/**
+ * Set array to provided src array.
+ */
+int mlx_array_set(mlx_array* arr, const mlx_array src);
+/**
+ * Set array to a bool scalar.
+ */
+int mlx_array_set_bool(mlx_array* arr, bool val);
+/**
+ * Set array to a int scalar.
+ */
+int mlx_array_set_int(mlx_array* arr, int val);
+/**
+ * Set array to a float32 scalar.
+ */
+int mlx_array_set_float32(mlx_array* arr, float val);
+/**
+ * Set array to a float scalar.
+ */
+int mlx_array_set_float(mlx_array* arr, float val);
+/**
+ * Set array to a float64 scalar.
+ */
+int mlx_array_set_float64(mlx_array* arr, double val);
+/**
+ * Set array to a double scalar.
+ */
+int mlx_array_set_double(mlx_array* arr, double val);
+/**
+ * Set array to a complex scalar.
+ */
+int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val);
+/**
+ * Set array to specified data and shape.
+ * @param arr Destination array.
+ * @param data A buffer which will be copied.
+ * @param shape Shape of the array.
+ * @param dim Number of dimensions (size of `shape`).
+ * @param dtype Type of array elements.
+ */
+int mlx_array_set_data(
+ mlx_array* arr,
+ const void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype);
+
+/**
+ * The size of the array's datatype in bytes.
+ */
+size_t mlx_array_itemsize(const mlx_array arr);
+/**
+ * Number of elements in the array.
+ */
+size_t mlx_array_size(const mlx_array arr);
+/**
+ * The number of bytes in the array.
+ */
+size_t mlx_array_nbytes(const mlx_array arr);
+/**
+ * The array's dimension.
+ */
+size_t mlx_array_ndim(const mlx_array arr);
+/**
+ * The shape of the array.
+ * Returns: a pointer to the sizes of each dimension.
+ */
+const int* mlx_array_shape(const mlx_array arr);
+/**
+ * The strides of the array.
+ * Returns: a pointer to the sizes of each dimension.
+ */
+const size_t* mlx_array_strides(const mlx_array arr);
+/**
+ * The shape of the array in a particular dimension.
+ */
+int mlx_array_dim(const mlx_array arr, int dim);
+/**
+ * The array element type.
+ */
+mlx_dtype mlx_array_dtype(const mlx_array arr);
+
+/**
+ * Evaluate the array.
+ */
+int mlx_array_eval(mlx_array arr);
+
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_bool(bool* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_uint8(uint8_t* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_uint16(uint16_t* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_uint32(uint32_t* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_uint64(uint64_t* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_int8(int8_t* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_int16(int16_t* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_int32(int32_t* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_int64(int64_t* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_float32(float* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_float64(double* res, const mlx_array arr);
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
+
+#ifdef HAS_FLOAT16
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_float16(float16_t* res, const mlx_array arr);
+#endif
+
+#ifdef HAS_BFLOAT16
+/**
+ * Access the value of a scalar array.
+ */
+int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr);
+#endif
+
+/**
+ * Returns a pointer to the array data, cast to `bool*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const bool* mlx_array_data_bool(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `uint8_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const uint8_t* mlx_array_data_uint8(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `uint16_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const uint16_t* mlx_array_data_uint16(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `uint32_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const uint32_t* mlx_array_data_uint32(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `uint64_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const uint64_t* mlx_array_data_uint64(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `int8_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const int8_t* mlx_array_data_int8(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `int16_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const int16_t* mlx_array_data_int16(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `int32_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const int32_t* mlx_array_data_int32(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `int64_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const int64_t* mlx_array_data_int64(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `float32*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const float* mlx_array_data_float32(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `float64*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const double* mlx_array_data_float64(const mlx_array arr);
+/**
+ * Returns a pointer to the array data, cast to `_Complex*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
+
+#ifdef HAS_FLOAT16
+/**
+ * Returns a pointer to the array data, cast to `float16_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const float16_t* mlx_array_data_float16(const mlx_array arr);
+#endif
+
+#ifdef HAS_BFLOAT16
+/**
+ * Returns a pointer to the array data, cast to `bfloat16_t*`.
+ * Array must be evaluated, otherwise returns NULL.
+ */
+const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr);
+#endif
+
+/**
+ * Check if the array is available.
+ * Internal function: use at your own risk.
+ */
+int _mlx_array_is_available(bool* res, const mlx_array arr);
+
+/**
+ * Wait on the array to be available. After this `_mlx_array_is_available`
+ * returns `true`. Internal function: use at your own risk.
+ */
+int _mlx_array_wait(const mlx_array arr);
+
+/**
+ * Whether the array is contiguous in memory.
+ * Internal function: use at your own risk.
+ */
+int _mlx_array_is_contiguous(bool* res, const mlx_array arr);
+
+/**
+ * Whether the array's rows are contiguous in memory.
+ * Internal function: use at your own risk.
+ */
+int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr);
+
+/**
+ * Whether the array's columns are contiguous in memory.
+ * Internal function: use at your own risk.
+ */
+int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/closure.h b/x/mlxrunner/mlx/include/mlx/c/closure.h
new file mode 100644
index 000000000..33f711572
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/closure.h
@@ -0,0 +1,197 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_CLOSURE_H
+#define MLX_CLOSURE_H
+
+#include "mlx/c/array.h"
+#include "mlx/c/map.h"
+#include "mlx/c/optional.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_closure Closures
+ * MLX closure objects.
+ */
+/**@{*/
+
+typedef struct mlx_closure_ {
+ void* ctx;
+} mlx_closure;
+mlx_closure mlx_closure_new(void);
+int mlx_closure_free(mlx_closure cls);
+mlx_closure mlx_closure_new_func(
+ int (*fun)(mlx_vector_array*, const mlx_vector_array));
+mlx_closure mlx_closure_new_func_payload(
+ int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
+ void* payload,
+ void (*dtor)(void*));
+int mlx_closure_set(mlx_closure* cls, const mlx_closure src);
+int mlx_closure_apply(
+ mlx_vector_array* res,
+ mlx_closure cls,
+ const mlx_vector_array input);
+
+mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
+
+typedef struct mlx_closure_kwargs_ {
+ void* ctx;
+} mlx_closure_kwargs;
+mlx_closure_kwargs mlx_closure_kwargs_new(void);
+int mlx_closure_kwargs_free(mlx_closure_kwargs cls);
+mlx_closure_kwargs mlx_closure_kwargs_new_func(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_map_string_to_array));
+mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_map_string_to_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+int mlx_closure_kwargs_set(
+ mlx_closure_kwargs* cls,
+ const mlx_closure_kwargs src);
+int mlx_closure_kwargs_apply(
+ mlx_vector_array* res,
+ mlx_closure_kwargs cls,
+ const mlx_vector_array input_0,
+ const mlx_map_string_to_array input_1);
+
+typedef struct mlx_closure_value_and_grad_ {
+ void* ctx;
+} mlx_closure_value_and_grad;
+mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void);
+int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls);
+mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
+ int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
+mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_array*,
+ const mlx_vector_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+int mlx_closure_value_and_grad_set(
+ mlx_closure_value_and_grad* cls,
+ const mlx_closure_value_and_grad src);
+int mlx_closure_value_and_grad_apply(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ mlx_closure_value_and_grad cls,
+ const mlx_vector_array input);
+
+typedef struct mlx_closure_custom_ {
+ void* ctx;
+} mlx_closure_custom;
+mlx_closure_custom mlx_closure_custom_new(void);
+int mlx_closure_custom_free(mlx_closure_custom cls);
+mlx_closure_custom mlx_closure_custom_new_func(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const mlx_vector_array));
+mlx_closure_custom mlx_closure_custom_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+int mlx_closure_custom_set(
+ mlx_closure_custom* cls,
+ const mlx_closure_custom src);
+int mlx_closure_custom_apply(
+ mlx_vector_array* res,
+ mlx_closure_custom cls,
+ const mlx_vector_array input_0,
+ const mlx_vector_array input_1,
+ const mlx_vector_array input_2);
+
+typedef struct mlx_closure_custom_jvp_ {
+ void* ctx;
+} mlx_closure_custom_jvp;
+mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void);
+int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls);
+mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const int*,
+ size_t _num));
+mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const int*,
+ size_t _num,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+int mlx_closure_custom_jvp_set(
+ mlx_closure_custom_jvp* cls,
+ const mlx_closure_custom_jvp src);
+int mlx_closure_custom_jvp_apply(
+ mlx_vector_array* res,
+ mlx_closure_custom_jvp cls,
+ const mlx_vector_array input_0,
+ const mlx_vector_array input_1,
+ const int* input_2,
+ size_t input_2_num);
+
+typedef struct mlx_closure_custom_vmap_ {
+ void* ctx;
+} mlx_closure_custom_vmap;
+mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void);
+int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls);
+mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_int*,
+ const mlx_vector_array,
+ const int*,
+ size_t _num));
+mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_int*,
+ const mlx_vector_array,
+ const int*,
+ size_t _num,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+int mlx_closure_custom_vmap_set(
+ mlx_closure_custom_vmap* cls,
+ const mlx_closure_custom_vmap src);
+int mlx_closure_custom_vmap_apply(
+ mlx_vector_array* res_0,
+ mlx_vector_int* res_1,
+ mlx_closure_custom_vmap cls,
+ const mlx_vector_array input_0,
+ const int* input_1,
+ size_t input_1_num);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/compile.h b/x/mlxrunner/mlx/include/mlx/c/compile.h
new file mode 100644
index 000000000..04567fb3a
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/compile.h
@@ -0,0 +1,57 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_COMPILE_H
+#define MLX_COMPILE_H
+
+#include
+#include
+#include
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup compile Compilation operations
+ */
+/**@{*/
+
+typedef enum mlx_compile_mode_ {
+ MLX_COMPILE_MODE_DISABLED,
+ MLX_COMPILE_MODE_NO_SIMPLIFY,
+ MLX_COMPILE_MODE_NO_FUSE,
+ MLX_COMPILE_MODE_ENABLED
+} mlx_compile_mode;
+int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless);
+int mlx_detail_compile(
+ mlx_closure* res,
+ const mlx_closure fun,
+ uintptr_t fun_id,
+ bool shapeless,
+ const uint64_t* constants,
+ size_t constants_num);
+int mlx_detail_compile_clear_cache(void);
+int mlx_detail_compile_erase(uintptr_t fun_id);
+int mlx_disable_compile(void);
+int mlx_enable_compile(void);
+int mlx_set_compile_mode(mlx_compile_mode mode);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/cuda.h b/x/mlxrunner/mlx/include/mlx/c/cuda.h
new file mode 100644
index 000000000..4734f8c51
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/cuda.h
@@ -0,0 +1,39 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_CUDA_H
+#define MLX_CUDA_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup cuda CUDA-specific operations
+ */
+/**@{*/
+
+int mlx_cuda_is_available(bool* res);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/device.h b/x/mlxrunner/mlx/include/mlx/c/device.h
new file mode 100644
index 000000000..4b74e39d3
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/device.h
@@ -0,0 +1,154 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_DEVICE_H
+#define MLX_DEVICE_H
+
+#include <stdbool.h>
+#include <stddef.h>
+
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_device Device
+ * MLX device object.
+ */
+/**@{*/
+
+/**
+ * An MLX device object.
+ */
+typedef struct mlx_device_ {
+ void* ctx;
+} mlx_device;
+
+/**
+ * Device type.
+ */
+typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type;
+
+/**
+ * Returns a new empty device.
+ */
+mlx_device mlx_device_new(void);
+
+/**
+ * Returns a new device of specified `type`, with specified `index`.
+ */
+mlx_device mlx_device_new_type(mlx_device_type type, int index);
+/**
+ * Free a device.
+ */
+int mlx_device_free(mlx_device dev);
+/**
+ * Set device to provided src device.
+ */
+int mlx_device_set(mlx_device* dev, const mlx_device src);
+/**
+ * Get device description.
+ */
+int mlx_device_tostring(mlx_string* str, mlx_device dev);
+/**
+ * Check if devices are the same.
+ */
+bool mlx_device_equal(mlx_device lhs, mlx_device rhs);
+/**
+ * Returns the index of the device.
+ */
+int mlx_device_get_index(int* index, mlx_device dev);
+/**
+ * Returns the type of the device.
+ */
+int mlx_device_get_type(mlx_device_type* type, mlx_device dev);
+/**
+ * Returns the default MLX device.
+ */
+int mlx_get_default_device(mlx_device* dev);
+/**
+ * Set the default MLX device.
+ */
+int mlx_set_default_device(mlx_device dev);
+/**
+ * Check if device is available.
+ */
+int mlx_device_is_available(bool* avail, mlx_device dev);
+/**
+ * Get the number of available devices for a device type.
+ */
+int mlx_device_count(int* count, mlx_device_type type);
+
+/**
+ * An MLX device info object.
+ * Contains key-value pairs with device properties.
+ * Keys vary by backend but common keys include:
+ * - device_name (string): Device name
+ * - architecture (string): Architecture identifier
+ * Additional keys may be present depending on the backend.
+ */
+typedef struct mlx_device_info_ {
+ void* ctx;
+} mlx_device_info;
+
+/**
+ * Returns a new empty device info object.
+ */
+mlx_device_info mlx_device_info_new(void);
+/**
+ * Get device information for a device.
+ */
+int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
+/**
+ * Free a device info object.
+ */
+int mlx_device_info_free(mlx_device_info info);
+/**
+ * Check if a key exists in the device info.
+ * Returns 0 on success, 1 on error.
+ * Sets *exists to true if the key exists, false otherwise.
+ */
+int mlx_device_info_has_key(
+ bool* exists,
+ mlx_device_info info,
+ const char* key);
+/**
+ * Check if a value is a string type.
+ * Returns 0 on success, 1 on error.
+ * Sets *is_string to true if the value is a string, false if it's a size_t.
+ */
+int mlx_device_info_is_string(
+ bool* is_string,
+ mlx_device_info info,
+ const char* key);
+/**
+ * Get a string value from device info.
+ * Returns 0 on success, 1 on error, 2 if key not found or wrong type.
+ */
+int mlx_device_info_get_string(
+ const char** value,
+ mlx_device_info info,
+ const char* key);
+/**
+ * Get a size_t value from device info.
+ * Returns 0 on success, 1 on error, 2 if key not found or wrong type.
+ */
+int mlx_device_info_get_size(
+ size_t* value,
+ mlx_device_info info,
+ const char* key);
+/**
+ * Get all keys from device info.
+ * Returns 0 on success, 1 on error.
+ */
+int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/distributed.h b/x/mlxrunner/mlx/include/mlx/c/distributed.h
new file mode 100644
index 000000000..c3b0baeee
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/distributed.h
@@ -0,0 +1,83 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_DISTRIBUTED_H
+#define MLX_DISTRIBUTED_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup distributed Distributed collectives
+ */
+/**@{*/
+
+int mlx_distributed_all_gather(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+int mlx_distributed_all_max(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+int mlx_distributed_all_min(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+int mlx_distributed_all_sum(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+int mlx_distributed_recv(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ int src,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+int mlx_distributed_recv_like(
+ mlx_array* res,
+ const mlx_array x,
+ int src,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+int mlx_distributed_send(
+ mlx_array* res,
+ const mlx_array x,
+ int dst,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+int mlx_distributed_sum_scatter(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/distributed_group.h b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h
new file mode 100644
index 000000000..3cfccc806
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h
@@ -0,0 +1,58 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_DISTRIBUTED_GROUP_H
+#define MLX_DISTRIBUTED_GROUP_H
+
+#include <stdbool.h>
+
+#include "mlx/c/stream.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_distributed_group MLX distributed
+ */
+/**@{*/
+
+/**
+ * An MLX distributed group object.
+ */
+typedef struct mlx_distributed_group_ {
+ void* ctx;
+} mlx_distributed_group;
+
+/**
+ * Get the rank.
+ */
+int mlx_distributed_group_rank(mlx_distributed_group group);
+
+/**
+ * Get the group size.
+ */
+int mlx_distributed_group_size(mlx_distributed_group group);
+
+/**
+ * Split the group.
+ */
+mlx_distributed_group
+mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
+
+/**
+ * Check if distributed is available.
+ */
+bool mlx_distributed_is_available(void);
+
+/**
+ * Initialize distributed.
+ */
+mlx_distributed_group mlx_distributed_init(bool strict);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/error.h b/x/mlxrunner/mlx/include/mlx/c/error.h
new file mode 100644
index 000000000..8c063a403
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/error.h
@@ -0,0 +1,41 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_ERROR_H
+#define MLX_ERROR_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_error Error management
+ */
+/**@{*/
+
+typedef void (*mlx_error_handler_func)(const char* msg, void* data);
+
+/**
+ * Set the error handler.
+ */
+void mlx_set_error_handler(
+ mlx_error_handler_func handler,
+ void* data,
+ void (*dtor)(void*));
+
+/**
+ * Throw an error.
+ */
+void _mlx_error(const char* file, const int line, const char* fmt, ...);
+
+/**
+ * Throw an error. Macro which passes file name and line number to _mlx_error().
+ */
+#define mlx_error(...) _mlx_error(__FILE__, __LINE__, __VA_ARGS__)
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/export.h b/x/mlxrunner/mlx/include/mlx/c/export.h
new file mode 100644
index 000000000..52cb2835c
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/export.h
@@ -0,0 +1,75 @@
+/* Copyright © 2023-2025 Apple Inc. */
+
+#ifndef MLX_EXPORT_H
+#define MLX_EXPORT_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup export Function serialization
+ */
+/**@{*/
+int mlx_export_function(
+ const char* file,
+ const mlx_closure fun,
+ const mlx_vector_array args,
+ bool shapeless);
+int mlx_export_function_kwargs(
+ const char* file,
+ const mlx_closure_kwargs fun,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs,
+ bool shapeless);
+
+typedef struct mlx_function_exporter_ {
+ void* ctx;
+} mlx_function_exporter;
+mlx_function_exporter mlx_function_exporter_new(
+ const char* file,
+ const mlx_closure fun,
+ bool shapeless);
+int mlx_function_exporter_free(mlx_function_exporter xfunc);
+int mlx_function_exporter_apply(
+ const mlx_function_exporter xfunc,
+ const mlx_vector_array args);
+int mlx_function_exporter_apply_kwargs(
+ const mlx_function_exporter xfunc,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs);
+
+typedef struct mlx_imported_function_ {
+ void* ctx;
+} mlx_imported_function;
+mlx_imported_function mlx_imported_function_new(const char* file);
+int mlx_imported_function_free(mlx_imported_function xfunc);
+int mlx_imported_function_apply(
+ mlx_vector_array* res,
+ const mlx_imported_function xfunc,
+ const mlx_vector_array args);
+int mlx_imported_function_apply_kwargs(
+ mlx_vector_array* res,
+ const mlx_imported_function xfunc,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs);
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/fast.h b/x/mlxrunner/mlx/include/mlx/c/fast.h
new file mode 100644
index 000000000..c825d00e5
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/fast.h
@@ -0,0 +1,206 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_FAST_H
+#define MLX_FAST_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup fast Fast custom operations
+ */
+/**@{*/
+
+typedef struct mlx_fast_cuda_kernel_config_ {
+ void* ctx;
+} mlx_fast_cuda_kernel_config;
+mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void);
+void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls);
+
+int mlx_fast_cuda_kernel_config_add_output_arg(
+ mlx_fast_cuda_kernel_config cls,
+ const int* shape,
+ size_t size,
+ mlx_dtype dtype);
+int mlx_fast_cuda_kernel_config_set_grid(
+ mlx_fast_cuda_kernel_config cls,
+ int grid1,
+ int grid2,
+ int grid3);
+int mlx_fast_cuda_kernel_config_set_thread_group(
+ mlx_fast_cuda_kernel_config cls,
+ int thread1,
+ int thread2,
+ int thread3);
+int mlx_fast_cuda_kernel_config_set_init_value(
+ mlx_fast_cuda_kernel_config cls,
+ float value);
+int mlx_fast_cuda_kernel_config_set_verbose(
+ mlx_fast_cuda_kernel_config cls,
+ bool verbose);
+int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ mlx_dtype dtype);
+int mlx_fast_cuda_kernel_config_add_template_arg_int(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ int value);
+int mlx_fast_cuda_kernel_config_add_template_arg_bool(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ bool value);
+
+typedef struct mlx_fast_cuda_kernel_ {
+ void* ctx;
+} mlx_fast_cuda_kernel;
+
+mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
+ const char* name,
+ const mlx_vector_string input_names,
+ const mlx_vector_string output_names,
+ const char* source,
+ const char* header,
+ bool ensure_row_contiguous,
+ int shared_memory);
+
+void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls);
+
+int mlx_fast_cuda_kernel_apply(
+ mlx_vector_array* outputs,
+ mlx_fast_cuda_kernel cls,
+ const mlx_vector_array inputs,
+ const mlx_fast_cuda_kernel_config config,
+ const mlx_stream stream);
+
+int mlx_fast_layer_norm(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array weight /* may be null */,
+ const mlx_array bias /* may be null */,
+ float eps,
+ const mlx_stream s);
+
+typedef struct mlx_fast_metal_kernel_config_ {
+ void* ctx;
+} mlx_fast_metal_kernel_config;
+mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void);
+void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls);
+
+int mlx_fast_metal_kernel_config_add_output_arg(
+ mlx_fast_metal_kernel_config cls,
+ const int* shape,
+ size_t size,
+ mlx_dtype dtype);
+int mlx_fast_metal_kernel_config_set_grid(
+ mlx_fast_metal_kernel_config cls,
+ int grid1,
+ int grid2,
+ int grid3);
+int mlx_fast_metal_kernel_config_set_thread_group(
+ mlx_fast_metal_kernel_config cls,
+ int thread1,
+ int thread2,
+ int thread3);
+int mlx_fast_metal_kernel_config_set_init_value(
+ mlx_fast_metal_kernel_config cls,
+ float value);
+int mlx_fast_metal_kernel_config_set_verbose(
+ mlx_fast_metal_kernel_config cls,
+ bool verbose);
+int mlx_fast_metal_kernel_config_add_template_arg_dtype(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ mlx_dtype dtype);
+int mlx_fast_metal_kernel_config_add_template_arg_int(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ int value);
+int mlx_fast_metal_kernel_config_add_template_arg_bool(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ bool value);
+
+typedef struct mlx_fast_metal_kernel_ {
+ void* ctx;
+} mlx_fast_metal_kernel;
+
+mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
+ const char* name,
+ const mlx_vector_string input_names,
+ const mlx_vector_string output_names,
+ const char* source,
+ const char* header,
+ bool ensure_row_contiguous,
+ bool atomic_outputs);
+
+void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
+
+int mlx_fast_metal_kernel_apply(
+ mlx_vector_array* outputs,
+ mlx_fast_metal_kernel cls,
+ const mlx_vector_array inputs,
+ const mlx_fast_metal_kernel_config config,
+ const mlx_stream stream);
+
+int mlx_fast_rms_norm(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array weight /* may be null */,
+ float eps,
+ const mlx_stream s);
+int mlx_fast_rope(
+ mlx_array* res,
+ const mlx_array x,
+ int dims,
+ bool traditional,
+ mlx_optional_float base,
+ float scale,
+ int offset,
+ const mlx_array freqs /* may be null */,
+ const mlx_stream s);
+int mlx_fast_rope_dynamic(
+ mlx_array* res,
+ const mlx_array x,
+ int dims,
+ bool traditional,
+ mlx_optional_float base,
+ float scale,
+ const mlx_array offset,
+ const mlx_array freqs /* may be null */,
+ const mlx_stream s);
+int mlx_fast_scaled_dot_product_attention(
+ mlx_array* res,
+ const mlx_array queries,
+ const mlx_array keys,
+ const mlx_array values,
+ float scale,
+ const char* mask_mode,
+ const mlx_array mask_arr /* may be null */,
+ const mlx_array sinks /* may be null */,
+ const mlx_stream s);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/fft.h b/x/mlxrunner/mlx/include/mlx/c/fft.h
new file mode 100644
index 000000000..779803e9b
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/fft.h
@@ -0,0 +1,138 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_FFT_H
+#define MLX_FFT_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup fft FFT operations
+ */
+/**@{*/
+
+int mlx_fft_fft(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s);
+int mlx_fft_fft2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_fftn(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_fftshift(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_ifft(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s);
+int mlx_fft_ifft2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_ifftn(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_ifftshift(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_irfft(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s);
+int mlx_fft_irfft2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_irfftn(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_rfft(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s);
+int mlx_fft_rfft2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_fft_rfftn(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/half.h b/x/mlxrunner/mlx/include/mlx/c/half.h
new file mode 100644
index 000000000..958d555f5
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/half.h
@@ -0,0 +1,26 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_HALF_H
+#define MLX_HALF_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__)
+#define HAS_FLOAT16
+#include <arm_fp16.h>
+typedef __fp16 float16_t;
+#endif
+
+#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__)
+#define HAS_BFLOAT16
+#include <arm_bf16.h>
+typedef __bf16 bfloat16_t;
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/io.h b/x/mlxrunner/mlx/include/mlx/c/io.h
new file mode 100644
index 000000000..6eb205c9a
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/io.h
@@ -0,0 +1,63 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_IO_H
+#define MLX_IO_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup io IO operations
+ */
+/**@{*/
+
+int mlx_load_reader(
+ mlx_array* res,
+ mlx_io_reader in_stream,
+ const mlx_stream s);
+int mlx_load(mlx_array* res, const char* file, const mlx_stream s);
+int mlx_load_safetensors_reader(
+ mlx_map_string_to_array* res_0,
+ mlx_map_string_to_string* res_1,
+ mlx_io_reader in_stream,
+ const mlx_stream s);
+int mlx_load_safetensors(
+ mlx_map_string_to_array* res_0,
+ mlx_map_string_to_string* res_1,
+ const char* file,
+ const mlx_stream s);
+int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a);
+int mlx_save(const char* file, const mlx_array a);
+int mlx_save_safetensors_writer(
+ mlx_io_writer in_stream,
+ const mlx_map_string_to_array param,
+ const mlx_map_string_to_string metadata);
+int mlx_save_safetensors(
+ const char* file,
+ const mlx_map_string_to_array param,
+ const mlx_map_string_to_string metadata);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/io_types.h b/x/mlxrunner/mlx/include/mlx/c/io_types.h
new file mode 100644
index 000000000..88349b57c
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/io_types.h
@@ -0,0 +1,104 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_IO_TYPES_H
+#define MLX_IO_TYPES_H
+
+#include <stdbool.h>
+
+#include "mlx/c/string.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_io_types IO Types
+ * MLX IO type objects.
+ */
+/**@{*/
+
+/**
+ * An MLX IO reader object.
+ */
+typedef struct mlx_io_reader_ {
+ void* ctx;
+} mlx_io_reader;
+/**
+ * An MLX IO writer object.
+ */
+typedef struct mlx_io_writer_ {
+ void* ctx;
+} mlx_io_writer;
+
+/**
+ * Virtual table for custom IO reader and writer objects.
+ */
+typedef struct mlx_io_vtable_ {
+ bool (*is_open)(void*);
+ bool (*good)(void*);
+ size_t (*tell)(void*);
+ void (*seek)(void*, int64_t off, int whence);
+ void (*read)(void*, char* data, size_t n);
+ void (*read_at_offset)(void*, char* data, size_t n, size_t off);
+ void (*write)(void*, const char* data, size_t n);
+ const char* (*label)(void*);
+ void (*free)(void*);
+} mlx_io_vtable;
+
+/**
+ * Returns a new custom IO reader.
+ * `vtable` operates on user descriptor `desc`.
+ */
+mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable);
+
+/**
+ * Get IO reader user descriptor.
+ */
+int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io);
+
+/**
+ * Get IO reader description.
+ */
+int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io);
+
+/**
+ * Free IO reader.
+ *
+ * Note that MLX arrays are lazily evaluated, so the underlying object may
+ * not be freed right away. The ``free()`` callback from ``mlx_io_vtable``
+ * will be called when the underlying object is actually freed.
+ */
+int mlx_io_reader_free(mlx_io_reader io);
+
+/**
+ * Returns a new custom IO writer.
+ * `vtable` operates on user descriptor `desc`.
+ */
+mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable);
+
+/**
+ * Get IO writer user descriptor.
+ */
+int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io);
+
+/**
+ * Get IO writer description.
+ */
+int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io);
+
+/**
+ * Free IO writer.
+ *
+ * Note that MLX arrays are lazily evaluated, so the underlying object may
+ * not be freed right away. The ``free()`` callback from ``mlx_io_vtable``
+ * will be called when the underlying object is actually freed.
+ */
+int mlx_io_writer_free(mlx_io_writer io);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/linalg.h b/x/mlxrunner/mlx/include/mlx/c/linalg.h
new file mode 100644
index 000000000..91d5d661e
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/linalg.h
@@ -0,0 +1,128 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_LINALG_H
+#define MLX_LINALG_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup linalg Linear algebra operations
+ */
+/**@{*/
+
+int mlx_linalg_cholesky(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s);
+int mlx_linalg_cholesky_inv(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s);
+int mlx_linalg_cross(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int axis,
+ const mlx_stream s);
+int mlx_linalg_eig(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s);
+int mlx_linalg_eigh(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const char* UPLO,
+ const mlx_stream s);
+int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_linalg_eigvalsh(
+ mlx_array* res,
+ const mlx_array a,
+ const char* UPLO,
+ const mlx_stream s);
+int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
+int mlx_linalg_lu_factor(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s);
+int mlx_linalg_norm(
+ mlx_array* res,
+ const mlx_array a,
+ double ord,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_linalg_norm_matrix(
+ mlx_array* res,
+ const mlx_array a,
+ const char* ord,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_linalg_norm_l2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_linalg_qr(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s);
+int mlx_linalg_solve(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_linalg_solve_triangular(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ bool upper,
+ const mlx_stream s);
+int mlx_linalg_svd(
+ mlx_vector_array* res,
+ const mlx_array a,
+ bool compute_uv,
+ const mlx_stream s);
+int mlx_linalg_tri_inv(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/map.h b/x/mlxrunner/mlx/include/mlx/c/map.h
new file mode 100644
index 000000000..56abe84f1
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/map.h
@@ -0,0 +1,149 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_MAP_H
+#define MLX_MAP_H
+
+#include "mlx/c/array.h"
+#include "mlx/c/string.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_map Maps
+ * MLX map objects.
+ */
+/**@{*/
+
+/**
+ * A string-to-array map
+ */
+typedef struct mlx_map_string_to_array_ {
+ void* ctx;
+} mlx_map_string_to_array;
+
+/**
+ * Returns a new empty string-to-array map.
+ */
+mlx_map_string_to_array mlx_map_string_to_array_new(void);
+/**
+ * Set map to provided src map.
+ */
+int mlx_map_string_to_array_set(
+ mlx_map_string_to_array* map,
+ const mlx_map_string_to_array src);
+/**
+ * Free a string-to-array map.
+ */
+int mlx_map_string_to_array_free(mlx_map_string_to_array map);
+/**
+ * Insert a new `value` at the specified `key` in the map.
+ */
+int mlx_map_string_to_array_insert(
+ mlx_map_string_to_array map,
+ const char* key,
+ const mlx_array value);
+/**
+ * Returns the value indexed at the specified `key` in the map.
+ */
+int mlx_map_string_to_array_get(
+ mlx_array* value,
+ const mlx_map_string_to_array map,
+ const char* key);
+
+/**
+ * An iterator over a string-to-array map.
+ */
+typedef struct mlx_map_string_to_array_iterator_ {
+ void* ctx;
+ void* map_ctx;
+} mlx_map_string_to_array_iterator;
+/**
+ * Returns a new iterator over the given map.
+ */
+mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(
+ mlx_map_string_to_array map);
+/**
+ * Free iterator.
+ */
+int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it);
+/**
+ * Increment iterator.
+ */
+int mlx_map_string_to_array_iterator_next(
+ const char** key,
+ mlx_array* value,
+ mlx_map_string_to_array_iterator it);
+
+/**
+ * A string-to-string map
+ */
+typedef struct mlx_map_string_to_string_ {
+ void* ctx;
+} mlx_map_string_to_string;
+
+/**
+ * Returns a new empty string-to-string map.
+ */
+mlx_map_string_to_string mlx_map_string_to_string_new(void);
+/**
+ * Set map to provided src map.
+ */
+int mlx_map_string_to_string_set(
+ mlx_map_string_to_string* map,
+ const mlx_map_string_to_string src);
+/**
+ * Free a string-to-string map.
+ */
+int mlx_map_string_to_string_free(mlx_map_string_to_string map);
+/**
+ * Insert a new `value` at the specified `key` in the map.
+ */
+int mlx_map_string_to_string_insert(
+ mlx_map_string_to_string map,
+ const char* key,
+ const char* value);
+/**
+ * Returns the value indexed at the specified `key` in the map.
+ */
+int mlx_map_string_to_string_get(
+ const char** value,
+ const mlx_map_string_to_string map,
+ const char* key);
+
+/**
+ * An iterator over a string-to-string map.
+ */
+typedef struct mlx_map_string_to_string_iterator_ {
+ void* ctx;
+ void* map_ctx;
+} mlx_map_string_to_string_iterator;
+/**
+ * Returns a new iterator over the given map.
+ */
+mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(
+ mlx_map_string_to_string map);
+/**
+ * Free iterator.
+ */
+int mlx_map_string_to_string_iterator_free(
+ mlx_map_string_to_string_iterator it);
+/**
+ * Increment iterator.
+ */
+int mlx_map_string_to_string_iterator_next(
+ const char** key,
+ const char** value,
+ mlx_map_string_to_string_iterator it);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/memory.h b/x/mlxrunner/mlx/include/mlx/c/memory.h
new file mode 100644
index 000000000..bae9e08ec
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/memory.h
@@ -0,0 +1,47 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_MEMORY_H
+#define MLX_MEMORY_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup memory Memory operations
+ */
+/**@{*/
+
+int mlx_clear_cache(void);
+int mlx_get_active_memory(size_t* res);
+int mlx_get_cache_memory(size_t* res);
+int mlx_get_memory_limit(size_t* res);
+int mlx_get_peak_memory(size_t* res);
+int mlx_reset_peak_memory(void);
+int mlx_set_cache_limit(size_t* res, size_t limit);
+int mlx_set_memory_limit(size_t* res, size_t limit);
+int mlx_set_wired_limit(size_t* res, size_t limit);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/metal.h b/x/mlxrunner/mlx/include/mlx/c/metal.h
new file mode 100644
index 000000000..5877b224b
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/metal.h
@@ -0,0 +1,41 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_METAL_H
+#define MLX_METAL_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup metal Metal-specific operations
+ */
+/**@{*/
+
+int mlx_metal_is_available(bool* res);
+int mlx_metal_start_capture(const char* path);
+int mlx_metal_stop_capture(void);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/mlx.h b/x/mlxrunner/mlx/include/mlx/c/mlx.h
new file mode 100644
index 000000000..ffadac89a
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/mlx.h
@@ -0,0 +1,34 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_ALL_H
+#define MLX_ALL_H
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/compile.h"
+#include "mlx/c/cuda.h"
+#include "mlx/c/device.h"
+#include "mlx/c/distributed.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/error.h"
+#include "mlx/c/export.h"
+#include "mlx/c/fast.h"
+#include "mlx/c/fft.h"
+#include "mlx/c/half.h"
+#include "mlx/c/io.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/linalg.h"
+#include "mlx/c/map.h"
+#include "mlx/c/memory.h"
+#include "mlx/c/metal.h"
+#include "mlx/c/ops.h"
+#include "mlx/c/optional.h"
+#include "mlx/c/random.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/transforms.h"
+#include "mlx/c/transforms_impl.h"
+#include "mlx/c/vector.h"
+#include "mlx/c/version.h"
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/ops.h b/x/mlxrunner/mlx/include/mlx/c/ops.h
new file mode 100644
index 000000000..a1446fb9e
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/ops.h
@@ -0,0 +1,1235 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_OPS_H
+#define MLX_OPS_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup ops Core array operations
+ */
+/**@{*/
+
+int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_add(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_addmm(
+ mlx_array* res,
+ const mlx_array c,
+ const mlx_array a,
+ const mlx_array b,
+ float alpha,
+ float beta,
+ const mlx_stream s);
+int mlx_all_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_all_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_all(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_allclose(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ double rtol,
+ double atol,
+ bool equal_nan,
+ const mlx_stream s);
+int mlx_any_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_any_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_any(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_arange(
+ mlx_array* res,
+ double start,
+ double stop,
+ double step,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_arctan2(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_argmax_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_argmax(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_argmin_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_argmin(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_argpartition_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ int axis,
+ const mlx_stream s);
+int mlx_argpartition(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ const mlx_stream s);
+int mlx_argsort_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s);
+int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_array_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ bool equal_nan,
+ const mlx_stream s);
+int mlx_as_strided(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const int64_t* strides,
+ size_t strides_num,
+ size_t offset,
+ const mlx_stream s);
+int mlx_astype(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_bitwise_and(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_bitwise_or(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_bitwise_xor(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_block_masked_mm(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int block_size,
+ const mlx_array mask_out /* may be null */,
+ const mlx_array mask_lhs /* may be null */,
+ const mlx_array mask_rhs /* may be null */,
+ const mlx_stream s);
+int mlx_broadcast_arrays(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_stream s);
+int mlx_broadcast_to(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s);
+int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_clip(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array a_min /* may be null */,
+ const mlx_array a_max /* may be null */,
+ const mlx_stream s);
+int mlx_concatenate_axis(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ int axis,
+ const mlx_stream s);
+int mlx_concatenate(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ const mlx_stream s);
+int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_contiguous(
+ mlx_array* res,
+ const mlx_array a,
+ bool allow_col_major,
+ const mlx_stream s);
+int mlx_conv1d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride,
+ int padding,
+ int dilation,
+ int groups,
+ const mlx_stream s);
+int mlx_conv2d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int padding_0,
+ int padding_1,
+ int dilation_0,
+ int dilation_1,
+ int groups,
+ const mlx_stream s);
+int mlx_conv3d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int stride_2,
+ int padding_0,
+ int padding_1,
+ int padding_2,
+ int dilation_0,
+ int dilation_1,
+ int dilation_2,
+ int groups,
+ const mlx_stream s);
+int mlx_conv_general(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ const int* stride,
+ size_t stride_num,
+ const int* padding_lo,
+ size_t padding_lo_num,
+ const int* padding_hi,
+ size_t padding_hi_num,
+ const int* kernel_dilation,
+ size_t kernel_dilation_num,
+ const int* input_dilation,
+ size_t input_dilation_num,
+ int groups,
+ bool flip,
+ const mlx_stream s);
+int mlx_conv_transpose1d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride,
+ int padding,
+ int dilation,
+ int output_padding,
+ int groups,
+ const mlx_stream s);
+int mlx_conv_transpose2d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int padding_0,
+ int padding_1,
+ int dilation_0,
+ int dilation_1,
+ int output_padding_0,
+ int output_padding_1,
+ int groups,
+ const mlx_stream s);
+int mlx_conv_transpose3d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int stride_2,
+ int padding_0,
+ int padding_1,
+ int padding_2,
+ int dilation_0,
+ int dilation_1,
+ int dilation_2,
+ int output_padding_0,
+ int output_padding_1,
+ int output_padding_2,
+ int groups,
+ const mlx_stream s);
+int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_cummax(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+int mlx_cummin(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+int mlx_cumprod(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+int mlx_cumsum(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_depends(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_vector_array dependencies);
+int mlx_dequantize(
+ mlx_array* res,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ mlx_optional_dtype dtype,
+ const mlx_stream s);
+int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
+int mlx_diagonal(
+ mlx_array* res,
+ const mlx_array a,
+ int offset,
+ int axis1,
+ int axis2,
+ const mlx_stream s);
+int mlx_divide(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_divmod(
+ mlx_vector_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_einsum(
+ mlx_array* res,
+ const char* subscripts,
+ const mlx_vector_array operands,
+ const mlx_stream s);
+int mlx_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_expand_dims_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_expand_dims(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s);
+int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_eye(
+ mlx_array* res,
+ int n,
+ int m,
+ int k,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_flatten(
+ mlx_array* res,
+ const mlx_array a,
+ int start_axis,
+ int end_axis,
+ const mlx_stream s);
+int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_floor_divide(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_from_fp8(
+ mlx_array* res,
+ const mlx_array x,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_full(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array vals,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_full_like(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array vals,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_gather(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const int* axes,
+ size_t axes_num,
+ const int* slice_sizes,
+ size_t slice_sizes_num,
+ const mlx_stream s);
+int mlx_gather_single(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const int* slice_sizes,
+ size_t slice_sizes_num,
+ const mlx_stream s);
+int mlx_gather_mm(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_array lhs_indices /* may be null */,
+ const mlx_array rhs_indices /* may be null */,
+ bool sorted_indices,
+ const mlx_stream s);
+int mlx_gather_qmm(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ const mlx_array lhs_indices /* may be null */,
+ const mlx_array rhs_indices /* may be null */,
+ bool transpose,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ bool sorted_indices,
+ const mlx_stream s);
+int mlx_greater(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_greater_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_hadamard_transform(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_optional_float scale,
+ const mlx_stream s);
+int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
+int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_inner(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_isclose(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ double rtol,
+ double atol,
+ bool equal_nan,
+ const mlx_stream s);
+int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_kron(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_left_shift(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_less(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_less_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_linspace(
+ mlx_array* res,
+ double start,
+ double stop,
+ int num,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_logaddexp(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_logcumsumexp(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+int mlx_logical_and(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_logical_or(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_logsumexp_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_logsumexp_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_logsumexp(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_masked_scatter(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array mask,
+ const mlx_array src,
+ const mlx_stream s);
+int mlx_matmul(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_max_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_max_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_max(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_maximum(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_mean_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_mean_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_mean(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_median(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_meshgrid(
+ mlx_vector_array* res,
+ const mlx_vector_array arrays,
+ bool sparse,
+ const char* indexing,
+ const mlx_stream s);
+int mlx_min_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_min_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_min(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_minimum(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_moveaxis(
+ mlx_array* res,
+ const mlx_array a,
+ int source,
+ int destination,
+ const mlx_stream s);
+int mlx_multiply(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_nan_to_num(
+ mlx_array* res,
+ const mlx_array a,
+ float nan,
+ mlx_optional_float posinf,
+ mlx_optional_float neginf,
+ const mlx_stream s);
+int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_not_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_number_of_elements(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool inverted,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_ones(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_outer(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_pad(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const int* low_pad_size,
+ size_t low_pad_size_num,
+ const int* high_pad_size,
+ size_t high_pad_size_num,
+ const mlx_array pad_value,
+ const char* mode,
+ const mlx_stream s);
+int mlx_pad_symmetric(
+ mlx_array* res,
+ const mlx_array a,
+ int pad_width,
+ const mlx_array pad_value,
+ const char* mode,
+ const mlx_stream s);
+int mlx_partition_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ int axis,
+ const mlx_stream s);
+int mlx_partition(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ const mlx_stream s);
+int mlx_power(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_prod_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_prod_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_prod(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_put_along_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array values,
+ int axis,
+ const mlx_stream s);
+int mlx_qqmm(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array w_scales /* may be null */,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s);
+int mlx_quantize(
+ mlx_vector_array* res,
+ const mlx_array w,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s);
+int mlx_quantized_matmul(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ bool transpose,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s);
+int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_remainder(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_repeat_axis(
+ mlx_array* res,
+ const mlx_array arr,
+ int repeats,
+ int axis,
+ const mlx_stream s);
+int mlx_repeat(
+ mlx_array* res,
+ const mlx_array arr,
+ int repeats,
+ const mlx_stream s);
+int mlx_reshape(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s);
+int mlx_right_shift(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_roll_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ int axis,
+ const mlx_stream s);
+int mlx_roll_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_roll(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ const mlx_stream s);
+int mlx_round(
+ mlx_array* res,
+ const mlx_array a,
+ int decimals,
+ const mlx_stream s);
+int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_scatter(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_scatter_single(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array updates,
+ int axis,
+ const mlx_stream s);
+int mlx_scatter_add(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_scatter_add_single(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array updates,
+ int axis,
+ const mlx_stream s);
+int mlx_scatter_add_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array values,
+ int axis,
+ const mlx_stream s);
+int mlx_scatter_max(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_scatter_max_single(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array updates,
+ int axis,
+ const mlx_stream s);
+int mlx_scatter_min(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_scatter_min_single(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array updates,
+ int axis,
+ const mlx_stream s);
+int mlx_scatter_prod(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_scatter_prod_single(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array updates,
+ int axis,
+ const mlx_stream s);
+int mlx_segmented_mm(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_array segments,
+ const mlx_stream s);
+int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_slice(
+ mlx_array* res,
+ const mlx_array a,
+ const int* start,
+ size_t start_num,
+ const int* stop,
+ size_t stop_num,
+ const int* strides,
+ size_t strides_num,
+ const mlx_stream s);
+int mlx_slice_dynamic(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array start,
+ const int* axes,
+ size_t axes_num,
+ const int* slice_size,
+ size_t slice_size_num,
+ const mlx_stream s);
+int mlx_slice_update(
+ mlx_array* res,
+ const mlx_array src,
+ const mlx_array update,
+ const int* start,
+ size_t start_num,
+ const int* stop,
+ size_t stop_num,
+ const int* strides,
+ size_t strides_num,
+ const mlx_stream s);
+int mlx_slice_update_dynamic(
+ mlx_array* res,
+ const mlx_array src,
+ const mlx_array update,
+ const mlx_array start,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_softmax_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool precise,
+ const mlx_stream s);
+int mlx_softmax_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool precise,
+ const mlx_stream s);
+int mlx_softmax(
+ mlx_array* res,
+ const mlx_array a,
+ bool precise,
+ const mlx_stream s);
+int mlx_sort_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s);
+int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_split(
+ mlx_vector_array* res,
+ const mlx_array a,
+ int num_splits,
+ int axis,
+ const mlx_stream s);
+int mlx_split_sections(
+ mlx_vector_array* res,
+ const mlx_array a,
+ const int* indices,
+ size_t indices_num,
+ int axis,
+ const mlx_stream s);
+int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_squeeze_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_squeeze_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s);
+int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_stack_axis(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ int axis,
+ const mlx_stream s);
+int mlx_stack(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ const mlx_stream s);
+int mlx_std_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+int mlx_std_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+int mlx_std(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_subtract(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+int mlx_sum_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_sum_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_sum(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+int mlx_swapaxes(
+ mlx_array* res,
+ const mlx_array a,
+ int axis1,
+ int axis2,
+ const mlx_stream s);
+int mlx_take_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const mlx_stream s);
+int mlx_take(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_stream s);
+int mlx_take_along_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const mlx_stream s);
+int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_tensordot(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const int* axes_a,
+ size_t axes_a_num,
+ const int* axes_b,
+ size_t axes_b_num,
+ const mlx_stream s);
+int mlx_tensordot_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int axis,
+ const mlx_stream s);
+int mlx_tile(
+ mlx_array* res,
+ const mlx_array arr,
+ const int* reps,
+ size_t reps_num,
+ const mlx_stream s);
+int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s);
+int mlx_topk_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int k,
+ int axis,
+ const mlx_stream s);
+int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
+int mlx_trace(
+ mlx_array* res,
+ const mlx_array a,
+ int offset,
+ int axis1,
+ int axis2,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_transpose_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s);
+int mlx_tri(
+ mlx_array* res,
+ int n,
+ int m,
+ int k,
+ mlx_dtype type,
+ const mlx_stream s);
+int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
+int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
+int mlx_unflatten(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s);
+int mlx_var_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+int mlx_var_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+int mlx_var(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+int mlx_view(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_where(
+ mlx_array* res,
+ const mlx_array condition,
+ const mlx_array x,
+ const mlx_array y,
+ const mlx_stream s);
+int mlx_zeros(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_stream s);
+int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/optional.h b/x/mlxrunner/mlx/include/mlx/c/optional.h
new file mode 100644
index 000000000..ff9ea14e5
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/optional.h
@@ -0,0 +1,51 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_OPTIONAL_H
+#define MLX_OPTIONAL_H
+
+#include <stdbool.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/string.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_optional Optionals
+ * MLX optional scalars.
+ */
+/**@{*/
+
+/**
+ * An optional int: an int payload paired with a presence flag.
+ */
+typedef struct mlx_optional_int_ {
+ int value; /* payload; presumably meaningful only when has_value is true */
+ bool has_value; /* presence flag for value */
+} mlx_optional_int;
+
+/**
+ * An optional float: a float payload paired with a presence flag.
+ */
+typedef struct mlx_optional_float_ {
+ float value; /* payload; presumably meaningful only when has_value is true */
+ bool has_value; /* presence flag for value */
+} mlx_optional_float;
+
+/**
+ * An optional dtype: an mlx_dtype payload paired with a presence flag.
+ */
+typedef struct mlx_optional_dtype_ {
+ mlx_dtype value; /* payload; presumably meaningful only when has_value is true */
+ bool has_value; /* presence flag for value */
+} mlx_optional_dtype;
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/random.h b/x/mlxrunner/mlx/include/mlx/c/random.h
new file mode 100644
index 000000000..dbce0be37
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/random.h
@@ -0,0 +1,166 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_RANDOM_H
+#define MLX_RANDOM_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup random Random number operations
+ */
+/**@{*/
+
+/* Conventions visible in the signatures below:
+ *  - every function returns `int` and, except mlx_random_seed, writes its
+ *    result through the leading `res` / `res_N` output pointer(s);
+ *  - a requested shape is passed as `shape` plus its element count
+ *    `shape_num`;
+ *  - `key` parameters annotated "may be null" are optional.
+ * NOTE(review): the `int` is presumably a status code — confirm the
+ * success value against the mlx-c error-handling documentation. */
+int mlx_random_bernoulli(
+ mlx_array* res,
+ const mlx_array p,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+int mlx_random_bits(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ int width,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* Categorical form that takes an explicit output `shape`. */
+int mlx_random_categorical_shape(
+ mlx_array* res,
+ const mlx_array logits,
+ int axis,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* Categorical form that takes a sample count instead of a shape. */
+int mlx_random_categorical_num_samples(
+ mlx_array* res,
+ const mlx_array logits_,
+ int axis,
+ int num_samples,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* Categorical form with neither a shape nor a sample count. */
+int mlx_random_categorical(
+ mlx_array* res,
+ const mlx_array logits,
+ int axis,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+int mlx_random_gumbel(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* Takes only a 64-bit seed; no stream or key argument.
+ * NOTE(review): presumably builds a key array usable as the `key` argument
+ * of the samplers in this header — confirm. */
+int mlx_random_key(mlx_array* res, uint64_t seed);
+int mlx_random_laplace(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ float loc,
+ float scale,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+int mlx_random_multivariate_normal(
+ mlx_array* res,
+ const mlx_array mean,
+ const mlx_array cov,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* Normal form whose `loc` and `scale` are arrays, each of which may be null
+ * (see mlx_random_normal for the scalar form). */
+int mlx_random_normal_broadcast(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array loc /* may be null */,
+ const mlx_array scale /* may be null */,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* Normal form whose `loc` and `scale` are plain floats. */
+int mlx_random_normal(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ float loc,
+ float scale,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* Permutation form that takes an input array `x` and an `axis`. */
+int mlx_random_permutation(
+ mlx_array* res,
+ const mlx_array x,
+ int axis,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* Permutation form that takes an integer `x` instead of an array.
+ * NOTE(review): presumably permutes the range [0, x) — confirm. */
+int mlx_random_permutation_arange(
+ mlx_array* res,
+ int x,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+int mlx_random_randint(
+ mlx_array* res,
+ const mlx_array low,
+ const mlx_array high,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+/* The only function here with no output parameter.
+ * NOTE(review): presumably seeds global generator state — confirm. */
+int mlx_random_seed(uint64_t seed);
+/* Split form that takes a count `num` and writes a single array to `res`.
+ * Unlike the samplers above, `key` is not annotated as nullable. */
+int mlx_random_split_num(
+ mlx_array* res,
+ const mlx_array key,
+ int num,
+ const mlx_stream s);
+/* Split form that writes two arrays, `res_0` and `res_1`. */
+int mlx_random_split(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array key,
+ const mlx_stream s);
+int mlx_random_truncated_normal(
+ mlx_array* res,
+ const mlx_array lower,
+ const mlx_array upper,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+int mlx_random_uniform(
+ mlx_array* res,
+ const mlx_array low,
+ const mlx_array high,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/stream.h b/x/mlxrunner/mlx/include/mlx/c/stream.h
new file mode 100644
index 000000000..d5865b806
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/stream.h
@@ -0,0 +1,88 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_STREAM_H
+#define MLX_STREAM_H
+
+#include <stdbool.h>
+
+#include "mlx/c/device.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_stream Stream
+ * MLX stream object.
+ */
+/**@{*/
+
+/**
+ * An MLX stream object.
+ */
+typedef struct mlx_stream_ {
+ void* ctx;
+} mlx_stream;
+
+/**
+ * Returns a new empty stream.
+ */
+mlx_stream mlx_stream_new(void);
+
+/**
+ * Returns a new stream on a device.
+ */
+mlx_stream mlx_stream_new_device(mlx_device dev);
+/**
+ * Set stream to provided src stream.
+ */
+int mlx_stream_set(mlx_stream* stream, const mlx_stream src);
+/**
+ * Free a stream.
+ */
+int mlx_stream_free(mlx_stream stream);
+/**
+ * Get stream description.
+ */
+int mlx_stream_tostring(mlx_string* str, mlx_stream stream);
+/**
+ * Check if streams are the same.
+ */
+bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs);
+/**
+ * Return the device of the stream.
+ */
+int mlx_stream_get_device(mlx_device* dev, mlx_stream stream);
+/**
+ * Return the index of the stream.
+ */
+int mlx_stream_get_index(int* index, mlx_stream stream);
+/**
+ * Synchronize with the provided stream.
+ */
+int mlx_synchronize(mlx_stream stream);
+/**
+ * Returns the default stream on the given device.
+ */
+int mlx_get_default_stream(mlx_stream* stream, mlx_device dev);
+/**
+ * Set default stream.
+ */
+int mlx_set_default_stream(mlx_stream stream);
+/**
+ * Returns the current default CPU stream.
+ */
+mlx_stream mlx_default_cpu_stream_new(void);
+
+/**
+ * Returns the current default GPU stream.
+ */
+mlx_stream mlx_default_gpu_stream_new(void);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/string.h b/x/mlxrunner/mlx/include/mlx/c/string.h
new file mode 100644
index 000000000..0d2a356ba
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/string.h
@@ -0,0 +1,55 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_STRING_H
+#define MLX_STRING_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_string String
+ * MLX string object.
+ */
+/**@{*/
+
+/**
+ * An MLX string object.
+ */
+typedef struct mlx_string_ {
+ void* ctx;
+} mlx_string;
+
+/**
+ * Returns a new empty string.
+ */
+mlx_string mlx_string_new(void);
+
+/**
+ * Returns a new string, copying contents from `str`, which must end with `\0`.
+ */
+mlx_string mlx_string_new_data(const char* str);
+
+/**
+ * Set string to src string.
+ */
+int mlx_string_set(mlx_string* str, const mlx_string src);
+
+/**
+ * Returns a pointer to the string contents.
+ * The pointer remains valid for the lifetime of the string.
+ */
+const char* mlx_string_data(mlx_string str);
+
+/**
+ * Free string.
+ */
+int mlx_string_free(mlx_string str);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/transforms.h b/x/mlxrunner/mlx/include/mlx/c/transforms.h
new file mode 100644
index 000000000..b2434619b
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/transforms.h
@@ -0,0 +1,68 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_TRANSFORMS_H
+#define MLX_TRANSFORMS_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup transforms Transform operations
+ */
+/**@{*/
+
+/* Conventions visible in the signatures below: every function returns `int`;
+ * functions that produce a value write it through `res` / `res_0` / `res_1`;
+ * mlx_eval and mlx_async_eval take only the arrays to act on and have no
+ * output parameter.
+ * NOTE(review): the `int` is presumably a status code — confirm the success
+ * value against the mlx-c error-handling documentation. */
+int mlx_async_eval(const mlx_vector_array outputs);
+int mlx_checkpoint(mlx_closure* res, const mlx_closure fun);
+/* Wraps `fun` into a new closure; each of the three custom closures may be
+ * null, per the inline annotations. */
+int mlx_custom_function(
+ mlx_closure* res,
+ const mlx_closure fun,
+ const mlx_closure_custom fun_vjp /* may be null */,
+ const mlx_closure_custom_jvp fun_jvp /* may be null */,
+ const mlx_closure_custom_vmap fun_vmap /* may be null */);
+/* Like mlx_custom_function but takes only `fun_vjp`, which is not annotated
+ * as nullable here. */
+int mlx_custom_vjp(
+ mlx_closure* res,
+ const mlx_closure fun,
+ const mlx_closure_custom fun_vjp);
+int mlx_eval(const mlx_vector_array outputs);
+/* Writes two array vectors for `fun` given `primals` and `tangents`.
+ * NOTE(review): presumably (outputs, Jacobian-vector products) — confirm the
+ * order against the MLX transforms documentation. */
+int mlx_jvp(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array primals,
+ const mlx_vector_array tangents);
+/* `argnums` is passed together with its element count `argnums_num`; the
+ * result is a closure handle, not evaluated values. */
+int mlx_value_and_grad(
+ mlx_closure_value_and_grad* res,
+ const mlx_closure fun,
+ const int* argnums,
+ size_t argnums_num);
+/* Writes two array vectors for `fun` given `primals` and `cotangents`.
+ * NOTE(review): presumably (outputs, vector-Jacobian products) — confirm the
+ * order against the MLX transforms documentation. */
+int mlx_vjp(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array primals,
+ const mlx_vector_array cotangents);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h b/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h
new file mode 100644
index 000000000..2b1356ebc
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h
@@ -0,0 +1,54 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_TRANSFORMS_IMPL_H
+#define MLX_TRANSFORMS_IMPL_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "mlx/c/array.h"
+#include "mlx/c/closure.h"
+#include "mlx/c/distributed_group.h"
+#include "mlx/c/io_types.h"
+#include "mlx/c/map.h"
+#include "mlx/c/stream.h"
+#include "mlx/c/string.h"
+#include "mlx/c/vector.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup transforms_impl Implementation detail operations
+ */
+/**@{*/
+
+/* Takes three array vectors plus two axis lists, each list passed as a
+ * pointer with its element count (`in_axes`/`in_axes_num`,
+ * `out_axes`/`out_axes_num`); writes one array vector to `res`.
+ * NOTE(review): the enclosing group labels these as implementation details;
+ * exact semantics are not visible here — confirm against MLX's vmap code. */
+int mlx_detail_vmap_replace(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_vector_array s_inputs,
+ const mlx_vector_array s_outputs,
+ const int* in_axes,
+ size_t in_axes_num,
+ const int* out_axes,
+ size_t out_axes_num);
+/* Takes a closure, its inputs and one axis list (`in_axes`/`in_axes_num`);
+ * writes two array vectors, `res_0` and `res_1`.
+ * NOTE(review): these presumably correspond to the `s_inputs` / `s_outputs`
+ * arguments of mlx_detail_vmap_replace — confirm. */
+int mlx_detail_vmap_trace(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array inputs,
+ const int* in_axes,
+ size_t in_axes_num);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/vector.h b/x/mlxrunner/mlx/include/mlx/c/vector.h
new file mode 100644
index 000000000..81bcf7495
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/vector.h
@@ -0,0 +1,133 @@
+/* Copyright © 2023-2024 Apple Inc. */
+/* */
+/* This file is auto-generated. Do not edit manually. */
+/* */
+
+#ifndef MLX_VECTOR_H
+#define MLX_VECTOR_H
+
+#include "mlx/c/array.h"
+#include "mlx/c/string.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \defgroup mlx_vector Vectors
+ * MLX vector objects.
+ */
+/**@{*/
+
+/**
+ * A vector of arrays.
+ *
+ * Shape of the API, as visible in the signatures (the same pattern repeats
+ * for every vector type in this header):
+ *  - `_new`, `_new_data`, `_new_value` return a handle by value;
+ *  - `_set`, `_set_data`, `_set_value` update the handle behind `vec`;
+ *  - `_append_data`, `_append_value` take the handle by value;
+ *  - `_size` returns a `size_t`; `_get` writes element `idx` through `res`;
+ *  - `_data` variants take a pointer together with an element count `size`.
+ * NOTE(review): ownership and copy semantics are not visible here — confirm
+ * whether a handle must be released with `_free` and whether `_get` hands
+ * back a copy.
+ */
+typedef struct mlx_vector_array_ {
+ void* ctx;
+} mlx_vector_array;
+mlx_vector_array mlx_vector_array_new(void);
+int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src);
+int mlx_vector_array_free(mlx_vector_array vec);
+mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size);
+mlx_vector_array mlx_vector_array_new_value(const mlx_array val);
+int mlx_vector_array_set_data(
+ mlx_vector_array* vec,
+ const mlx_array* data,
+ size_t size);
+int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val);
+int mlx_vector_array_append_data(
+ mlx_vector_array vec,
+ const mlx_array* data,
+ size_t size);
+int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val);
+size_t mlx_vector_array_size(mlx_vector_array vec);
+int mlx_vector_array_get(
+ mlx_array* res,
+ const mlx_vector_array vec,
+ size_t idx);
+
+/**
+ * A vector of vector_array, i.e. each element is an mlx_vector_array.
+ */
+typedef struct mlx_vector_vector_array_ {
+ void* ctx;
+} mlx_vector_vector_array;
+mlx_vector_vector_array mlx_vector_vector_array_new(void);
+int mlx_vector_vector_array_set(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_vector_array src);
+int mlx_vector_vector_array_free(mlx_vector_vector_array vec);
+mlx_vector_vector_array mlx_vector_vector_array_new_data(
+ const mlx_vector_array* data,
+ size_t size);
+mlx_vector_vector_array mlx_vector_vector_array_new_value(
+ const mlx_vector_array val);
+int mlx_vector_vector_array_set_data(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_array* data,
+ size_t size);
+int mlx_vector_vector_array_set_value(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_array val);
+int mlx_vector_vector_array_append_data(
+ mlx_vector_vector_array vec,
+ const mlx_vector_array* data,
+ size_t size);
+int mlx_vector_vector_array_append_value(
+ mlx_vector_vector_array vec,
+ const mlx_vector_array val);
+size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec);
+int mlx_vector_vector_array_get(
+ mlx_vector_array* res,
+ const mlx_vector_vector_array vec,
+ size_t idx);
+
+/**
+ * A vector of int.
+ *
+ * NOTE(review): unlike the other vector types, the `_data` functions here
+ * take a non-const `int*`; whether the buffer is ever written to is not
+ * visible from this header — confirm before passing read-only memory.
+ */
+typedef struct mlx_vector_int_ {
+ void* ctx;
+} mlx_vector_int;
+mlx_vector_int mlx_vector_int_new(void);
+int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src);
+int mlx_vector_int_free(mlx_vector_int vec);
+mlx_vector_int mlx_vector_int_new_data(int* data, size_t size);
+mlx_vector_int mlx_vector_int_new_value(int val);
+int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size);
+int mlx_vector_int_set_value(mlx_vector_int* vec, int val);
+int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size);
+int mlx_vector_int_append_value(mlx_vector_int vec, int val);
+size_t mlx_vector_int_size(mlx_vector_int vec);
+int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx);
+
+/**
+ * A vector of string.
+ *
+ * Elements go in as `const char*` and come out of `_get` through a
+ * `char**`.
+ * NOTE(review): the lifetime of the pointer stored by `_get` is not visible
+ * here — presumably it is owned by the vector; confirm before retaining it.
+ */
+typedef struct mlx_vector_string_ {
+ void* ctx;
+} mlx_vector_string;
+mlx_vector_string mlx_vector_string_new(void);
+int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src);
+int mlx_vector_string_free(mlx_vector_string vec);
+mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size);
+mlx_vector_string mlx_vector_string_new_value(const char* val);
+int mlx_vector_string_set_data(
+ mlx_vector_string* vec,
+ const char** data,
+ size_t size);
+int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val);
+int mlx_vector_string_append_data(
+ mlx_vector_string vec,
+ const char** data,
+ size_t size);
+int mlx_vector_string_append_value(mlx_vector_string vec, const char* val);
+size_t mlx_vector_string_size(mlx_vector_string vec);
+int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx);
+
+/**@}*/
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/include/mlx/c/version.h b/x/mlxrunner/mlx/include/mlx/c/version.h
new file mode 100644
index 000000000..96dd23877
--- /dev/null
+++ b/x/mlxrunner/mlx/include/mlx/c/version.h
@@ -0,0 +1,18 @@
+/* Copyright © 2023-2024 Apple Inc. */
+
+#ifndef MLX_VERSION_H
+#define MLX_VERSION_H
+
+#include "mlx/c/string.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* Writes the MLX version into the caller-provided string `str_`; the Go
+ * wrapper passes a string from mlx_string_new() and reads it back with
+ * mlx_string_data().
+ * NOTE(review): the `int` is presumably a status code — confirm. */
+int mlx_version(mlx_string* str_);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/x/mlxrunner/mlx/io.go b/x/mlxrunner/mlx/io.go
index 84868e005..0ddbd3b59 100644
--- a/x/mlxrunner/mlx/io.go
+++ b/x/mlxrunner/mlx/io.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
@@ -7,6 +5,7 @@ import "C"
import (
"iter"
+ "runtime"
"unsafe"
)
@@ -21,10 +20,17 @@ func Load(path string) iter.Seq2[string, *Array] {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
- cpu := C.mlx_default_cpu_stream_new()
- defer C.mlx_stream_free(cpu)
+	// Use GPU stream so tensors load directly to GPU memory (CUDA has Load::eval_gpu).
+	// Metal lacks eval_gpu for Load and CPU-only hosts have no GPU device: use CPU stream.
+	var stream C.mlx_stream
+	if runtime.GOOS == "darwin" || !GPUIsAvailable() {
+		stream = C.mlx_default_cpu_stream_new()
+	} else {
+		stream = C.mlx_default_gpu_stream_new()
+	}
+	defer C.mlx_stream_free(stream)
- C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
+ C.mlx_load_safetensors(&string2array, &string2string, cPath, stream)
it := C.mlx_map_string_to_array_iterator_new(string2array)
defer C.mlx_map_string_to_array_iterator_free(it)
diff --git a/x/mlxrunner/mlx/memory.go b/x/mlxrunner/mlx/memory.go
index cf36c304c..a243b72c0 100644
--- a/x/mlxrunner/mlx/memory.go
+++ b/x/mlxrunner/mlx/memory.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go
index 962d192ae..d64a6b0ee 100644
--- a/x/mlxrunner/mlx/mlx.go
+++ b/x/mlxrunner/mlx/mlx.go
@@ -1,19 +1,22 @@
-//go:build mlx
-
package mlx
-//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
-//go:generate cmake --build build --parallel
-//go:generate cmake --install build
-//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"
+//go:generate sh -c "go run generator/main.go -output=. ./include/mlx/c/*.h"
// #cgo CXXFLAGS: -std=c++17
-// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
-// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
+// #cgo CPPFLAGS: -I${SRCDIR}/include
+// #cgo LDFLAGS: -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
import "C"
+// Version reports the version string of the linked MLX core library.
+func Version() string {
+	version := C.mlx_string_new()
+	defer C.mlx_string_free(version)
+
+	// mlx_version fills the string in place; copy it out before the deferred free.
+	C.mlx_version(&version)
+	data := C.mlx_string_data(version)
+	return C.GoString(data)
+}
+
func doEval(outputs []*Array, async bool) {
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go
index 3d5691368..d3a99a6cd 100644
--- a/x/mlxrunner/mlx/nn.go
+++ b/x/mlxrunner/mlx/nn.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
type Linear struct {
diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go
index 2f97ba8d2..3d8ec7dec 100644
--- a/x/mlxrunner/mlx/ops.go
+++ b/x/mlxrunner/mlx/ops.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go
index 283a2141c..a2e25d68b 100644
--- a/x/mlxrunner/mlx/ops_extra.go
+++ b/x/mlxrunner/mlx/ops_extra.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/random.go b/x/mlxrunner/mlx/random.go
index 6afdbbab4..82c3d6785 100644
--- a/x/mlxrunner/mlx/random.go
+++ b/x/mlxrunner/mlx/random.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/slice.go b/x/mlxrunner/mlx/slice.go
index ab1324774..ea642ebf7 100644
--- a/x/mlxrunner/mlx/slice.go
+++ b/x/mlxrunner/mlx/slice.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
diff --git a/x/mlxrunner/mlx/stream.go b/x/mlxrunner/mlx/stream.go
index 83a3eeffd..9b01b4a85 100644
--- a/x/mlxrunner/mlx/stream.go
+++ b/x/mlxrunner/mlx/stream.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlx
// #include "generated.h"
@@ -27,6 +25,22 @@ var DefaultDevice = sync.OnceValue(func() Device {
return Device{d}
})
+// GPUIsAvailable reports whether MLX considers GPU device 0 available.
+func GPUIsAvailable() bool {
+	var available C.bool
+
+	gpu := C.mlx_device_new_type(C.MLX_GPU, 0)
+	defer C.mlx_device_free(gpu)
+
+	C.mlx_device_is_available(&available, gpu)
+	return bool(available)
+}
+
+// SetDefaultDeviceGPU makes GPU device 0 the default MLX device.
+func SetDefaultDeviceGPU() {
+	gpu := C.mlx_device_new_type(C.MLX_GPU, 0)
+	// The temporary device handle is released once the default has been set.
+	defer C.mlx_device_free(gpu)
+
+	C.mlx_set_default_device(gpu)
+}
+
type Stream struct {
ctx C.mlx_stream
}
diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go
index 4cdf6df33..3a85b6eb0 100644
--- a/x/mlxrunner/model/base/base.go
+++ b/x/mlxrunner/model/base/base.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package base
import (
diff --git a/x/mlxrunner/model/base/base_stub.go b/x/mlxrunner/model/base/base_stub.go
deleted file mode 100644
index 318d8f911..000000000
--- a/x/mlxrunner/model/base/base_stub.go
+++ /dev/null
@@ -1,3 +0,0 @@
-//go:build !mlx
-
-package base
diff --git a/x/mlxrunner/model/linear.go b/x/mlxrunner/model/linear.go
index fffdbdb29..788e4e3f0 100644
--- a/x/mlxrunner/model/linear.go
+++ b/x/mlxrunner/model/linear.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package model
import (
diff --git a/x/mlxrunner/model/quant.go b/x/mlxrunner/model/quant.go
index 10896e4b4..d4a56c35c 100644
--- a/x/mlxrunner/model/quant.go
+++ b/x/mlxrunner/model/quant.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package model
import (
diff --git a/x/mlxrunner/model/root.go b/x/mlxrunner/model/root.go
index c912f7f4c..1c05ee6a8 100644
--- a/x/mlxrunner/model/root.go
+++ b/x/mlxrunner/model/root.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package model
import (
diff --git a/x/mlxrunner/model/root_stub.go b/x/mlxrunner/model/root_stub.go
deleted file mode 100644
index 3fcda9c25..000000000
--- a/x/mlxrunner/model/root_stub.go
+++ /dev/null
@@ -1,3 +0,0 @@
-//go:build !mlx
-
-package model
diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go
index 852b04dcc..3ce148c02 100644
--- a/x/mlxrunner/pipeline.go
+++ b/x/mlxrunner/pipeline.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlxrunner
import (
diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go
index acaef79bf..08a376d43 100644
--- a/x/mlxrunner/runner.go
+++ b/x/mlxrunner/runner.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlxrunner
import (
diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go
index a25b23d03..df9da7a99 100644
--- a/x/mlxrunner/sample/sample.go
+++ b/x/mlxrunner/sample/sample.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package sample
import (
diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go
index 9c7d7e775..a9972bfdc 100644
--- a/x/mlxrunner/server.go
+++ b/x/mlxrunner/server.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package mlxrunner
import (
@@ -29,6 +27,13 @@ func Execute(args []string) error {
return fmt.Errorf("MLX not available: %w", err)
}
+ if mlx.GPUIsAvailable() {
+ mlx.SetDefaultDeviceGPU()
+ slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "gpu")
+ } else {
+ slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "cpu")
+ }
+
var (
modelName string
port int
diff --git a/x/mlxrunner/server_stub.go b/x/mlxrunner/server_stub.go
deleted file mode 100644
index 3b0f35500..000000000
--- a/x/mlxrunner/server_stub.go
+++ /dev/null
@@ -1,10 +0,0 @@
-//go:build !mlx
-
-package mlxrunner
-
-import "errors"
-
-// Execute returns an error when not built with MLX support.
-func Execute(args []string) error {
- return errors.New("MLX runner not available: build with mlx tag")
-}
diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go
index edf66657c..01f0559a0 100644
--- a/x/models/gemma3/gemma3.go
+++ b/x/models/gemma3/gemma3.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package gemma3 provides the Gemma 3 text model implementation for MLX.
package gemma3
diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go
index fb9c4af6f..2e8365580 100644
--- a/x/models/glm4_moe_lite/glm4_moe_lite.go
+++ b/x/models/glm4_moe_lite/glm4_moe_lite.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite
diff --git a/x/models/glm4_moe_lite/parser.go b/x/models/glm4_moe_lite/parser.go
index de1b2cc17..9a9985f68 100644
--- a/x/models/glm4_moe_lite/parser.go
+++ b/x/models/glm4_moe_lite/parser.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package glm4_moe_lite
import (
diff --git a/x/models/glm4_moe_lite/parser_test.go b/x/models/glm4_moe_lite/parser_test.go
index 0ce382709..d15b4e803 100644
--- a/x/models/glm4_moe_lite/parser_test.go
+++ b/x/models/glm4_moe_lite/parser_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package glm4_moe_lite
import (
diff --git a/x/models/glm4_moe_lite/render.go b/x/models/glm4_moe_lite/render.go
index 4998604bf..d15a99e51 100644
--- a/x/models/glm4_moe_lite/render.go
+++ b/x/models/glm4_moe_lite/render.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package glm4_moe_lite
import (
diff --git a/x/models/glm4_moe_lite/render_test.go b/x/models/glm4_moe_lite/render_test.go
index f0d576bec..91b871acc 100644
--- a/x/models/glm4_moe_lite/render_test.go
+++ b/x/models/glm4_moe_lite/render_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package glm4_moe_lite
import (
diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go
index fc7f34488..18e39bc0a 100644
--- a/x/models/llama/llama.go
+++ b/x/models/llama/llama.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package llama provides a Llama-style decoder-only transformer for MLX.
package llama
diff --git a/x/models/nn/nn.go b/x/models/nn/nn.go
index 78f1b92b6..07047024d 100644
--- a/x/models/nn/nn.go
+++ b/x/models/nn/nn.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package nn
import "github.com/ollama/ollama/x/mlxrunner/mlx"
diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go
index 85d427f58..71596af98 100644
--- a/x/models/qwen3/qwen3.go
+++ b/x/models/qwen3/qwen3.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package qwen3 provides the Qwen3 text model implementation for MLX.
package qwen3
diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go
index fbee82b59..642ea1bba 100644
--- a/x/models/qwen3_5/qwen3_5.go
+++ b/x/models/qwen3_5/qwen3_5.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package qwen3_5 provides the Qwen 3.5 text and MoE implementation for MLX.
package qwen3_5
diff --git a/x/models/qwen3_5/qwen3_5_test.go b/x/models/qwen3_5/qwen3_5_test.go
index 0a70da189..8165cd484 100644
--- a/x/models/qwen3_5/qwen3_5_test.go
+++ b/x/models/qwen3_5/qwen3_5_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package qwen3_5
import (
diff --git a/x/models/qwen3_5_moe/qwen3_5_moe.go b/x/models/qwen3_5_moe/qwen3_5_moe.go
index 9e0be26be..a505b458e 100644
--- a/x/models/qwen3_5_moe/qwen3_5_moe.go
+++ b/x/models/qwen3_5_moe/qwen3_5_moe.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
package qwen3_5_moe
diff --git a/x/tokenizer/tokenizer.go b/x/tokenizer/tokenizer.go
index 301e51aea..a1ce5e8ee 100644
--- a/x/tokenizer/tokenizer.go
+++ b/x/tokenizer/tokenizer.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
diff --git a/x/tokenizer/tokenizer_benchmark_test.go b/x/tokenizer/tokenizer_benchmark_test.go
index e65a59786..9f3d2a10e 100644
--- a/x/tokenizer/tokenizer_benchmark_test.go
+++ b/x/tokenizer/tokenizer_benchmark_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import (
diff --git a/x/tokenizer/tokenizer_bpe.go b/x/tokenizer/tokenizer_bpe.go
index 1e625c20a..9037f15bc 100644
--- a/x/tokenizer/tokenizer_bpe.go
+++ b/x/tokenizer/tokenizer_bpe.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import "container/heap"
diff --git a/x/tokenizer/tokenizer_correctness_test.go b/x/tokenizer/tokenizer_correctness_test.go
index 2fe94d279..91adc167d 100644
--- a/x/tokenizer/tokenizer_correctness_test.go
+++ b/x/tokenizer/tokenizer_correctness_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import (
diff --git a/x/tokenizer/tokenizer_decode.go b/x/tokenizer/tokenizer_decode.go
index e02d2a88b..0056948a2 100644
--- a/x/tokenizer/tokenizer_decode.go
+++ b/x/tokenizer/tokenizer_decode.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import (
diff --git a/x/tokenizer/tokenizer_encode.go b/x/tokenizer/tokenizer_encode.go
index 1b71ea6d3..3eb629e56 100644
--- a/x/tokenizer/tokenizer_encode.go
+++ b/x/tokenizer/tokenizer_encode.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import (
diff --git a/x/tokenizer/tokenizer_ggml_parity_test.go b/x/tokenizer/tokenizer_ggml_parity_test.go
index 4cef3d3dd..ee9b68f38 100644
--- a/x/tokenizer/tokenizer_ggml_parity_test.go
+++ b/x/tokenizer/tokenizer_ggml_parity_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import (
diff --git a/x/tokenizer/tokenizer_load.go b/x/tokenizer/tokenizer_load.go
index d2a253e17..efd086628 100644
--- a/x/tokenizer/tokenizer_load.go
+++ b/x/tokenizer/tokenizer_load.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import (
diff --git a/x/tokenizer/tokenizer_load_test.go b/x/tokenizer/tokenizer_load_test.go
index 136399c2e..caf2b0d35 100644
--- a/x/tokenizer/tokenizer_load_test.go
+++ b/x/tokenizer/tokenizer_load_test.go
@@ -1,5 +1,3 @@
-//go:build mlx
-
package tokenizer
import (