diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 6d9596807..a51ca1b9c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -117,6 +117,25 @@ jobs: install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe flags: '' runner_dir: 'vulkan' + - os: windows + arch: amd64 + preset: 'MLX CUDA 13' + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"cufft"' + - '"cufft_dev"' + - '"nvrtc"' + - '"nvrtc_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' + flags: '' runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release env: @@ -125,8 +144,10 @@ jobs: - name: Install system dependencies run: | choco install -y --no-progress ccache ninja - ccache -o cache_dir=${{ github.workspace }}\.ccache - - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') + if (Get-Command ccache -ErrorAction SilentlyContinue) { + ccache -o cache_dir=${{ github.workspace }}\.ccache + } + - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ') id: cache-install uses: actions/cache/restore@v4 with: @@ -134,8 +155,9 @@ jobs: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm C:\VulkanSDK - key: ${{ matrix.install }} - - if: startsWith(matrix.preset, 'CUDA ') + C:\Program Files\NVIDIA\CUDNN + key: ${{ matrix.install }}-${{ matrix.cudnn-install }} + - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ') name: Install CUDA ${{ matrix.cuda-version }} run: | $ErrorActionPreference = "Stop" @@ -179,6 +201,23 @@ jobs: run: | echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append + - if: startsWith(matrix.preset, 'MLX ') + name: Install cuDNN for MLX + run: | + $ErrorActionPreference = "Stop" + $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN" + if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { + Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip" + Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted + $cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName + New-Item -ItemType Directory -Force -Path $cudnnRoot + Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse + } + + echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: @@ -186,7 +225,8 @@ jobs: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm C:\VulkanSDK - key: ${{ matrix.install }} + C:\Program Files\NVIDIA\CUDNN + key: ${{ matrix.install }}-${{ matrix.cudnn-install }} - uses: actions/checkout@v4 - uses: actions/cache@v4 with: @@ -198,7 +238,7 @@ jobs: Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}" cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}" - cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip + cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue env: CMAKE_GENERATOR: Ninja diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a47156516..cf0545b56 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -37,7 +37,7 @@ jobs: | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))" } - echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT + echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT linux: @@ -60,6 +60,10 @@ jobs: mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev vulkan-sdk cmake ccache g++ make + - preset: 'MLX CUDA 13' + container: nvidia/cuda:13.0.0-devel-ubuntu22.04 + extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl + flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu' runs-on: linux container: ${{ matrix.container }} steps: @@ -76,6 +80,10 @@ jobs: $sudo apt-get update fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} + # MLX requires CMake 3.25+, install from official releases + if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then + curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1 + fi # Export VULKAN_SDK if provided by LunarG package (defensive) if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then echo "VULKAN_SDK=/usr" >> $GITHUB_ENV @@ -87,8 +95,8 @@ jobs: path: /github/home/.cache/ccache key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }} - run: | - cmake --preset ${{ matrix.preset }} ${{ matrix.flags }} - cmake --build --preset ${{ matrix.preset }} --parallel + cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} + cmake --build --preset "${{ matrix.preset }}" --parallel windows: needs: [changes] @@ -114,12 +122,31 @@ jobs: flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - preset: Vulkan install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe + - preset: 'MLX CUDA 13' + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip + flags: '-DCMAKE_CUDA_ARCHITECTURES=80' + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"cufft"' + - '"cufft_dev"' + - '"nvrtc"' + - '"nvrtc_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' runs-on: windows steps: - run: | choco install -y --no-progress ccache ninja - ccache -o cache_dir=${{ github.workspace }}\.ccache - - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' + if (Get-Command ccache -ErrorAction SilentlyContinue) { + ccache -o cache_dir=${{ github.workspace }}\.ccache + } + - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13' id: cache-install uses: actions/cache/restore@v4 with: @@ -127,8 +154,9 @@ jobs: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm C:\VulkanSDK - key: ${{ matrix.install }} - - if: matrix.preset == 'CUDA' + C:\Program Files\NVIDIA\CUDNN + key: ${{ matrix.install }}-${{ matrix.cudnn-install }} + - if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13' name: Install CUDA ${{ matrix.cuda-version }} run: | $ErrorActionPreference = "Stop" @@ -164,10 +192,27 @@ jobs: Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait } - + $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV + - if: matrix.preset == 'MLX CUDA 13' + name: Install cuDNN for MLX + run: | + $ErrorActionPreference = "Stop" + $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN" + if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { + Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip" + Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted + $cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName + New-Item -ItemType Directory -Force -Path $cudnnRoot + Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse + } + + echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: @@ -175,7 +220,8 @@ jobs: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm C:\VulkanSDK - key: ${{ matrix.install }} + C:\Program Files\NVIDIA\CUDNN + key: ${{ matrix.install }}-${{ matrix.cudnn-install }} - uses: actions/checkout@v4 - uses: actions/cache@v4 with: diff --git a/CMakeLists.txt b/CMakeLists.txt index a9e4471d8..8c37d3374 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR}) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx) +# Store ggml include paths for use with target_include_directories later. +# We avoid global include_directories() to prevent polluting the include path +# for other projects like MLX (whose openblas dependency has its own common.h). +set(GGML_INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src + ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include + ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu + ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx +) add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0) @@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS) set(CPU_VARIANTS "ggml-cpu") endif() +# Apply ggml include directories to ggml targets only (not globally) +target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS}) +foreach(variant ${CPU_VARIANTS}) + if(TARGET ${variant}) + target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS}) + endif() +endforeach() + install(TARGETS ggml-base ${CPU_VARIANTS} RUNTIME_DEPENDENCIES PRE_EXCLUDE_REGEXES ".*" @@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER) find_package(CUDAToolkit) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) + target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS}) install(TARGETS ggml-cuda RUNTIME_DEPENDENCIES DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} @@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER) if(AMDGPU_TARGETS) find_package(hip REQUIRED) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) + target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS}) if (WIN32) target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) @@ -168,6 +183,7 @@ if(NOT APPLE) find_package(Vulkan) if(Vulkan_FOUND) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) + target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS}) install(TARGETS ggml-vulkan RUNTIME_DEPENDENCIES PRE_INCLUDE_REGEXES vulkan @@ -179,7 +195,6 @@ if(NOT APPLE) endif() option(MLX_ENGINE "Enable MLX backend" OFF) - if(MLX_ENGINE) message(STATUS "Setting up MLX (this takes a while...)") add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx) @@ -187,10 +202,36 @@ if(MLX_ENGINE) # Find CUDA toolkit if MLX is built with CUDA support find_package(CUDAToolkit) + # Build list of directories for runtime dependency resolution + set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}) + # Add cuDNN bin paths for DLLs (Windows MLX CUDA builds) + # CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location + if(DEFINED ENV{CUDNN_ROOT_DIR}) + # cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/) + file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*") + list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS}) + endif() + # Add build output directory and MLX dependency build directories + list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR}) + # OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/) + list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin) + # NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the + # regex below. If NCCL is not found, MLX links a static stub (OBJECT lib) + # so there is no runtime dependency. This path covers the stub build dir + # for windows so we include the DLL in our dependencies. + list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release) + + # Base regexes for runtime dependencies (cross-platform) + set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran) + # On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer) + if(WIN32) + list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$") + endif() + install(TARGETS mlx mlxc RUNTIME_DEPENDENCIES - DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} - PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran + DIRECTORIES ${MLX_RUNTIME_DIRS} + PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES} PRE_EXCLUDE_REGEXES ".*" RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX @@ -205,13 +246,54 @@ if(MLX_ENGINE) COMPONENT MLX) endif() - # Manually install cudart and cublas since they might not be picked up as direct dependencies + # Install CCCL headers for NVRTC JIT compilation at runtime. + # MLX's own install rules use the default component so they get skipped by + # --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR. + # On Linux, MLX's jit_module.cpp resolves CCCL via + # current_binary_dir().parent_path() / "include" / "cccl", so we create a + # symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include + # This will need refinement if we add multiple CUDA versions for MLX in the future. + if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda) + install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda + DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl + COMPONENT MLX) + install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv + DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl + COMPONENT MLX) + if(NOT WIN32 AND NOT APPLE) + install(CODE " + set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\") + set(_target \"${OLLAMA_RUNNER_DIR}/include\") + if(NOT EXISTS \${_link}) + execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link}) + endif() + " COMPONENT MLX) + endif() + endif() + + # On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation) + # RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because + # dlfcn-win32 is a known CMake target with its own install rules (which install + # to the wrong destination). We must install it explicitly here. + if(WIN32) + install(FILES ${OLLAMA_BUILD_DIR}/dl.dll + DESTINATION ${OLLAMA_INSTALL_DIR} + COMPONENT MLX) + endif() + + # Manually install CUDA runtime libraries that MLX loads via dlopen + # (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps) if(CUDAToolkit_FOUND) - file(GLOB CUDART_LIBS + file(GLOB MLX_CUDA_LIBS "${CUDAToolkit_LIBRARY_DIR}/libcudart.so*" - "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*") - if(CUDART_LIBS) - install(FILES ${CUDART_LIBS} + "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*" + "${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*" + "${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*" + "${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*" + "${CUDAToolkit_LIBRARY_DIR}/libcufft.so*" + "${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*") + if(MLX_CUDA_LIBS) + install(FILES ${MLX_CUDA_LIBS} DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX) endif() diff --git a/CMakePresets.json b/CMakePresets.json index 0d643038a..d099d3f16 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -112,6 +112,7 @@ "name": "MLX CUDA 13", "inherits": [ "MLX", "CUDA 13" ], "cacheVariables": { + "MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual", "OLLAMA_RUNNER_DIR": "mlx_cuda_v13" } } diff --git a/Dockerfile b/Dockerfile index cabb6cc82..0743cc45d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,11 @@ ARG CMAKEVERSION=3.31.2 ARG NINJAVERSION=1.12.1 ARG VULKANVERSION=1.4.321.1 +# Default empty stages for local MLX source overrides. +# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c +FROM scratch AS local-mlx +FROM scratch AS local-mlx-c + FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo @@ -152,12 +157,20 @@ COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY x/imagegen/mlx x/imagegen/mlx COPY go.mod go.sum . -COPY MLX_VERSION . +COPY MLX_VERSION MLX_CORE_VERSION . RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local ENV PATH=/usr/local/go/bin:$PATH RUN go mod download RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \ + --mount=type=bind,from=local-mlx,target=/tmp/local-mlx \ + --mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \ + if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \ + export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \ + fi \ + && if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \ + export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \ + fi \ + && cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \ && cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \ && cmake --install build --component MLX --strip @@ -168,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux- ENV PATH=/usr/local/go/bin:$PATH RUN go mod download COPY . . -# Clone mlx-c headers for CGO (version from MLX_VERSION file) -RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src ARG GOFLAGS="'-ldflags=-w -s'" ENV CGO_ENABLED=1 ARG CGO_CFLAGS ARG CGO_CXXFLAGS -ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src" +ENV CGO_CFLAGS="${CGO_CFLAGS}" ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}" RUN --mount=type=cache,target=/root/.cache/go-build \ - go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama . + go build -trimpath -buildmode=pie -o /bin/ollama . FROM --platform=linux/amd64 scratch AS amd64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ diff --git a/MLX_CORE_VERSION b/MLX_CORE_VERSION new file mode 100644 index 000000000..912750052 --- /dev/null +++ b/MLX_CORE_VERSION @@ -0,0 +1 @@ +v0.30.6 diff --git a/docs/development.md b/docs/development.md index d0120a191..12c69c204 100644 --- a/docs/development.md +++ b/docs/development.md @@ -51,6 +51,9 @@ Install prerequisites: - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network) - (Optional) VULKAN GPU support - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs +- (Optional) MLX engine support + - [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads) + - [cuDNN 9+](https://developer.nvidia.com/cudnn) Then, configure and build the project: @@ -101,6 +104,10 @@ Install prerequisites: - (Optional) VULKAN GPU support - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs - Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS) +- (Optional) MLX engine support + - [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads) + - [cuDNN 9+](https://developer.nvidia.com/cudnn) + - OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian) > [!IMPORTANT] > Ensure prerequisites are in `PATH` before running CMake. @@ -118,6 +125,67 @@ Lastly, run Ollama: go run . serve ``` +## MLX Engine (Optional) + +The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On MacOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13. + +### macOS (Apple Silicon) + +Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then: + +```shell +xcodebuild -downloadComponent MetalToolchain +``` + +Verify it's installed correctly (should print "no input files"): + +```shell +xcrun metal +``` + +Then build: + +```shell +cmake -B build --preset MLX +cmake --build build --preset MLX --parallel +cmake --install build --component MLX +``` + +> [!NOTE] +> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing. + +### Windows / Linux (CUDA) + +Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+. + +```shell +cmake -B build --preset "MLX CUDA 13" +cmake --build build --target mlx --target mlxc --config Release --parallel +cmake --install build --component MLX --strip +``` + +### Local MLX source overrides + +To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake: + +```shell +export OLLAMA_MLX_SOURCE=/path/to/mlx +export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c +``` + +For example, using the helper scripts with local mlx and mlx-c repos: +```shell +OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh + +OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh +``` + +```powershell +$env:OLLAMA_MLX_SOURCE="../mlx" +$env:OLLAMA_MLX_C_SOURCE="../mlx-c" +./scripts/build_darwin.ps1 +``` + ## Docker ```shell diff --git a/parser/parser.go b/parser/parser.go index 5ef918bf2..f3b6dcb55 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -181,6 +181,9 @@ func fileDigestMap(path string) (map[string]string, error) { } if !filepath.IsLocal(rel) { + if strings.Contains(rel, ".cache") { + return nil, fmt.Errorf("insecure path: %s\n\nUse --local-dir when downloading model to disable caching", rel) + } return nil, fmt.Errorf("insecure path: %s", rel) } diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index 4325a9787..4bec54cf8 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -59,7 +59,7 @@ _build_darwin() { cmake --install $BUILD_DIR --component CPU cmake --install $BUILD_DIR --component MLX # Override CGO flags to point to the amd64 build directory - MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" + MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0" MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0" else BUILD_DIR=build @@ -70,10 +70,10 @@ _build_darwin() { cmake --build --preset MLX --parallel cmake --install $BUILD_DIR --component MLX # Use default CGO flags from mlx.go for arm64 - MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" + MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0" MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0" fi - GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX . + GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -o $INSTALL_PREFIX . # Copy MLX libraries to same directory as executable for dlopen cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/ cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/ diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 21e6f3be0..f162e4f7e 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -4,7 +4,10 @@ # # gcloud auth application-default login -$ErrorActionPreference = "Stop" +# Use "Continue" so that stderr output from native commands (e.g. CGo warnings) +# is not promoted to a terminating exception by the try/catch block. +# All native commands already check $LASTEXITCODE explicitly. +$ErrorActionPreference = "Continue" mkdir -Force -path .\dist | Out-Null @@ -16,13 +19,13 @@ function checkEnv { if ($null -ne $arch) { $script:ARCH = ($arch.ToString().ToLower()).Replace("x64", "amd64") } else { - write-host "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override" + Write-Output "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override" $script:ARCH="amd64" } } $script:TARGET_ARCH=$script:ARCH Write-host "Building for ${script:TARGET_ARCH}" - write-host "Locating required tools and paths" + Write-Output "Locating required tools and paths" $script:SRC_DIR=$PWD # Locate CUDA versions @@ -37,16 +40,17 @@ function checkEnv { $script:CUDA_DIRS=($cudaList | sort-object -Descending) } if ($script:CUDA_DIRS.length -gt 0) { - write-host "Available CUDA Versions: $script:CUDA_DIRS" + Write-Output "Available CUDA Versions: $script:CUDA_DIRS" } else { - write-host "No CUDA versions detected" + Write-Output "No CUDA versions detected" } - # Locate ROCm version - if ($null -ne $env:HIP_PATH) { + # Locate ROCm v6 + $rocmDir=(get-item "C:\Program Files\AMD\ROCm\6.*" -ea 'silentlycontinue' | sort-object -Descending | select-object -First 1) + if ($null -ne $rocmDir) { + $script:HIP_PATH=$rocmDir.FullName + } elseif ($null -ne $env:HIP_PATH -and $env:HIP_PATH -match '[/\\]6\.') { $script:HIP_PATH=$env:HIP_PATH - } else { - $script:HIP_PATH=(get-item "C:\Program Files\AMD\ROCm\*\bin\" -ea 'silentlycontinue' | sort-object -Descending) } $inoSetup=(get-item "C:\Program Files*\Inno Setup*\") @@ -78,7 +82,7 @@ function checkEnv { } else { $script:PKG_VERSION="0.0.0" } - write-host "Building Ollama $script:VERSION with package version $script:PKG_VERSION" + Write-Output "Building Ollama $script:VERSION with package version $script:PKG_VERSION" # Note: Windows Kits 10 signtool crashes with GCP's plugin if ($null -eq $env:SIGN_TOOL) { @@ -87,12 +91,32 @@ function checkEnv { ${script:SignTool}=${env:SIGN_TOOL} } if ("${env:KEY_CONTAINER}") { - ${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt") - Write-host "Code signing enabled" + if (Test-Path "${script:SRC_DIR}\ollama_inc.crt") { + ${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt") + Write-host "Code signing enabled" + } else { + Write-Output "WARNING: KEY_CONTAINER is set but ollama_inc.crt not found at ${script:SRC_DIR}\ollama_inc.crt - code signing disabled" + } } else { - write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree" + Write-Output "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree" } - $script:JOBS=([Environment]::ProcessorCount) + if ($env:OLLAMA_BUILD_PARALLEL) { + $script:JOBS=[int]$env:OLLAMA_BUILD_PARALLEL + } else { + # Use physical core count rather than logical processors (hyperthreads) + # to avoid saturating the system during builds + try { + $cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfCores -Sum).Sum + } catch { + $cores = 0 + } + if ($cores -gt 0) { + $script:JOBS = $cores + } else { + $script:JOBS = [Environment]::ProcessorCount + } + } + Write-Output "Build parallelism: $script:JOBS (set OLLAMA_BUILD_PARALLEL to override)" } @@ -127,7 +151,7 @@ function cuda11 { } } } - write-host "Building CUDA v$cudaMajorVer backend libraries $cuda" + Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda" $env:CUDAToolkit_ROOT=$cuda & cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix "$script:DIST_DIR" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -136,12 +160,12 @@ function cuda11 { & cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } else { - write-host "CUDA v$cudaMajorVer not detected, skipping" + Write-Output "CUDA v$cudaMajorVer not detected, skipping" } } else { - write-host "not arch we wanted" + Write-Output "not arch we wanted" } - write-host "done" + Write-Output "done" } function cudaCommon { @@ -159,7 +183,7 @@ function cudaCommon { } } } - write-host "Building CUDA v$cudaMajorVer backend libraries $cuda" + Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda" $env:CUDAToolkit_ROOT=$cuda & cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -168,7 +192,7 @@ function cudaCommon { & cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } else { - write-host "CUDA v$cudaMajorVer not detected, skipping" + Write-Output "CUDA v$cudaMajorVer not detected, skipping" } } } @@ -181,11 +205,11 @@ function cuda13 { cudaCommon("13") } -function rocm { +function rocm6 { mkdir -Force -path "${script:DIST_DIR}\" | Out-Null if ($script:ARCH -ne "arm64") { if ($script:HIP_PATH) { - write-host "Building ROCm backend libraries $script:HIP_PATH" + Write-Output "Building ROCm backend libraries $script:HIP_PATH" if (-Not (get-command -ErrorAction silent ninja)) { $NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe).Directory.FullName $env:PATH="$NINJA_DIR;$env:PATH" @@ -193,9 +217,11 @@ function rocm { $env:HIPCXX="${script:HIP_PATH}\bin\clang++.exe" $env:HIP_PLATFORM="amd" $env:CMAKE_PREFIX_PATH="${script:HIP_PATH}" + # Set CC/CXX via environment instead of -D flags to avoid triggering + # spurious compiler-change reconfigures that reset CMAKE_INSTALL_PREFIX + $env:CC="${script:HIP_PATH}\bin\clang.exe" + $env:CXX="${script:HIP_PATH}\bin\clang++.exe" & cmake -B build\rocm --preset "ROCm 6" -G Ninja ` - -DCMAKE_C_COMPILER=clang ` - -DCMAKE_CXX_COMPILER=clang++ ` -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` --install-prefix $script:DIST_DIR @@ -203,20 +229,22 @@ function rocm { $env:HIPCXX="" $env:HIP_PLATFORM="" $env:CMAKE_PREFIX_PATH="" + $env:CC="" + $env:CXX="" & cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build\rocm --component "HIP" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue } else { - write-host "ROCm not detected, skipping" + Write-Output "ROCm not detected, skipping" } } } function vulkan { if ($env:VULKAN_SDK) { - write-host "Building Vulkan backend libraries" + Write-Output "Building Vulkan backend libraries" & cmake -B build\vulkan --preset Vulkan --install-prefix $script:DIST_DIR if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --build build\vulkan --target ggml-vulkan --config Release --parallel $script:JOBS @@ -224,33 +252,91 @@ function vulkan { & cmake --install build\vulkan --component Vulkan --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } else { - write-host "Vulkan not detected, skipping" + Write-Output "Vulkan not detected, skipping" + } +} + +function mlxCuda13 { + mkdir -Force -path "${script:DIST_DIR}\" | Out-Null + $cudaMajorVer="13" + if ($script:ARCH -ne "arm64") { + if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) { + foreach ($d in $Script:CUDA_DIRS){ + if ($d.FullName.Contains("v$cudaMajorVer")) { + if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) { + $cuda=($d.FullName|split-path -parent) + break + } + } + } + + # Check for cuDNN - required for MLX CUDA backend + # Supports two layouts: + # 1. CI/zip extract: CUDNN\include\cudnn.h, lib\x64\, bin\x64\ + # 2. Official installer: CUDNN\v*\include\{cuda-ver}\cudnn.h, lib\{cuda-ver}\x64\, bin\{cuda-ver}\ + if ($env:CUDNN_INCLUDE_PATH -and $env:CUDNN_LIBRARY_PATH) { + Write-Output "Using cuDNN from environment: $env:CUDNN_INCLUDE_PATH" + } elseif (Test-Path "C:\Program Files\NVIDIA\CUDNN\include\cudnn.h") { + # CI/zip layout (flat) + $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN" + $env:CUDNN_ROOT_DIR = $cudnnRoot + $env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include" + $env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\x64" + Write-Output "Found cuDNN at $cudnnRoot (flat layout)" + } else { + # Official installer layout (versioned) + $cudnnRoot = $null + $resolved = Resolve-Path -Path "C:\Program Files\NVIDIA\CUDNN\v*" -ErrorAction SilentlyContinue | Sort-Object -Descending | Select-Object -First 1 + if ($resolved -and (Test-Path "$($resolved.Path)\include\$cudaMajorVer.0\cudnn.h")) { + $cudnnRoot = $resolved.Path + $env:CUDNN_ROOT_DIR = $cudnnRoot + $env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include\$cudaMajorVer.0" + $env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\$cudaMajorVer.0\x64" + Write-Output "Found cuDNN at $cudnnRoot (official installer, CUDA $cudaMajorVer.0)" + } else { + Write-Output "cuDNN not found - set CUDNN_INCLUDE_PATH and CUDNN_LIBRARY_PATH environment variables" + Write-Output "Skipping MLX build" + return + } + } + + Write-Output "Building MLX CUDA v$cudaMajorVer backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake -B build\mlx_cuda_v$cudaMajorVer --preset "MLX CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build build\mlx_cuda_v$cudaMajorVer --target mlx --target mlxc --config Release --parallel $script:JOBS -- /nodeReuse:false + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build\mlx_cuda_v$cudaMajorVer --component "MLX" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } else { + Write-Output "CUDA v$cudaMajorVer not detected, skipping MLX build" + } } } function ollama { mkdir -Force -path "${script:DIST_DIR}\" | Out-Null - write-host "Building ollama CLI" + Write-Output "Building ollama CLI" & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" . if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} cp .\ollama.exe "${script:DIST_DIR}\" } function app { - write-host "Building Ollama App $script:VERSION with package version $script:PKG_VERSION" + Write-Output "Building Ollama App $script:VERSION with package version $script:PKG_VERSION" if (!(Get-Command npm -ErrorAction SilentlyContinue)) { - write-host "npm is not installed. Please install Node.js and npm first:" - write-host " Visit: https://nodejs.org/" + Write-Output "npm is not installed. Please install Node.js and npm first:" + Write-Output " Visit: https://nodejs.org/" exit 1 } if (!(Get-Command tsc -ErrorAction SilentlyContinue)) { - write-host "Installing TypeScript compiler..." + Write-Output "Installing TypeScript compiler..." npm install -g typescript } if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) { - write-host "Installing tscriptify..." + Write-Output "Installing tscriptify..." go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest } if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) { @@ -260,32 +346,32 @@ function app { Push-Location app/ui/app npm install if ($LASTEXITCODE -ne 0) { - write-host "ERROR: npm install failed with exit code $LASTEXITCODE" + Write-Output "ERROR: npm install failed with exit code $LASTEXITCODE" exit $LASTEXITCODE } - write-host "Building React application..." + Write-Output "Building React application..." npm run build if ($LASTEXITCODE -ne 0) { - write-host "ERROR: npm run build failed with exit code $LASTEXITCODE" + Write-Output "ERROR: npm run build failed with exit code $LASTEXITCODE" exit $LASTEXITCODE } # Check if dist directory exists and has content if (!(Test-Path "dist")) { - write-host "ERROR: dist directory was not created by npm run build" + Write-Output "ERROR: dist directory was not created by npm run build" exit 1 } $distFiles = Get-ChildItem "dist" -Recurse if ($distFiles.Count -eq 0) { - write-host "ERROR: dist directory is empty after npm run build" + Write-Output "ERROR: dist directory is empty after npm run build" exit 1 } Pop-Location - write-host "Running go generate" + Write-Output "Running go generate" & go generate ./... if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/ollama/ollama/app/version.Version=$script:VERSION" -o .\dist\windows-ollama-app-${script:ARCH}.exe ./app/cmd/app/ @@ -293,42 +379,42 @@ function app { } function deps { - write-host "Download MSVC Redistributables" + Write-Output "Download MSVC Redistributables" mkdir -Force -path "${script:SRC_DIR}\dist\\windows-arm64" | Out-Null mkdir -Force -path "${script:SRC_DIR}\dist\\windows-amd64" | Out-Null - invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" - invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" - write-host "Done." + invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" -ErrorAction Stop + invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" -ErrorAction Stop + Write-Output "Done." } function sign { # Copy install.ps1 to dist for release packaging - write-host "Copying install.ps1 to dist" - Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" + Write-Output "Copying install.ps1 to dist" + Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" -ErrorAction Stop if ("${env:KEY_CONTAINER}") { - write-host "Signing Ollama executables, scripts and libraries" + Write-Output "Signing Ollama executables, scripts and libraries" & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ` $(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll')) if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - write-host "Signing install.ps1" + Write-Output "Signing install.ps1" & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ` "${script:SRC_DIR}\dist\install.ps1" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } else { - write-host "Signing not enabled" + Write-Output "Signing not enabled" } } function installer { if ($null -eq ${script:INNO_SETUP_DIR}) { - write-host "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php" + Write-Output "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php" exit 1 } - write-host "Building Ollama Installer" + Write-Output "Building Ollama Installer" cd "${script:SRC_DIR}\app" $env:PKG_VERSION=$script:PKG_VERSION if ("${env:KEY_CONTAINER}") { @@ -342,24 +428,24 @@ function installer { function zip { if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") { if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") { - write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" + Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" # Temporarily adjust paths so we can retain the same directory structure Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm" mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt" - Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" + Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" -ErrorAction Stop Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force } - write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip" + Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip" Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") { - Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" + Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" -ErrorAction Stop } } if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") { - write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip" + Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip" Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force } } @@ -375,8 +461,9 @@ try { cpu cuda12 cuda13 - rocm + rocm6 vulkan + mlxCuda13 ollama app deps @@ -385,13 +472,13 @@ try { zip } else { for ( $i = 0; $i -lt $args.count; $i++ ) { - write-host "running build step $($args[$i])" + Write-Output "running build step $($args[$i])" & $($args[$i]) } } } catch { - write-host "Build Failed" - write-host $_ + Write-Error "Build Failed: $($_.Exception.Message)" + Write-Error "$($_.ScriptStackTrace)" } finally { set-location $script:SRC_DIR $env:PKG_VERSION="" diff --git a/scripts/env.sh b/scripts/env.sh index 65a970bdc..3e6d69cc9 100644 --- a/scripts/env.sh +++ b/scripts/env.sh @@ -18,6 +18,14 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \ --build-arg=GPU_RUNNER_CPU_FLAGS \ --build-arg=AMDGPU_TARGETS" +# Forward local MLX source overrides as Docker build contexts +if [ -n "${OLLAMA_MLX_SOURCE:-}" ]; then + OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx=$(cd "$OLLAMA_MLX_SOURCE" && pwd)" +fi +if [ -n "${OLLAMA_MLX_C_SOURCE:-}" ]; then + OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx-c=$(cd "$OLLAMA_MLX_C_SOURCE" && pwd)" +fi + echo "Building Ollama" echo "VERSION=$VERSION" echo "PLATFORM=$PLATFORM" \ No newline at end of file diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go index 6ccb1ad6d..bb379bbac 100644 --- a/x/create/client/quantize.go +++ b/x/create/client/quantize.go @@ -1,5 +1,3 @@ -//go:build mlx - package client import ( @@ -194,9 +192,10 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) { return blobData, nil } -// QuantizeSupported returns true if quantization is supported (MLX build) +// QuantizeSupported returns true if quantization is supported (MLX library available) func QuantizeSupported() bool { - return true + mlx.InitMLX() + return mlx.IsMLXAvailable() } // ensureTempDir creates the temp directory for quantization if it doesn't exist diff --git a/x/create/client/quantize_stub.go b/x/create/client/quantize_stub.go deleted file mode 100644 index 7a75671a0..000000000 --- a/x/create/client/quantize_stub.go +++ /dev/null @@ -1,25 +0,0 @@ -//go:build !mlx - -package client - -import ( - "fmt" - "io" - - "github.com/ollama/ollama/x/create" -) - -// quantizeTensor is not available without MLX -func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) { - return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") -} - -// quantizePackedGroup is not available without MLX -func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) { - return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") -} - -// QuantizeSupported returns false when MLX is not available -func QuantizeSupported() bool { - return false -} diff --git a/x/imagegen/cache/cache.go b/x/imagegen/cache/cache.go index 8a25193cd..d4e19ba55 100644 --- a/x/imagegen/cache/cache.go +++ b/x/imagegen/cache/cache.go @@ -1,5 +1,3 @@ -//go:build mlx - package cache import "github.com/ollama/ollama/x/imagegen/mlx" diff --git a/x/imagegen/cache/step.go b/x/imagegen/cache/step.go index f91f22fa0..066f2f645 100644 --- a/x/imagegen/cache/step.go +++ b/x/imagegen/cache/step.go @@ -1,5 +1,3 @@ -//go:build mlx - package cache import "github.com/ollama/ollama/x/imagegen/mlx" diff --git a/x/imagegen/cache/teacache.go b/x/imagegen/cache/teacache.go index 60031d8cb..fb06047ea 100644 --- a/x/imagegen/cache/teacache.go +++ b/x/imagegen/cache/teacache.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package cache provides caching mechanisms for diffusion model inference. package cache diff --git a/x/imagegen/cmd/engine/README.md b/x/imagegen/cmd/engine/README.md index 3991c02a8..e6ab2f5d3 100644 --- a/x/imagegen/cmd/engine/README.md +++ b/x/imagegen/cmd/engine/README.md @@ -5,7 +5,7 @@ Experimental MLX backend for running models on Apple Silicon and CUDA. ## Build ```bash -go build -tags mlx -o engine ./x/imagegen/cmd/engine +go build -o engine ./x/imagegen/cmd/engine ``` ## Text Generation diff --git a/x/imagegen/cmd/engine/generate.go b/x/imagegen/cmd/engine/generate.go index 51118afc1..95173cd80 100644 --- a/x/imagegen/cmd/engine/generate.go +++ b/x/imagegen/cmd/engine/generate.go @@ -1,5 +1,3 @@ -//go:build mlx - package main import ( diff --git a/x/imagegen/cmd/engine/image.go b/x/imagegen/cmd/engine/image.go index e8af2222a..3c393cf66 100644 --- a/x/imagegen/cmd/engine/image.go +++ b/x/imagegen/cmd/engine/image.go @@ -1,5 +1,3 @@ -//go:build mlx - package main import ( diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go index 6ec7de9e1..31411f466 100644 --- a/x/imagegen/cmd/engine/main.go +++ b/x/imagegen/cmd/engine/main.go @@ -1,5 +1,3 @@ -//go:build mlx - package main import ( diff --git a/x/imagegen/cmd/engine/sample.go b/x/imagegen/cmd/engine/sample.go index 5d723e6dc..165c40774 100644 --- a/x/imagegen/cmd/engine/sample.go +++ b/x/imagegen/cmd/engine/sample.go @@ -1,5 +1,3 @@ -//go:build mlx - package main import "github.com/ollama/ollama/x/imagegen/mlx" diff --git a/x/imagegen/image.go b/x/imagegen/image.go index 2dca0ee1d..9bdd95304 100644 --- a/x/imagegen/image.go +++ b/x/imagegen/image.go @@ -1,5 +1,3 @@ -//go:build mlx - package imagegen import ( diff --git a/x/imagegen/image_processor.go b/x/imagegen/image_processor.go index 7a562feb5..c3e68ebb3 100644 --- a/x/imagegen/image_processor.go +++ b/x/imagegen/image_processor.go @@ -1,5 +1,3 @@ -//go:build mlx - package imagegen import ( diff --git a/x/imagegen/imagegen.go b/x/imagegen/imagegen.go index d870bed9b..c4b586505 100644 --- a/x/imagegen/imagegen.go +++ b/x/imagegen/imagegen.go @@ -1,5 +1,3 @@ -//go:build mlx - package imagegen import ( diff --git a/x/imagegen/manifest/weights.go b/x/imagegen/manifest/weights.go index e1209c9db..fcb30449e 100644 --- a/x/imagegen/manifest/weights.go +++ b/x/imagegen/manifest/weights.go @@ -1,5 +1,3 @@ -//go:build mlx - package manifest import ( diff --git a/x/imagegen/mlx/CMakeLists.txt b/x/imagegen/mlx/CMakeLists.txt index b62cbf2eb..70246ef4b 100644 --- a/x/imagegen/mlx/CMakeLists.txt +++ b/x/imagegen/mlx/CMakeLists.txt @@ -4,6 +4,10 @@ include(FetchContent) file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG) string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG) +# Read MLX core version from top-level file +file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG) +string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG) + set(MLX_C_BUILD_EXAMPLES OFF) set(MLX_BUILD_GGUF OFF) @@ -43,6 +47,17 @@ if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES) message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}") endif() +# Forward cuDNN environment variables to cmake variables so MLX's FindCUDNN.cmake +# can find them via HINTS ${CUDNN_INCLUDE_PATH} / ${CUDNN_LIBRARY_PATH}. +if(DEFINED ENV{CUDNN_INCLUDE_PATH} AND NOT CUDNN_INCLUDE_PATH) + set(CUDNN_INCLUDE_PATH "$ENV{CUDNN_INCLUDE_PATH}" CACHE PATH "cuDNN include path") + message(STATUS "Using CUDNN_INCLUDE_PATH from environment: ${CUDNN_INCLUDE_PATH}") +endif() +if(DEFINED ENV{CUDNN_LIBRARY_PATH} AND NOT CUDNN_LIBRARY_PATH) + set(CUDNN_LIBRARY_PATH "$ENV{CUDNN_LIBRARY_PATH}" CACHE PATH "cuDNN library path") + message(STATUS "Using CUDNN_LIBRARY_PATH from environment: ${CUDNN_LIBRARY_PATH}") +endif() + # Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER) set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE) @@ -51,11 +66,58 @@ elseif(MLX_CUDA_ARCHITECTURES) message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled") endif() +# Allow local source overrides via environment variables +# Resolve to absolute paths so FetchContent doesn't break on relative dirs. +if(DEFINED ENV{OLLAMA_MLX_SOURCE}) + get_filename_component(_mlx_src "$ENV{OLLAMA_MLX_SOURCE}" ABSOLUTE BASE_DIR ${CMAKE_SOURCE_DIR}) + set(FETCHCONTENT_SOURCE_DIR_MLX "${_mlx_src}" CACHE PATH "" FORCE) + message(STATUS "Using local MLX source: ${_mlx_src}") +endif() +if(DEFINED ENV{OLLAMA_MLX_C_SOURCE}) + get_filename_component(_mlx_c_src "$ENV{OLLAMA_MLX_C_SOURCE}" ABSOLUTE BASE_DIR ${CMAKE_SOURCE_DIR}) + set(FETCHCONTENT_SOURCE_DIR_MLX-C "${_mlx_c_src}" CACHE PATH "" FORCE) + message(STATUS "Using local MLX-C source: ${_mlx_c_src}") +endif() + +# Pre-declare mlx so our pinned version takes precedence over the one +# hardcoded in mlx-c's CMakeLists.txt (first FetchContent_Declare wins). FetchContent_Declare( - mlx-c - GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" - GIT_TAG ${MLX_C_GIT_TAG}) + mlx + GIT_REPOSITORY "https://github.com/ml-explore/mlx.git" + GIT_TAG ${MLX_GIT_TAG} +) + +FetchContent_Declare( + mlx-c + GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" + GIT_TAG ${MLX_C_GIT_TAG} +) FetchContent_MakeAvailable(mlx-c) +# Sync vendored headers with fetched version +file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h") +file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/") + +# For local dev builds, override MLX_VERSION with git describe output +if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX) + execute_process( + COMMAND git describe --tags --first-parent --abbrev=7 --long --dirty --always + WORKING_DIRECTORY ${mlx_SOURCE_DIR} + OUTPUT_VARIABLE _mlx_git_version + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + RESULT_VARIABLE _mlx_git_result + ) + if(_mlx_git_result EQUAL 0 AND _mlx_git_version) + # Strip leading "v" prefix for consistency + string(REGEX REPLACE "^v" "" _mlx_git_version "${_mlx_git_version}") + get_target_property(_mlx_defs mlx_version COMPILE_DEFINITIONS) + list(FILTER _mlx_defs EXCLUDE REGEX "^MLX_VERSION=") + set_target_properties(mlx_version PROPERTIES COMPILE_DEFINITIONS "${_mlx_defs}") + target_compile_definitions(mlx_version PRIVATE "MLX_VERSION=\"${_mlx_git_version}\"") + message(STATUS "MLX version (local dev): ${_mlx_git_version}") + endif() +endif() + set_target_output_directory(mlx) set_target_output_directory(mlxc) diff --git a/x/imagegen/mlx/compile.go b/x/imagegen/mlx/compile.go index 0dd2dd02a..746e6eaf6 100644 --- a/x/imagegen/mlx/compile.go +++ b/x/imagegen/mlx/compile.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx /* diff --git a/x/imagegen/mlx/doc.go b/x/imagegen/mlx/doc.go index ced1802b0..5410f3b2d 100644 --- a/x/imagegen/mlx/doc.go +++ b/x/imagegen/mlx/doc.go @@ -1,6 +1,4 @@ -//go:build mlx - // Package mlx provides Go bindings for the MLX-C library with dynamic loading support. // -//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c +//go:generate go run generate_wrappers.go ../../mlxrunner/mlx/include/mlx/c mlx.h mlx.c package mlx diff --git a/x/imagegen/mlx/generate_wrappers.go b/x/imagegen/mlx/generate_wrappers.go index 8aa5bd0c8..114ac2a15 100644 --- a/x/imagegen/mlx/generate_wrappers.go +++ b/x/imagegen/mlx/generate_wrappers.go @@ -291,8 +291,15 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err implBuf.WriteString("#include \"mlx/c/mlx.h\"\n") implBuf.WriteString("#include \"mlx_dynamic.h\"\n") - implBuf.WriteString("#include \n") - implBuf.WriteString("#include \n\n") + implBuf.WriteString("#include \n\n") + implBuf.WriteString("// Platform-specific dynamic loading\n") + implBuf.WriteString("#ifdef _WIN32\n") + implBuf.WriteString("#include \n") + implBuf.WriteString("#define GET_SYM(handle, name) (void*)GetProcAddress((HMODULE)(handle), name)\n") + implBuf.WriteString("#else\n") + implBuf.WriteString("#include \n") + implBuf.WriteString("#define GET_SYM(handle, name) dlsym(handle, name)\n") + implBuf.WriteString("#endif\n\n") // Function pointer definitions implBuf.WriteString("// Function pointer definitions\n") @@ -308,7 +315,7 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err implBuf.WriteString("\n") // Initialization function - implBuf.WriteString("// Initialize all function pointers via dlsym\n") + implBuf.WriteString("// Initialize all function pointers\n") implBuf.WriteString("int mlx_load_functions(void* handle) {\n") implBuf.WriteString(" if (handle == NULL) {\n") implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n") @@ -319,7 +326,7 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err if fn.NeedsARM64Guard { implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n") } - implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name)) + implBuf.WriteString(fmt.Sprintf(" %s_ptr = GET_SYM(handle, \"%s\");\n", fn.Name, fn.Name)) implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name)) implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name)) implBuf.WriteString(" return -1;\n") diff --git a/x/imagegen/mlx/mlx.c b/x/imagegen/mlx/mlx.c index 770b60922..b0ccbacdf 100644 --- a/x/imagegen/mlx/mlx.c +++ b/x/imagegen/mlx/mlx.c @@ -5,7 +5,15 @@ #include "mlx/c/mlx.h" #include "mlx_dynamic.h" #include + +// Platform-specific dynamic loading +#ifdef _WIN32 +#include +#define GET_SYM(handle, name) (void*)GetProcAddress((HMODULE)(handle), name) +#else #include +#define GET_SYM(handle, name) dlsym(handle, name) +#endif // Function pointer definitions size_t (*mlx_dtype_size_ptr)(mlx_dtype dtype) = NULL; @@ -603,2947 +611,2947 @@ size_t (*mlx_vector_string_size_ptr)(mlx_vector_string vec) = NULL; int (*mlx_vector_string_get_ptr)(char** res, const mlx_vector_string vec, size_t idx) = NULL; int (*mlx_version_ptr)(mlx_string* str_) = NULL; -// Initialize all function pointers via dlsym +// Initialize all function pointers int mlx_load_functions(void* handle) { if (handle == NULL) { fprintf(stderr, "MLX: Invalid library handle\n"); return -1; } - mlx_dtype_size_ptr = dlsym(handle, "mlx_dtype_size"); + mlx_dtype_size_ptr = GET_SYM(handle, "mlx_dtype_size"); if (mlx_dtype_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_dtype_size\n"); return -1; } - mlx_array_tostring_ptr = dlsym(handle, "mlx_array_tostring"); + mlx_array_tostring_ptr = GET_SYM(handle, "mlx_array_tostring"); if (mlx_array_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_tostring\n"); return -1; } - mlx_array_new_ptr = dlsym(handle, "mlx_array_new"); + mlx_array_new_ptr = GET_SYM(handle, "mlx_array_new"); if (mlx_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new\n"); return -1; } - mlx_array_free_ptr = dlsym(handle, "mlx_array_free"); + mlx_array_free_ptr = GET_SYM(handle, "mlx_array_free"); if (mlx_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_free\n"); return -1; } - mlx_array_new_bool_ptr = dlsym(handle, "mlx_array_new_bool"); + mlx_array_new_bool_ptr = GET_SYM(handle, "mlx_array_new_bool"); if (mlx_array_new_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_bool\n"); return -1; } - mlx_array_new_int_ptr = dlsym(handle, "mlx_array_new_int"); + mlx_array_new_int_ptr = GET_SYM(handle, "mlx_array_new_int"); if (mlx_array_new_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_int\n"); return -1; } - mlx_array_new_float32_ptr = dlsym(handle, "mlx_array_new_float32"); + mlx_array_new_float32_ptr = GET_SYM(handle, "mlx_array_new_float32"); if (mlx_array_new_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float32\n"); return -1; } - mlx_array_new_float_ptr = dlsym(handle, "mlx_array_new_float"); + mlx_array_new_float_ptr = GET_SYM(handle, "mlx_array_new_float"); if (mlx_array_new_float_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float\n"); return -1; } - mlx_array_new_float64_ptr = dlsym(handle, "mlx_array_new_float64"); + mlx_array_new_float64_ptr = GET_SYM(handle, "mlx_array_new_float64"); if (mlx_array_new_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float64\n"); return -1; } - mlx_array_new_double_ptr = dlsym(handle, "mlx_array_new_double"); + mlx_array_new_double_ptr = GET_SYM(handle, "mlx_array_new_double"); if (mlx_array_new_double_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_double\n"); return -1; } - mlx_array_new_complex_ptr = dlsym(handle, "mlx_array_new_complex"); + mlx_array_new_complex_ptr = GET_SYM(handle, "mlx_array_new_complex"); if (mlx_array_new_complex_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_complex\n"); return -1; } - mlx_array_new_data_ptr = dlsym(handle, "mlx_array_new_data"); + mlx_array_new_data_ptr = GET_SYM(handle, "mlx_array_new_data"); if (mlx_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n"); return -1; } - mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed"); + mlx_array_new_data_managed_ptr = GET_SYM(handle, "mlx_array_new_data_managed"); if (mlx_array_new_data_managed_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n"); return -1; } - mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload"); + mlx_array_new_data_managed_payload_ptr = GET_SYM(handle, "mlx_array_new_data_managed_payload"); if (mlx_array_new_data_managed_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n"); return -1; } - mlx_array_set_ptr = dlsym(handle, "mlx_array_set"); + mlx_array_set_ptr = GET_SYM(handle, "mlx_array_set"); if (mlx_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n"); return -1; } - mlx_array_set_bool_ptr = dlsym(handle, "mlx_array_set_bool"); + mlx_array_set_bool_ptr = GET_SYM(handle, "mlx_array_set_bool"); if (mlx_array_set_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_bool\n"); return -1; } - mlx_array_set_int_ptr = dlsym(handle, "mlx_array_set_int"); + mlx_array_set_int_ptr = GET_SYM(handle, "mlx_array_set_int"); if (mlx_array_set_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_int\n"); return -1; } - mlx_array_set_float32_ptr = dlsym(handle, "mlx_array_set_float32"); + mlx_array_set_float32_ptr = GET_SYM(handle, "mlx_array_set_float32"); if (mlx_array_set_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float32\n"); return -1; } - mlx_array_set_float_ptr = dlsym(handle, "mlx_array_set_float"); + mlx_array_set_float_ptr = GET_SYM(handle, "mlx_array_set_float"); if (mlx_array_set_float_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float\n"); return -1; } - mlx_array_set_float64_ptr = dlsym(handle, "mlx_array_set_float64"); + mlx_array_set_float64_ptr = GET_SYM(handle, "mlx_array_set_float64"); if (mlx_array_set_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float64\n"); return -1; } - mlx_array_set_double_ptr = dlsym(handle, "mlx_array_set_double"); + mlx_array_set_double_ptr = GET_SYM(handle, "mlx_array_set_double"); if (mlx_array_set_double_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_double\n"); return -1; } - mlx_array_set_complex_ptr = dlsym(handle, "mlx_array_set_complex"); + mlx_array_set_complex_ptr = GET_SYM(handle, "mlx_array_set_complex"); if (mlx_array_set_complex_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_complex\n"); return -1; } - mlx_array_set_data_ptr = dlsym(handle, "mlx_array_set_data"); + mlx_array_set_data_ptr = GET_SYM(handle, "mlx_array_set_data"); if (mlx_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_data\n"); return -1; } - mlx_array_itemsize_ptr = dlsym(handle, "mlx_array_itemsize"); + mlx_array_itemsize_ptr = GET_SYM(handle, "mlx_array_itemsize"); if (mlx_array_itemsize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_itemsize\n"); return -1; } - mlx_array_size_ptr = dlsym(handle, "mlx_array_size"); + mlx_array_size_ptr = GET_SYM(handle, "mlx_array_size"); if (mlx_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_size\n"); return -1; } - mlx_array_nbytes_ptr = dlsym(handle, "mlx_array_nbytes"); + mlx_array_nbytes_ptr = GET_SYM(handle, "mlx_array_nbytes"); if (mlx_array_nbytes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_nbytes\n"); return -1; } - mlx_array_ndim_ptr = dlsym(handle, "mlx_array_ndim"); + mlx_array_ndim_ptr = GET_SYM(handle, "mlx_array_ndim"); if (mlx_array_ndim_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_ndim\n"); return -1; } - mlx_array_shape_ptr = dlsym(handle, "mlx_array_shape"); + mlx_array_shape_ptr = GET_SYM(handle, "mlx_array_shape"); if (mlx_array_shape_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_shape\n"); return -1; } - mlx_array_strides_ptr = dlsym(handle, "mlx_array_strides"); + mlx_array_strides_ptr = GET_SYM(handle, "mlx_array_strides"); if (mlx_array_strides_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_strides\n"); return -1; } - mlx_array_dim_ptr = dlsym(handle, "mlx_array_dim"); + mlx_array_dim_ptr = GET_SYM(handle, "mlx_array_dim"); if (mlx_array_dim_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dim\n"); return -1; } - mlx_array_dtype_ptr = dlsym(handle, "mlx_array_dtype"); + mlx_array_dtype_ptr = GET_SYM(handle, "mlx_array_dtype"); if (mlx_array_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dtype\n"); return -1; } - mlx_array_eval_ptr = dlsym(handle, "mlx_array_eval"); + mlx_array_eval_ptr = GET_SYM(handle, "mlx_array_eval"); if (mlx_array_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_eval\n"); return -1; } - mlx_array_item_bool_ptr = dlsym(handle, "mlx_array_item_bool"); + mlx_array_item_bool_ptr = GET_SYM(handle, "mlx_array_item_bool"); if (mlx_array_item_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bool\n"); return -1; } - mlx_array_item_uint8_ptr = dlsym(handle, "mlx_array_item_uint8"); + mlx_array_item_uint8_ptr = GET_SYM(handle, "mlx_array_item_uint8"); if (mlx_array_item_uint8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint8\n"); return -1; } - mlx_array_item_uint16_ptr = dlsym(handle, "mlx_array_item_uint16"); + mlx_array_item_uint16_ptr = GET_SYM(handle, "mlx_array_item_uint16"); if (mlx_array_item_uint16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint16\n"); return -1; } - mlx_array_item_uint32_ptr = dlsym(handle, "mlx_array_item_uint32"); + mlx_array_item_uint32_ptr = GET_SYM(handle, "mlx_array_item_uint32"); if (mlx_array_item_uint32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint32\n"); return -1; } - mlx_array_item_uint64_ptr = dlsym(handle, "mlx_array_item_uint64"); + mlx_array_item_uint64_ptr = GET_SYM(handle, "mlx_array_item_uint64"); if (mlx_array_item_uint64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint64\n"); return -1; } - mlx_array_item_int8_ptr = dlsym(handle, "mlx_array_item_int8"); + mlx_array_item_int8_ptr = GET_SYM(handle, "mlx_array_item_int8"); if (mlx_array_item_int8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int8\n"); return -1; } - mlx_array_item_int16_ptr = dlsym(handle, "mlx_array_item_int16"); + mlx_array_item_int16_ptr = GET_SYM(handle, "mlx_array_item_int16"); if (mlx_array_item_int16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int16\n"); return -1; } - mlx_array_item_int32_ptr = dlsym(handle, "mlx_array_item_int32"); + mlx_array_item_int32_ptr = GET_SYM(handle, "mlx_array_item_int32"); if (mlx_array_item_int32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int32\n"); return -1; } - mlx_array_item_int64_ptr = dlsym(handle, "mlx_array_item_int64"); + mlx_array_item_int64_ptr = GET_SYM(handle, "mlx_array_item_int64"); if (mlx_array_item_int64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int64\n"); return -1; } - mlx_array_item_float32_ptr = dlsym(handle, "mlx_array_item_float32"); + mlx_array_item_float32_ptr = GET_SYM(handle, "mlx_array_item_float32"); if (mlx_array_item_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float32\n"); return -1; } - mlx_array_item_float64_ptr = dlsym(handle, "mlx_array_item_float64"); + mlx_array_item_float64_ptr = GET_SYM(handle, "mlx_array_item_float64"); if (mlx_array_item_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float64\n"); return -1; } - mlx_array_item_complex64_ptr = dlsym(handle, "mlx_array_item_complex64"); + mlx_array_item_complex64_ptr = GET_SYM(handle, "mlx_array_item_complex64"); if (mlx_array_item_complex64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_complex64\n"); return -1; } #if defined(__aarch64__) || defined(_M_ARM64) - mlx_array_item_float16_ptr = dlsym(handle, "mlx_array_item_float16"); + mlx_array_item_float16_ptr = GET_SYM(handle, "mlx_array_item_float16"); if (mlx_array_item_float16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float16\n"); return -1; } #endif #if defined(__aarch64__) || defined(_M_ARM64) - mlx_array_item_bfloat16_ptr = dlsym(handle, "mlx_array_item_bfloat16"); + mlx_array_item_bfloat16_ptr = GET_SYM(handle, "mlx_array_item_bfloat16"); if (mlx_array_item_bfloat16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bfloat16\n"); return -1; } #endif - mlx_array_data_bool_ptr = dlsym(handle, "mlx_array_data_bool"); + mlx_array_data_bool_ptr = GET_SYM(handle, "mlx_array_data_bool"); if (mlx_array_data_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bool\n"); return -1; } - mlx_array_data_uint8_ptr = dlsym(handle, "mlx_array_data_uint8"); + mlx_array_data_uint8_ptr = GET_SYM(handle, "mlx_array_data_uint8"); if (mlx_array_data_uint8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint8\n"); return -1; } - mlx_array_data_uint16_ptr = dlsym(handle, "mlx_array_data_uint16"); + mlx_array_data_uint16_ptr = GET_SYM(handle, "mlx_array_data_uint16"); if (mlx_array_data_uint16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint16\n"); return -1; } - mlx_array_data_uint32_ptr = dlsym(handle, "mlx_array_data_uint32"); + mlx_array_data_uint32_ptr = GET_SYM(handle, "mlx_array_data_uint32"); if (mlx_array_data_uint32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint32\n"); return -1; } - mlx_array_data_uint64_ptr = dlsym(handle, "mlx_array_data_uint64"); + mlx_array_data_uint64_ptr = GET_SYM(handle, "mlx_array_data_uint64"); if (mlx_array_data_uint64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint64\n"); return -1; } - mlx_array_data_int8_ptr = dlsym(handle, "mlx_array_data_int8"); + mlx_array_data_int8_ptr = GET_SYM(handle, "mlx_array_data_int8"); if (mlx_array_data_int8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int8\n"); return -1; } - mlx_array_data_int16_ptr = dlsym(handle, "mlx_array_data_int16"); + mlx_array_data_int16_ptr = GET_SYM(handle, "mlx_array_data_int16"); if (mlx_array_data_int16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int16\n"); return -1; } - mlx_array_data_int32_ptr = dlsym(handle, "mlx_array_data_int32"); + mlx_array_data_int32_ptr = GET_SYM(handle, "mlx_array_data_int32"); if (mlx_array_data_int32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int32\n"); return -1; } - mlx_array_data_int64_ptr = dlsym(handle, "mlx_array_data_int64"); + mlx_array_data_int64_ptr = GET_SYM(handle, "mlx_array_data_int64"); if (mlx_array_data_int64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int64\n"); return -1; } - mlx_array_data_float32_ptr = dlsym(handle, "mlx_array_data_float32"); + mlx_array_data_float32_ptr = GET_SYM(handle, "mlx_array_data_float32"); if (mlx_array_data_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float32\n"); return -1; } - mlx_array_data_float64_ptr = dlsym(handle, "mlx_array_data_float64"); + mlx_array_data_float64_ptr = GET_SYM(handle, "mlx_array_data_float64"); if (mlx_array_data_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float64\n"); return -1; } - mlx_array_data_complex64_ptr = dlsym(handle, "mlx_array_data_complex64"); + mlx_array_data_complex64_ptr = GET_SYM(handle, "mlx_array_data_complex64"); if (mlx_array_data_complex64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_complex64\n"); return -1; } #if defined(__aarch64__) || defined(_M_ARM64) - mlx_array_data_float16_ptr = dlsym(handle, "mlx_array_data_float16"); + mlx_array_data_float16_ptr = GET_SYM(handle, "mlx_array_data_float16"); if (mlx_array_data_float16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float16\n"); return -1; } #endif #if defined(__aarch64__) || defined(_M_ARM64) - mlx_array_data_bfloat16_ptr = dlsym(handle, "mlx_array_data_bfloat16"); + mlx_array_data_bfloat16_ptr = GET_SYM(handle, "mlx_array_data_bfloat16"); if (mlx_array_data_bfloat16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bfloat16\n"); return -1; } #endif - _mlx_array_is_available_ptr = dlsym(handle, "_mlx_array_is_available"); + _mlx_array_is_available_ptr = GET_SYM(handle, "_mlx_array_is_available"); if (_mlx_array_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_available\n"); return -1; } - _mlx_array_wait_ptr = dlsym(handle, "_mlx_array_wait"); + _mlx_array_wait_ptr = GET_SYM(handle, "_mlx_array_wait"); if (_mlx_array_wait_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_wait\n"); return -1; } - _mlx_array_is_contiguous_ptr = dlsym(handle, "_mlx_array_is_contiguous"); + _mlx_array_is_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_contiguous"); if (_mlx_array_is_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_contiguous\n"); return -1; } - _mlx_array_is_row_contiguous_ptr = dlsym(handle, "_mlx_array_is_row_contiguous"); + _mlx_array_is_row_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_row_contiguous"); if (_mlx_array_is_row_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_row_contiguous\n"); return -1; } - _mlx_array_is_col_contiguous_ptr = dlsym(handle, "_mlx_array_is_col_contiguous"); + _mlx_array_is_col_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_col_contiguous"); if (_mlx_array_is_col_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_col_contiguous\n"); return -1; } - mlx_closure_new_ptr = dlsym(handle, "mlx_closure_new"); + mlx_closure_new_ptr = GET_SYM(handle, "mlx_closure_new"); if (mlx_closure_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new\n"); return -1; } - mlx_closure_free_ptr = dlsym(handle, "mlx_closure_free"); + mlx_closure_free_ptr = GET_SYM(handle, "mlx_closure_free"); if (mlx_closure_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_free\n"); return -1; } - mlx_closure_new_func_ptr = dlsym(handle, "mlx_closure_new_func"); + mlx_closure_new_func_ptr = GET_SYM(handle, "mlx_closure_new_func"); if (mlx_closure_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func\n"); return -1; } - mlx_closure_new_func_payload_ptr = dlsym(handle, "mlx_closure_new_func_payload"); + mlx_closure_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_new_func_payload"); if (mlx_closure_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func_payload\n"); return -1; } - mlx_closure_set_ptr = dlsym(handle, "mlx_closure_set"); + mlx_closure_set_ptr = GET_SYM(handle, "mlx_closure_set"); if (mlx_closure_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_set\n"); return -1; } - mlx_closure_apply_ptr = dlsym(handle, "mlx_closure_apply"); + mlx_closure_apply_ptr = GET_SYM(handle, "mlx_closure_apply"); if (mlx_closure_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_apply\n"); return -1; } - mlx_closure_new_unary_ptr = dlsym(handle, "mlx_closure_new_unary"); + mlx_closure_new_unary_ptr = GET_SYM(handle, "mlx_closure_new_unary"); if (mlx_closure_new_unary_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_unary\n"); return -1; } - mlx_closure_kwargs_new_ptr = dlsym(handle, "mlx_closure_kwargs_new"); + mlx_closure_kwargs_new_ptr = GET_SYM(handle, "mlx_closure_kwargs_new"); if (mlx_closure_kwargs_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new\n"); return -1; } - mlx_closure_kwargs_free_ptr = dlsym(handle, "mlx_closure_kwargs_free"); + mlx_closure_kwargs_free_ptr = GET_SYM(handle, "mlx_closure_kwargs_free"); if (mlx_closure_kwargs_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_free\n"); return -1; } - mlx_closure_kwargs_new_func_ptr = dlsym(handle, "mlx_closure_kwargs_new_func"); + mlx_closure_kwargs_new_func_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func"); if (mlx_closure_kwargs_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func\n"); return -1; } - mlx_closure_kwargs_new_func_payload_ptr = dlsym(handle, "mlx_closure_kwargs_new_func_payload"); + mlx_closure_kwargs_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func_payload"); if (mlx_closure_kwargs_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func_payload\n"); return -1; } - mlx_closure_kwargs_set_ptr = dlsym(handle, "mlx_closure_kwargs_set"); + mlx_closure_kwargs_set_ptr = GET_SYM(handle, "mlx_closure_kwargs_set"); if (mlx_closure_kwargs_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_set\n"); return -1; } - mlx_closure_kwargs_apply_ptr = dlsym(handle, "mlx_closure_kwargs_apply"); + mlx_closure_kwargs_apply_ptr = GET_SYM(handle, "mlx_closure_kwargs_apply"); if (mlx_closure_kwargs_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_apply\n"); return -1; } - mlx_closure_value_and_grad_new_ptr = dlsym(handle, "mlx_closure_value_and_grad_new"); + mlx_closure_value_and_grad_new_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new"); if (mlx_closure_value_and_grad_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new\n"); return -1; } - mlx_closure_value_and_grad_free_ptr = dlsym(handle, "mlx_closure_value_and_grad_free"); + mlx_closure_value_and_grad_free_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_free"); if (mlx_closure_value_and_grad_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_free\n"); return -1; } - mlx_closure_value_and_grad_new_func_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func"); + mlx_closure_value_and_grad_new_func_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func"); if (mlx_closure_value_and_grad_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func\n"); return -1; } - mlx_closure_value_and_grad_new_func_payload_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func_payload"); + mlx_closure_value_and_grad_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func_payload"); if (mlx_closure_value_and_grad_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func_payload\n"); return -1; } - mlx_closure_value_and_grad_set_ptr = dlsym(handle, "mlx_closure_value_and_grad_set"); + mlx_closure_value_and_grad_set_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_set"); if (mlx_closure_value_and_grad_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_set\n"); return -1; } - mlx_closure_value_and_grad_apply_ptr = dlsym(handle, "mlx_closure_value_and_grad_apply"); + mlx_closure_value_and_grad_apply_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_apply"); if (mlx_closure_value_and_grad_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_apply\n"); return -1; } - mlx_closure_custom_new_ptr = dlsym(handle, "mlx_closure_custom_new"); + mlx_closure_custom_new_ptr = GET_SYM(handle, "mlx_closure_custom_new"); if (mlx_closure_custom_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new\n"); return -1; } - mlx_closure_custom_free_ptr = dlsym(handle, "mlx_closure_custom_free"); + mlx_closure_custom_free_ptr = GET_SYM(handle, "mlx_closure_custom_free"); if (mlx_closure_custom_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_free\n"); return -1; } - mlx_closure_custom_new_func_ptr = dlsym(handle, "mlx_closure_custom_new_func"); + mlx_closure_custom_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_new_func"); if (mlx_closure_custom_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func\n"); return -1; } - mlx_closure_custom_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_new_func_payload"); + mlx_closure_custom_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_new_func_payload"); if (mlx_closure_custom_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func_payload\n"); return -1; } - mlx_closure_custom_set_ptr = dlsym(handle, "mlx_closure_custom_set"); + mlx_closure_custom_set_ptr = GET_SYM(handle, "mlx_closure_custom_set"); if (mlx_closure_custom_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_set\n"); return -1; } - mlx_closure_custom_apply_ptr = dlsym(handle, "mlx_closure_custom_apply"); + mlx_closure_custom_apply_ptr = GET_SYM(handle, "mlx_closure_custom_apply"); if (mlx_closure_custom_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_apply\n"); return -1; } - mlx_closure_custom_jvp_new_ptr = dlsym(handle, "mlx_closure_custom_jvp_new"); + mlx_closure_custom_jvp_new_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new"); if (mlx_closure_custom_jvp_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new\n"); return -1; } - mlx_closure_custom_jvp_free_ptr = dlsym(handle, "mlx_closure_custom_jvp_free"); + mlx_closure_custom_jvp_free_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_free"); if (mlx_closure_custom_jvp_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_free\n"); return -1; } - mlx_closure_custom_jvp_new_func_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func"); + mlx_closure_custom_jvp_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func"); if (mlx_closure_custom_jvp_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func\n"); return -1; } - mlx_closure_custom_jvp_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func_payload"); + mlx_closure_custom_jvp_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func_payload"); if (mlx_closure_custom_jvp_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func_payload\n"); return -1; } - mlx_closure_custom_jvp_set_ptr = dlsym(handle, "mlx_closure_custom_jvp_set"); + mlx_closure_custom_jvp_set_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_set"); if (mlx_closure_custom_jvp_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_set\n"); return -1; } - mlx_closure_custom_jvp_apply_ptr = dlsym(handle, "mlx_closure_custom_jvp_apply"); + mlx_closure_custom_jvp_apply_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_apply"); if (mlx_closure_custom_jvp_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_apply\n"); return -1; } - mlx_closure_custom_vmap_new_ptr = dlsym(handle, "mlx_closure_custom_vmap_new"); + mlx_closure_custom_vmap_new_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new"); if (mlx_closure_custom_vmap_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new\n"); return -1; } - mlx_closure_custom_vmap_free_ptr = dlsym(handle, "mlx_closure_custom_vmap_free"); + mlx_closure_custom_vmap_free_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_free"); if (mlx_closure_custom_vmap_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_free\n"); return -1; } - mlx_closure_custom_vmap_new_func_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func"); + mlx_closure_custom_vmap_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func"); if (mlx_closure_custom_vmap_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func\n"); return -1; } - mlx_closure_custom_vmap_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func_payload"); + mlx_closure_custom_vmap_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func_payload"); if (mlx_closure_custom_vmap_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func_payload\n"); return -1; } - mlx_closure_custom_vmap_set_ptr = dlsym(handle, "mlx_closure_custom_vmap_set"); + mlx_closure_custom_vmap_set_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_set"); if (mlx_closure_custom_vmap_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_set\n"); return -1; } - mlx_closure_custom_vmap_apply_ptr = dlsym(handle, "mlx_closure_custom_vmap_apply"); + mlx_closure_custom_vmap_apply_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_apply"); if (mlx_closure_custom_vmap_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_apply\n"); return -1; } - mlx_compile_ptr = dlsym(handle, "mlx_compile"); + mlx_compile_ptr = GET_SYM(handle, "mlx_compile"); if (mlx_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_compile\n"); return -1; } - mlx_detail_compile_ptr = dlsym(handle, "mlx_detail_compile"); + mlx_detail_compile_ptr = GET_SYM(handle, "mlx_detail_compile"); if (mlx_detail_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile\n"); return -1; } - mlx_detail_compile_clear_cache_ptr = dlsym(handle, "mlx_detail_compile_clear_cache"); + mlx_detail_compile_clear_cache_ptr = GET_SYM(handle, "mlx_detail_compile_clear_cache"); if (mlx_detail_compile_clear_cache_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_clear_cache\n"); return -1; } - mlx_detail_compile_erase_ptr = dlsym(handle, "mlx_detail_compile_erase"); + mlx_detail_compile_erase_ptr = GET_SYM(handle, "mlx_detail_compile_erase"); if (mlx_detail_compile_erase_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_erase\n"); return -1; } - mlx_disable_compile_ptr = dlsym(handle, "mlx_disable_compile"); + mlx_disable_compile_ptr = GET_SYM(handle, "mlx_disable_compile"); if (mlx_disable_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_disable_compile\n"); return -1; } - mlx_enable_compile_ptr = dlsym(handle, "mlx_enable_compile"); + mlx_enable_compile_ptr = GET_SYM(handle, "mlx_enable_compile"); if (mlx_enable_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_enable_compile\n"); return -1; } - mlx_set_compile_mode_ptr = dlsym(handle, "mlx_set_compile_mode"); + mlx_set_compile_mode_ptr = GET_SYM(handle, "mlx_set_compile_mode"); if (mlx_set_compile_mode_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n"); return -1; } - mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available"); + mlx_cuda_is_available_ptr = GET_SYM(handle, "mlx_cuda_is_available"); if (mlx_cuda_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n"); return -1; } - mlx_device_new_ptr = dlsym(handle, "mlx_device_new"); + mlx_device_new_ptr = GET_SYM(handle, "mlx_device_new"); if (mlx_device_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n"); return -1; } - mlx_device_new_type_ptr = dlsym(handle, "mlx_device_new_type"); + mlx_device_new_type_ptr = GET_SYM(handle, "mlx_device_new_type"); if (mlx_device_new_type_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new_type\n"); return -1; } - mlx_device_free_ptr = dlsym(handle, "mlx_device_free"); + mlx_device_free_ptr = GET_SYM(handle, "mlx_device_free"); if (mlx_device_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_free\n"); return -1; } - mlx_device_set_ptr = dlsym(handle, "mlx_device_set"); + mlx_device_set_ptr = GET_SYM(handle, "mlx_device_set"); if (mlx_device_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_set\n"); return -1; } - mlx_device_tostring_ptr = dlsym(handle, "mlx_device_tostring"); + mlx_device_tostring_ptr = GET_SYM(handle, "mlx_device_tostring"); if (mlx_device_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_tostring\n"); return -1; } - mlx_device_equal_ptr = dlsym(handle, "mlx_device_equal"); + mlx_device_equal_ptr = GET_SYM(handle, "mlx_device_equal"); if (mlx_device_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_equal\n"); return -1; } - mlx_device_get_index_ptr = dlsym(handle, "mlx_device_get_index"); + mlx_device_get_index_ptr = GET_SYM(handle, "mlx_device_get_index"); if (mlx_device_get_index_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_index\n"); return -1; } - mlx_device_get_type_ptr = dlsym(handle, "mlx_device_get_type"); + mlx_device_get_type_ptr = GET_SYM(handle, "mlx_device_get_type"); if (mlx_device_get_type_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_type\n"); return -1; } - mlx_get_default_device_ptr = dlsym(handle, "mlx_get_default_device"); + mlx_get_default_device_ptr = GET_SYM(handle, "mlx_get_default_device"); if (mlx_get_default_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_device\n"); return -1; } - mlx_set_default_device_ptr = dlsym(handle, "mlx_set_default_device"); + mlx_set_default_device_ptr = GET_SYM(handle, "mlx_set_default_device"); if (mlx_set_default_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n"); return -1; } - mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available"); + mlx_device_is_available_ptr = GET_SYM(handle, "mlx_device_is_available"); if (mlx_device_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n"); return -1; } - mlx_device_count_ptr = dlsym(handle, "mlx_device_count"); + mlx_device_count_ptr = GET_SYM(handle, "mlx_device_count"); if (mlx_device_count_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n"); return -1; } - mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new"); + mlx_device_info_new_ptr = GET_SYM(handle, "mlx_device_info_new"); if (mlx_device_info_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n"); return -1; } - mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get"); + mlx_device_info_get_ptr = GET_SYM(handle, "mlx_device_info_get"); if (mlx_device_info_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n"); return -1; } - mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free"); + mlx_device_info_free_ptr = GET_SYM(handle, "mlx_device_info_free"); if (mlx_device_info_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n"); return -1; } - mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key"); + mlx_device_info_has_key_ptr = GET_SYM(handle, "mlx_device_info_has_key"); if (mlx_device_info_has_key_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n"); return -1; } - mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string"); + mlx_device_info_is_string_ptr = GET_SYM(handle, "mlx_device_info_is_string"); if (mlx_device_info_is_string_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n"); return -1; } - mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string"); + mlx_device_info_get_string_ptr = GET_SYM(handle, "mlx_device_info_get_string"); if (mlx_device_info_get_string_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n"); return -1; } - mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size"); + mlx_device_info_get_size_ptr = GET_SYM(handle, "mlx_device_info_get_size"); if (mlx_device_info_get_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n"); return -1; } - mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys"); + mlx_device_info_get_keys_ptr = GET_SYM(handle, "mlx_device_info_get_keys"); if (mlx_device_info_get_keys_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n"); return -1; } - mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather"); + mlx_distributed_all_gather_ptr = GET_SYM(handle, "mlx_distributed_all_gather"); if (mlx_distributed_all_gather_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n"); return -1; } - mlx_distributed_all_max_ptr = dlsym(handle, "mlx_distributed_all_max"); + mlx_distributed_all_max_ptr = GET_SYM(handle, "mlx_distributed_all_max"); if (mlx_distributed_all_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_max\n"); return -1; } - mlx_distributed_all_min_ptr = dlsym(handle, "mlx_distributed_all_min"); + mlx_distributed_all_min_ptr = GET_SYM(handle, "mlx_distributed_all_min"); if (mlx_distributed_all_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_min\n"); return -1; } - mlx_distributed_all_sum_ptr = dlsym(handle, "mlx_distributed_all_sum"); + mlx_distributed_all_sum_ptr = GET_SYM(handle, "mlx_distributed_all_sum"); if (mlx_distributed_all_sum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_sum\n"); return -1; } - mlx_distributed_recv_ptr = dlsym(handle, "mlx_distributed_recv"); + mlx_distributed_recv_ptr = GET_SYM(handle, "mlx_distributed_recv"); if (mlx_distributed_recv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv\n"); return -1; } - mlx_distributed_recv_like_ptr = dlsym(handle, "mlx_distributed_recv_like"); + mlx_distributed_recv_like_ptr = GET_SYM(handle, "mlx_distributed_recv_like"); if (mlx_distributed_recv_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv_like\n"); return -1; } - mlx_distributed_send_ptr = dlsym(handle, "mlx_distributed_send"); + mlx_distributed_send_ptr = GET_SYM(handle, "mlx_distributed_send"); if (mlx_distributed_send_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_send\n"); return -1; } - mlx_distributed_sum_scatter_ptr = dlsym(handle, "mlx_distributed_sum_scatter"); + mlx_distributed_sum_scatter_ptr = GET_SYM(handle, "mlx_distributed_sum_scatter"); if (mlx_distributed_sum_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_sum_scatter\n"); return -1; } - mlx_distributed_group_rank_ptr = dlsym(handle, "mlx_distributed_group_rank"); + mlx_distributed_group_rank_ptr = GET_SYM(handle, "mlx_distributed_group_rank"); if (mlx_distributed_group_rank_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_rank\n"); return -1; } - mlx_distributed_group_size_ptr = dlsym(handle, "mlx_distributed_group_size"); + mlx_distributed_group_size_ptr = GET_SYM(handle, "mlx_distributed_group_size"); if (mlx_distributed_group_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_size\n"); return -1; } - mlx_distributed_group_split_ptr = dlsym(handle, "mlx_distributed_group_split"); + mlx_distributed_group_split_ptr = GET_SYM(handle, "mlx_distributed_group_split"); if (mlx_distributed_group_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_split\n"); return -1; } - mlx_distributed_is_available_ptr = dlsym(handle, "mlx_distributed_is_available"); + mlx_distributed_is_available_ptr = GET_SYM(handle, "mlx_distributed_is_available"); if (mlx_distributed_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_is_available\n"); return -1; } - mlx_distributed_init_ptr = dlsym(handle, "mlx_distributed_init"); + mlx_distributed_init_ptr = GET_SYM(handle, "mlx_distributed_init"); if (mlx_distributed_init_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n"); return -1; } - mlx_set_error_handler_ptr = dlsym(handle, "mlx_set_error_handler"); + mlx_set_error_handler_ptr = GET_SYM(handle, "mlx_set_error_handler"); if (mlx_set_error_handler_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_error_handler\n"); return -1; } - _mlx_error_ptr = dlsym(handle, "_mlx_error"); + _mlx_error_ptr = GET_SYM(handle, "_mlx_error"); if (_mlx_error_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_error\n"); return -1; } - mlx_export_function_ptr = dlsym(handle, "mlx_export_function"); + mlx_export_function_ptr = GET_SYM(handle, "mlx_export_function"); if (mlx_export_function_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function\n"); return -1; } - mlx_export_function_kwargs_ptr = dlsym(handle, "mlx_export_function_kwargs"); + mlx_export_function_kwargs_ptr = GET_SYM(handle, "mlx_export_function_kwargs"); if (mlx_export_function_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function_kwargs\n"); return -1; } - mlx_function_exporter_new_ptr = dlsym(handle, "mlx_function_exporter_new"); + mlx_function_exporter_new_ptr = GET_SYM(handle, "mlx_function_exporter_new"); if (mlx_function_exporter_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_new\n"); return -1; } - mlx_function_exporter_free_ptr = dlsym(handle, "mlx_function_exporter_free"); + mlx_function_exporter_free_ptr = GET_SYM(handle, "mlx_function_exporter_free"); if (mlx_function_exporter_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_free\n"); return -1; } - mlx_function_exporter_apply_ptr = dlsym(handle, "mlx_function_exporter_apply"); + mlx_function_exporter_apply_ptr = GET_SYM(handle, "mlx_function_exporter_apply"); if (mlx_function_exporter_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply\n"); return -1; } - mlx_function_exporter_apply_kwargs_ptr = dlsym(handle, "mlx_function_exporter_apply_kwargs"); + mlx_function_exporter_apply_kwargs_ptr = GET_SYM(handle, "mlx_function_exporter_apply_kwargs"); if (mlx_function_exporter_apply_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply_kwargs\n"); return -1; } - mlx_imported_function_new_ptr = dlsym(handle, "mlx_imported_function_new"); + mlx_imported_function_new_ptr = GET_SYM(handle, "mlx_imported_function_new"); if (mlx_imported_function_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_new\n"); return -1; } - mlx_imported_function_free_ptr = dlsym(handle, "mlx_imported_function_free"); + mlx_imported_function_free_ptr = GET_SYM(handle, "mlx_imported_function_free"); if (mlx_imported_function_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_free\n"); return -1; } - mlx_imported_function_apply_ptr = dlsym(handle, "mlx_imported_function_apply"); + mlx_imported_function_apply_ptr = GET_SYM(handle, "mlx_imported_function_apply"); if (mlx_imported_function_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply\n"); return -1; } - mlx_imported_function_apply_kwargs_ptr = dlsym(handle, "mlx_imported_function_apply_kwargs"); + mlx_imported_function_apply_kwargs_ptr = GET_SYM(handle, "mlx_imported_function_apply_kwargs"); if (mlx_imported_function_apply_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply_kwargs\n"); return -1; } - mlx_fast_cuda_kernel_config_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_new"); + mlx_fast_cuda_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_new"); if (mlx_fast_cuda_kernel_config_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_new\n"); return -1; } - mlx_fast_cuda_kernel_config_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_free"); + mlx_fast_cuda_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_free"); if (mlx_fast_cuda_kernel_config_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_free\n"); return -1; } - mlx_fast_cuda_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_output_arg"); + mlx_fast_cuda_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_output_arg"); if (mlx_fast_cuda_kernel_config_add_output_arg_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_output_arg\n"); return -1; } - mlx_fast_cuda_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_grid"); + mlx_fast_cuda_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_grid"); if (mlx_fast_cuda_kernel_config_set_grid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_grid\n"); return -1; } - mlx_fast_cuda_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_thread_group"); + mlx_fast_cuda_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_thread_group"); if (mlx_fast_cuda_kernel_config_set_thread_group_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_thread_group\n"); return -1; } - mlx_fast_cuda_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_init_value"); + mlx_fast_cuda_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_init_value"); if (mlx_fast_cuda_kernel_config_set_init_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_init_value\n"); return -1; } - mlx_fast_cuda_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_verbose"); + mlx_fast_cuda_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_verbose"); if (mlx_fast_cuda_kernel_config_set_verbose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_verbose\n"); return -1; } - mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype"); + mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype"); if (mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_dtype\n"); return -1; } - mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int"); + mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int"); if (mlx_fast_cuda_kernel_config_add_template_arg_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_int\n"); return -1; } - mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool"); + mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool"); if (mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_bool\n"); return -1; } - mlx_fast_cuda_kernel_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_new"); + mlx_fast_cuda_kernel_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_new"); if (mlx_fast_cuda_kernel_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_new\n"); return -1; } - mlx_fast_cuda_kernel_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_free"); + mlx_fast_cuda_kernel_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_free"); if (mlx_fast_cuda_kernel_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_free\n"); return -1; } - mlx_fast_cuda_kernel_apply_ptr = dlsym(handle, "mlx_fast_cuda_kernel_apply"); + mlx_fast_cuda_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_apply"); if (mlx_fast_cuda_kernel_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_apply\n"); return -1; } - mlx_fast_layer_norm_ptr = dlsym(handle, "mlx_fast_layer_norm"); + mlx_fast_layer_norm_ptr = GET_SYM(handle, "mlx_fast_layer_norm"); if (mlx_fast_layer_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_layer_norm\n"); return -1; } - mlx_fast_metal_kernel_config_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_new"); + mlx_fast_metal_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_new"); if (mlx_fast_metal_kernel_config_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_new\n"); return -1; } - mlx_fast_metal_kernel_config_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_free"); + mlx_fast_metal_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_free"); if (mlx_fast_metal_kernel_config_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_free\n"); return -1; } - mlx_fast_metal_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_output_arg"); + mlx_fast_metal_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_output_arg"); if (mlx_fast_metal_kernel_config_add_output_arg_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_output_arg\n"); return -1; } - mlx_fast_metal_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_grid"); + mlx_fast_metal_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_grid"); if (mlx_fast_metal_kernel_config_set_grid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_grid\n"); return -1; } - mlx_fast_metal_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_thread_group"); + mlx_fast_metal_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_thread_group"); if (mlx_fast_metal_kernel_config_set_thread_group_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_thread_group\n"); return -1; } - mlx_fast_metal_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_init_value"); + mlx_fast_metal_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_init_value"); if (mlx_fast_metal_kernel_config_set_init_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_init_value\n"); return -1; } - mlx_fast_metal_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_verbose"); + mlx_fast_metal_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_verbose"); if (mlx_fast_metal_kernel_config_set_verbose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_verbose\n"); return -1; } - mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype"); + mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype"); if (mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_dtype\n"); return -1; } - mlx_fast_metal_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_int"); + mlx_fast_metal_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_int"); if (mlx_fast_metal_kernel_config_add_template_arg_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_int\n"); return -1; } - mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool"); + mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool"); if (mlx_fast_metal_kernel_config_add_template_arg_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_bool\n"); return -1; } - mlx_fast_metal_kernel_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_new"); + mlx_fast_metal_kernel_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_new"); if (mlx_fast_metal_kernel_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_new\n"); return -1; } - mlx_fast_metal_kernel_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_free"); + mlx_fast_metal_kernel_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_free"); if (mlx_fast_metal_kernel_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_free\n"); return -1; } - mlx_fast_metal_kernel_apply_ptr = dlsym(handle, "mlx_fast_metal_kernel_apply"); + mlx_fast_metal_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_apply"); if (mlx_fast_metal_kernel_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_apply\n"); return -1; } - mlx_fast_rms_norm_ptr = dlsym(handle, "mlx_fast_rms_norm"); + mlx_fast_rms_norm_ptr = GET_SYM(handle, "mlx_fast_rms_norm"); if (mlx_fast_rms_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rms_norm\n"); return -1; } - mlx_fast_rope_ptr = dlsym(handle, "mlx_fast_rope"); + mlx_fast_rope_ptr = GET_SYM(handle, "mlx_fast_rope"); if (mlx_fast_rope_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope\n"); return -1; } - mlx_fast_rope_dynamic_ptr = dlsym(handle, "mlx_fast_rope_dynamic"); + mlx_fast_rope_dynamic_ptr = GET_SYM(handle, "mlx_fast_rope_dynamic"); if (mlx_fast_rope_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope_dynamic\n"); return -1; } - mlx_fast_scaled_dot_product_attention_ptr = dlsym(handle, "mlx_fast_scaled_dot_product_attention"); + mlx_fast_scaled_dot_product_attention_ptr = GET_SYM(handle, "mlx_fast_scaled_dot_product_attention"); if (mlx_fast_scaled_dot_product_attention_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_scaled_dot_product_attention\n"); return -1; } - mlx_fft_fft_ptr = dlsym(handle, "mlx_fft_fft"); + mlx_fft_fft_ptr = GET_SYM(handle, "mlx_fft_fft"); if (mlx_fft_fft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft\n"); return -1; } - mlx_fft_fft2_ptr = dlsym(handle, "mlx_fft_fft2"); + mlx_fft_fft2_ptr = GET_SYM(handle, "mlx_fft_fft2"); if (mlx_fft_fft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft2\n"); return -1; } - mlx_fft_fftn_ptr = dlsym(handle, "mlx_fft_fftn"); + mlx_fft_fftn_ptr = GET_SYM(handle, "mlx_fft_fftn"); if (mlx_fft_fftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftn\n"); return -1; } - mlx_fft_fftshift_ptr = dlsym(handle, "mlx_fft_fftshift"); + mlx_fft_fftshift_ptr = GET_SYM(handle, "mlx_fft_fftshift"); if (mlx_fft_fftshift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftshift\n"); return -1; } - mlx_fft_ifft_ptr = dlsym(handle, "mlx_fft_ifft"); + mlx_fft_ifft_ptr = GET_SYM(handle, "mlx_fft_ifft"); if (mlx_fft_ifft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft\n"); return -1; } - mlx_fft_ifft2_ptr = dlsym(handle, "mlx_fft_ifft2"); + mlx_fft_ifft2_ptr = GET_SYM(handle, "mlx_fft_ifft2"); if (mlx_fft_ifft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft2\n"); return -1; } - mlx_fft_ifftn_ptr = dlsym(handle, "mlx_fft_ifftn"); + mlx_fft_ifftn_ptr = GET_SYM(handle, "mlx_fft_ifftn"); if (mlx_fft_ifftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftn\n"); return -1; } - mlx_fft_ifftshift_ptr = dlsym(handle, "mlx_fft_ifftshift"); + mlx_fft_ifftshift_ptr = GET_SYM(handle, "mlx_fft_ifftshift"); if (mlx_fft_ifftshift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftshift\n"); return -1; } - mlx_fft_irfft_ptr = dlsym(handle, "mlx_fft_irfft"); + mlx_fft_irfft_ptr = GET_SYM(handle, "mlx_fft_irfft"); if (mlx_fft_irfft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft\n"); return -1; } - mlx_fft_irfft2_ptr = dlsym(handle, "mlx_fft_irfft2"); + mlx_fft_irfft2_ptr = GET_SYM(handle, "mlx_fft_irfft2"); if (mlx_fft_irfft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft2\n"); return -1; } - mlx_fft_irfftn_ptr = dlsym(handle, "mlx_fft_irfftn"); + mlx_fft_irfftn_ptr = GET_SYM(handle, "mlx_fft_irfftn"); if (mlx_fft_irfftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfftn\n"); return -1; } - mlx_fft_rfft_ptr = dlsym(handle, "mlx_fft_rfft"); + mlx_fft_rfft_ptr = GET_SYM(handle, "mlx_fft_rfft"); if (mlx_fft_rfft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft\n"); return -1; } - mlx_fft_rfft2_ptr = dlsym(handle, "mlx_fft_rfft2"); + mlx_fft_rfft2_ptr = GET_SYM(handle, "mlx_fft_rfft2"); if (mlx_fft_rfft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft2\n"); return -1; } - mlx_fft_rfftn_ptr = dlsym(handle, "mlx_fft_rfftn"); + mlx_fft_rfftn_ptr = GET_SYM(handle, "mlx_fft_rfftn"); if (mlx_fft_rfftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftn\n"); return -1; } - mlx_load_reader_ptr = dlsym(handle, "mlx_load_reader"); + mlx_load_reader_ptr = GET_SYM(handle, "mlx_load_reader"); if (mlx_load_reader_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_reader\n"); return -1; } - mlx_load_ptr = dlsym(handle, "mlx_load"); + mlx_load_ptr = GET_SYM(handle, "mlx_load"); if (mlx_load_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load\n"); return -1; } - mlx_load_safetensors_reader_ptr = dlsym(handle, "mlx_load_safetensors_reader"); + mlx_load_safetensors_reader_ptr = GET_SYM(handle, "mlx_load_safetensors_reader"); if (mlx_load_safetensors_reader_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors_reader\n"); return -1; } - mlx_load_safetensors_ptr = dlsym(handle, "mlx_load_safetensors"); + mlx_load_safetensors_ptr = GET_SYM(handle, "mlx_load_safetensors"); if (mlx_load_safetensors_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors\n"); return -1; } - mlx_save_writer_ptr = dlsym(handle, "mlx_save_writer"); + mlx_save_writer_ptr = GET_SYM(handle, "mlx_save_writer"); if (mlx_save_writer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_writer\n"); return -1; } - mlx_save_ptr = dlsym(handle, "mlx_save"); + mlx_save_ptr = GET_SYM(handle, "mlx_save"); if (mlx_save_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save\n"); return -1; } - mlx_save_safetensors_writer_ptr = dlsym(handle, "mlx_save_safetensors_writer"); + mlx_save_safetensors_writer_ptr = GET_SYM(handle, "mlx_save_safetensors_writer"); if (mlx_save_safetensors_writer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors_writer\n"); return -1; } - mlx_save_safetensors_ptr = dlsym(handle, "mlx_save_safetensors"); + mlx_save_safetensors_ptr = GET_SYM(handle, "mlx_save_safetensors"); if (mlx_save_safetensors_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors\n"); return -1; } - mlx_io_reader_new_ptr = dlsym(handle, "mlx_io_reader_new"); + mlx_io_reader_new_ptr = GET_SYM(handle, "mlx_io_reader_new"); if (mlx_io_reader_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_new\n"); return -1; } - mlx_io_reader_descriptor_ptr = dlsym(handle, "mlx_io_reader_descriptor"); + mlx_io_reader_descriptor_ptr = GET_SYM(handle, "mlx_io_reader_descriptor"); if (mlx_io_reader_descriptor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_descriptor\n"); return -1; } - mlx_io_reader_tostring_ptr = dlsym(handle, "mlx_io_reader_tostring"); + mlx_io_reader_tostring_ptr = GET_SYM(handle, "mlx_io_reader_tostring"); if (mlx_io_reader_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_tostring\n"); return -1; } - mlx_io_reader_free_ptr = dlsym(handle, "mlx_io_reader_free"); + mlx_io_reader_free_ptr = GET_SYM(handle, "mlx_io_reader_free"); if (mlx_io_reader_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_free\n"); return -1; } - mlx_io_writer_new_ptr = dlsym(handle, "mlx_io_writer_new"); + mlx_io_writer_new_ptr = GET_SYM(handle, "mlx_io_writer_new"); if (mlx_io_writer_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_new\n"); return -1; } - mlx_io_writer_descriptor_ptr = dlsym(handle, "mlx_io_writer_descriptor"); + mlx_io_writer_descriptor_ptr = GET_SYM(handle, "mlx_io_writer_descriptor"); if (mlx_io_writer_descriptor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_descriptor\n"); return -1; } - mlx_io_writer_tostring_ptr = dlsym(handle, "mlx_io_writer_tostring"); + mlx_io_writer_tostring_ptr = GET_SYM(handle, "mlx_io_writer_tostring"); if (mlx_io_writer_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_tostring\n"); return -1; } - mlx_io_writer_free_ptr = dlsym(handle, "mlx_io_writer_free"); + mlx_io_writer_free_ptr = GET_SYM(handle, "mlx_io_writer_free"); if (mlx_io_writer_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_free\n"); return -1; } - mlx_linalg_cholesky_ptr = dlsym(handle, "mlx_linalg_cholesky"); + mlx_linalg_cholesky_ptr = GET_SYM(handle, "mlx_linalg_cholesky"); if (mlx_linalg_cholesky_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky\n"); return -1; } - mlx_linalg_cholesky_inv_ptr = dlsym(handle, "mlx_linalg_cholesky_inv"); + mlx_linalg_cholesky_inv_ptr = GET_SYM(handle, "mlx_linalg_cholesky_inv"); if (mlx_linalg_cholesky_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky_inv\n"); return -1; } - mlx_linalg_cross_ptr = dlsym(handle, "mlx_linalg_cross"); + mlx_linalg_cross_ptr = GET_SYM(handle, "mlx_linalg_cross"); if (mlx_linalg_cross_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cross\n"); return -1; } - mlx_linalg_eig_ptr = dlsym(handle, "mlx_linalg_eig"); + mlx_linalg_eig_ptr = GET_SYM(handle, "mlx_linalg_eig"); if (mlx_linalg_eig_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eig\n"); return -1; } - mlx_linalg_eigh_ptr = dlsym(handle, "mlx_linalg_eigh"); + mlx_linalg_eigh_ptr = GET_SYM(handle, "mlx_linalg_eigh"); if (mlx_linalg_eigh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigh\n"); return -1; } - mlx_linalg_eigvals_ptr = dlsym(handle, "mlx_linalg_eigvals"); + mlx_linalg_eigvals_ptr = GET_SYM(handle, "mlx_linalg_eigvals"); if (mlx_linalg_eigvals_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvals\n"); return -1; } - mlx_linalg_eigvalsh_ptr = dlsym(handle, "mlx_linalg_eigvalsh"); + mlx_linalg_eigvalsh_ptr = GET_SYM(handle, "mlx_linalg_eigvalsh"); if (mlx_linalg_eigvalsh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvalsh\n"); return -1; } - mlx_linalg_inv_ptr = dlsym(handle, "mlx_linalg_inv"); + mlx_linalg_inv_ptr = GET_SYM(handle, "mlx_linalg_inv"); if (mlx_linalg_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_inv\n"); return -1; } - mlx_linalg_lu_ptr = dlsym(handle, "mlx_linalg_lu"); + mlx_linalg_lu_ptr = GET_SYM(handle, "mlx_linalg_lu"); if (mlx_linalg_lu_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu\n"); return -1; } - mlx_linalg_lu_factor_ptr = dlsym(handle, "mlx_linalg_lu_factor"); + mlx_linalg_lu_factor_ptr = GET_SYM(handle, "mlx_linalg_lu_factor"); if (mlx_linalg_lu_factor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu_factor\n"); return -1; } - mlx_linalg_norm_ptr = dlsym(handle, "mlx_linalg_norm"); + mlx_linalg_norm_ptr = GET_SYM(handle, "mlx_linalg_norm"); if (mlx_linalg_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm\n"); return -1; } - mlx_linalg_norm_matrix_ptr = dlsym(handle, "mlx_linalg_norm_matrix"); + mlx_linalg_norm_matrix_ptr = GET_SYM(handle, "mlx_linalg_norm_matrix"); if (mlx_linalg_norm_matrix_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_matrix\n"); return -1; } - mlx_linalg_norm_l2_ptr = dlsym(handle, "mlx_linalg_norm_l2"); + mlx_linalg_norm_l2_ptr = GET_SYM(handle, "mlx_linalg_norm_l2"); if (mlx_linalg_norm_l2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_l2\n"); return -1; } - mlx_linalg_pinv_ptr = dlsym(handle, "mlx_linalg_pinv"); + mlx_linalg_pinv_ptr = GET_SYM(handle, "mlx_linalg_pinv"); if (mlx_linalg_pinv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_pinv\n"); return -1; } - mlx_linalg_qr_ptr = dlsym(handle, "mlx_linalg_qr"); + mlx_linalg_qr_ptr = GET_SYM(handle, "mlx_linalg_qr"); if (mlx_linalg_qr_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_qr\n"); return -1; } - mlx_linalg_solve_ptr = dlsym(handle, "mlx_linalg_solve"); + mlx_linalg_solve_ptr = GET_SYM(handle, "mlx_linalg_solve"); if (mlx_linalg_solve_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve\n"); return -1; } - mlx_linalg_solve_triangular_ptr = dlsym(handle, "mlx_linalg_solve_triangular"); + mlx_linalg_solve_triangular_ptr = GET_SYM(handle, "mlx_linalg_solve_triangular"); if (mlx_linalg_solve_triangular_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve_triangular\n"); return -1; } - mlx_linalg_svd_ptr = dlsym(handle, "mlx_linalg_svd"); + mlx_linalg_svd_ptr = GET_SYM(handle, "mlx_linalg_svd"); if (mlx_linalg_svd_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_svd\n"); return -1; } - mlx_linalg_tri_inv_ptr = dlsym(handle, "mlx_linalg_tri_inv"); + mlx_linalg_tri_inv_ptr = GET_SYM(handle, "mlx_linalg_tri_inv"); if (mlx_linalg_tri_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_tri_inv\n"); return -1; } - mlx_map_string_to_array_new_ptr = dlsym(handle, "mlx_map_string_to_array_new"); + mlx_map_string_to_array_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_new"); if (mlx_map_string_to_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_new\n"); return -1; } - mlx_map_string_to_array_set_ptr = dlsym(handle, "mlx_map_string_to_array_set"); + mlx_map_string_to_array_set_ptr = GET_SYM(handle, "mlx_map_string_to_array_set"); if (mlx_map_string_to_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_set\n"); return -1; } - mlx_map_string_to_array_free_ptr = dlsym(handle, "mlx_map_string_to_array_free"); + mlx_map_string_to_array_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_free"); if (mlx_map_string_to_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_free\n"); return -1; } - mlx_map_string_to_array_insert_ptr = dlsym(handle, "mlx_map_string_to_array_insert"); + mlx_map_string_to_array_insert_ptr = GET_SYM(handle, "mlx_map_string_to_array_insert"); if (mlx_map_string_to_array_insert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_insert\n"); return -1; } - mlx_map_string_to_array_get_ptr = dlsym(handle, "mlx_map_string_to_array_get"); + mlx_map_string_to_array_get_ptr = GET_SYM(handle, "mlx_map_string_to_array_get"); if (mlx_map_string_to_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_get\n"); return -1; } - mlx_map_string_to_array_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_new"); + mlx_map_string_to_array_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_new"); if (mlx_map_string_to_array_iterator_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_new\n"); return -1; } - mlx_map_string_to_array_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_free"); + mlx_map_string_to_array_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_free"); if (mlx_map_string_to_array_iterator_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_free\n"); return -1; } - mlx_map_string_to_array_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_next"); + mlx_map_string_to_array_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_next"); if (mlx_map_string_to_array_iterator_next_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_next\n"); return -1; } - mlx_map_string_to_string_new_ptr = dlsym(handle, "mlx_map_string_to_string_new"); + mlx_map_string_to_string_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_new"); if (mlx_map_string_to_string_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_new\n"); return -1; } - mlx_map_string_to_string_set_ptr = dlsym(handle, "mlx_map_string_to_string_set"); + mlx_map_string_to_string_set_ptr = GET_SYM(handle, "mlx_map_string_to_string_set"); if (mlx_map_string_to_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_set\n"); return -1; } - mlx_map_string_to_string_free_ptr = dlsym(handle, "mlx_map_string_to_string_free"); + mlx_map_string_to_string_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_free"); if (mlx_map_string_to_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_free\n"); return -1; } - mlx_map_string_to_string_insert_ptr = dlsym(handle, "mlx_map_string_to_string_insert"); + mlx_map_string_to_string_insert_ptr = GET_SYM(handle, "mlx_map_string_to_string_insert"); if (mlx_map_string_to_string_insert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_insert\n"); return -1; } - mlx_map_string_to_string_get_ptr = dlsym(handle, "mlx_map_string_to_string_get"); + mlx_map_string_to_string_get_ptr = GET_SYM(handle, "mlx_map_string_to_string_get"); if (mlx_map_string_to_string_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_get\n"); return -1; } - mlx_map_string_to_string_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_new"); + mlx_map_string_to_string_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_new"); if (mlx_map_string_to_string_iterator_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_new\n"); return -1; } - mlx_map_string_to_string_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_free"); + mlx_map_string_to_string_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_free"); if (mlx_map_string_to_string_iterator_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_free\n"); return -1; } - mlx_map_string_to_string_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_next"); + mlx_map_string_to_string_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_next"); if (mlx_map_string_to_string_iterator_next_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_next\n"); return -1; } - mlx_clear_cache_ptr = dlsym(handle, "mlx_clear_cache"); + mlx_clear_cache_ptr = GET_SYM(handle, "mlx_clear_cache"); if (mlx_clear_cache_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_clear_cache\n"); return -1; } - mlx_get_active_memory_ptr = dlsym(handle, "mlx_get_active_memory"); + mlx_get_active_memory_ptr = GET_SYM(handle, "mlx_get_active_memory"); if (mlx_get_active_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_active_memory\n"); return -1; } - mlx_get_cache_memory_ptr = dlsym(handle, "mlx_get_cache_memory"); + mlx_get_cache_memory_ptr = GET_SYM(handle, "mlx_get_cache_memory"); if (mlx_get_cache_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_cache_memory\n"); return -1; } - mlx_get_memory_limit_ptr = dlsym(handle, "mlx_get_memory_limit"); + mlx_get_memory_limit_ptr = GET_SYM(handle, "mlx_get_memory_limit"); if (mlx_get_memory_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_memory_limit\n"); return -1; } - mlx_get_peak_memory_ptr = dlsym(handle, "mlx_get_peak_memory"); + mlx_get_peak_memory_ptr = GET_SYM(handle, "mlx_get_peak_memory"); if (mlx_get_peak_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_peak_memory\n"); return -1; } - mlx_reset_peak_memory_ptr = dlsym(handle, "mlx_reset_peak_memory"); + mlx_reset_peak_memory_ptr = GET_SYM(handle, "mlx_reset_peak_memory"); if (mlx_reset_peak_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reset_peak_memory\n"); return -1; } - mlx_set_cache_limit_ptr = dlsym(handle, "mlx_set_cache_limit"); + mlx_set_cache_limit_ptr = GET_SYM(handle, "mlx_set_cache_limit"); if (mlx_set_cache_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_cache_limit\n"); return -1; } - mlx_set_memory_limit_ptr = dlsym(handle, "mlx_set_memory_limit"); + mlx_set_memory_limit_ptr = GET_SYM(handle, "mlx_set_memory_limit"); if (mlx_set_memory_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_memory_limit\n"); return -1; } - mlx_set_wired_limit_ptr = dlsym(handle, "mlx_set_wired_limit"); + mlx_set_wired_limit_ptr = GET_SYM(handle, "mlx_set_wired_limit"); if (mlx_set_wired_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n"); return -1; } - mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available"); + mlx_metal_is_available_ptr = GET_SYM(handle, "mlx_metal_is_available"); if (mlx_metal_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n"); return -1; } - mlx_metal_start_capture_ptr = dlsym(handle, "mlx_metal_start_capture"); + mlx_metal_start_capture_ptr = GET_SYM(handle, "mlx_metal_start_capture"); if (mlx_metal_start_capture_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_start_capture\n"); return -1; } - mlx_metal_stop_capture_ptr = dlsym(handle, "mlx_metal_stop_capture"); + mlx_metal_stop_capture_ptr = GET_SYM(handle, "mlx_metal_stop_capture"); if (mlx_metal_stop_capture_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_stop_capture\n"); return -1; } - mlx_abs_ptr = dlsym(handle, "mlx_abs"); + mlx_abs_ptr = GET_SYM(handle, "mlx_abs"); if (mlx_abs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_abs\n"); return -1; } - mlx_add_ptr = dlsym(handle, "mlx_add"); + mlx_add_ptr = GET_SYM(handle, "mlx_add"); if (mlx_add_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_add\n"); return -1; } - mlx_addmm_ptr = dlsym(handle, "mlx_addmm"); + mlx_addmm_ptr = GET_SYM(handle, "mlx_addmm"); if (mlx_addmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_addmm\n"); return -1; } - mlx_all_axes_ptr = dlsym(handle, "mlx_all_axes"); + mlx_all_axes_ptr = GET_SYM(handle, "mlx_all_axes"); if (mlx_all_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axes\n"); return -1; } - mlx_all_axis_ptr = dlsym(handle, "mlx_all_axis"); + mlx_all_axis_ptr = GET_SYM(handle, "mlx_all_axis"); if (mlx_all_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axis\n"); return -1; } - mlx_all_ptr = dlsym(handle, "mlx_all"); + mlx_all_ptr = GET_SYM(handle, "mlx_all"); if (mlx_all_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all\n"); return -1; } - mlx_allclose_ptr = dlsym(handle, "mlx_allclose"); + mlx_allclose_ptr = GET_SYM(handle, "mlx_allclose"); if (mlx_allclose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_allclose\n"); return -1; } - mlx_any_axes_ptr = dlsym(handle, "mlx_any_axes"); + mlx_any_axes_ptr = GET_SYM(handle, "mlx_any_axes"); if (mlx_any_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axes\n"); return -1; } - mlx_any_axis_ptr = dlsym(handle, "mlx_any_axis"); + mlx_any_axis_ptr = GET_SYM(handle, "mlx_any_axis"); if (mlx_any_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axis\n"); return -1; } - mlx_any_ptr = dlsym(handle, "mlx_any"); + mlx_any_ptr = GET_SYM(handle, "mlx_any"); if (mlx_any_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any\n"); return -1; } - mlx_arange_ptr = dlsym(handle, "mlx_arange"); + mlx_arange_ptr = GET_SYM(handle, "mlx_arange"); if (mlx_arange_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arange\n"); return -1; } - mlx_arccos_ptr = dlsym(handle, "mlx_arccos"); + mlx_arccos_ptr = GET_SYM(handle, "mlx_arccos"); if (mlx_arccos_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arccos\n"); return -1; } - mlx_arccosh_ptr = dlsym(handle, "mlx_arccosh"); + mlx_arccosh_ptr = GET_SYM(handle, "mlx_arccosh"); if (mlx_arccosh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arccosh\n"); return -1; } - mlx_arcsin_ptr = dlsym(handle, "mlx_arcsin"); + mlx_arcsin_ptr = GET_SYM(handle, "mlx_arcsin"); if (mlx_arcsin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsin\n"); return -1; } - mlx_arcsinh_ptr = dlsym(handle, "mlx_arcsinh"); + mlx_arcsinh_ptr = GET_SYM(handle, "mlx_arcsinh"); if (mlx_arcsinh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsinh\n"); return -1; } - mlx_arctan_ptr = dlsym(handle, "mlx_arctan"); + mlx_arctan_ptr = GET_SYM(handle, "mlx_arctan"); if (mlx_arctan_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan\n"); return -1; } - mlx_arctan2_ptr = dlsym(handle, "mlx_arctan2"); + mlx_arctan2_ptr = GET_SYM(handle, "mlx_arctan2"); if (mlx_arctan2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan2\n"); return -1; } - mlx_arctanh_ptr = dlsym(handle, "mlx_arctanh"); + mlx_arctanh_ptr = GET_SYM(handle, "mlx_arctanh"); if (mlx_arctanh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctanh\n"); return -1; } - mlx_argmax_axis_ptr = dlsym(handle, "mlx_argmax_axis"); + mlx_argmax_axis_ptr = GET_SYM(handle, "mlx_argmax_axis"); if (mlx_argmax_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax_axis\n"); return -1; } - mlx_argmax_ptr = dlsym(handle, "mlx_argmax"); + mlx_argmax_ptr = GET_SYM(handle, "mlx_argmax"); if (mlx_argmax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax\n"); return -1; } - mlx_argmin_axis_ptr = dlsym(handle, "mlx_argmin_axis"); + mlx_argmin_axis_ptr = GET_SYM(handle, "mlx_argmin_axis"); if (mlx_argmin_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin_axis\n"); return -1; } - mlx_argmin_ptr = dlsym(handle, "mlx_argmin"); + mlx_argmin_ptr = GET_SYM(handle, "mlx_argmin"); if (mlx_argmin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin\n"); return -1; } - mlx_argpartition_axis_ptr = dlsym(handle, "mlx_argpartition_axis"); + mlx_argpartition_axis_ptr = GET_SYM(handle, "mlx_argpartition_axis"); if (mlx_argpartition_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition_axis\n"); return -1; } - mlx_argpartition_ptr = dlsym(handle, "mlx_argpartition"); + mlx_argpartition_ptr = GET_SYM(handle, "mlx_argpartition"); if (mlx_argpartition_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition\n"); return -1; } - mlx_argsort_axis_ptr = dlsym(handle, "mlx_argsort_axis"); + mlx_argsort_axis_ptr = GET_SYM(handle, "mlx_argsort_axis"); if (mlx_argsort_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort_axis\n"); return -1; } - mlx_argsort_ptr = dlsym(handle, "mlx_argsort"); + mlx_argsort_ptr = GET_SYM(handle, "mlx_argsort"); if (mlx_argsort_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort\n"); return -1; } - mlx_array_equal_ptr = dlsym(handle, "mlx_array_equal"); + mlx_array_equal_ptr = GET_SYM(handle, "mlx_array_equal"); if (mlx_array_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_equal\n"); return -1; } - mlx_as_strided_ptr = dlsym(handle, "mlx_as_strided"); + mlx_as_strided_ptr = GET_SYM(handle, "mlx_as_strided"); if (mlx_as_strided_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_as_strided\n"); return -1; } - mlx_astype_ptr = dlsym(handle, "mlx_astype"); + mlx_astype_ptr = GET_SYM(handle, "mlx_astype"); if (mlx_astype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_astype\n"); return -1; } - mlx_atleast_1d_ptr = dlsym(handle, "mlx_atleast_1d"); + mlx_atleast_1d_ptr = GET_SYM(handle, "mlx_atleast_1d"); if (mlx_atleast_1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_1d\n"); return -1; } - mlx_atleast_2d_ptr = dlsym(handle, "mlx_atleast_2d"); + mlx_atleast_2d_ptr = GET_SYM(handle, "mlx_atleast_2d"); if (mlx_atleast_2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_2d\n"); return -1; } - mlx_atleast_3d_ptr = dlsym(handle, "mlx_atleast_3d"); + mlx_atleast_3d_ptr = GET_SYM(handle, "mlx_atleast_3d"); if (mlx_atleast_3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n"); return -1; } - mlx_bitwise_and_ptr = dlsym(handle, "mlx_bitwise_and"); + mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and"); if (mlx_bitwise_and_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n"); return -1; } - mlx_bitwise_invert_ptr = dlsym(handle, "mlx_bitwise_invert"); + mlx_bitwise_invert_ptr = GET_SYM(handle, "mlx_bitwise_invert"); if (mlx_bitwise_invert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_invert\n"); return -1; } - mlx_bitwise_or_ptr = dlsym(handle, "mlx_bitwise_or"); + mlx_bitwise_or_ptr = GET_SYM(handle, "mlx_bitwise_or"); if (mlx_bitwise_or_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_or\n"); return -1; } - mlx_bitwise_xor_ptr = dlsym(handle, "mlx_bitwise_xor"); + mlx_bitwise_xor_ptr = GET_SYM(handle, "mlx_bitwise_xor"); if (mlx_bitwise_xor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n"); return -1; } - mlx_block_masked_mm_ptr = dlsym(handle, "mlx_block_masked_mm"); + mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm"); if (mlx_block_masked_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n"); return -1; } - mlx_broadcast_arrays_ptr = dlsym(handle, "mlx_broadcast_arrays"); + mlx_broadcast_arrays_ptr = GET_SYM(handle, "mlx_broadcast_arrays"); if (mlx_broadcast_arrays_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_arrays\n"); return -1; } - mlx_broadcast_to_ptr = dlsym(handle, "mlx_broadcast_to"); + mlx_broadcast_to_ptr = GET_SYM(handle, "mlx_broadcast_to"); if (mlx_broadcast_to_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_to\n"); return -1; } - mlx_ceil_ptr = dlsym(handle, "mlx_ceil"); + mlx_ceil_ptr = GET_SYM(handle, "mlx_ceil"); if (mlx_ceil_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ceil\n"); return -1; } - mlx_clip_ptr = dlsym(handle, "mlx_clip"); + mlx_clip_ptr = GET_SYM(handle, "mlx_clip"); if (mlx_clip_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_clip\n"); return -1; } - mlx_concatenate_axis_ptr = dlsym(handle, "mlx_concatenate_axis"); + mlx_concatenate_axis_ptr = GET_SYM(handle, "mlx_concatenate_axis"); if (mlx_concatenate_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate_axis\n"); return -1; } - mlx_concatenate_ptr = dlsym(handle, "mlx_concatenate"); + mlx_concatenate_ptr = GET_SYM(handle, "mlx_concatenate"); if (mlx_concatenate_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate\n"); return -1; } - mlx_conjugate_ptr = dlsym(handle, "mlx_conjugate"); + mlx_conjugate_ptr = GET_SYM(handle, "mlx_conjugate"); if (mlx_conjugate_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conjugate\n"); return -1; } - mlx_contiguous_ptr = dlsym(handle, "mlx_contiguous"); + mlx_contiguous_ptr = GET_SYM(handle, "mlx_contiguous"); if (mlx_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_contiguous\n"); return -1; } - mlx_conv1d_ptr = dlsym(handle, "mlx_conv1d"); + mlx_conv1d_ptr = GET_SYM(handle, "mlx_conv1d"); if (mlx_conv1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv1d\n"); return -1; } - mlx_conv2d_ptr = dlsym(handle, "mlx_conv2d"); + mlx_conv2d_ptr = GET_SYM(handle, "mlx_conv2d"); if (mlx_conv2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv2d\n"); return -1; } - mlx_conv3d_ptr = dlsym(handle, "mlx_conv3d"); + mlx_conv3d_ptr = GET_SYM(handle, "mlx_conv3d"); if (mlx_conv3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv3d\n"); return -1; } - mlx_conv_general_ptr = dlsym(handle, "mlx_conv_general"); + mlx_conv_general_ptr = GET_SYM(handle, "mlx_conv_general"); if (mlx_conv_general_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_general\n"); return -1; } - mlx_conv_transpose1d_ptr = dlsym(handle, "mlx_conv_transpose1d"); + mlx_conv_transpose1d_ptr = GET_SYM(handle, "mlx_conv_transpose1d"); if (mlx_conv_transpose1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose1d\n"); return -1; } - mlx_conv_transpose2d_ptr = dlsym(handle, "mlx_conv_transpose2d"); + mlx_conv_transpose2d_ptr = GET_SYM(handle, "mlx_conv_transpose2d"); if (mlx_conv_transpose2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose2d\n"); return -1; } - mlx_conv_transpose3d_ptr = dlsym(handle, "mlx_conv_transpose3d"); + mlx_conv_transpose3d_ptr = GET_SYM(handle, "mlx_conv_transpose3d"); if (mlx_conv_transpose3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose3d\n"); return -1; } - mlx_copy_ptr = dlsym(handle, "mlx_copy"); + mlx_copy_ptr = GET_SYM(handle, "mlx_copy"); if (mlx_copy_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_copy\n"); return -1; } - mlx_cos_ptr = dlsym(handle, "mlx_cos"); + mlx_cos_ptr = GET_SYM(handle, "mlx_cos"); if (mlx_cos_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cos\n"); return -1; } - mlx_cosh_ptr = dlsym(handle, "mlx_cosh"); + mlx_cosh_ptr = GET_SYM(handle, "mlx_cosh"); if (mlx_cosh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cosh\n"); return -1; } - mlx_cummax_ptr = dlsym(handle, "mlx_cummax"); + mlx_cummax_ptr = GET_SYM(handle, "mlx_cummax"); if (mlx_cummax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cummax\n"); return -1; } - mlx_cummin_ptr = dlsym(handle, "mlx_cummin"); + mlx_cummin_ptr = GET_SYM(handle, "mlx_cummin"); if (mlx_cummin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cummin\n"); return -1; } - mlx_cumprod_ptr = dlsym(handle, "mlx_cumprod"); + mlx_cumprod_ptr = GET_SYM(handle, "mlx_cumprod"); if (mlx_cumprod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cumprod\n"); return -1; } - mlx_cumsum_ptr = dlsym(handle, "mlx_cumsum"); + mlx_cumsum_ptr = GET_SYM(handle, "mlx_cumsum"); if (mlx_cumsum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cumsum\n"); return -1; } - mlx_degrees_ptr = dlsym(handle, "mlx_degrees"); + mlx_degrees_ptr = GET_SYM(handle, "mlx_degrees"); if (mlx_degrees_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_degrees\n"); return -1; } - mlx_depends_ptr = dlsym(handle, "mlx_depends"); + mlx_depends_ptr = GET_SYM(handle, "mlx_depends"); if (mlx_depends_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_depends\n"); return -1; } - mlx_dequantize_ptr = dlsym(handle, "mlx_dequantize"); + mlx_dequantize_ptr = GET_SYM(handle, "mlx_dequantize"); if (mlx_dequantize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_dequantize\n"); return -1; } - mlx_diag_ptr = dlsym(handle, "mlx_diag"); + mlx_diag_ptr = GET_SYM(handle, "mlx_diag"); if (mlx_diag_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_diag\n"); return -1; } - mlx_diagonal_ptr = dlsym(handle, "mlx_diagonal"); + mlx_diagonal_ptr = GET_SYM(handle, "mlx_diagonal"); if (mlx_diagonal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_diagonal\n"); return -1; } - mlx_divide_ptr = dlsym(handle, "mlx_divide"); + mlx_divide_ptr = GET_SYM(handle, "mlx_divide"); if (mlx_divide_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_divide\n"); return -1; } - mlx_divmod_ptr = dlsym(handle, "mlx_divmod"); + mlx_divmod_ptr = GET_SYM(handle, "mlx_divmod"); if (mlx_divmod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_divmod\n"); return -1; } - mlx_einsum_ptr = dlsym(handle, "mlx_einsum"); + mlx_einsum_ptr = GET_SYM(handle, "mlx_einsum"); if (mlx_einsum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_einsum\n"); return -1; } - mlx_equal_ptr = dlsym(handle, "mlx_equal"); + mlx_equal_ptr = GET_SYM(handle, "mlx_equal"); if (mlx_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_equal\n"); return -1; } - mlx_erf_ptr = dlsym(handle, "mlx_erf"); + mlx_erf_ptr = GET_SYM(handle, "mlx_erf"); if (mlx_erf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_erf\n"); return -1; } - mlx_erfinv_ptr = dlsym(handle, "mlx_erfinv"); + mlx_erfinv_ptr = GET_SYM(handle, "mlx_erfinv"); if (mlx_erfinv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_erfinv\n"); return -1; } - mlx_exp_ptr = dlsym(handle, "mlx_exp"); + mlx_exp_ptr = GET_SYM(handle, "mlx_exp"); if (mlx_exp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_exp\n"); return -1; } - mlx_expand_dims_axes_ptr = dlsym(handle, "mlx_expand_dims_axes"); + mlx_expand_dims_axes_ptr = GET_SYM(handle, "mlx_expand_dims_axes"); if (mlx_expand_dims_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims_axes\n"); return -1; } - mlx_expand_dims_ptr = dlsym(handle, "mlx_expand_dims"); + mlx_expand_dims_ptr = GET_SYM(handle, "mlx_expand_dims"); if (mlx_expand_dims_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims\n"); return -1; } - mlx_expm1_ptr = dlsym(handle, "mlx_expm1"); + mlx_expm1_ptr = GET_SYM(handle, "mlx_expm1"); if (mlx_expm1_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_expm1\n"); return -1; } - mlx_eye_ptr = dlsym(handle, "mlx_eye"); + mlx_eye_ptr = GET_SYM(handle, "mlx_eye"); if (mlx_eye_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_eye\n"); return -1; } - mlx_flatten_ptr = dlsym(handle, "mlx_flatten"); + mlx_flatten_ptr = GET_SYM(handle, "mlx_flatten"); if (mlx_flatten_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_flatten\n"); return -1; } - mlx_floor_ptr = dlsym(handle, "mlx_floor"); + mlx_floor_ptr = GET_SYM(handle, "mlx_floor"); if (mlx_floor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_floor\n"); return -1; } - mlx_floor_divide_ptr = dlsym(handle, "mlx_floor_divide"); + mlx_floor_divide_ptr = GET_SYM(handle, "mlx_floor_divide"); if (mlx_floor_divide_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_floor_divide\n"); return -1; } - mlx_from_fp8_ptr = dlsym(handle, "mlx_from_fp8"); + mlx_from_fp8_ptr = GET_SYM(handle, "mlx_from_fp8"); if (mlx_from_fp8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_from_fp8\n"); return -1; } - mlx_full_ptr = dlsym(handle, "mlx_full"); + mlx_full_ptr = GET_SYM(handle, "mlx_full"); if (mlx_full_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_full\n"); return -1; } - mlx_full_like_ptr = dlsym(handle, "mlx_full_like"); + mlx_full_like_ptr = GET_SYM(handle, "mlx_full_like"); if (mlx_full_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_full_like\n"); return -1; } - mlx_gather_ptr = dlsym(handle, "mlx_gather"); + mlx_gather_ptr = GET_SYM(handle, "mlx_gather"); if (mlx_gather_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather\n"); return -1; } - mlx_gather_single_ptr = dlsym(handle, "mlx_gather_single"); + mlx_gather_single_ptr = GET_SYM(handle, "mlx_gather_single"); if (mlx_gather_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_single\n"); return -1; } - mlx_gather_mm_ptr = dlsym(handle, "mlx_gather_mm"); + mlx_gather_mm_ptr = GET_SYM(handle, "mlx_gather_mm"); if (mlx_gather_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_mm\n"); return -1; } - mlx_gather_qmm_ptr = dlsym(handle, "mlx_gather_qmm"); + mlx_gather_qmm_ptr = GET_SYM(handle, "mlx_gather_qmm"); if (mlx_gather_qmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_qmm\n"); return -1; } - mlx_greater_ptr = dlsym(handle, "mlx_greater"); + mlx_greater_ptr = GET_SYM(handle, "mlx_greater"); if (mlx_greater_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_greater\n"); return -1; } - mlx_greater_equal_ptr = dlsym(handle, "mlx_greater_equal"); + mlx_greater_equal_ptr = GET_SYM(handle, "mlx_greater_equal"); if (mlx_greater_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_greater_equal\n"); return -1; } - mlx_hadamard_transform_ptr = dlsym(handle, "mlx_hadamard_transform"); + mlx_hadamard_transform_ptr = GET_SYM(handle, "mlx_hadamard_transform"); if (mlx_hadamard_transform_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n"); return -1; } - mlx_identity_ptr = dlsym(handle, "mlx_identity"); + mlx_identity_ptr = GET_SYM(handle, "mlx_identity"); if (mlx_identity_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n"); return -1; } - mlx_imag_ptr = dlsym(handle, "mlx_imag"); + mlx_imag_ptr = GET_SYM(handle, "mlx_imag"); if (mlx_imag_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imag\n"); return -1; } - mlx_inner_ptr = dlsym(handle, "mlx_inner"); + mlx_inner_ptr = GET_SYM(handle, "mlx_inner"); if (mlx_inner_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_inner\n"); return -1; } - mlx_isclose_ptr = dlsym(handle, "mlx_isclose"); + mlx_isclose_ptr = GET_SYM(handle, "mlx_isclose"); if (mlx_isclose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isclose\n"); return -1; } - mlx_isfinite_ptr = dlsym(handle, "mlx_isfinite"); + mlx_isfinite_ptr = GET_SYM(handle, "mlx_isfinite"); if (mlx_isfinite_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isfinite\n"); return -1; } - mlx_isinf_ptr = dlsym(handle, "mlx_isinf"); + mlx_isinf_ptr = GET_SYM(handle, "mlx_isinf"); if (mlx_isinf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isinf\n"); return -1; } - mlx_isnan_ptr = dlsym(handle, "mlx_isnan"); + mlx_isnan_ptr = GET_SYM(handle, "mlx_isnan"); if (mlx_isnan_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isnan\n"); return -1; } - mlx_isneginf_ptr = dlsym(handle, "mlx_isneginf"); + mlx_isneginf_ptr = GET_SYM(handle, "mlx_isneginf"); if (mlx_isneginf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isneginf\n"); return -1; } - mlx_isposinf_ptr = dlsym(handle, "mlx_isposinf"); + mlx_isposinf_ptr = GET_SYM(handle, "mlx_isposinf"); if (mlx_isposinf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isposinf\n"); return -1; } - mlx_kron_ptr = dlsym(handle, "mlx_kron"); + mlx_kron_ptr = GET_SYM(handle, "mlx_kron"); if (mlx_kron_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_kron\n"); return -1; } - mlx_left_shift_ptr = dlsym(handle, "mlx_left_shift"); + mlx_left_shift_ptr = GET_SYM(handle, "mlx_left_shift"); if (mlx_left_shift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_left_shift\n"); return -1; } - mlx_less_ptr = dlsym(handle, "mlx_less"); + mlx_less_ptr = GET_SYM(handle, "mlx_less"); if (mlx_less_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_less\n"); return -1; } - mlx_less_equal_ptr = dlsym(handle, "mlx_less_equal"); + mlx_less_equal_ptr = GET_SYM(handle, "mlx_less_equal"); if (mlx_less_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_less_equal\n"); return -1; } - mlx_linspace_ptr = dlsym(handle, "mlx_linspace"); + mlx_linspace_ptr = GET_SYM(handle, "mlx_linspace"); if (mlx_linspace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linspace\n"); return -1; } - mlx_log_ptr = dlsym(handle, "mlx_log"); + mlx_log_ptr = GET_SYM(handle, "mlx_log"); if (mlx_log_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log\n"); return -1; } - mlx_log10_ptr = dlsym(handle, "mlx_log10"); + mlx_log10_ptr = GET_SYM(handle, "mlx_log10"); if (mlx_log10_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log10\n"); return -1; } - mlx_log1p_ptr = dlsym(handle, "mlx_log1p"); + mlx_log1p_ptr = GET_SYM(handle, "mlx_log1p"); if (mlx_log1p_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log1p\n"); return -1; } - mlx_log2_ptr = dlsym(handle, "mlx_log2"); + mlx_log2_ptr = GET_SYM(handle, "mlx_log2"); if (mlx_log2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log2\n"); return -1; } - mlx_logaddexp_ptr = dlsym(handle, "mlx_logaddexp"); + mlx_logaddexp_ptr = GET_SYM(handle, "mlx_logaddexp"); if (mlx_logaddexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logaddexp\n"); return -1; } - mlx_logcumsumexp_ptr = dlsym(handle, "mlx_logcumsumexp"); + mlx_logcumsumexp_ptr = GET_SYM(handle, "mlx_logcumsumexp"); if (mlx_logcumsumexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logcumsumexp\n"); return -1; } - mlx_logical_and_ptr = dlsym(handle, "mlx_logical_and"); + mlx_logical_and_ptr = GET_SYM(handle, "mlx_logical_and"); if (mlx_logical_and_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_and\n"); return -1; } - mlx_logical_not_ptr = dlsym(handle, "mlx_logical_not"); + mlx_logical_not_ptr = GET_SYM(handle, "mlx_logical_not"); if (mlx_logical_not_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_not\n"); return -1; } - mlx_logical_or_ptr = dlsym(handle, "mlx_logical_or"); + mlx_logical_or_ptr = GET_SYM(handle, "mlx_logical_or"); if (mlx_logical_or_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_or\n"); return -1; } - mlx_logsumexp_axes_ptr = dlsym(handle, "mlx_logsumexp_axes"); + mlx_logsumexp_axes_ptr = GET_SYM(handle, "mlx_logsumexp_axes"); if (mlx_logsumexp_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axes\n"); return -1; } - mlx_logsumexp_axis_ptr = dlsym(handle, "mlx_logsumexp_axis"); + mlx_logsumexp_axis_ptr = GET_SYM(handle, "mlx_logsumexp_axis"); if (mlx_logsumexp_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axis\n"); return -1; } - mlx_logsumexp_ptr = dlsym(handle, "mlx_logsumexp"); + mlx_logsumexp_ptr = GET_SYM(handle, "mlx_logsumexp"); if (mlx_logsumexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp\n"); return -1; } - mlx_masked_scatter_ptr = dlsym(handle, "mlx_masked_scatter"); + mlx_masked_scatter_ptr = GET_SYM(handle, "mlx_masked_scatter"); if (mlx_masked_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_masked_scatter\n"); return -1; } - mlx_matmul_ptr = dlsym(handle, "mlx_matmul"); + mlx_matmul_ptr = GET_SYM(handle, "mlx_matmul"); if (mlx_matmul_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_matmul\n"); return -1; } - mlx_max_axes_ptr = dlsym(handle, "mlx_max_axes"); + mlx_max_axes_ptr = GET_SYM(handle, "mlx_max_axes"); if (mlx_max_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axes\n"); return -1; } - mlx_max_axis_ptr = dlsym(handle, "mlx_max_axis"); + mlx_max_axis_ptr = GET_SYM(handle, "mlx_max_axis"); if (mlx_max_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axis\n"); return -1; } - mlx_max_ptr = dlsym(handle, "mlx_max"); + mlx_max_ptr = GET_SYM(handle, "mlx_max"); if (mlx_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max\n"); return -1; } - mlx_maximum_ptr = dlsym(handle, "mlx_maximum"); + mlx_maximum_ptr = GET_SYM(handle, "mlx_maximum"); if (mlx_maximum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_maximum\n"); return -1; } - mlx_mean_axes_ptr = dlsym(handle, "mlx_mean_axes"); + mlx_mean_axes_ptr = GET_SYM(handle, "mlx_mean_axes"); if (mlx_mean_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axes\n"); return -1; } - mlx_mean_axis_ptr = dlsym(handle, "mlx_mean_axis"); + mlx_mean_axis_ptr = GET_SYM(handle, "mlx_mean_axis"); if (mlx_mean_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axis\n"); return -1; } - mlx_mean_ptr = dlsym(handle, "mlx_mean"); + mlx_mean_ptr = GET_SYM(handle, "mlx_mean"); if (mlx_mean_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean\n"); return -1; } - mlx_median_ptr = dlsym(handle, "mlx_median"); + mlx_median_ptr = GET_SYM(handle, "mlx_median"); if (mlx_median_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_median\n"); return -1; } - mlx_meshgrid_ptr = dlsym(handle, "mlx_meshgrid"); + mlx_meshgrid_ptr = GET_SYM(handle, "mlx_meshgrid"); if (mlx_meshgrid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_meshgrid\n"); return -1; } - mlx_min_axes_ptr = dlsym(handle, "mlx_min_axes"); + mlx_min_axes_ptr = GET_SYM(handle, "mlx_min_axes"); if (mlx_min_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axes\n"); return -1; } - mlx_min_axis_ptr = dlsym(handle, "mlx_min_axis"); + mlx_min_axis_ptr = GET_SYM(handle, "mlx_min_axis"); if (mlx_min_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axis\n"); return -1; } - mlx_min_ptr = dlsym(handle, "mlx_min"); + mlx_min_ptr = GET_SYM(handle, "mlx_min"); if (mlx_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min\n"); return -1; } - mlx_minimum_ptr = dlsym(handle, "mlx_minimum"); + mlx_minimum_ptr = GET_SYM(handle, "mlx_minimum"); if (mlx_minimum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_minimum\n"); return -1; } - mlx_moveaxis_ptr = dlsym(handle, "mlx_moveaxis"); + mlx_moveaxis_ptr = GET_SYM(handle, "mlx_moveaxis"); if (mlx_moveaxis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_moveaxis\n"); return -1; } - mlx_multiply_ptr = dlsym(handle, "mlx_multiply"); + mlx_multiply_ptr = GET_SYM(handle, "mlx_multiply"); if (mlx_multiply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_multiply\n"); return -1; } - mlx_nan_to_num_ptr = dlsym(handle, "mlx_nan_to_num"); + mlx_nan_to_num_ptr = GET_SYM(handle, "mlx_nan_to_num"); if (mlx_nan_to_num_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_nan_to_num\n"); return -1; } - mlx_negative_ptr = dlsym(handle, "mlx_negative"); + mlx_negative_ptr = GET_SYM(handle, "mlx_negative"); if (mlx_negative_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_negative\n"); return -1; } - mlx_not_equal_ptr = dlsym(handle, "mlx_not_equal"); + mlx_not_equal_ptr = GET_SYM(handle, "mlx_not_equal"); if (mlx_not_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_not_equal\n"); return -1; } - mlx_number_of_elements_ptr = dlsym(handle, "mlx_number_of_elements"); + mlx_number_of_elements_ptr = GET_SYM(handle, "mlx_number_of_elements"); if (mlx_number_of_elements_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_number_of_elements\n"); return -1; } - mlx_ones_ptr = dlsym(handle, "mlx_ones"); + mlx_ones_ptr = GET_SYM(handle, "mlx_ones"); if (mlx_ones_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ones\n"); return -1; } - mlx_ones_like_ptr = dlsym(handle, "mlx_ones_like"); + mlx_ones_like_ptr = GET_SYM(handle, "mlx_ones_like"); if (mlx_ones_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ones_like\n"); return -1; } - mlx_outer_ptr = dlsym(handle, "mlx_outer"); + mlx_outer_ptr = GET_SYM(handle, "mlx_outer"); if (mlx_outer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_outer\n"); return -1; } - mlx_pad_ptr = dlsym(handle, "mlx_pad"); + mlx_pad_ptr = GET_SYM(handle, "mlx_pad"); if (mlx_pad_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_pad\n"); return -1; } - mlx_pad_symmetric_ptr = dlsym(handle, "mlx_pad_symmetric"); + mlx_pad_symmetric_ptr = GET_SYM(handle, "mlx_pad_symmetric"); if (mlx_pad_symmetric_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_pad_symmetric\n"); return -1; } - mlx_partition_axis_ptr = dlsym(handle, "mlx_partition_axis"); + mlx_partition_axis_ptr = GET_SYM(handle, "mlx_partition_axis"); if (mlx_partition_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_partition_axis\n"); return -1; } - mlx_partition_ptr = dlsym(handle, "mlx_partition"); + mlx_partition_ptr = GET_SYM(handle, "mlx_partition"); if (mlx_partition_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_partition\n"); return -1; } - mlx_power_ptr = dlsym(handle, "mlx_power"); + mlx_power_ptr = GET_SYM(handle, "mlx_power"); if (mlx_power_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_power\n"); return -1; } - mlx_prod_axes_ptr = dlsym(handle, "mlx_prod_axes"); + mlx_prod_axes_ptr = GET_SYM(handle, "mlx_prod_axes"); if (mlx_prod_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axes\n"); return -1; } - mlx_prod_axis_ptr = dlsym(handle, "mlx_prod_axis"); + mlx_prod_axis_ptr = GET_SYM(handle, "mlx_prod_axis"); if (mlx_prod_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axis\n"); return -1; } - mlx_prod_ptr = dlsym(handle, "mlx_prod"); + mlx_prod_ptr = GET_SYM(handle, "mlx_prod"); if (mlx_prod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod\n"); return -1; } - mlx_put_along_axis_ptr = dlsym(handle, "mlx_put_along_axis"); + mlx_put_along_axis_ptr = GET_SYM(handle, "mlx_put_along_axis"); if (mlx_put_along_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_put_along_axis\n"); return -1; } - mlx_qqmm_ptr = dlsym(handle, "mlx_qqmm"); + mlx_qqmm_ptr = GET_SYM(handle, "mlx_qqmm"); if (mlx_qqmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_qqmm\n"); return -1; } - mlx_quantize_ptr = dlsym(handle, "mlx_quantize"); + mlx_quantize_ptr = GET_SYM(handle, "mlx_quantize"); if (mlx_quantize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_quantize\n"); return -1; } - mlx_quantized_matmul_ptr = dlsym(handle, "mlx_quantized_matmul"); + mlx_quantized_matmul_ptr = GET_SYM(handle, "mlx_quantized_matmul"); if (mlx_quantized_matmul_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_quantized_matmul\n"); return -1; } - mlx_radians_ptr = dlsym(handle, "mlx_radians"); + mlx_radians_ptr = GET_SYM(handle, "mlx_radians"); if (mlx_radians_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_radians\n"); return -1; } - mlx_real_ptr = dlsym(handle, "mlx_real"); + mlx_real_ptr = GET_SYM(handle, "mlx_real"); if (mlx_real_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_real\n"); return -1; } - mlx_reciprocal_ptr = dlsym(handle, "mlx_reciprocal"); + mlx_reciprocal_ptr = GET_SYM(handle, "mlx_reciprocal"); if (mlx_reciprocal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reciprocal\n"); return -1; } - mlx_remainder_ptr = dlsym(handle, "mlx_remainder"); + mlx_remainder_ptr = GET_SYM(handle, "mlx_remainder"); if (mlx_remainder_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_remainder\n"); return -1; } - mlx_repeat_axis_ptr = dlsym(handle, "mlx_repeat_axis"); + mlx_repeat_axis_ptr = GET_SYM(handle, "mlx_repeat_axis"); if (mlx_repeat_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat_axis\n"); return -1; } - mlx_repeat_ptr = dlsym(handle, "mlx_repeat"); + mlx_repeat_ptr = GET_SYM(handle, "mlx_repeat"); if (mlx_repeat_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat\n"); return -1; } - mlx_reshape_ptr = dlsym(handle, "mlx_reshape"); + mlx_reshape_ptr = GET_SYM(handle, "mlx_reshape"); if (mlx_reshape_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reshape\n"); return -1; } - mlx_right_shift_ptr = dlsym(handle, "mlx_right_shift"); + mlx_right_shift_ptr = GET_SYM(handle, "mlx_right_shift"); if (mlx_right_shift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_right_shift\n"); return -1; } - mlx_roll_axis_ptr = dlsym(handle, "mlx_roll_axis"); + mlx_roll_axis_ptr = GET_SYM(handle, "mlx_roll_axis"); if (mlx_roll_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axis\n"); return -1; } - mlx_roll_axes_ptr = dlsym(handle, "mlx_roll_axes"); + mlx_roll_axes_ptr = GET_SYM(handle, "mlx_roll_axes"); if (mlx_roll_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axes\n"); return -1; } - mlx_roll_ptr = dlsym(handle, "mlx_roll"); + mlx_roll_ptr = GET_SYM(handle, "mlx_roll"); if (mlx_roll_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll\n"); return -1; } - mlx_round_ptr = dlsym(handle, "mlx_round"); + mlx_round_ptr = GET_SYM(handle, "mlx_round"); if (mlx_round_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_round\n"); return -1; } - mlx_rsqrt_ptr = dlsym(handle, "mlx_rsqrt"); + mlx_rsqrt_ptr = GET_SYM(handle, "mlx_rsqrt"); if (mlx_rsqrt_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_rsqrt\n"); return -1; } - mlx_scatter_ptr = dlsym(handle, "mlx_scatter"); + mlx_scatter_ptr = GET_SYM(handle, "mlx_scatter"); if (mlx_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter\n"); return -1; } - mlx_scatter_single_ptr = dlsym(handle, "mlx_scatter_single"); + mlx_scatter_single_ptr = GET_SYM(handle, "mlx_scatter_single"); if (mlx_scatter_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_single\n"); return -1; } - mlx_scatter_add_ptr = dlsym(handle, "mlx_scatter_add"); + mlx_scatter_add_ptr = GET_SYM(handle, "mlx_scatter_add"); if (mlx_scatter_add_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add\n"); return -1; } - mlx_scatter_add_single_ptr = dlsym(handle, "mlx_scatter_add_single"); + mlx_scatter_add_single_ptr = GET_SYM(handle, "mlx_scatter_add_single"); if (mlx_scatter_add_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_single\n"); return -1; } - mlx_scatter_add_axis_ptr = dlsym(handle, "mlx_scatter_add_axis"); + mlx_scatter_add_axis_ptr = GET_SYM(handle, "mlx_scatter_add_axis"); if (mlx_scatter_add_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_axis\n"); return -1; } - mlx_scatter_max_ptr = dlsym(handle, "mlx_scatter_max"); + mlx_scatter_max_ptr = GET_SYM(handle, "mlx_scatter_max"); if (mlx_scatter_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max\n"); return -1; } - mlx_scatter_max_single_ptr = dlsym(handle, "mlx_scatter_max_single"); + mlx_scatter_max_single_ptr = GET_SYM(handle, "mlx_scatter_max_single"); if (mlx_scatter_max_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max_single\n"); return -1; } - mlx_scatter_min_ptr = dlsym(handle, "mlx_scatter_min"); + mlx_scatter_min_ptr = GET_SYM(handle, "mlx_scatter_min"); if (mlx_scatter_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min\n"); return -1; } - mlx_scatter_min_single_ptr = dlsym(handle, "mlx_scatter_min_single"); + mlx_scatter_min_single_ptr = GET_SYM(handle, "mlx_scatter_min_single"); if (mlx_scatter_min_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min_single\n"); return -1; } - mlx_scatter_prod_ptr = dlsym(handle, "mlx_scatter_prod"); + mlx_scatter_prod_ptr = GET_SYM(handle, "mlx_scatter_prod"); if (mlx_scatter_prod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod\n"); return -1; } - mlx_scatter_prod_single_ptr = dlsym(handle, "mlx_scatter_prod_single"); + mlx_scatter_prod_single_ptr = GET_SYM(handle, "mlx_scatter_prod_single"); if (mlx_scatter_prod_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod_single\n"); return -1; } - mlx_segmented_mm_ptr = dlsym(handle, "mlx_segmented_mm"); + mlx_segmented_mm_ptr = GET_SYM(handle, "mlx_segmented_mm"); if (mlx_segmented_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_segmented_mm\n"); return -1; } - mlx_sigmoid_ptr = dlsym(handle, "mlx_sigmoid"); + mlx_sigmoid_ptr = GET_SYM(handle, "mlx_sigmoid"); if (mlx_sigmoid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sigmoid\n"); return -1; } - mlx_sign_ptr = dlsym(handle, "mlx_sign"); + mlx_sign_ptr = GET_SYM(handle, "mlx_sign"); if (mlx_sign_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sign\n"); return -1; } - mlx_sin_ptr = dlsym(handle, "mlx_sin"); + mlx_sin_ptr = GET_SYM(handle, "mlx_sin"); if (mlx_sin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sin\n"); return -1; } - mlx_sinh_ptr = dlsym(handle, "mlx_sinh"); + mlx_sinh_ptr = GET_SYM(handle, "mlx_sinh"); if (mlx_sinh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sinh\n"); return -1; } - mlx_slice_ptr = dlsym(handle, "mlx_slice"); + mlx_slice_ptr = GET_SYM(handle, "mlx_slice"); if (mlx_slice_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice\n"); return -1; } - mlx_slice_dynamic_ptr = dlsym(handle, "mlx_slice_dynamic"); + mlx_slice_dynamic_ptr = GET_SYM(handle, "mlx_slice_dynamic"); if (mlx_slice_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_dynamic\n"); return -1; } - mlx_slice_update_ptr = dlsym(handle, "mlx_slice_update"); + mlx_slice_update_ptr = GET_SYM(handle, "mlx_slice_update"); if (mlx_slice_update_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update\n"); return -1; } - mlx_slice_update_dynamic_ptr = dlsym(handle, "mlx_slice_update_dynamic"); + mlx_slice_update_dynamic_ptr = GET_SYM(handle, "mlx_slice_update_dynamic"); if (mlx_slice_update_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_dynamic\n"); return -1; } - mlx_softmax_axes_ptr = dlsym(handle, "mlx_softmax_axes"); + mlx_softmax_axes_ptr = GET_SYM(handle, "mlx_softmax_axes"); if (mlx_softmax_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axes\n"); return -1; } - mlx_softmax_axis_ptr = dlsym(handle, "mlx_softmax_axis"); + mlx_softmax_axis_ptr = GET_SYM(handle, "mlx_softmax_axis"); if (mlx_softmax_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axis\n"); return -1; } - mlx_softmax_ptr = dlsym(handle, "mlx_softmax"); + mlx_softmax_ptr = GET_SYM(handle, "mlx_softmax"); if (mlx_softmax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax\n"); return -1; } - mlx_sort_axis_ptr = dlsym(handle, "mlx_sort_axis"); + mlx_sort_axis_ptr = GET_SYM(handle, "mlx_sort_axis"); if (mlx_sort_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sort_axis\n"); return -1; } - mlx_sort_ptr = dlsym(handle, "mlx_sort"); + mlx_sort_ptr = GET_SYM(handle, "mlx_sort"); if (mlx_sort_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sort\n"); return -1; } - mlx_split_ptr = dlsym(handle, "mlx_split"); + mlx_split_ptr = GET_SYM(handle, "mlx_split"); if (mlx_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_split\n"); return -1; } - mlx_split_sections_ptr = dlsym(handle, "mlx_split_sections"); + mlx_split_sections_ptr = GET_SYM(handle, "mlx_split_sections"); if (mlx_split_sections_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_split_sections\n"); return -1; } - mlx_sqrt_ptr = dlsym(handle, "mlx_sqrt"); + mlx_sqrt_ptr = GET_SYM(handle, "mlx_sqrt"); if (mlx_sqrt_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sqrt\n"); return -1; } - mlx_square_ptr = dlsym(handle, "mlx_square"); + mlx_square_ptr = GET_SYM(handle, "mlx_square"); if (mlx_square_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_square\n"); return -1; } - mlx_squeeze_axes_ptr = dlsym(handle, "mlx_squeeze_axes"); + mlx_squeeze_axes_ptr = GET_SYM(handle, "mlx_squeeze_axes"); if (mlx_squeeze_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axes\n"); return -1; } - mlx_squeeze_axis_ptr = dlsym(handle, "mlx_squeeze_axis"); + mlx_squeeze_axis_ptr = GET_SYM(handle, "mlx_squeeze_axis"); if (mlx_squeeze_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axis\n"); return -1; } - mlx_squeeze_ptr = dlsym(handle, "mlx_squeeze"); + mlx_squeeze_ptr = GET_SYM(handle, "mlx_squeeze"); if (mlx_squeeze_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze\n"); return -1; } - mlx_stack_axis_ptr = dlsym(handle, "mlx_stack_axis"); + mlx_stack_axis_ptr = GET_SYM(handle, "mlx_stack_axis"); if (mlx_stack_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stack_axis\n"); return -1; } - mlx_stack_ptr = dlsym(handle, "mlx_stack"); + mlx_stack_ptr = GET_SYM(handle, "mlx_stack"); if (mlx_stack_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stack\n"); return -1; } - mlx_std_axes_ptr = dlsym(handle, "mlx_std_axes"); + mlx_std_axes_ptr = GET_SYM(handle, "mlx_std_axes"); if (mlx_std_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axes\n"); return -1; } - mlx_std_axis_ptr = dlsym(handle, "mlx_std_axis"); + mlx_std_axis_ptr = GET_SYM(handle, "mlx_std_axis"); if (mlx_std_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axis\n"); return -1; } - mlx_std_ptr = dlsym(handle, "mlx_std"); + mlx_std_ptr = GET_SYM(handle, "mlx_std"); if (mlx_std_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std\n"); return -1; } - mlx_stop_gradient_ptr = dlsym(handle, "mlx_stop_gradient"); + mlx_stop_gradient_ptr = GET_SYM(handle, "mlx_stop_gradient"); if (mlx_stop_gradient_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stop_gradient\n"); return -1; } - mlx_subtract_ptr = dlsym(handle, "mlx_subtract"); + mlx_subtract_ptr = GET_SYM(handle, "mlx_subtract"); if (mlx_subtract_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_subtract\n"); return -1; } - mlx_sum_axes_ptr = dlsym(handle, "mlx_sum_axes"); + mlx_sum_axes_ptr = GET_SYM(handle, "mlx_sum_axes"); if (mlx_sum_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axes\n"); return -1; } - mlx_sum_axis_ptr = dlsym(handle, "mlx_sum_axis"); + mlx_sum_axis_ptr = GET_SYM(handle, "mlx_sum_axis"); if (mlx_sum_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axis\n"); return -1; } - mlx_sum_ptr = dlsym(handle, "mlx_sum"); + mlx_sum_ptr = GET_SYM(handle, "mlx_sum"); if (mlx_sum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum\n"); return -1; } - mlx_swapaxes_ptr = dlsym(handle, "mlx_swapaxes"); + mlx_swapaxes_ptr = GET_SYM(handle, "mlx_swapaxes"); if (mlx_swapaxes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_swapaxes\n"); return -1; } - mlx_take_axis_ptr = dlsym(handle, "mlx_take_axis"); + mlx_take_axis_ptr = GET_SYM(handle, "mlx_take_axis"); if (mlx_take_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take_axis\n"); return -1; } - mlx_take_ptr = dlsym(handle, "mlx_take"); + mlx_take_ptr = GET_SYM(handle, "mlx_take"); if (mlx_take_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take\n"); return -1; } - mlx_take_along_axis_ptr = dlsym(handle, "mlx_take_along_axis"); + mlx_take_along_axis_ptr = GET_SYM(handle, "mlx_take_along_axis"); if (mlx_take_along_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take_along_axis\n"); return -1; } - mlx_tan_ptr = dlsym(handle, "mlx_tan"); + mlx_tan_ptr = GET_SYM(handle, "mlx_tan"); if (mlx_tan_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tan\n"); return -1; } - mlx_tanh_ptr = dlsym(handle, "mlx_tanh"); + mlx_tanh_ptr = GET_SYM(handle, "mlx_tanh"); if (mlx_tanh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tanh\n"); return -1; } - mlx_tensordot_ptr = dlsym(handle, "mlx_tensordot"); + mlx_tensordot_ptr = GET_SYM(handle, "mlx_tensordot"); if (mlx_tensordot_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot\n"); return -1; } - mlx_tensordot_axis_ptr = dlsym(handle, "mlx_tensordot_axis"); + mlx_tensordot_axis_ptr = GET_SYM(handle, "mlx_tensordot_axis"); if (mlx_tensordot_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot_axis\n"); return -1; } - mlx_tile_ptr = dlsym(handle, "mlx_tile"); + mlx_tile_ptr = GET_SYM(handle, "mlx_tile"); if (mlx_tile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tile\n"); return -1; } - mlx_to_fp8_ptr = dlsym(handle, "mlx_to_fp8"); + mlx_to_fp8_ptr = GET_SYM(handle, "mlx_to_fp8"); if (mlx_to_fp8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_to_fp8\n"); return -1; } - mlx_topk_axis_ptr = dlsym(handle, "mlx_topk_axis"); + mlx_topk_axis_ptr = GET_SYM(handle, "mlx_topk_axis"); if (mlx_topk_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_topk_axis\n"); return -1; } - mlx_topk_ptr = dlsym(handle, "mlx_topk"); + mlx_topk_ptr = GET_SYM(handle, "mlx_topk"); if (mlx_topk_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_topk\n"); return -1; } - mlx_trace_ptr = dlsym(handle, "mlx_trace"); + mlx_trace_ptr = GET_SYM(handle, "mlx_trace"); if (mlx_trace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_trace\n"); return -1; } - mlx_transpose_axes_ptr = dlsym(handle, "mlx_transpose_axes"); + mlx_transpose_axes_ptr = GET_SYM(handle, "mlx_transpose_axes"); if (mlx_transpose_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose_axes\n"); return -1; } - mlx_transpose_ptr = dlsym(handle, "mlx_transpose"); + mlx_transpose_ptr = GET_SYM(handle, "mlx_transpose"); if (mlx_transpose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose\n"); return -1; } - mlx_tri_ptr = dlsym(handle, "mlx_tri"); + mlx_tri_ptr = GET_SYM(handle, "mlx_tri"); if (mlx_tri_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tri\n"); return -1; } - mlx_tril_ptr = dlsym(handle, "mlx_tril"); + mlx_tril_ptr = GET_SYM(handle, "mlx_tril"); if (mlx_tril_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tril\n"); return -1; } - mlx_triu_ptr = dlsym(handle, "mlx_triu"); + mlx_triu_ptr = GET_SYM(handle, "mlx_triu"); if (mlx_triu_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_triu\n"); return -1; } - mlx_unflatten_ptr = dlsym(handle, "mlx_unflatten"); + mlx_unflatten_ptr = GET_SYM(handle, "mlx_unflatten"); if (mlx_unflatten_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_unflatten\n"); return -1; } - mlx_var_axes_ptr = dlsym(handle, "mlx_var_axes"); + mlx_var_axes_ptr = GET_SYM(handle, "mlx_var_axes"); if (mlx_var_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axes\n"); return -1; } - mlx_var_axis_ptr = dlsym(handle, "mlx_var_axis"); + mlx_var_axis_ptr = GET_SYM(handle, "mlx_var_axis"); if (mlx_var_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axis\n"); return -1; } - mlx_var_ptr = dlsym(handle, "mlx_var"); + mlx_var_ptr = GET_SYM(handle, "mlx_var"); if (mlx_var_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var\n"); return -1; } - mlx_view_ptr = dlsym(handle, "mlx_view"); + mlx_view_ptr = GET_SYM(handle, "mlx_view"); if (mlx_view_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_view\n"); return -1; } - mlx_where_ptr = dlsym(handle, "mlx_where"); + mlx_where_ptr = GET_SYM(handle, "mlx_where"); if (mlx_where_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_where\n"); return -1; } - mlx_zeros_ptr = dlsym(handle, "mlx_zeros"); + mlx_zeros_ptr = GET_SYM(handle, "mlx_zeros"); if (mlx_zeros_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros\n"); return -1; } - mlx_zeros_like_ptr = dlsym(handle, "mlx_zeros_like"); + mlx_zeros_like_ptr = GET_SYM(handle, "mlx_zeros_like"); if (mlx_zeros_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros_like\n"); return -1; } - mlx_random_bernoulli_ptr = dlsym(handle, "mlx_random_bernoulli"); + mlx_random_bernoulli_ptr = GET_SYM(handle, "mlx_random_bernoulli"); if (mlx_random_bernoulli_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bernoulli\n"); return -1; } - mlx_random_bits_ptr = dlsym(handle, "mlx_random_bits"); + mlx_random_bits_ptr = GET_SYM(handle, "mlx_random_bits"); if (mlx_random_bits_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bits\n"); return -1; } - mlx_random_categorical_shape_ptr = dlsym(handle, "mlx_random_categorical_shape"); + mlx_random_categorical_shape_ptr = GET_SYM(handle, "mlx_random_categorical_shape"); if (mlx_random_categorical_shape_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_shape\n"); return -1; } - mlx_random_categorical_num_samples_ptr = dlsym(handle, "mlx_random_categorical_num_samples"); + mlx_random_categorical_num_samples_ptr = GET_SYM(handle, "mlx_random_categorical_num_samples"); if (mlx_random_categorical_num_samples_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_num_samples\n"); return -1; } - mlx_random_categorical_ptr = dlsym(handle, "mlx_random_categorical"); + mlx_random_categorical_ptr = GET_SYM(handle, "mlx_random_categorical"); if (mlx_random_categorical_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical\n"); return -1; } - mlx_random_gumbel_ptr = dlsym(handle, "mlx_random_gumbel"); + mlx_random_gumbel_ptr = GET_SYM(handle, "mlx_random_gumbel"); if (mlx_random_gumbel_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_gumbel\n"); return -1; } - mlx_random_key_ptr = dlsym(handle, "mlx_random_key"); + mlx_random_key_ptr = GET_SYM(handle, "mlx_random_key"); if (mlx_random_key_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_key\n"); return -1; } - mlx_random_laplace_ptr = dlsym(handle, "mlx_random_laplace"); + mlx_random_laplace_ptr = GET_SYM(handle, "mlx_random_laplace"); if (mlx_random_laplace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_laplace\n"); return -1; } - mlx_random_multivariate_normal_ptr = dlsym(handle, "mlx_random_multivariate_normal"); + mlx_random_multivariate_normal_ptr = GET_SYM(handle, "mlx_random_multivariate_normal"); if (mlx_random_multivariate_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_multivariate_normal\n"); return -1; } - mlx_random_normal_broadcast_ptr = dlsym(handle, "mlx_random_normal_broadcast"); + mlx_random_normal_broadcast_ptr = GET_SYM(handle, "mlx_random_normal_broadcast"); if (mlx_random_normal_broadcast_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal_broadcast\n"); return -1; } - mlx_random_normal_ptr = dlsym(handle, "mlx_random_normal"); + mlx_random_normal_ptr = GET_SYM(handle, "mlx_random_normal"); if (mlx_random_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal\n"); return -1; } - mlx_random_permutation_ptr = dlsym(handle, "mlx_random_permutation"); + mlx_random_permutation_ptr = GET_SYM(handle, "mlx_random_permutation"); if (mlx_random_permutation_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation\n"); return -1; } - mlx_random_permutation_arange_ptr = dlsym(handle, "mlx_random_permutation_arange"); + mlx_random_permutation_arange_ptr = GET_SYM(handle, "mlx_random_permutation_arange"); if (mlx_random_permutation_arange_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation_arange\n"); return -1; } - mlx_random_randint_ptr = dlsym(handle, "mlx_random_randint"); + mlx_random_randint_ptr = GET_SYM(handle, "mlx_random_randint"); if (mlx_random_randint_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_randint\n"); return -1; } - mlx_random_seed_ptr = dlsym(handle, "mlx_random_seed"); + mlx_random_seed_ptr = GET_SYM(handle, "mlx_random_seed"); if (mlx_random_seed_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_seed\n"); return -1; } - mlx_random_split_num_ptr = dlsym(handle, "mlx_random_split_num"); + mlx_random_split_num_ptr = GET_SYM(handle, "mlx_random_split_num"); if (mlx_random_split_num_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split_num\n"); return -1; } - mlx_random_split_ptr = dlsym(handle, "mlx_random_split"); + mlx_random_split_ptr = GET_SYM(handle, "mlx_random_split"); if (mlx_random_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split\n"); return -1; } - mlx_random_truncated_normal_ptr = dlsym(handle, "mlx_random_truncated_normal"); + mlx_random_truncated_normal_ptr = GET_SYM(handle, "mlx_random_truncated_normal"); if (mlx_random_truncated_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_truncated_normal\n"); return -1; } - mlx_random_uniform_ptr = dlsym(handle, "mlx_random_uniform"); + mlx_random_uniform_ptr = GET_SYM(handle, "mlx_random_uniform"); if (mlx_random_uniform_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_uniform\n"); return -1; } - mlx_stream_new_ptr = dlsym(handle, "mlx_stream_new"); + mlx_stream_new_ptr = GET_SYM(handle, "mlx_stream_new"); if (mlx_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new\n"); return -1; } - mlx_stream_new_device_ptr = dlsym(handle, "mlx_stream_new_device"); + mlx_stream_new_device_ptr = GET_SYM(handle, "mlx_stream_new_device"); if (mlx_stream_new_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new_device\n"); return -1; } - mlx_stream_set_ptr = dlsym(handle, "mlx_stream_set"); + mlx_stream_set_ptr = GET_SYM(handle, "mlx_stream_set"); if (mlx_stream_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_set\n"); return -1; } - mlx_stream_free_ptr = dlsym(handle, "mlx_stream_free"); + mlx_stream_free_ptr = GET_SYM(handle, "mlx_stream_free"); if (mlx_stream_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_free\n"); return -1; } - mlx_stream_tostring_ptr = dlsym(handle, "mlx_stream_tostring"); + mlx_stream_tostring_ptr = GET_SYM(handle, "mlx_stream_tostring"); if (mlx_stream_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_tostring\n"); return -1; } - mlx_stream_equal_ptr = dlsym(handle, "mlx_stream_equal"); + mlx_stream_equal_ptr = GET_SYM(handle, "mlx_stream_equal"); if (mlx_stream_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_equal\n"); return -1; } - mlx_stream_get_device_ptr = dlsym(handle, "mlx_stream_get_device"); + mlx_stream_get_device_ptr = GET_SYM(handle, "mlx_stream_get_device"); if (mlx_stream_get_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_device\n"); return -1; } - mlx_stream_get_index_ptr = dlsym(handle, "mlx_stream_get_index"); + mlx_stream_get_index_ptr = GET_SYM(handle, "mlx_stream_get_index"); if (mlx_stream_get_index_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_index\n"); return -1; } - mlx_synchronize_ptr = dlsym(handle, "mlx_synchronize"); + mlx_synchronize_ptr = GET_SYM(handle, "mlx_synchronize"); if (mlx_synchronize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_synchronize\n"); return -1; } - mlx_get_default_stream_ptr = dlsym(handle, "mlx_get_default_stream"); + mlx_get_default_stream_ptr = GET_SYM(handle, "mlx_get_default_stream"); if (mlx_get_default_stream_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_stream\n"); return -1; } - mlx_set_default_stream_ptr = dlsym(handle, "mlx_set_default_stream"); + mlx_set_default_stream_ptr = GET_SYM(handle, "mlx_set_default_stream"); if (mlx_set_default_stream_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_stream\n"); return -1; } - mlx_default_cpu_stream_new_ptr = dlsym(handle, "mlx_default_cpu_stream_new"); + mlx_default_cpu_stream_new_ptr = GET_SYM(handle, "mlx_default_cpu_stream_new"); if (mlx_default_cpu_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_default_cpu_stream_new\n"); return -1; } - mlx_default_gpu_stream_new_ptr = dlsym(handle, "mlx_default_gpu_stream_new"); + mlx_default_gpu_stream_new_ptr = GET_SYM(handle, "mlx_default_gpu_stream_new"); if (mlx_default_gpu_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_default_gpu_stream_new\n"); return -1; } - mlx_string_new_ptr = dlsym(handle, "mlx_string_new"); + mlx_string_new_ptr = GET_SYM(handle, "mlx_string_new"); if (mlx_string_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new\n"); return -1; } - mlx_string_new_data_ptr = dlsym(handle, "mlx_string_new_data"); + mlx_string_new_data_ptr = GET_SYM(handle, "mlx_string_new_data"); if (mlx_string_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new_data\n"); return -1; } - mlx_string_set_ptr = dlsym(handle, "mlx_string_set"); + mlx_string_set_ptr = GET_SYM(handle, "mlx_string_set"); if (mlx_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_set\n"); return -1; } - mlx_string_data_ptr = dlsym(handle, "mlx_string_data"); + mlx_string_data_ptr = GET_SYM(handle, "mlx_string_data"); if (mlx_string_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_data\n"); return -1; } - mlx_string_free_ptr = dlsym(handle, "mlx_string_free"); + mlx_string_free_ptr = GET_SYM(handle, "mlx_string_free"); if (mlx_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_free\n"); return -1; } - mlx_async_eval_ptr = dlsym(handle, "mlx_async_eval"); + mlx_async_eval_ptr = GET_SYM(handle, "mlx_async_eval"); if (mlx_async_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_async_eval\n"); return -1; } - mlx_checkpoint_ptr = dlsym(handle, "mlx_checkpoint"); + mlx_checkpoint_ptr = GET_SYM(handle, "mlx_checkpoint"); if (mlx_checkpoint_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_checkpoint\n"); return -1; } - mlx_custom_function_ptr = dlsym(handle, "mlx_custom_function"); + mlx_custom_function_ptr = GET_SYM(handle, "mlx_custom_function"); if (mlx_custom_function_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_function\n"); return -1; } - mlx_custom_vjp_ptr = dlsym(handle, "mlx_custom_vjp"); + mlx_custom_vjp_ptr = GET_SYM(handle, "mlx_custom_vjp"); if (mlx_custom_vjp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_vjp\n"); return -1; } - mlx_eval_ptr = dlsym(handle, "mlx_eval"); + mlx_eval_ptr = GET_SYM(handle, "mlx_eval"); if (mlx_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_eval\n"); return -1; } - mlx_jvp_ptr = dlsym(handle, "mlx_jvp"); + mlx_jvp_ptr = GET_SYM(handle, "mlx_jvp"); if (mlx_jvp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_jvp\n"); return -1; } - mlx_value_and_grad_ptr = dlsym(handle, "mlx_value_and_grad"); + mlx_value_and_grad_ptr = GET_SYM(handle, "mlx_value_and_grad"); if (mlx_value_and_grad_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_value_and_grad\n"); return -1; } - mlx_vjp_ptr = dlsym(handle, "mlx_vjp"); + mlx_vjp_ptr = GET_SYM(handle, "mlx_vjp"); if (mlx_vjp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vjp\n"); return -1; } - mlx_detail_vmap_replace_ptr = dlsym(handle, "mlx_detail_vmap_replace"); + mlx_detail_vmap_replace_ptr = GET_SYM(handle, "mlx_detail_vmap_replace"); if (mlx_detail_vmap_replace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_replace\n"); return -1; } - mlx_detail_vmap_trace_ptr = dlsym(handle, "mlx_detail_vmap_trace"); + mlx_detail_vmap_trace_ptr = GET_SYM(handle, "mlx_detail_vmap_trace"); if (mlx_detail_vmap_trace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_trace\n"); return -1; } - mlx_vector_array_new_ptr = dlsym(handle, "mlx_vector_array_new"); + mlx_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_array_new"); if (mlx_vector_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new\n"); return -1; } - mlx_vector_array_set_ptr = dlsym(handle, "mlx_vector_array_set"); + mlx_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_array_set"); if (mlx_vector_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set\n"); return -1; } - mlx_vector_array_free_ptr = dlsym(handle, "mlx_vector_array_free"); + mlx_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_array_free"); if (mlx_vector_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_free\n"); return -1; } - mlx_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_array_new_data"); + mlx_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_array_new_data"); if (mlx_vector_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_data\n"); return -1; } - mlx_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_array_new_value"); + mlx_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_array_new_value"); if (mlx_vector_array_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_value\n"); return -1; } - mlx_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_array_set_data"); + mlx_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_array_set_data"); if (mlx_vector_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_data\n"); return -1; } - mlx_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_array_set_value"); + mlx_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_array_set_value"); if (mlx_vector_array_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_value\n"); return -1; } - mlx_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_array_append_data"); + mlx_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_array_append_data"); if (mlx_vector_array_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_data\n"); return -1; } - mlx_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_array_append_value"); + mlx_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_array_append_value"); if (mlx_vector_array_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_value\n"); return -1; } - mlx_vector_array_size_ptr = dlsym(handle, "mlx_vector_array_size"); + mlx_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_array_size"); if (mlx_vector_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_size\n"); return -1; } - mlx_vector_array_get_ptr = dlsym(handle, "mlx_vector_array_get"); + mlx_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_array_get"); if (mlx_vector_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_get\n"); return -1; } - mlx_vector_vector_array_new_ptr = dlsym(handle, "mlx_vector_vector_array_new"); + mlx_vector_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_vector_array_new"); if (mlx_vector_vector_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new\n"); return -1; } - mlx_vector_vector_array_set_ptr = dlsym(handle, "mlx_vector_vector_array_set"); + mlx_vector_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_vector_array_set"); if (mlx_vector_vector_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set\n"); return -1; } - mlx_vector_vector_array_free_ptr = dlsym(handle, "mlx_vector_vector_array_free"); + mlx_vector_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_vector_array_free"); if (mlx_vector_vector_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_free\n"); return -1; } - mlx_vector_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_vector_array_new_data"); + mlx_vector_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_data"); if (mlx_vector_vector_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_data\n"); return -1; } - mlx_vector_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_vector_array_new_value"); + mlx_vector_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_value"); if (mlx_vector_vector_array_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_value\n"); return -1; } - mlx_vector_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_vector_array_set_data"); + mlx_vector_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_data"); if (mlx_vector_vector_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_data\n"); return -1; } - mlx_vector_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_vector_array_set_value"); + mlx_vector_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_value"); if (mlx_vector_vector_array_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_value\n"); return -1; } - mlx_vector_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_vector_array_append_data"); + mlx_vector_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_data"); if (mlx_vector_vector_array_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_data\n"); return -1; } - mlx_vector_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_vector_array_append_value"); + mlx_vector_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_value"); if (mlx_vector_vector_array_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_value\n"); return -1; } - mlx_vector_vector_array_size_ptr = dlsym(handle, "mlx_vector_vector_array_size"); + mlx_vector_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_vector_array_size"); if (mlx_vector_vector_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_size\n"); return -1; } - mlx_vector_vector_array_get_ptr = dlsym(handle, "mlx_vector_vector_array_get"); + mlx_vector_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_vector_array_get"); if (mlx_vector_vector_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_get\n"); return -1; } - mlx_vector_int_new_ptr = dlsym(handle, "mlx_vector_int_new"); + mlx_vector_int_new_ptr = GET_SYM(handle, "mlx_vector_int_new"); if (mlx_vector_int_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new\n"); return -1; } - mlx_vector_int_set_ptr = dlsym(handle, "mlx_vector_int_set"); + mlx_vector_int_set_ptr = GET_SYM(handle, "mlx_vector_int_set"); if (mlx_vector_int_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set\n"); return -1; } - mlx_vector_int_free_ptr = dlsym(handle, "mlx_vector_int_free"); + mlx_vector_int_free_ptr = GET_SYM(handle, "mlx_vector_int_free"); if (mlx_vector_int_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_free\n"); return -1; } - mlx_vector_int_new_data_ptr = dlsym(handle, "mlx_vector_int_new_data"); + mlx_vector_int_new_data_ptr = GET_SYM(handle, "mlx_vector_int_new_data"); if (mlx_vector_int_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_data\n"); return -1; } - mlx_vector_int_new_value_ptr = dlsym(handle, "mlx_vector_int_new_value"); + mlx_vector_int_new_value_ptr = GET_SYM(handle, "mlx_vector_int_new_value"); if (mlx_vector_int_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_value\n"); return -1; } - mlx_vector_int_set_data_ptr = dlsym(handle, "mlx_vector_int_set_data"); + mlx_vector_int_set_data_ptr = GET_SYM(handle, "mlx_vector_int_set_data"); if (mlx_vector_int_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_data\n"); return -1; } - mlx_vector_int_set_value_ptr = dlsym(handle, "mlx_vector_int_set_value"); + mlx_vector_int_set_value_ptr = GET_SYM(handle, "mlx_vector_int_set_value"); if (mlx_vector_int_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_value\n"); return -1; } - mlx_vector_int_append_data_ptr = dlsym(handle, "mlx_vector_int_append_data"); + mlx_vector_int_append_data_ptr = GET_SYM(handle, "mlx_vector_int_append_data"); if (mlx_vector_int_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_data\n"); return -1; } - mlx_vector_int_append_value_ptr = dlsym(handle, "mlx_vector_int_append_value"); + mlx_vector_int_append_value_ptr = GET_SYM(handle, "mlx_vector_int_append_value"); if (mlx_vector_int_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_value\n"); return -1; } - mlx_vector_int_size_ptr = dlsym(handle, "mlx_vector_int_size"); + mlx_vector_int_size_ptr = GET_SYM(handle, "mlx_vector_int_size"); if (mlx_vector_int_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_size\n"); return -1; } - mlx_vector_int_get_ptr = dlsym(handle, "mlx_vector_int_get"); + mlx_vector_int_get_ptr = GET_SYM(handle, "mlx_vector_int_get"); if (mlx_vector_int_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_get\n"); return -1; } - mlx_vector_string_new_ptr = dlsym(handle, "mlx_vector_string_new"); + mlx_vector_string_new_ptr = GET_SYM(handle, "mlx_vector_string_new"); if (mlx_vector_string_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new\n"); return -1; } - mlx_vector_string_set_ptr = dlsym(handle, "mlx_vector_string_set"); + mlx_vector_string_set_ptr = GET_SYM(handle, "mlx_vector_string_set"); if (mlx_vector_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set\n"); return -1; } - mlx_vector_string_free_ptr = dlsym(handle, "mlx_vector_string_free"); + mlx_vector_string_free_ptr = GET_SYM(handle, "mlx_vector_string_free"); if (mlx_vector_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_free\n"); return -1; } - mlx_vector_string_new_data_ptr = dlsym(handle, "mlx_vector_string_new_data"); + mlx_vector_string_new_data_ptr = GET_SYM(handle, "mlx_vector_string_new_data"); if (mlx_vector_string_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_data\n"); return -1; } - mlx_vector_string_new_value_ptr = dlsym(handle, "mlx_vector_string_new_value"); + mlx_vector_string_new_value_ptr = GET_SYM(handle, "mlx_vector_string_new_value"); if (mlx_vector_string_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_value\n"); return -1; } - mlx_vector_string_set_data_ptr = dlsym(handle, "mlx_vector_string_set_data"); + mlx_vector_string_set_data_ptr = GET_SYM(handle, "mlx_vector_string_set_data"); if (mlx_vector_string_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_data\n"); return -1; } - mlx_vector_string_set_value_ptr = dlsym(handle, "mlx_vector_string_set_value"); + mlx_vector_string_set_value_ptr = GET_SYM(handle, "mlx_vector_string_set_value"); if (mlx_vector_string_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_value\n"); return -1; } - mlx_vector_string_append_data_ptr = dlsym(handle, "mlx_vector_string_append_data"); + mlx_vector_string_append_data_ptr = GET_SYM(handle, "mlx_vector_string_append_data"); if (mlx_vector_string_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_data\n"); return -1; } - mlx_vector_string_append_value_ptr = dlsym(handle, "mlx_vector_string_append_value"); + mlx_vector_string_append_value_ptr = GET_SYM(handle, "mlx_vector_string_append_value"); if (mlx_vector_string_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_value\n"); return -1; } - mlx_vector_string_size_ptr = dlsym(handle, "mlx_vector_string_size"); + mlx_vector_string_size_ptr = GET_SYM(handle, "mlx_vector_string_size"); if (mlx_vector_string_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_size\n"); return -1; } - mlx_vector_string_get_ptr = dlsym(handle, "mlx_vector_string_get"); + mlx_vector_string_get_ptr = GET_SYM(handle, "mlx_vector_string_get"); if (mlx_vector_string_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_get\n"); return -1; } - mlx_version_ptr = dlsym(handle, "mlx_version"); + mlx_version_ptr = GET_SYM(handle, "mlx_version"); if (mlx_version_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_version\n"); return -1; diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go index cf3e51572..b529b9088 100644 --- a/x/imagegen/mlx/mlx.go +++ b/x/imagegen/mlx/mlx.go @@ -1,9 +1,7 @@ -//go:build mlx - package mlx /* -#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR} +#cgo CFLAGS: -O3 -I${SRCDIR}/../../mlxrunner/mlx/include -I${SRCDIR} #cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate #cgo linux LDFLAGS: -lstdc++ -ldl #cgo windows LDFLAGS: -lstdc++ @@ -32,7 +30,7 @@ static inline void set_default_stream(mlx_stream s) { _default_stream = s; } -// CPU stream for file loading (Load primitive only runs on CPU) +// CPU stream for operations that only support CPU evaluation static inline mlx_stream cpu_stream() { if (_cpu_stream.ctx == NULL) { _cpu_stream = mlx_default_cpu_stream_new(); @@ -45,8 +43,11 @@ static inline mlx_stream cpu_stream() { // nocallback: function won't call back into Go */ import "C" + import ( "fmt" + "os" + "path/filepath" "reflect" "runtime" "sync" @@ -1502,15 +1503,21 @@ type SafetensorsFile struct { metadata C.mlx_map_string_to_string } -// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader -// Note: Uses CPU stream because Load primitive only runs on CPU +// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader. +// On CUDA, Load::eval_gpu is implemented so we use the default (GPU) stream. +// On Metal, Load::eval_gpu is not implemented so we must use the CPU stream. func LoadSafetensorsNative(path string) (*SafetensorsFile, error) { cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) + stream := C.default_stream() + if runtime.GOOS == "darwin" { + stream = C.cpu_stream() + } + var arrays C.mlx_map_string_to_array var metadata C.mlx_map_string_to_string - if C.mlx_load_safetensors(&arrays, &metadata, cPath, C.cpu_stream()) != 0 { + if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 { return nil, fmt.Errorf("failed to load safetensors: %s", path) } return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil @@ -1689,11 +1696,91 @@ func ArgmaxKeepArray(logits *Array) *Array { // // Thread safety: Protected by randomStateMu, mimicking Python's GIL behavior. // All random functions that use global state acquire this lock. -var RandomState = []*Array{nil} -var randomStateMu sync.Mutex +var ( + RandomState = []*Array{nil} + randomStateMu sync.Mutex +) -var mlxInitialized bool -var mlxInitError error +var ( + mlxInitialized bool + mlxInitError error +) + +// mlxLibName returns the platform-specific shared library filename. +func mlxLibName() string { + switch runtime.GOOS { + case "windows": + return "mlxc.dll" + case "darwin": + return "libmlxc.dylib" + default: + return "libmlxc.so" + } +} + +// findMLXLibrary searches for the MLX shared library in standard locations. +// Returns the path to the library, or empty string if not found. +func findMLXLibrary() string { + libName := mlxLibName() + + // 1. OLLAMA_LIBRARY_PATH — check each dir and mlx_* subdirs + if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok { + for _, dir := range filepath.SplitList(paths) { + candidate := filepath.Join(dir, libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx*")); err == nil { + for _, mlxDir := range mlxDirs { + candidate = filepath.Join(mlxDir, libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + } + } + } + } + + // 2. Executable directory and lib/ollama/mlx* subdirs + if exe, err := os.Executable(); err == nil { + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + exeDir := filepath.Dir(exe) + + // Check exe dir directly (macOS copies dylib here) + candidate := filepath.Join(exeDir, libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + + // Check exe_dir/lib/ollama/mlx* subdirectories + // and exe_dir/../lib/ollama/mlx* (standard bin/lib sibling layout) + for _, libOllamaDir := range []string{ + filepath.Join(exeDir, "lib", "ollama"), + filepath.Join(exeDir, "..", "lib", "ollama"), + } { + if mlxDirs, err := filepath.Glob(filepath.Join(libOllamaDir, "mlx*")); err == nil { + for _, mlxDir := range mlxDirs { + candidate = filepath.Join(mlxDir, libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + } + } + } + } + + // 3. Build directory (for tests run from repo root) + if cwd, err := os.Getwd(); err == nil { + candidate := filepath.Join(cwd, "build", "lib", "ollama", libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + } + + return "" +} // InitMLX initializes the MLX library by dynamically loading libmlxc. // This must be called before using any MLX functions. @@ -1703,9 +1790,16 @@ func InitMLX() error { return mlxInitError } - // Try to load the MLX dynamic library - ret := C.mlx_dynamic_init() - if ret != 0 { + // Search for the library using Go path discovery + libPath := findMLXLibrary() + if libPath == "" { + mlxInitError = fmt.Errorf("failed to initialize MLX: %s not found", mlxLibName()) + return mlxInitError + } + + cPath := C.CString(libPath) + defer C.free(unsafe.Pointer(cPath)) + if C.mlx_dynamic_init_path(cPath) != 0 { errMsg := C.GoString(C.mlx_dynamic_error()) mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg) return mlxInitError @@ -1713,8 +1807,7 @@ func InitMLX() error { // Initialize all function pointers via dlsym handle := C.mlx_get_handle() - ret = C.mlx_load_functions(handle) - if ret != 0 { + if C.mlx_load_functions(handle) != 0 { mlxInitError = fmt.Errorf("failed to load MLX function symbols") return mlxInitError } diff --git a/x/imagegen/mlx/mlx_dynamic.c b/x/imagegen/mlx/mlx_dynamic.c index aedef7a01..1281ec0e3 100644 --- a/x/imagegen/mlx/mlx_dynamic.c +++ b/x/imagegen/mlx/mlx_dynamic.c @@ -9,114 +9,76 @@ #ifdef _WIN32 #include typedef HMODULE lib_handle_t; -#define LOAD_LIB(path) LoadLibraryA(path) -#define GET_SYMBOL(handle, name) GetProcAddress(handle, name) -#define CLOSE_LIB(handle) FreeLibrary(handle) -#define LIB_ERROR() "LoadLibrary failed" +static char win_error_buffer[256] = {0}; +static const char* get_win_error(void) { + DWORD err = GetLastError(); + snprintf(win_error_buffer, sizeof(win_error_buffer), "error code %lu", err); + return win_error_buffer; +} +#define LIB_ERROR() get_win_error() #else #include typedef void* lib_handle_t; -#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL) -#define GET_SYMBOL(handle, name) dlsym(handle, name) -#define CLOSE_LIB(handle) dlclose(handle) #define LIB_ERROR() dlerror() -#ifdef __APPLE__ -#include -#include -#endif #endif static lib_handle_t mlx_handle = NULL; static int mlx_initialized = 0; static char mlx_error_buffer[512] = {0}; -#ifdef __APPLE__ -// Get path to library in same directory as executable -static char* get_exe_relative_path(const char* libname) { - static char path[1024]; - uint32_t size = sizeof(path); - if (_NSGetExecutablePath(path, &size) != 0) { - return NULL; +#ifdef _WIN32 +// Windows: Load library from a path with dependency resolution. +// Temporarily adds the library's directory to the DLL search path +// so that dependencies (like mlx.dll) in the same directory are found. +static int try_load_win(const char* path) { + if (!path) return 0; + + // Extract directory and add to DLL search path for dependency resolution + char dir_path[MAX_PATH]; + strncpy(dir_path, path, MAX_PATH - 1); + dir_path[MAX_PATH - 1] = '\0'; + char* last_slash = strrchr(dir_path, '\\'); + if (!last_slash) last_slash = strrchr(dir_path, '/'); + if (last_slash) { + *last_slash = '\0'; + SetDllDirectoryA(dir_path); } - // Get directory of executable - char* dir = dirname(path); - static char fullpath[1024]; - snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname); - return fullpath; + + mlx_handle = LoadLibraryA(path); + SetDllDirectoryA(NULL); + return mlx_handle != NULL; } #endif // Try to load library from a specific path static int try_load_lib(const char* path) { if (!path) return 0; - mlx_handle = LOAD_LIB(path); +#ifdef _WIN32 + return try_load_win(path); +#else + mlx_handle = dlopen(path, RTLD_LAZY | RTLD_GLOBAL); return mlx_handle != NULL; +#endif } -// Initialize MLX dynamic library -// Returns 0 on success, -1 on failure -// On failure, call mlx_dynamic_error() to get error message -int mlx_dynamic_init(void) { +// Initialize the MLX dynamic library from a specific path. +// Returns 0 on success, -1 on failure. +int mlx_dynamic_init_path(const char* path) { if (mlx_initialized) { - return 0; // Already initialized + return 0; } - const char* lib_path = NULL; - const char* tried_paths[8] = {0}; - int num_tried = 0; - -#ifdef _WIN32 - // Windows: try same directory as executable - lib_path = "libmlxc.dll"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; -#elif defined(__APPLE__) - // macOS: try executable directory first - lib_path = get_exe_relative_path("libmlxc.dylib"); - if (lib_path) { - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; + if (try_load_lib(path)) { + mlx_initialized = 1; + snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), + "MLX: Successfully loaded %s", path ? path : "library"); + return 0; } - // Try build directory (for tests run from repo root) - lib_path = "./build/lib/ollama/libmlxc.dylib"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; - // Fallback to system paths - lib_path = "libmlxc.dylib"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; -#else - // Linux: try build directory first (for tests) - lib_path = "./build/lib/ollama/libmlxc.so"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; - // Fallback to system paths - lib_path = "libmlxc.so"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; -#endif - // Failed to load library - build error message with all tried paths - { - const char* err = LIB_ERROR(); - int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), - "MLX: Failed to load libmlxc library. Tried: "); - for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) { - offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset, - "%s%s", i > 0 ? ", " : "", tried_paths[i]); - } - if (err) { - snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset, - ". Last error: %s", err); - } - } - return -1; - -success: - mlx_initialized = 1; + const char* err = LIB_ERROR(); snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), - "MLX: Successfully loaded %s", lib_path ? lib_path : "library"); - return 0; + "MLX: Failed to load %s: %s", path ? path : "(null)", err ? err : "unknown error"); + return -1; } // Get the last error message @@ -124,21 +86,8 @@ const char* mlx_dynamic_error(void) { return mlx_error_buffer; } -// Check if MLX is initialized -int mlx_dynamic_is_initialized(void) { - return mlx_initialized; -} - // Get the library handle (for use by generated wrappers) void* mlx_get_handle(void) { return mlx_handle; } -// Cleanup (optional, called at program exit) -void mlx_dynamic_cleanup(void) { - if (mlx_handle != NULL) { - CLOSE_LIB(mlx_handle); - mlx_handle = NULL; - mlx_initialized = 0; - } -} diff --git a/x/imagegen/mlx/mlx_dynamic.h b/x/imagegen/mlx/mlx_dynamic.h index 9ca1473f9..3f4d0fd74 100644 --- a/x/imagegen/mlx/mlx_dynamic.h +++ b/x/imagegen/mlx/mlx_dynamic.h @@ -6,22 +6,16 @@ extern "C" { #endif -// Initialize the MLX dynamic library +// Initialize the MLX dynamic library from a specific path // Returns 0 on success, -1 on failure -int mlx_dynamic_init(void); +int mlx_dynamic_init_path(const char* path); // Get the last error message from dynamic loading const char* mlx_dynamic_error(void); -// Check if MLX is initialized -int mlx_dynamic_is_initialized(void); - // Get the library handle (for use by generated wrappers) void* mlx_get_handle(void); -// Cleanup resources (optional, for clean shutdown) -void mlx_dynamic_cleanup(void); - #ifdef __cplusplus } #endif diff --git a/x/imagegen/mlx/mlx_test.go b/x/imagegen/mlx/mlx_test.go index 37b3ac63b..82221cab7 100644 --- a/x/imagegen/mlx/mlx_test.go +++ b/x/imagegen/mlx/mlx_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx import ( diff --git a/x/imagegen/models/flux2/flux2.go b/x/imagegen/models/flux2/flux2.go index 894af41f8..908de3c87 100644 --- a/x/imagegen/models/flux2/flux2.go +++ b/x/imagegen/models/flux2/flux2.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package flux2 implements the FLUX.2 Klein diffusion transformer model. // Klein is a 4B parameter distilled model that supports sub-second inference. package flux2 diff --git a/x/imagegen/models/flux2/rope.go b/x/imagegen/models/flux2/rope.go index c349e7010..bc245b7d8 100644 --- a/x/imagegen/models/flux2/rope.go +++ b/x/imagegen/models/flux2/rope.go @@ -1,5 +1,3 @@ -//go:build mlx - package flux2 import ( diff --git a/x/imagegen/models/flux2/scheduler.go b/x/imagegen/models/flux2/scheduler.go index aba3c871f..033ee8c3c 100644 --- a/x/imagegen/models/flux2/scheduler.go +++ b/x/imagegen/models/flux2/scheduler.go @@ -1,5 +1,3 @@ -//go:build mlx - package flux2 import ( diff --git a/x/imagegen/models/flux2/transformer.go b/x/imagegen/models/flux2/transformer.go index 93771a661..3a48f27b7 100644 --- a/x/imagegen/models/flux2/transformer.go +++ b/x/imagegen/models/flux2/transformer.go @@ -1,5 +1,3 @@ -//go:build mlx - package flux2 import ( diff --git a/x/imagegen/models/flux2/vae.go b/x/imagegen/models/flux2/vae.go index 4b09b1ba4..057523d17 100644 --- a/x/imagegen/models/flux2/vae.go +++ b/x/imagegen/models/flux2/vae.go @@ -1,5 +1,3 @@ -//go:build mlx - package flux2 import ( diff --git a/x/imagegen/models/qwen3/text_encoder.go b/x/imagegen/models/qwen3/text_encoder.go index de32bd347..46b5a5e29 100644 --- a/x/imagegen/models/qwen3/text_encoder.go +++ b/x/imagegen/models/qwen3/text_encoder.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models. package qwen3 diff --git a/x/imagegen/models/zimage/scheduler.go b/x/imagegen/models/zimage/scheduler.go index 6c474f5a4..f5e1ccc46 100644 --- a/x/imagegen/models/zimage/scheduler.go +++ b/x/imagegen/models/zimage/scheduler.go @@ -1,5 +1,3 @@ -//go:build mlx - package zimage import ( diff --git a/x/imagegen/models/zimage/text_encoder.go b/x/imagegen/models/zimage/text_encoder.go index 2c2688c31..65d9ab596 100644 --- a/x/imagegen/models/zimage/text_encoder.go +++ b/x/imagegen/models/zimage/text_encoder.go @@ -1,5 +1,3 @@ -//go:build mlx - package zimage import ( diff --git a/x/imagegen/models/zimage/transformer.go b/x/imagegen/models/zimage/transformer.go index 2c42d8c25..b27064507 100644 --- a/x/imagegen/models/zimage/transformer.go +++ b/x/imagegen/models/zimage/transformer.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package zimage implements the Z-Image diffusion transformer model. package zimage diff --git a/x/imagegen/models/zimage/vae.go b/x/imagegen/models/zimage/vae.go index aca2b1bfc..a31ec210b 100644 --- a/x/imagegen/models/zimage/vae.go +++ b/x/imagegen/models/zimage/vae.go @@ -1,5 +1,3 @@ -//go:build mlx - package zimage import ( diff --git a/x/imagegen/models/zimage/zimage.go b/x/imagegen/models/zimage/zimage.go index e7ce8436d..4058819c7 100644 --- a/x/imagegen/models/zimage/zimage.go +++ b/x/imagegen/models/zimage/zimage.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package zimage implements the Z-Image diffusion transformer model. package zimage diff --git a/x/imagegen/nn/nn.go b/x/imagegen/nn/nn.go index d72474358..0a08f05be 100644 --- a/x/imagegen/nn/nn.go +++ b/x/imagegen/nn/nn.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package nn provides neural network layer types. package nn diff --git a/x/imagegen/nn/nn_test.go b/x/imagegen/nn/nn_test.go index 00e69ccb0..dbc5e2b30 100644 --- a/x/imagegen/nn/nn_test.go +++ b/x/imagegen/nn/nn_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package nn import ( diff --git a/x/imagegen/runner.go b/x/imagegen/runner.go index 0409c4bf7..d92b59059 100644 --- a/x/imagegen/runner.go +++ b/x/imagegen/runner.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package imagegen provides a unified MLX runner for both LLM and image generation models. package imagegen diff --git a/x/imagegen/runner_stub.go b/x/imagegen/runner_stub.go deleted file mode 100644 index 866a4408c..000000000 --- a/x/imagegen/runner_stub.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !mlx - -package imagegen - -import "errors" - -// Execute returns an error when not built with MLX support. -func Execute(args []string) error { - return errors.New("MLX runner not available: build with mlx tag") -} diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go index fbd443e05..e6e74929b 100644 --- a/x/imagegen/safetensors/loader.go +++ b/x/imagegen/safetensors/loader.go @@ -1,5 +1,3 @@ -//go:build mlx - package safetensors import ( diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go index 4dbcf59a3..df7b52465 100644 --- a/x/imagegen/safetensors/safetensors.go +++ b/x/imagegen/safetensors/safetensors.go @@ -1,5 +1,3 @@ -//go:build mlx - package safetensors import ( diff --git a/x/imagegen/safetensors/safetensors_test.go b/x/imagegen/safetensors/safetensors_test.go index f00268751..5f3e10d71 100644 --- a/x/imagegen/safetensors/safetensors_test.go +++ b/x/imagegen/safetensors/safetensors_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package safetensors import ( diff --git a/x/imagegen/tokenizer/tokenizer.go b/x/imagegen/tokenizer/tokenizer.go index bf8ff63af..d2f1aac18 100644 --- a/x/imagegen/tokenizer/tokenizer.go +++ b/x/imagegen/tokenizer/tokenizer.go @@ -1,5 +1,3 @@ -//go:build mlx - // tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models // // Based on standard BPE algorithm (Sennrich et al. 2015) with: diff --git a/x/imagegen/tokenizer/tokenizer_test.go b/x/imagegen/tokenizer/tokenizer_test.go index 2ac79ab1e..a72e447c6 100644 --- a/x/imagegen/tokenizer/tokenizer_test.go +++ b/x/imagegen/tokenizer/tokenizer_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/imagegen/vae/tiling.go b/x/imagegen/vae/tiling.go index 1babfef98..fcb5af701 100644 --- a/x/imagegen/vae/tiling.go +++ b/x/imagegen/vae/tiling.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package vae provides shared utilities for VAE (Variational Autoencoder) operations. package vae diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index a9ff8904c..0216ffeaa 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 7d0d0b060..a452fbcb2 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -1,5 +1,3 @@ -//go:build mlx - package cache import ( diff --git a/x/mlxrunner/cache/recurrent.go b/x/mlxrunner/cache/recurrent.go index 0cbbc01e2..86c592be5 100644 --- a/x/mlxrunner/cache/recurrent.go +++ b/x/mlxrunner/cache/recurrent.go @@ -1,5 +1,3 @@ -//go:build mlx - package cache import "github.com/ollama/ollama/x/mlxrunner/mlx" diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index f1a0e4cca..2f18105af 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -72,14 +72,23 @@ func NewClient(modelName string) (*Client, error) { cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port)) cmd.Env = os.Environ() - // On Linux, set LD_LIBRARY_PATH to include MLX library directories - if runtime.GOOS == "linux" { + // Set library path environment variable for MLX libraries + // Linux: LD_LIBRARY_PATH, Windows: PATH + var libPathEnvVar string + switch runtime.GOOS { + case "linux": + libPathEnvVar = "LD_LIBRARY_PATH" + case "windows": + libPathEnvVar = "PATH" + } + + if libPathEnvVar != "" { libraryPaths := []string{ml.LibOllamaPath} if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil { libraryPaths = append(libraryPaths, mlxDirs...) } - if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok { + if existingPath, ok := os.LookupEnv(libPathEnvVar); ok { libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...) } @@ -87,16 +96,20 @@ func NewClient(modelName string) (*Client, error) { found := false for i := range cmd.Env { - if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") { - cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal + envName := cmd.Env[i] + if runtime.GOOS == "windows" { + envName = strings.ToUpper(envName) + } + if strings.HasPrefix(envName, libPathEnvVar+"=") { + cmd.Env[i] = libPathEnvVar + "=" + pathEnvVal found = true break } } if !found { - cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal) + cmd.Env = append(cmd.Env, libPathEnvVar+"="+pathEnvVal) } - slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) + slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal) } c := &Client{ diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go index 9202dac34..6b6394d60 100644 --- a/x/mlxrunner/imports.go +++ b/x/mlxrunner/imports.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( diff --git a/x/mlxrunner/mlx/CMakeLists.txt b/x/mlxrunner/mlx/CMakeLists.txt index 1ca13bdaf..9825c441b 100644 --- a/x/mlxrunner/mlx/CMakeLists.txt +++ b/x/mlxrunner/mlx/CMakeLists.txt @@ -24,3 +24,7 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(mlx-c) + +# Sync vendored headers with fetched version +file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h") +file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/include/mlx/c/") diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go index 3134a127a..ce0e48eda 100644 --- a/x/mlxrunner/mlx/act.go +++ b/x/mlxrunner/mlx/act.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index 6047aacec..de91813fc 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/array_test.go b/x/mlxrunner/mlx/array_test.go index aab5db7ba..bc6a4ca4a 100644 --- a/x/mlxrunner/mlx/array_test.go +++ b/x/mlxrunner/mlx/array_test.go @@ -1,10 +1,16 @@ -//go:build mlx - package mlx import "testing" +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } +} + func TestFromValue(t *testing.T) { + skipIfNoMLX(t) for got, want := range map[*Array]DType{ FromValue(true): DTypeBool, FromValue(false): DTypeBool, @@ -22,6 +28,7 @@ func TestFromValue(t *testing.T) { } func TestFromValues(t *testing.T) { + skipIfNoMLX(t) for got, want := range map[*Array]DType{ FromValues([]bool{true, false, true}, 3): DTypeBool, FromValues([]uint8{1, 2, 3}, 3): DTypeUint8, diff --git a/x/mlxrunner/mlx/dtype.go b/x/mlxrunner/mlx/dtype.go index 95237c792..b0a0ce6c1 100644 --- a/x/mlxrunner/mlx/dtype.go +++ b/x/mlxrunner/mlx/dtype.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/dynamic.go b/x/mlxrunner/mlx/dynamic.go index a1286da59..38f825d24 100644 --- a/x/mlxrunner/mlx/dynamic.go +++ b/x/mlxrunner/mlx/dynamic.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "dynamic.h" @@ -24,10 +22,16 @@ func CheckInit() error { return initError } -// tryLoadFromDir searches a directory for libmlxc.* and tries to load it. +// tryLoadFromDir searches a directory for the mlxc shared library and tries to load it. // Returns true if the library was successfully loaded. func tryLoadFromDir(dir string) bool { - matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*") + // On Windows, MSVC produces mlxc.dll (no lib prefix) + // On Unix, it's libmlxc.so or libmlxc.dylib + pattern := "libmlxc.*" + if runtime.GOOS == "windows" { + pattern = "mlxc.*" + } + matches, err := fs.Glob(os.DirFS(dir), pattern) if err != nil || len(matches) == 0 { return false } @@ -60,7 +64,10 @@ func tryLoadFromDir(dir string) bool { // Returns true if the library was successfully loaded. func tryLoadByName() bool { libraryName := "libmlxc.dylib" - if runtime.GOOS == "linux" { + switch runtime.GOOS { + case "windows": + libraryName = "mlxc.dll" + case "linux": libraryName = "libmlxc.so" } @@ -81,19 +88,25 @@ func tryLoadByName() bool { func init() { switch runtime.GOOS { - case "darwin": + case "darwin", "linux", "windows": - case "windows": default: return } - // Try OLLAMA_LIBRARY_PATH first + // Try OLLAMA_LIBRARY_PATH first, including mlx_* subdirectories if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok { for _, dir := range filepath.SplitList(paths) { if tryLoadFromDir(dir) { return } + if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*")); err == nil { + for _, mlxDir := range mlxDirs { + if tryLoadFromDir(mlxDir) { + return + } + } + } } } @@ -115,12 +128,21 @@ func init() { searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama")) } + // Also scan mlx_* subdirectories within each search dir + var expanded []string for _, dir := range searchDirs { + expanded = append(expanded, dir) + if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*")); err == nil { + expanded = append(expanded, mlxDirs...) + } + } + + for _, dir := range expanded { if tryLoadFromDir(dir) { return } } initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs) - slog.Warn("MLX dynamic library not available", "error", initError) + slog.Debug("MLX dynamic library not available", "error", initError) } diff --git a/x/mlxrunner/mlx/dynamic.h b/x/mlxrunner/mlx/dynamic.h index f93d8fab7..f29825ce6 100644 --- a/x/mlxrunner/mlx/dynamic.h +++ b/x/mlxrunner/mlx/dynamic.h @@ -3,7 +3,7 @@ #ifdef _WIN32 #include -#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol) +#define DLSYM(handle, symbol) (void*)GetProcAddress((HMODULE)(handle.ctx), symbol) #else #include #define DLSYM(handle, symbol) dlsym(handle.ctx, symbol) @@ -23,9 +23,15 @@ typedef uint16_t float16_t; typedef uint16_t bfloat16_t; #endif -#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1 -#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); } -#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_) +// Undef ERROR to avoid conflict with wingdi.h on Windows +#ifdef ERROR +#undef ERROR +#endif +#define MLX_ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1 +#define CHECK(x) if (!(x)) { MLX_ERROR("CHECK failed: " #x); } +#define CHECK_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x); CHECK(x##_) +// OPTIONAL_LOAD: load symbol if available, leave function pointer NULL otherwise +#define OPTIONAL_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x) typedef struct { void* ctx; diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go index 0570840d6..7feca3b1e 100644 --- a/x/mlxrunner/mlx/fast.go +++ b/x/mlxrunner/mlx/fast.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/gated_delta.go b/x/mlxrunner/mlx/gated_delta.go index 7ace1f6d3..31550cef1 100644 --- a/x/mlxrunner/mlx/gated_delta.go +++ b/x/mlxrunner/mlx/gated_delta.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c index 29d1330af..ecf9d30c8 100644 --- a/x/mlxrunner/mlx/generated.c +++ b/x/mlxrunner/mlx/generated.c @@ -2299,8 +2299,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_array_item_float32); CHECK_LOAD(handle, mlx_array_item_float64); CHECK_LOAD(handle, mlx_array_item_complex64); - CHECK_LOAD(handle, mlx_array_item_float16); - CHECK_LOAD(handle, mlx_array_item_bfloat16); + OPTIONAL_LOAD(handle, mlx_array_item_float16); + OPTIONAL_LOAD(handle, mlx_array_item_bfloat16); CHECK_LOAD(handle, mlx_array_data_bool); CHECK_LOAD(handle, mlx_array_data_uint8); CHECK_LOAD(handle, mlx_array_data_uint16); @@ -2313,8 +2313,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_array_data_float32); CHECK_LOAD(handle, mlx_array_data_float64); CHECK_LOAD(handle, mlx_array_data_complex64); - CHECK_LOAD(handle, mlx_array_data_float16); - CHECK_LOAD(handle, mlx_array_data_bfloat16); + OPTIONAL_LOAD(handle, mlx_array_data_float16); + OPTIONAL_LOAD(handle, mlx_array_data_bfloat16); CHECK_LOAD(handle, _mlx_array_is_available); CHECK_LOAD(handle, _mlx_array_wait); CHECK_LOAD(handle, _mlx_array_is_contiguous); diff --git a/x/mlxrunner/mlx/generator/generated.c.gotmpl b/x/mlxrunner/mlx/generator/generated.c.gotmpl index c31b34a76..227589aa8 100644 --- a/x/mlxrunner/mlx/generator/generated.c.gotmpl +++ b/x/mlxrunner/mlx/generator/generated.c.gotmpl @@ -11,7 +11,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { {{- range .Functions }} - CHECK_LOAD(handle, {{ .Name }}); + {{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }}); {{- end }} return 0; } diff --git a/x/mlxrunner/mlx/generator/main.go b/x/mlxrunner/mlx/generator/main.go index a98046a2f..d1203add4 100644 --- a/x/mlxrunner/mlx/generator/main.go +++ b/x/mlxrunner/mlx/generator/main.go @@ -17,11 +17,21 @@ import ( //go:embed *.gotmpl var fsys embed.FS +// optionalSymbols lists symbols that may not be present in all builds +// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX). +var optionalSymbols = map[string]bool{ + "mlx_array_item_float16": true, + "mlx_array_item_bfloat16": true, + "mlx_array_data_float16": true, + "mlx_array_data_bfloat16": true, +} + type Function struct { Type, Name, Parameters, Args string + Optional bool } func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function { @@ -104,7 +114,9 @@ func main() { matches := qc.Matches(query, tree.RootNode(), bts) for match := matches.Next(); match != nil; match = matches.Next() { for _, capture := range match.Captures { - funs = append(funs, ParseFunction(&capture.Node, tc, bts)) + fn := ParseFunction(&capture.Node, tc, bts) + fn.Optional = optionalSymbols[fn.Name] + funs = append(funs, fn) } } } diff --git a/x/mlxrunner/mlx/include/mlx/c/README.md b/x/mlxrunner/mlx/include/mlx/c/README.md new file mode 100644 index 000000000..905ca451c --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/README.md @@ -0,0 +1,12 @@ +# Vendored MLX-C Headers + +These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c). +The pinned version is in `MLX_VERSION` at the repo root. + +Headers are automatically refreshed when you run a CMake build: + +```shell +cmake --preset 'MLX CUDA 13' +``` + +See the [MLX Engine](../../../../../../../docs/development.md#mlx-engine-optional) section of the development docs for full build instructions. diff --git a/x/mlxrunner/mlx/include/mlx/c/array.h b/x/mlxrunner/mlx/include/mlx/c/array.h new file mode 100644 index 000000000..a3b382bb2 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/array.h @@ -0,0 +1,420 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_ARRAY_H +#define MLX_ARRAY_H + +#include "mlx/c/string.h" + +#include +#include +#include +#include + +// Complex number support +#ifdef _MSC_VER +#define _CRT_USE_C_COMPLEX_H +#include +typedef _Fcomplex mlx_complex64_t; +#else +#include +typedef float _Complex mlx_complex64_t; +#endif + +#include "half.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_array Array + * MLX N-dimensional array object. + */ +/**@{*/ + +/** + * A N-dimensional array object. + */ +typedef struct mlx_array_ { + void* ctx; +} mlx_array; + +static mlx_array mlx_array_empty; + +/** + * Array element type. + */ +typedef enum mlx_dtype_ { + MLX_BOOL, + MLX_UINT8, + MLX_UINT16, + MLX_UINT32, + MLX_UINT64, + MLX_INT8, + MLX_INT16, + MLX_INT32, + MLX_INT64, + MLX_FLOAT16, + MLX_FLOAT32, + MLX_FLOAT64, + MLX_BFLOAT16, + MLX_COMPLEX64, +} mlx_dtype; + +/** + * Size of given mlx_dtype datatype in bytes. + */ +size_t mlx_dtype_size(mlx_dtype dtype); + +/** + * Get array description. + */ +int mlx_array_tostring(mlx_string* str, const mlx_array arr); + +/** + * New empty array. + */ +mlx_array mlx_array_new(void); + +/** + * Free an array. + */ +int mlx_array_free(mlx_array arr); + +/** + * New array from a bool scalar. + */ +mlx_array mlx_array_new_bool(bool val); +/** + * New array from a int scalar. + */ +mlx_array mlx_array_new_int(int val); +/** + * New array from a float32 scalar. + */ +mlx_array mlx_array_new_float32(float val); +/** + * New array from a float scalar. + * Same as float32. + */ +mlx_array mlx_array_new_float(float val); +/** + * New array from a float64 scalar. + */ +mlx_array mlx_array_new_float64(double val); +/** + * New array from a double scalar. + * Same as float64. + */ +mlx_array mlx_array_new_double(double val); +/** + * New array from a complex scalar. + */ +mlx_array mlx_array_new_complex(float real_val, float imag_val); +/** + * New array from existing buffer. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + */ +mlx_array mlx_array_new_data( + const void* data, + const int* shape, + int dim, + mlx_dtype dtype); +/** + * New array from existing buffer. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + * @param dtor Callback for when the buffer is no longer needed. + */ +mlx_array mlx_array_new_data_managed( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void (*dtor)(void*)); +/** + * New array from existing buffer. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + * @param payload Payload pointer passed to the `dtor` callback instead of + * `data`. + * @param dtor Callback for when the buffer is no longer needed. + */ +mlx_array mlx_array_new_data_managed_payload( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void* payload, + void (*dtor)(void*)); +/** + * Set array to provided src array. + */ +int mlx_array_set(mlx_array* arr, const mlx_array src); +/** + * Set array to a bool scalar. + */ +int mlx_array_set_bool(mlx_array* arr, bool val); +/** + * Set array to a int scalar. + */ +int mlx_array_set_int(mlx_array* arr, int val); +/** + * Set array to a float32 scalar. + */ +int mlx_array_set_float32(mlx_array* arr, float val); +/** + * Set array to a float scalar. + */ +int mlx_array_set_float(mlx_array* arr, float val); +/** + * Set array to a float64 scalar. + */ +int mlx_array_set_float64(mlx_array* arr, double val); +/** + * Set array to a double scalar. + */ +int mlx_array_set_double(mlx_array* arr, double val); +/** + * Set array to a complex scalar. + */ +int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val); +/** + * Set array to specified data and shape. + * @param arr Destination array. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + */ +int mlx_array_set_data( + mlx_array* arr, + const void* data, + const int* shape, + int dim, + mlx_dtype dtype); + +/** + * The size of the array's datatype in bytes. + */ +size_t mlx_array_itemsize(const mlx_array arr); +/** + * Number of elements in the array. + */ +size_t mlx_array_size(const mlx_array arr); +/** + * The number of bytes in the array. + */ +size_t mlx_array_nbytes(const mlx_array arr); +/** + * The array's dimension. + */ +size_t mlx_array_ndim(const mlx_array arr); +/** + * The shape of the array. + * Returns: a pointer to the sizes of each dimension. + */ +const int* mlx_array_shape(const mlx_array arr); +/** + * The strides of the array. + * Returns: a pointer to the sizes of each dimension. + */ +const size_t* mlx_array_strides(const mlx_array arr); +/** + * The shape of the array in a particular dimension. + */ +int mlx_array_dim(const mlx_array arr, int dim); +/** + * The array element type. + */ +mlx_dtype mlx_array_dtype(const mlx_array arr); + +/** + * Evaluate the array. + */ +int mlx_array_eval(mlx_array arr); + +/** + * Access the value of a scalar array. + */ +int mlx_array_item_bool(bool* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint8(uint8_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint16(uint16_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint32(uint32_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint64(uint64_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int8(int8_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int16(int16_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int32(int32_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int64(int64_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float32(float* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float64(double* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr); + +#ifdef HAS_FLOAT16 +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float16(float16_t* res, const mlx_array arr); +#endif + +#ifdef HAS_BFLOAT16 +/** + * Access the value of a scalar array. + */ +int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr); +#endif + +/** + * Returns a pointer to the array data, cast to `bool*`. + * Array must be evaluated, otherwise returns NULL. + */ +const bool* mlx_array_data_bool(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint8_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint8_t* mlx_array_data_uint8(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint16_t* mlx_array_data_uint16(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint32_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint32_t* mlx_array_data_uint32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint64_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint64_t* mlx_array_data_uint64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int8_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int8_t* mlx_array_data_int8(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int16_t* mlx_array_data_int16(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int32_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int32_t* mlx_array_data_int32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int64_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int64_t* mlx_array_data_int64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `float32*`. + * Array must be evaluated, otherwise returns NULL. + */ +const float* mlx_array_data_float32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `float64*`. + * Array must be evaluated, otherwise returns NULL. + */ +const double* mlx_array_data_float64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `_Complex*`. + * Array must be evaluated, otherwise returns NULL. + */ +const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr); + +#ifdef HAS_FLOAT16 +/** + * Returns a pointer to the array data, cast to `float16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const float16_t* mlx_array_data_float16(const mlx_array arr); +#endif + +#ifdef HAS_BFLOAT16 +/** + * Returns a pointer to the array data, cast to `bfloat16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr); +#endif + +/** + * Check if the array is available. + * Internal function: use at your own risk. + */ +int _mlx_array_is_available(bool* res, const mlx_array arr); + +/** + * Wait on the array to be available. After this `_mlx_array_is_available` + * returns `true`. Internal function: use at your own risk. + */ +int _mlx_array_wait(const mlx_array arr); + +/** + * Whether the array is contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_contiguous(bool* res, const mlx_array arr); + +/** + * Whether the array's rows are contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr); + +/** + * Whether the array's columns are contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/closure.h b/x/mlxrunner/mlx/include/mlx/c/closure.h new file mode 100644 index 000000000..33f711572 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/closure.h @@ -0,0 +1,197 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_CLOSURE_H +#define MLX_CLOSURE_H + +#include "mlx/c/array.h" +#include "mlx/c/map.h" +#include "mlx/c/optional.h" +#include "mlx/c/stream.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_closure Closures + * MLX closure objects. + */ +/**@{*/ + +typedef struct mlx_closure_ { + void* ctx; +} mlx_closure; +mlx_closure mlx_closure_new(void); +int mlx_closure_free(mlx_closure cls); +mlx_closure mlx_closure_new_func( + int (*fun)(mlx_vector_array*, const mlx_vector_array)); +mlx_closure mlx_closure_new_func_payload( + int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_set(mlx_closure* cls, const mlx_closure src); +int mlx_closure_apply( + mlx_vector_array* res, + mlx_closure cls, + const mlx_vector_array input); + +mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)); + +typedef struct mlx_closure_kwargs_ { + void* ctx; +} mlx_closure_kwargs; +mlx_closure_kwargs mlx_closure_kwargs_new(void); +int mlx_closure_kwargs_free(mlx_closure_kwargs cls); +mlx_closure_kwargs mlx_closure_kwargs_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)); +mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_kwargs_set( + mlx_closure_kwargs* cls, + const mlx_closure_kwargs src); +int mlx_closure_kwargs_apply( + mlx_vector_array* res, + mlx_closure_kwargs cls, + const mlx_vector_array input_0, + const mlx_map_string_to_array input_1); + +typedef struct mlx_closure_value_and_grad_ { + void* ctx; +} mlx_closure_value_and_grad; +mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void); +int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls); +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func( + int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)); +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload( + int (*fun)( + mlx_vector_array*, + mlx_vector_array*, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_value_and_grad_set( + mlx_closure_value_and_grad* cls, + const mlx_closure_value_and_grad src); +int mlx_closure_value_and_grad_apply( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + mlx_closure_value_and_grad cls, + const mlx_vector_array input); + +typedef struct mlx_closure_custom_ { + void* ctx; +} mlx_closure_custom; +mlx_closure_custom mlx_closure_custom_new(void); +int mlx_closure_custom_free(mlx_closure_custom cls); +mlx_closure_custom mlx_closure_custom_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)); +mlx_closure_custom mlx_closure_custom_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_set( + mlx_closure_custom* cls, + const mlx_closure_custom src); +int mlx_closure_custom_apply( + mlx_vector_array* res, + mlx_closure_custom cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const mlx_vector_array input_2); + +typedef struct mlx_closure_custom_jvp_ { + void* ctx; +} mlx_closure_custom_jvp; +mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void); +int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls); +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)); +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_jvp_set( + mlx_closure_custom_jvp* cls, + const mlx_closure_custom_jvp src); +int mlx_closure_custom_jvp_apply( + mlx_vector_array* res, + mlx_closure_custom_jvp cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const int* input_2, + size_t input_2_num); + +typedef struct mlx_closure_custom_vmap_ { + void* ctx; +} mlx_closure_custom_vmap; +mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void); +int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls); +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)); +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_vmap_set( + mlx_closure_custom_vmap* cls, + const mlx_closure_custom_vmap src); +int mlx_closure_custom_vmap_apply( + mlx_vector_array* res_0, + mlx_vector_int* res_1, + mlx_closure_custom_vmap cls, + const mlx_vector_array input_0, + const int* input_1, + size_t input_1_num); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/compile.h b/x/mlxrunner/mlx/include/mlx/c/compile.h new file mode 100644 index 000000000..04567fb3a --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/compile.h @@ -0,0 +1,57 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_COMPILE_H +#define MLX_COMPILE_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup compile Compilation operations + */ +/**@{*/ + +typedef enum mlx_compile_mode_ { + MLX_COMPILE_MODE_DISABLED, + MLX_COMPILE_MODE_NO_SIMPLIFY, + MLX_COMPILE_MODE_NO_FUSE, + MLX_COMPILE_MODE_ENABLED +} mlx_compile_mode; +int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless); +int mlx_detail_compile( + mlx_closure* res, + const mlx_closure fun, + uintptr_t fun_id, + bool shapeless, + const uint64_t* constants, + size_t constants_num); +int mlx_detail_compile_clear_cache(void); +int mlx_detail_compile_erase(uintptr_t fun_id); +int mlx_disable_compile(void); +int mlx_enable_compile(void); +int mlx_set_compile_mode(mlx_compile_mode mode); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/cuda.h b/x/mlxrunner/mlx/include/mlx/c/cuda.h new file mode 100644 index 000000000..4734f8c51 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/cuda.h @@ -0,0 +1,39 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_CUDA_H +#define MLX_CUDA_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup cuda Cuda specific operations + */ +/**@{*/ + +int mlx_cuda_is_available(bool* res); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/device.h b/x/mlxrunner/mlx/include/mlx/c/device.h new file mode 100644 index 000000000..4b74e39d3 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/device.h @@ -0,0 +1,154 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_DEVICE_H +#define MLX_DEVICE_H + +#include +#include + +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_device Device + * MLX device object. + */ +/**@{*/ + +/** + * A MLX device object. + */ +typedef struct mlx_device_ { + void* ctx; +} mlx_device; + +/** + * Device type. + */ +typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type; + +/** + * Returns a new empty device. + */ +mlx_device mlx_device_new(void); + +/** + * Returns a new device of specified `type`, with specified `index`. + */ +mlx_device mlx_device_new_type(mlx_device_type type, int index); +/** + * Free a device. + */ +int mlx_device_free(mlx_device dev); +/** + * Set device to provided src device. + */ +int mlx_device_set(mlx_device* dev, const mlx_device src); +/** + * Get device description. + */ +int mlx_device_tostring(mlx_string* str, mlx_device dev); +/** + * Check if devices are the same. + */ +bool mlx_device_equal(mlx_device lhs, mlx_device rhs); +/** + * Returns the index of the device. + */ +int mlx_device_get_index(int* index, mlx_device dev); +/** + * Returns the type of the device. + */ +int mlx_device_get_type(mlx_device_type* type, mlx_device dev); +/** + * Returns the default MLX device. + */ +int mlx_get_default_device(mlx_device* dev); +/** + * Set the default MLX device. + */ +int mlx_set_default_device(mlx_device dev); +/** + * Check if device is available. + */ +int mlx_device_is_available(bool* avail, mlx_device dev); +/** + * Get the number of available devices for a device type. + */ +int mlx_device_count(int* count, mlx_device_type type); + +/** + * A MLX device info object. + * Contains key-value pairs with device properties. + * Keys vary by backend but common keys include: + * - device_name (string): Device name + * - architecture (string): Architecture identifier + * Additional keys may be present depending on the backend. + */ +typedef struct mlx_device_info_ { + void* ctx; +} mlx_device_info; + +/** + * Returns a new empty device info object. + */ +mlx_device_info mlx_device_info_new(void); +/** + * Get device information for a device. + */ +int mlx_device_info_get(mlx_device_info* info, mlx_device dev); +/** + * Free a device info object. + */ +int mlx_device_info_free(mlx_device_info info); +/** + * Check if a key exists in the device info. + * Returns 0 on success, 1 on error. + * Sets *exists to true if the key exists, false otherwise. + */ +int mlx_device_info_has_key( + bool* exists, + mlx_device_info info, + const char* key); +/** + * Check if a value is a string type. + * Returns 0 on success, 1 on error. + * Sets *is_string to true if the value is a string, false if it's a size_t. + */ +int mlx_device_info_is_string( + bool* is_string, + mlx_device_info info, + const char* key); +/** + * Get a string value from device info. + * Returns 0 on success, 1 on error, 2 if key not found or wrong type. + */ +int mlx_device_info_get_string( + const char** value, + mlx_device_info info, + const char* key); +/** + * Get a size_t value from device info. + * Returns 0 on success, 1 on error, 2 if key not found or wrong type. + */ +int mlx_device_info_get_size( + size_t* value, + mlx_device_info info, + const char* key); +/** + * Get all keys from device info. + * Returns 0 on success, 1 on error. + */ +int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/distributed.h b/x/mlxrunner/mlx/include/mlx/c/distributed.h new file mode 100644 index 000000000..c3b0baeee --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/distributed.h @@ -0,0 +1,83 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_DISTRIBUTED_H +#define MLX_DISTRIBUTED_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup distributed Distributed collectives + */ +/**@{*/ + +int mlx_distributed_all_gather( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream S); +int mlx_distributed_all_max( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_all_min( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_all_sum( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_recv( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_recv_like( + mlx_array* res, + const mlx_array x, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_send( + mlx_array* res, + const mlx_array x, + int dst, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_sum_scatter( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/distributed_group.h b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h new file mode 100644 index 000000000..3cfccc806 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h @@ -0,0 +1,58 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_DISTRIBUTED_GROUP_H +#define MLX_DISTRIBUTED_GROUP_H + +#include + +#include "mlx/c/stream.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_distributed_group MLX distributed + */ +/**@{*/ + +/** + * A MLX distributed group object. + */ +typedef struct mlx_distributed_group_ { + void* ctx; +} mlx_distributed_group; + +/** + * Get the rank. + */ +int mlx_distributed_group_rank(mlx_distributed_group group); + +/** + * Get the group size. + */ +int mlx_distributed_group_size(mlx_distributed_group group); + +/** + * Split the group. + */ +mlx_distributed_group +mlx_distributed_group_split(mlx_distributed_group group, int color, int key); + +/** + * Check if distributed is available. + */ +bool mlx_distributed_is_available(void); + +/** + * Initialize distributed. + */ +mlx_distributed_group mlx_distributed_init(bool strict); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/error.h b/x/mlxrunner/mlx/include/mlx/c/error.h new file mode 100644 index 000000000..8c063a403 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/error.h @@ -0,0 +1,41 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_ERROR_H +#define MLX_ERROR_H + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_error Error management + */ +/**@{*/ + +typedef void (*mlx_error_handler_func)(const char* msg, void* data); + +/** + * Set the error handler. + */ +void mlx_set_error_handler( + mlx_error_handler_func handler, + void* data, + void (*dtor)(void*)); + +/** + * Throw an error. + */ +void _mlx_error(const char* file, const int line, const char* fmt, ...); + +/** + * Throw an error. Macro which passes file name and line number to _mlx_error(). + */ +#define mlx_error(...) _mlx_error(__FILE__, __LINE__, __VA_ARGS__) + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/export.h b/x/mlxrunner/mlx/include/mlx/c/export.h new file mode 100644 index 000000000..52cb2835c --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/export.h @@ -0,0 +1,75 @@ +/* Copyright © 2023-2025 Apple Inc. */ + +#ifndef MLX_EXPORT_H +#define MLX_EXPORT_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup export Function serialization + */ +/**@{*/ +int mlx_export_function( + const char* file, + const mlx_closure fun, + const mlx_vector_array args, + bool shapeless); +int mlx_export_function_kwargs( + const char* file, + const mlx_closure_kwargs fun, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs, + bool shapeless); + +typedef struct mlx_function_exporter_ { + void* ctx; +} mlx_function_exporter; +mlx_function_exporter mlx_function_exporter_new( + const char* file, + const mlx_closure fun, + bool shapeless); +int mlx_function_exporter_free(mlx_function_exporter xfunc); +int mlx_function_exporter_apply( + const mlx_function_exporter xfunc, + const mlx_vector_array args); +int mlx_function_exporter_apply_kwargs( + const mlx_function_exporter xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs); + +typedef struct mlx_imported_function_ { + void* ctx; +} mlx_imported_function; +mlx_imported_function mlx_imported_function_new(const char* file); +int mlx_imported_function_free(mlx_imported_function xfunc); +int mlx_imported_function_apply( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args); +int mlx_imported_function_apply_kwargs( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/fast.h b/x/mlxrunner/mlx/include/mlx/c/fast.h new file mode 100644 index 000000000..c825d00e5 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/fast.h @@ -0,0 +1,206 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_FAST_H +#define MLX_FAST_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup fast Fast custom operations + */ +/**@{*/ + +typedef struct mlx_fast_cuda_kernel_config_ { + void* ctx; +} mlx_fast_cuda_kernel_config; +mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void); +void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls); + +int mlx_fast_cuda_kernel_config_add_output_arg( + mlx_fast_cuda_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +int mlx_fast_cuda_kernel_config_set_grid( + mlx_fast_cuda_kernel_config cls, + int grid1, + int grid2, + int grid3); +int mlx_fast_cuda_kernel_config_set_thread_group( + mlx_fast_cuda_kernel_config cls, + int thread1, + int thread2, + int thread3); +int mlx_fast_cuda_kernel_config_set_init_value( + mlx_fast_cuda_kernel_config cls, + float value); +int mlx_fast_cuda_kernel_config_set_verbose( + mlx_fast_cuda_kernel_config cls, + bool verbose); +int mlx_fast_cuda_kernel_config_add_template_arg_dtype( + mlx_fast_cuda_kernel_config cls, + const char* name, + mlx_dtype dtype); +int mlx_fast_cuda_kernel_config_add_template_arg_int( + mlx_fast_cuda_kernel_config cls, + const char* name, + int value); +int mlx_fast_cuda_kernel_config_add_template_arg_bool( + mlx_fast_cuda_kernel_config cls, + const char* name, + bool value); + +typedef struct mlx_fast_cuda_kernel_ { + void* ctx; +} mlx_fast_cuda_kernel; + +mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + int shared_memory); + +void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls); + +int mlx_fast_cuda_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_cuda_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_cuda_kernel_config config, + const mlx_stream stream); + +int mlx_fast_layer_norm( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + const mlx_array bias /* may be null */, + float eps, + const mlx_stream s); + +typedef struct mlx_fast_metal_kernel_config_ { + void* ctx; +} mlx_fast_metal_kernel_config; +mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void); +void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls); + +int mlx_fast_metal_kernel_config_add_output_arg( + mlx_fast_metal_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +int mlx_fast_metal_kernel_config_set_grid( + mlx_fast_metal_kernel_config cls, + int grid1, + int grid2, + int grid3); +int mlx_fast_metal_kernel_config_set_thread_group( + mlx_fast_metal_kernel_config cls, + int thread1, + int thread2, + int thread3); +int mlx_fast_metal_kernel_config_set_init_value( + mlx_fast_metal_kernel_config cls, + float value); +int mlx_fast_metal_kernel_config_set_verbose( + mlx_fast_metal_kernel_config cls, + bool verbose); +int mlx_fast_metal_kernel_config_add_template_arg_dtype( + mlx_fast_metal_kernel_config cls, + const char* name, + mlx_dtype dtype); +int mlx_fast_metal_kernel_config_add_template_arg_int( + mlx_fast_metal_kernel_config cls, + const char* name, + int value); +int mlx_fast_metal_kernel_config_add_template_arg_bool( + mlx_fast_metal_kernel_config cls, + const char* name, + bool value); + +typedef struct mlx_fast_metal_kernel_ { + void* ctx; +} mlx_fast_metal_kernel; + +mlx_fast_metal_kernel mlx_fast_metal_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + bool atomic_outputs); + +void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls); + +int mlx_fast_metal_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_metal_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_metal_kernel_config config, + const mlx_stream stream); + +int mlx_fast_rms_norm( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + float eps, + const mlx_stream s); +int mlx_fast_rope( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + int offset, + const mlx_array freqs /* may be null */, + const mlx_stream s); +int mlx_fast_rope_dynamic( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + const mlx_array offset, + const mlx_array freqs /* may be null */, + const mlx_stream s); +int mlx_fast_scaled_dot_product_attention( + mlx_array* res, + const mlx_array queries, + const mlx_array keys, + const mlx_array values, + float scale, + const char* mask_mode, + const mlx_array mask_arr /* may be null */, + const mlx_array sinks /* may be null */, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/fft.h b/x/mlxrunner/mlx/include/mlx/c/fft.h new file mode 100644 index 000000000..779803e9b --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/fft.h @@ -0,0 +1,138 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_FFT_H +#define MLX_FFT_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup fft FFT operations + */ +/**@{*/ + +int mlx_fft_fft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_fft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_fftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_fftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_ifft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_irfft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_irfft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_irfftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_rfft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_rfft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_rfftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/half.h b/x/mlxrunner/mlx/include/mlx/c/half.h new file mode 100644 index 000000000..958d555f5 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/half.h @@ -0,0 +1,26 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_HALF_H +#define MLX_HALF_H + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__) +#define HAS_FLOAT16 +#include +typedef __fp16 float16_t; +#endif + +#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__) +#define HAS_BFLOAT16 +#include +typedef __bf16 bfloat16_t; +#endif + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/io.h b/x/mlxrunner/mlx/include/mlx/c/io.h new file mode 100644 index 000000000..6eb205c9a --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/io.h @@ -0,0 +1,63 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_IO_H +#define MLX_IO_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup io IO operations + */ +/**@{*/ + +int mlx_load_reader( + mlx_array* res, + mlx_io_reader in_stream, + const mlx_stream s); +int mlx_load(mlx_array* res, const char* file, const mlx_stream s); +int mlx_load_safetensors_reader( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + mlx_io_reader in_stream, + const mlx_stream s); +int mlx_load_safetensors( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + const char* file, + const mlx_stream s); +int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a); +int mlx_save(const char* file, const mlx_array a); +int mlx_save_safetensors_writer( + mlx_io_writer in_stream, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata); +int mlx_save_safetensors( + const char* file, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/io_types.h b/x/mlxrunner/mlx/include/mlx/c/io_types.h new file mode 100644 index 000000000..88349b57c --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/io_types.h @@ -0,0 +1,104 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_IO_TYPES_H +#define MLX_IO_TYPES_H + +#include + +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_io_types IO Types + * MLX IO type objects. + */ +/**@{*/ + +/** + * A MLX IO reader object. + */ +typedef struct mlx_io_reader_ { + void* ctx; +} mlx_io_reader; +/** + * A MLX IO writer object. + */ +typedef struct mlx_io_writer_ { + void* ctx; +} mlx_io_writer; + +/** + * Virtual table for custom IO reader and writer objects. + */ +typedef struct mlx_io_vtable_ { + bool (*is_open)(void*); + bool (*good)(void*); + size_t (*tell)(void*); + void (*seek)(void*, int64_t off, int whence); + void (*read)(void*, char* data, size_t n); + void (*read_at_offset)(void*, char* data, size_t n, size_t off); + void (*write)(void*, const char* data, size_t n); + const char* (*label)(void*); + void (*free)(void*); +} mlx_io_vtable; + +/** + * Returns a new custom IO reader. + * `vtable` operates on user descriptor `desc`. + */ +mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable); + +/** + * Get IO reader user descriptor. + */ +int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io); + +/** + * Get IO reader description. + */ +int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io); + +/** + * Free IO reader. + * + * Note that MLX arrays are lazily evaluated, so the underlying object may + * be not freed right away. The ``free()`` callback from ``mlx_io_vtable`` + * will be called when the underlying object is actually freed. + */ +int mlx_io_reader_free(mlx_io_reader io); + +/** + * Returns a new custom IO writer. + * `vtable` operates on user descriptor `desc`. + */ +mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable); + +/** + * Get IO writer user descriptor. + */ +int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io); + +/** + * Get IO writer description. + */ +int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io); + +/** + * Free IO writer. + * + * Note that MLX arrays are lazily evaluated, so the underlying object may + * be not freed right away. The ``free()`` callback from ``mlx_io_vtable`` + * will be called when the underlying object is actually freed. + */ +int mlx_io_writer_free(mlx_io_writer io); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/linalg.h b/x/mlxrunner/mlx/include/mlx/c/linalg.h new file mode 100644 index 000000000..91d5d661e --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/linalg.h @@ -0,0 +1,128 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_LINALG_H +#define MLX_LINALG_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup linalg Linear algebra operations + */ +/**@{*/ + +int mlx_linalg_cholesky( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +int mlx_linalg_cholesky_inv( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +int mlx_linalg_cross( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s); +int mlx_linalg_eig( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_eigh( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const char* UPLO, + const mlx_stream s); +int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_eigvalsh( + mlx_array* res, + const mlx_array a, + const char* UPLO, + const mlx_stream s); +int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_lu_factor( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_norm( + mlx_array* res, + const mlx_array a, + double ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +int mlx_linalg_norm_matrix( + mlx_array* res, + const mlx_array a, + const char* ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +int mlx_linalg_norm_l2( + mlx_array* res, + const mlx_array a, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_qr( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_solve( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_linalg_solve_triangular( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool upper, + const mlx_stream s); +int mlx_linalg_svd( + mlx_vector_array* res, + const mlx_array a, + bool compute_uv, + const mlx_stream s); +int mlx_linalg_tri_inv( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/map.h b/x/mlxrunner/mlx/include/mlx/c/map.h new file mode 100644 index 000000000..56abe84f1 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/map.h @@ -0,0 +1,149 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_MAP_H +#define MLX_MAP_H + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_map Maps + * MLX map objects. + */ +/**@{*/ + +/** + * A string-to-array map + */ +typedef struct mlx_map_string_to_array_ { + void* ctx; +} mlx_map_string_to_array; + +/** + * Returns a new empty string-to-array map. + */ +mlx_map_string_to_array mlx_map_string_to_array_new(void); +/** + * Set map to provided src map. + */ +int mlx_map_string_to_array_set( + mlx_map_string_to_array* map, + const mlx_map_string_to_array src); +/** + * Free a string-to-array map. + */ +int mlx_map_string_to_array_free(mlx_map_string_to_array map); +/** + * Insert a new `value` at the specified `key` in the map. + */ +int mlx_map_string_to_array_insert( + mlx_map_string_to_array map, + const char* key, + const mlx_array value); +/** + * Returns the value indexed at the specified `key` in the map. + */ +int mlx_map_string_to_array_get( + mlx_array* value, + const mlx_map_string_to_array map, + const char* key); + +/** + * An iterator over a string-to-array map. + */ +typedef struct mlx_map_string_to_array_iterator_ { + void* ctx; + void* map_ctx; +} mlx_map_string_to_array_iterator; +/** + * Returns a new iterator over the given map. + */ +mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new( + mlx_map_string_to_array map); +/** + * Free iterator. + */ +int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it); +/** + * Increment iterator. + */ +int mlx_map_string_to_array_iterator_next( + const char** key, + mlx_array* value, + mlx_map_string_to_array_iterator it); + +/** + * A string-to-string map + */ +typedef struct mlx_map_string_to_string_ { + void* ctx; +} mlx_map_string_to_string; + +/** + * Returns a new empty string-to-string map. + */ +mlx_map_string_to_string mlx_map_string_to_string_new(void); +/** + * Set map to provided src map. + */ +int mlx_map_string_to_string_set( + mlx_map_string_to_string* map, + const mlx_map_string_to_string src); +/** + * Free a string-to-string map. + */ +int mlx_map_string_to_string_free(mlx_map_string_to_string map); +/** + * Insert a new `value` at the specified `key` in the map. + */ +int mlx_map_string_to_string_insert( + mlx_map_string_to_string map, + const char* key, + const char* value); +/** + * Returns the value indexed at the specified `key` in the map. + */ +int mlx_map_string_to_string_get( + const char** value, + const mlx_map_string_to_string map, + const char* key); + +/** + * An iterator over a string-to-string map. + */ +typedef struct mlx_map_string_to_string_iterator_ { + void* ctx; + void* map_ctx; +} mlx_map_string_to_string_iterator; +/** + * Returns a new iterator over the given map. + */ +mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new( + mlx_map_string_to_string map); +/** + * Free iterator. + */ +int mlx_map_string_to_string_iterator_free( + mlx_map_string_to_string_iterator it); +/** + * Increment iterator. + */ +int mlx_map_string_to_string_iterator_next( + const char** key, + const char** value, + mlx_map_string_to_string_iterator it); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/memory.h b/x/mlxrunner/mlx/include/mlx/c/memory.h new file mode 100644 index 000000000..bae9e08ec --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/memory.h @@ -0,0 +1,47 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_MEMORY_H +#define MLX_MEMORY_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup memory Memory operations + */ +/**@{*/ + +int mlx_clear_cache(void); +int mlx_get_active_memory(size_t* res); +int mlx_get_cache_memory(size_t* res); +int mlx_get_memory_limit(size_t* res); +int mlx_get_peak_memory(size_t* res); +int mlx_reset_peak_memory(void); +int mlx_set_cache_limit(size_t* res, size_t limit); +int mlx_set_memory_limit(size_t* res, size_t limit); +int mlx_set_wired_limit(size_t* res, size_t limit); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/metal.h b/x/mlxrunner/mlx/include/mlx/c/metal.h new file mode 100644 index 000000000..5877b224b --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/metal.h @@ -0,0 +1,41 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_METAL_H +#define MLX_METAL_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup metal Metal specific operations + */ +/**@{*/ + +int mlx_metal_is_available(bool* res); +int mlx_metal_start_capture(const char* path); +int mlx_metal_stop_capture(void); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/mlx.h b/x/mlxrunner/mlx/include/mlx/c/mlx.h new file mode 100644 index 000000000..ffadac89a --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/mlx.h @@ -0,0 +1,34 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_ALL_H +#define MLX_ALL_H + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/compile.h" +#include "mlx/c/cuda.h" +#include "mlx/c/device.h" +#include "mlx/c/distributed.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/error.h" +#include "mlx/c/export.h" +#include "mlx/c/fast.h" +#include "mlx/c/fft.h" +#include "mlx/c/half.h" +#include "mlx/c/io.h" +#include "mlx/c/io_types.h" +#include "mlx/c/linalg.h" +#include "mlx/c/map.h" +#include "mlx/c/memory.h" +#include "mlx/c/metal.h" +#include "mlx/c/ops.h" +#include "mlx/c/optional.h" +#include "mlx/c/random.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/transforms.h" +#include "mlx/c/transforms_impl.h" +#include "mlx/c/vector.h" +#include "mlx/c/version.h" + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/ops.h b/x/mlxrunner/mlx/include/mlx/c/ops.h new file mode 100644 index 000000000..a1446fb9e --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/ops.h @@ -0,0 +1,1235 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_OPS_H +#define MLX_OPS_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup ops Core array operations + */ +/**@{*/ + +int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_add( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_addmm( + mlx_array* res, + const mlx_array c, + const mlx_array a, + const mlx_array b, + float alpha, + float beta, + const mlx_stream s); +int mlx_all_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_all_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_all( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_allclose( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s); +int mlx_any_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_any_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_any( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_arange( + mlx_array* res, + double start, + double stop, + double step, + mlx_dtype dtype, + const mlx_stream s); +int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arctan2( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_argmax_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_argmax( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_argmin_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_argmin( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_argpartition_axis( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s); +int mlx_argpartition( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s); +int mlx_argsort_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_array_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool equal_nan, + const mlx_stream s); +int mlx_as_strided( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const int64_t* strides, + size_t strides_num, + size_t offset, + const mlx_stream s); +int mlx_astype( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s); +int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_bitwise_and( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_bitwise_or( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_bitwise_xor( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_block_masked_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int block_size, + const mlx_array mask_out /* may be null */, + const mlx_array mask_lhs /* may be null */, + const mlx_array mask_rhs /* may be null */, + const mlx_stream s); +int mlx_broadcast_arrays( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_stream s); +int mlx_broadcast_to( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_clip( + mlx_array* res, + const mlx_array a, + const mlx_array a_min /* may be null */, + const mlx_array a_max /* may be null */, + const mlx_stream s); +int mlx_concatenate_axis( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s); +int mlx_concatenate( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s); +int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_contiguous( + mlx_array* res, + const mlx_array a, + bool allow_col_major, + const mlx_stream s); +int mlx_conv1d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int groups, + const mlx_stream s); +int mlx_conv2d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int groups, + const mlx_stream s); +int mlx_conv3d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int groups, + const mlx_stream s); +int mlx_conv_general( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + const int* stride, + size_t stride_num, + const int* padding_lo, + size_t padding_lo_num, + const int* padding_hi, + size_t padding_hi_num, + const int* kernel_dilation, + size_t kernel_dilation_num, + const int* input_dilation, + size_t input_dilation_num, + int groups, + bool flip, + const mlx_stream s); +int mlx_conv_transpose1d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int output_padding, + int groups, + const mlx_stream s); +int mlx_conv_transpose2d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int output_padding_0, + int output_padding_1, + int groups, + const mlx_stream s); +int mlx_conv_transpose3d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int output_padding_0, + int output_padding_1, + int output_padding_2, + int groups, + const mlx_stream s); +int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cummax( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cummin( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cumprod( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cumsum( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_depends( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array dependencies); +int mlx_dequantize( + mlx_array* res, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + mlx_optional_dtype dtype, + const mlx_stream s); +int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +int mlx_diagonal( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + const mlx_stream s); +int mlx_divide( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_divmod( + mlx_vector_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_einsum( + mlx_array* res, + const char* subscripts, + const mlx_vector_array operands, + const mlx_stream s); +int mlx_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_expand_dims_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_expand_dims( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_eye( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype dtype, + const mlx_stream s); +int mlx_flatten( + mlx_array* res, + const mlx_array a, + int start_axis, + int end_axis, + const mlx_stream s); +int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_floor_divide( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_from_fp8( + mlx_array* res, + const mlx_array x, + mlx_dtype dtype, + const mlx_stream s); +int mlx_full( + mlx_array* res, + const int* shape, + size_t shape_num, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s); +int mlx_full_like( + mlx_array* res, + const mlx_array a, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s); +int mlx_gather( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const int* axes, + size_t axes_num, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s); +int mlx_gather_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s); +int mlx_gather_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool sorted_indices, + const mlx_stream s); +int mlx_gather_qmm( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + bool sorted_indices, + const mlx_stream s); +int mlx_greater( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_greater_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_hadamard_transform( + mlx_array* res, + const mlx_array a, + mlx_optional_float scale, + const mlx_stream s); +int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); +int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_inner( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_isclose( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s); +int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_kron( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_left_shift( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_less( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_less_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_linspace( + mlx_array* res, + double start, + double stop, + int num, + mlx_dtype dtype, + const mlx_stream s); +int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_logaddexp( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logcumsumexp( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_logical_and( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_logical_or( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logsumexp_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_logsumexp_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_logsumexp( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_masked_scatter( + mlx_array* res, + const mlx_array a, + const mlx_array mask, + const mlx_array src, + const mlx_stream s); +int mlx_matmul( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_max_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_max_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_max( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_maximum( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_mean_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_mean_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_mean( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_median( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_meshgrid( + mlx_vector_array* res, + const mlx_vector_array arrays, + bool sparse, + const char* indexing, + const mlx_stream s); +int mlx_min_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_min_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_min( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_minimum( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_moveaxis( + mlx_array* res, + const mlx_array a, + int source, + int destination, + const mlx_stream s); +int mlx_multiply( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_nan_to_num( + mlx_array* res, + const mlx_array a, + float nan, + mlx_optional_float posinf, + mlx_optional_float neginf, + const mlx_stream s); +int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_not_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_number_of_elements( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool inverted, + mlx_dtype dtype, + const mlx_stream s); +int mlx_ones( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s); +int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_outer( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_pad( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const int* low_pad_size, + size_t low_pad_size_num, + const int* high_pad_size, + size_t high_pad_size_num, + const mlx_array pad_value, + const char* mode, + const mlx_stream s); +int mlx_pad_symmetric( + mlx_array* res, + const mlx_array a, + int pad_width, + const mlx_array pad_value, + const char* mode, + const mlx_stream s); +int mlx_partition_axis( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s); +int mlx_partition( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s); +int mlx_power( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_prod_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_prod_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_prod( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_put_along_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s); +int mlx_qqmm( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array w_scales /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_quantize( + mlx_vector_array* res, + const mlx_array w, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_quantized_matmul( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_remainder( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_repeat_axis( + mlx_array* res, + const mlx_array arr, + int repeats, + int axis, + const mlx_stream s); +int mlx_repeat( + mlx_array* res, + const mlx_array arr, + int repeats, + const mlx_stream s); +int mlx_reshape( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_right_shift( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_roll_axis( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + int axis, + const mlx_stream s); +int mlx_roll_axes( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_roll( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const mlx_stream s); +int mlx_round( + mlx_array* res, + const mlx_array a, + int decimals, + const mlx_stream s); +int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_scatter( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_add( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_add_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_add_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s); +int mlx_scatter_max( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_max_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_min( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_min_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_prod( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_prod_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_segmented_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array segments, + const mlx_stream s); +int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_slice( + mlx_array* res, + const mlx_array a, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_dynamic( + mlx_array* res, + const mlx_array a, + const mlx_array start, + const int* axes, + size_t axes_num, + const int* slice_size, + size_t slice_size_num, + const mlx_stream s); +int mlx_slice_update( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_update_dynamic( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const mlx_array start, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_softmax_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool precise, + const mlx_stream s); +int mlx_softmax_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool precise, + const mlx_stream s); +int mlx_softmax( + mlx_array* res, + const mlx_array a, + bool precise, + const mlx_stream s); +int mlx_sort_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_split( + mlx_vector_array* res, + const mlx_array a, + int num_splits, + int axis, + const mlx_stream s); +int mlx_split_sections( + mlx_vector_array* res, + const mlx_array a, + const int* indices, + size_t indices_num, + int axis, + const mlx_stream s); +int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_squeeze_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_squeeze_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_stack_axis( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s); +int mlx_stack( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s); +int mlx_std_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_std_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_std( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_subtract( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_sum_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_sum_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_sum( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_swapaxes( + mlx_array* res, + const mlx_array a, + int axis1, + int axis2, + const mlx_stream s); +int mlx_take_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s); +int mlx_take( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_stream s); +int mlx_take_along_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s); +int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tensordot( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const int* axes_a, + size_t axes_a_num, + const int* axes_b, + size_t axes_b_num, + const mlx_stream s); +int mlx_tensordot_axis( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s); +int mlx_tile( + mlx_array* res, + const mlx_array arr, + const int* reps, + size_t reps_num, + const mlx_stream s); +int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s); +int mlx_topk_axis( + mlx_array* res, + const mlx_array a, + int k, + int axis, + const mlx_stream s); +int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +int mlx_trace( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + mlx_dtype dtype, + const mlx_stream s); +int mlx_transpose_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tri( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype type, + const mlx_stream s); +int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +int mlx_unflatten( + mlx_array* res, + const mlx_array a, + int axis, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_var_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_var_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_var( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_view( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s); +int mlx_where( + mlx_array* res, + const mlx_array condition, + const mlx_array x, + const mlx_array y, + const mlx_stream s); +int mlx_zeros( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s); +int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/optional.h b/x/mlxrunner/mlx/include/mlx/c/optional.h new file mode 100644 index 000000000..ff9ea14e5 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/optional.h @@ -0,0 +1,51 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_OPTIONAL_H +#define MLX_OPTIONAL_H + +#include + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_optional Optionals + * MLX optional scalars. + */ +/**@{*/ + +/** + * A int optional. + */ +typedef struct mlx_optional_int_ { + int value; + bool has_value; +} mlx_optional_int; + +/** + * A float optional. + */ +typedef struct mlx_optional_float_ { + float value; + bool has_value; +} mlx_optional_float; + +/** + * A dtype optional. + */ +typedef struct mlx_optional_dtype_ { + mlx_dtype value; + bool has_value; +} mlx_optional_dtype; + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/random.h b/x/mlxrunner/mlx/include/mlx/c/random.h new file mode 100644 index 000000000..dbce0be37 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/random.h @@ -0,0 +1,166 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_RANDOM_H +#define MLX_RANDOM_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup random Random number operations + */ +/**@{*/ + +int mlx_random_bernoulli( + mlx_array* res, + const mlx_array p, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_bits( + mlx_array* res, + const int* shape, + size_t shape_num, + int width, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical_shape( + mlx_array* res, + const mlx_array logits, + int axis, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical_num_samples( + mlx_array* res, + const mlx_array logits_, + int axis, + int num_samples, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical( + mlx_array* res, + const mlx_array logits, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_gumbel( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_key(mlx_array* res, uint64_t seed); +int mlx_random_laplace( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_multivariate_normal( + mlx_array* res, + const mlx_array mean, + const mlx_array cov, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_normal_broadcast( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array loc /* may be null */, + const mlx_array scale /* may be null */, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_normal( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_permutation( + mlx_array* res, + const mlx_array x, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_permutation_arange( + mlx_array* res, + int x, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_randint( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_seed(uint64_t seed); +int mlx_random_split_num( + mlx_array* res, + const mlx_array key, + int num, + const mlx_stream s); +int mlx_random_split( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array key, + const mlx_stream s); +int mlx_random_truncated_normal( + mlx_array* res, + const mlx_array lower, + const mlx_array upper, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_uniform( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/stream.h b/x/mlxrunner/mlx/include/mlx/c/stream.h new file mode 100644 index 000000000..d5865b806 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/stream.h @@ -0,0 +1,88 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_STREAM_H +#define MLX_STREAM_H + +#include + +#include "mlx/c/device.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_stream Stream + * MLX stream object. + */ +/**@{*/ + +/** + * A MLX stream object. + */ +typedef struct mlx_stream_ { + void* ctx; +} mlx_stream; + +/** + * Returns a new empty stream. + */ +mlx_stream mlx_stream_new(void); + +/** + * Returns a new stream on a device. + */ +mlx_stream mlx_stream_new_device(mlx_device dev); +/** + * Set stream to provided src stream. + */ +int mlx_stream_set(mlx_stream* stream, const mlx_stream src); +/** + * Free a stream. + */ +int mlx_stream_free(mlx_stream stream); +/** + * Get stream description. + */ +int mlx_stream_tostring(mlx_string* str, mlx_stream stream); +/** + * Check if streams are the same. + */ +bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs); +/** + * Return the device of the stream. + */ +int mlx_stream_get_device(mlx_device* dev, mlx_stream stream); +/** + * Return the index of the stream. + */ +int mlx_stream_get_index(int* index, mlx_stream stream); +/** + * Synchronize with the provided stream. + */ +int mlx_synchronize(mlx_stream stream); +/** + * Returns the default stream on the given device. + */ +int mlx_get_default_stream(mlx_stream* stream, mlx_device dev); +/** + * Set default stream. + */ +int mlx_set_default_stream(mlx_stream stream); +/** + * Returns the current default CPU stream. + */ +mlx_stream mlx_default_cpu_stream_new(void); + +/** + * Returns the current default GPU stream. + */ +mlx_stream mlx_default_gpu_stream_new(void); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/string.h b/x/mlxrunner/mlx/include/mlx/c/string.h new file mode 100644 index 000000000..0d2a356ba --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/string.h @@ -0,0 +1,55 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_STRING_H +#define MLX_STRING_H + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_string String + * MLX string object. + */ +/**@{*/ + +/** + * A MLX string object. + */ +typedef struct mlx_string_ { + void* ctx; +} mlx_string; + +/** + * Returns a new empty string. + */ +mlx_string mlx_string_new(void); + +/** + * Returns a new string, copying contents from `str`, which must end with `\0`. + */ +mlx_string mlx_string_new_data(const char* str); + +/** + * Set string to src string. + */ +int mlx_string_set(mlx_string* str, const mlx_string src); + +/** + * Returns a pointer to the string contents. + * The pointer is valid for the life duration of the string. + */ +const char* mlx_string_data(mlx_string str); + +/** + * Free string. + */ +int mlx_string_free(mlx_string str); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/transforms.h b/x/mlxrunner/mlx/include/mlx/c/transforms.h new file mode 100644 index 000000000..b2434619b --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/transforms.h @@ -0,0 +1,68 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_TRANSFORMS_H +#define MLX_TRANSFORMS_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup transforms Transform operations + */ +/**@{*/ + +int mlx_async_eval(const mlx_vector_array outputs); +int mlx_checkpoint(mlx_closure* res, const mlx_closure fun); +int mlx_custom_function( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp /* may be null */, + const mlx_closure_custom_jvp fun_jvp /* may be null */, + const mlx_closure_custom_vmap fun_vmap /* may be null */); +int mlx_custom_vjp( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp); +int mlx_eval(const mlx_vector_array outputs); +int mlx_jvp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array tangents); +int mlx_value_and_grad( + mlx_closure_value_and_grad* res, + const mlx_closure fun, + const int* argnums, + size_t argnums_num); +int mlx_vjp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array cotangents); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h b/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h new file mode 100644 index 000000000..2b1356ebc --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h @@ -0,0 +1,54 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_TRANSFORMS_IMPL_H +#define MLX_TRANSFORMS_IMPL_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup transforms_impl Implementation detail operations + */ +/**@{*/ + +int mlx_detail_vmap_replace( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num); +int mlx_detail_vmap_trace( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/vector.h b/x/mlxrunner/mlx/include/mlx/c/vector.h new file mode 100644 index 000000000..81bcf7495 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/vector.h @@ -0,0 +1,133 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_VECTOR_H +#define MLX_VECTOR_H + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_vector Vectors + * MLX vector objects. + */ +/**@{*/ + +/** + * A vector of array. + */ +typedef struct mlx_vector_array_ { + void* ctx; +} mlx_vector_array; +mlx_vector_array mlx_vector_array_new(void); +int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src); +int mlx_vector_array_free(mlx_vector_array vec); +mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size); +mlx_vector_array mlx_vector_array_new_value(const mlx_array val); +int mlx_vector_array_set_data( + mlx_vector_array* vec, + const mlx_array* data, + size_t size); +int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val); +int mlx_vector_array_append_data( + mlx_vector_array vec, + const mlx_array* data, + size_t size); +int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val); +size_t mlx_vector_array_size(mlx_vector_array vec); +int mlx_vector_array_get( + mlx_array* res, + const mlx_vector_array vec, + size_t idx); + +/** + * A vector of vector_array. + */ +typedef struct mlx_vector_vector_array_ { + void* ctx; +} mlx_vector_vector_array; +mlx_vector_vector_array mlx_vector_vector_array_new(void); +int mlx_vector_vector_array_set( + mlx_vector_vector_array* vec, + const mlx_vector_vector_array src); +int mlx_vector_vector_array_free(mlx_vector_vector_array vec); +mlx_vector_vector_array mlx_vector_vector_array_new_data( + const mlx_vector_array* data, + size_t size); +mlx_vector_vector_array mlx_vector_vector_array_new_value( + const mlx_vector_array val); +int mlx_vector_vector_array_set_data( + mlx_vector_vector_array* vec, + const mlx_vector_array* data, + size_t size); +int mlx_vector_vector_array_set_value( + mlx_vector_vector_array* vec, + const mlx_vector_array val); +int mlx_vector_vector_array_append_data( + mlx_vector_vector_array vec, + const mlx_vector_array* data, + size_t size); +int mlx_vector_vector_array_append_value( + mlx_vector_vector_array vec, + const mlx_vector_array val); +size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec); +int mlx_vector_vector_array_get( + mlx_vector_array* res, + const mlx_vector_vector_array vec, + size_t idx); + +/** + * A vector of int. + */ +typedef struct mlx_vector_int_ { + void* ctx; +} mlx_vector_int; +mlx_vector_int mlx_vector_int_new(void); +int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src); +int mlx_vector_int_free(mlx_vector_int vec); +mlx_vector_int mlx_vector_int_new_data(int* data, size_t size); +mlx_vector_int mlx_vector_int_new_value(int val); +int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size); +int mlx_vector_int_set_value(mlx_vector_int* vec, int val); +int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size); +int mlx_vector_int_append_value(mlx_vector_int vec, int val); +size_t mlx_vector_int_size(mlx_vector_int vec); +int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx); + +/** + * A vector of string. + */ +typedef struct mlx_vector_string_ { + void* ctx; +} mlx_vector_string; +mlx_vector_string mlx_vector_string_new(void); +int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src); +int mlx_vector_string_free(mlx_vector_string vec); +mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size); +mlx_vector_string mlx_vector_string_new_value(const char* val); +int mlx_vector_string_set_data( + mlx_vector_string* vec, + const char** data, + size_t size); +int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val); +int mlx_vector_string_append_data( + mlx_vector_string vec, + const char** data, + size_t size); +int mlx_vector_string_append_value(mlx_vector_string vec, const char* val); +size_t mlx_vector_string_size(mlx_vector_string vec); +int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/version.h b/x/mlxrunner/mlx/include/mlx/c/version.h new file mode 100644 index 000000000..96dd23877 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/version.h @@ -0,0 +1,18 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_VERSION_H +#define MLX_VERSION_H + +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int mlx_version(mlx_string* str_); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/io.go b/x/mlxrunner/mlx/io.go index 84868e005..0ddbd3b59 100644 --- a/x/mlxrunner/mlx/io.go +++ b/x/mlxrunner/mlx/io.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" @@ -7,6 +5,7 @@ import "C" import ( "iter" + "runtime" "unsafe" ) @@ -21,10 +20,17 @@ func Load(path string) iter.Seq2[string, *Array] { cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) - cpu := C.mlx_default_cpu_stream_new() - defer C.mlx_stream_free(cpu) + // Use GPU stream so tensors load directly to GPU memory (CUDA has Load::eval_gpu). + // macOS Metal doesn't implement eval_gpu for Load, so fall back to CPU stream. + var stream C.mlx_stream + if runtime.GOOS == "darwin" { + stream = C.mlx_default_cpu_stream_new() + } else { + stream = C.mlx_default_gpu_stream_new() + } + defer C.mlx_stream_free(stream) - C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu) + C.mlx_load_safetensors(&string2array, &string2string, cPath, stream) it := C.mlx_map_string_to_array_iterator_new(string2array) defer C.mlx_map_string_to_array_iterator_free(it) diff --git a/x/mlxrunner/mlx/memory.go b/x/mlxrunner/mlx/memory.go index cf36c304c..a243b72c0 100644 --- a/x/mlxrunner/mlx/memory.go +++ b/x/mlxrunner/mlx/memory.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go index 962d192ae..d64a6b0ee 100644 --- a/x/mlxrunner/mlx/mlx.go +++ b/x/mlxrunner/mlx/mlx.go @@ -1,19 +1,22 @@ -//go:build mlx - package mlx -//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release -//go:generate cmake --build build --parallel -//go:generate cmake --install build -//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h" +//go:generate sh -c "go run generator/main.go -output=. ./include/mlx/c/*.h" // #cgo CXXFLAGS: -std=c++17 -// #cgo CPPFLAGS: -I${SRCDIR}/dist/include -// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++ +// #cgo CPPFLAGS: -I${SRCDIR}/include +// #cgo LDFLAGS: -lstdc++ // #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate // #include "generated.h" import "C" +// Version returns the MLX core library version string. +func Version() string { + str := C.mlx_string_new() + defer C.mlx_string_free(str) + C.mlx_version(&str) + return C.GoString(C.mlx_string_data(str)) +} + func doEval(outputs []*Array, async bool) { vector := C.mlx_vector_array_new() defer C.mlx_vector_array_free(vector) diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go index 3d5691368..d3a99a6cd 100644 --- a/x/mlxrunner/mlx/nn.go +++ b/x/mlxrunner/mlx/nn.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx type Linear struct { diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go index 2f97ba8d2..3d8ec7dec 100644 --- a/x/mlxrunner/mlx/ops.go +++ b/x/mlxrunner/mlx/ops.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index 283a2141c..a2e25d68b 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/random.go b/x/mlxrunner/mlx/random.go index 6afdbbab4..82c3d6785 100644 --- a/x/mlxrunner/mlx/random.go +++ b/x/mlxrunner/mlx/random.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/slice.go b/x/mlxrunner/mlx/slice.go index ab1324774..ea642ebf7 100644 --- a/x/mlxrunner/mlx/slice.go +++ b/x/mlxrunner/mlx/slice.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/stream.go b/x/mlxrunner/mlx/stream.go index 83a3eeffd..9b01b4a85 100644 --- a/x/mlxrunner/mlx/stream.go +++ b/x/mlxrunner/mlx/stream.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" @@ -27,6 +25,22 @@ var DefaultDevice = sync.OnceValue(func() Device { return Device{d} }) +// GPUIsAvailable returns true if a GPU device is available. +func GPUIsAvailable() bool { + dev := C.mlx_device_new_type(C.MLX_GPU, 0) + defer C.mlx_device_free(dev) + var avail C.bool + C.mlx_device_is_available(&avail, dev) + return bool(avail) +} + +// SetDefaultDeviceGPU sets the default MLX device to GPU. +func SetDefaultDeviceGPU() { + dev := C.mlx_device_new_type(C.MLX_GPU, 0) + C.mlx_set_default_device(dev) + C.mlx_device_free(dev) +} + type Stream struct { ctx C.mlx_stream } diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go index 4cdf6df33..3a85b6eb0 100644 --- a/x/mlxrunner/model/base/base.go +++ b/x/mlxrunner/model/base/base.go @@ -1,5 +1,3 @@ -//go:build mlx - package base import ( diff --git a/x/mlxrunner/model/base/base_stub.go b/x/mlxrunner/model/base/base_stub.go deleted file mode 100644 index 318d8f911..000000000 --- a/x/mlxrunner/model/base/base_stub.go +++ /dev/null @@ -1,3 +0,0 @@ -//go:build !mlx - -package base diff --git a/x/mlxrunner/model/linear.go b/x/mlxrunner/model/linear.go index fffdbdb29..788e4e3f0 100644 --- a/x/mlxrunner/model/linear.go +++ b/x/mlxrunner/model/linear.go @@ -1,5 +1,3 @@ -//go:build mlx - package model import ( diff --git a/x/mlxrunner/model/quant.go b/x/mlxrunner/model/quant.go index 10896e4b4..d4a56c35c 100644 --- a/x/mlxrunner/model/quant.go +++ b/x/mlxrunner/model/quant.go @@ -1,5 +1,3 @@ -//go:build mlx - package model import ( diff --git a/x/mlxrunner/model/root.go b/x/mlxrunner/model/root.go index c912f7f4c..1c05ee6a8 100644 --- a/x/mlxrunner/model/root.go +++ b/x/mlxrunner/model/root.go @@ -1,5 +1,3 @@ -//go:build mlx - package model import ( diff --git a/x/mlxrunner/model/root_stub.go b/x/mlxrunner/model/root_stub.go deleted file mode 100644 index 3fcda9c25..000000000 --- a/x/mlxrunner/model/root_stub.go +++ /dev/null @@ -1,3 +0,0 @@ -//go:build !mlx - -package model diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 852b04dcc..3ce148c02 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index acaef79bf..08a376d43 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index a25b23d03..df9da7a99 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -1,5 +1,3 @@ -//go:build mlx - package sample import ( diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index 9c7d7e775..a9972bfdc 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( @@ -29,6 +27,13 @@ func Execute(args []string) error { return fmt.Errorf("MLX not available: %w", err) } + if mlx.GPUIsAvailable() { + mlx.SetDefaultDeviceGPU() + slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "gpu") + } else { + slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "cpu") + } + var ( modelName string port int diff --git a/x/mlxrunner/server_stub.go b/x/mlxrunner/server_stub.go deleted file mode 100644 index 3b0f35500..000000000 --- a/x/mlxrunner/server_stub.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !mlx - -package mlxrunner - -import "errors" - -// Execute returns an error when not built with MLX support. -func Execute(args []string) error { - return errors.New("MLX runner not available: build with mlx tag") -} diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go index edf66657c..01f0559a0 100644 --- a/x/models/gemma3/gemma3.go +++ b/x/models/gemma3/gemma3.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package gemma3 provides the Gemma 3 text model implementation for MLX. package gemma3 diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index fb9c4af6f..2e8365580 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX. // This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE). package glm4_moe_lite diff --git a/x/models/glm4_moe_lite/parser.go b/x/models/glm4_moe_lite/parser.go index de1b2cc17..9a9985f68 100644 --- a/x/models/glm4_moe_lite/parser.go +++ b/x/models/glm4_moe_lite/parser.go @@ -1,5 +1,3 @@ -//go:build mlx - package glm4_moe_lite import ( diff --git a/x/models/glm4_moe_lite/parser_test.go b/x/models/glm4_moe_lite/parser_test.go index 0ce382709..d15b4e803 100644 --- a/x/models/glm4_moe_lite/parser_test.go +++ b/x/models/glm4_moe_lite/parser_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package glm4_moe_lite import ( diff --git a/x/models/glm4_moe_lite/render.go b/x/models/glm4_moe_lite/render.go index 4998604bf..d15a99e51 100644 --- a/x/models/glm4_moe_lite/render.go +++ b/x/models/glm4_moe_lite/render.go @@ -1,5 +1,3 @@ -//go:build mlx - package glm4_moe_lite import ( diff --git a/x/models/glm4_moe_lite/render_test.go b/x/models/glm4_moe_lite/render_test.go index f0d576bec..91b871acc 100644 --- a/x/models/glm4_moe_lite/render_test.go +++ b/x/models/glm4_moe_lite/render_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package glm4_moe_lite import ( diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index fc7f34488..18e39bc0a 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package llama provides a Llama-style decoder-only transformer for MLX. package llama diff --git a/x/models/nn/nn.go b/x/models/nn/nn.go index 78f1b92b6..07047024d 100644 --- a/x/models/nn/nn.go +++ b/x/models/nn/nn.go @@ -1,5 +1,3 @@ -//go:build mlx - package nn import "github.com/ollama/ollama/x/mlxrunner/mlx" diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 85d427f58..71596af98 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package qwen3 provides the Qwen3 text model implementation for MLX. package qwen3 diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go index fbee82b59..642ea1bba 100644 --- a/x/models/qwen3_5/qwen3_5.go +++ b/x/models/qwen3_5/qwen3_5.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package qwen3_5 provides the Qwen 3.5 text and MoE implementation for MLX. package qwen3_5 diff --git a/x/models/qwen3_5/qwen3_5_test.go b/x/models/qwen3_5/qwen3_5_test.go index 0a70da189..8165cd484 100644 --- a/x/models/qwen3_5/qwen3_5_test.go +++ b/x/models/qwen3_5/qwen3_5_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package qwen3_5 import ( diff --git a/x/models/qwen3_5_moe/qwen3_5_moe.go b/x/models/qwen3_5_moe/qwen3_5_moe.go index 9e0be26be..a505b458e 100644 --- a/x/models/qwen3_5_moe/qwen3_5_moe.go +++ b/x/models/qwen3_5_moe/qwen3_5_moe.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases. package qwen3_5_moe diff --git a/x/tokenizer/tokenizer.go b/x/tokenizer/tokenizer.go index 301e51aea..a1ce5e8ee 100644 --- a/x/tokenizer/tokenizer.go +++ b/x/tokenizer/tokenizer.go @@ -1,5 +1,3 @@ -//go:build mlx - // tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models // // Based on standard BPE algorithm (Sennrich et al. 2015) with: diff --git a/x/tokenizer/tokenizer_benchmark_test.go b/x/tokenizer/tokenizer_benchmark_test.go index e65a59786..9f3d2a10e 100644 --- a/x/tokenizer/tokenizer_benchmark_test.go +++ b/x/tokenizer/tokenizer_benchmark_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_bpe.go b/x/tokenizer/tokenizer_bpe.go index 1e625c20a..9037f15bc 100644 --- a/x/tokenizer/tokenizer_bpe.go +++ b/x/tokenizer/tokenizer_bpe.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import "container/heap" diff --git a/x/tokenizer/tokenizer_correctness_test.go b/x/tokenizer/tokenizer_correctness_test.go index 2fe94d279..91adc167d 100644 --- a/x/tokenizer/tokenizer_correctness_test.go +++ b/x/tokenizer/tokenizer_correctness_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_decode.go b/x/tokenizer/tokenizer_decode.go index e02d2a88b..0056948a2 100644 --- a/x/tokenizer/tokenizer_decode.go +++ b/x/tokenizer/tokenizer_decode.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_encode.go b/x/tokenizer/tokenizer_encode.go index 1b71ea6d3..3eb629e56 100644 --- a/x/tokenizer/tokenizer_encode.go +++ b/x/tokenizer/tokenizer_encode.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_ggml_parity_test.go b/x/tokenizer/tokenizer_ggml_parity_test.go index 4cef3d3dd..ee9b68f38 100644 --- a/x/tokenizer/tokenizer_ggml_parity_test.go +++ b/x/tokenizer/tokenizer_ggml_parity_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_load.go b/x/tokenizer/tokenizer_load.go index d2a253e17..efd086628 100644 --- a/x/tokenizer/tokenizer_load.go +++ b/x/tokenizer/tokenizer_load.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_load_test.go b/x/tokenizer/tokenizer_load_test.go index 136399c2e..caf2b0d35 100644 --- a/x/tokenizer/tokenizer_load_test.go +++ b/x/tokenizer/tokenizer_load_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import (