MLX: harden for init failures (#14777)

The CLI now links to the lazy-load MLX code, but that still happens in
init functions.  On internal MLX errors, the CLI exits before it has a
chance to start.  This change re-wires the MLX error handling so it
doesn't exit by default.  The MLX-based runners currently expect exits
on failure, so they re-initialize the default error handling.  We can
refine error handling for better Go stack traces in the future.
This commit is contained in:
Daniel Hiltgen
2026-03-10 22:52:23 -07:00
committed by GitHub
parent 54e05172a0
commit 87d21c7fc0
4 changed files with 85 additions and 0 deletions

View File

@@ -79,6 +79,10 @@ func main() {
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
}
// Restore strict error handling now that we know MLX is working.
// During init(), a safe handler prevented exit(-1) on GPU errors.
mlx.RestoreDefaultErrorHandler()
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)

View File

@@ -8,6 +8,7 @@ package mlx
// Use generated wrappers instead of direct MLX headers
#include "mlx.h"
#include "mlx_error_handler.h"
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
@@ -1836,11 +1837,31 @@ func init() {
return
}
// Enter safe mode: replace the default exit(-1) error handler with one
// that logs and stores errors. This prevents a GPU init failure from
// killing the entire process during startup.
C.mlx_set_safe_init_mode()
// Lock main goroutine to OS thread for CUDA context stability.
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
runtime.LockOSThread()
RandomState[0] = RandomKey(uint64(time.Now().UnixMilli()))
Keep(RandomState[0]) // Global state should persist
// Check if the RandomKey call silently failed under safe mode
if C.mlx_had_init_error() != 0 {
msg := C.GoString(C.mlx_get_init_error())
mlxInitError = fmt.Errorf("MLX GPU init failed: %s", msg)
mlxInitialized = false
return
}
}
// RestoreDefaultErrorHandler restores the default MLX error handler (exit on error).
// Call this from runner entry points after confirming MLX is available,
// to get the original strict error behavior during actual MLX work.
//
// It is a thin wrapper around the C helper mlx_set_default_error_mode, which
// un-registers the non-fatal handler installed during package init. It does
// not clear an error that was already recorded while the safe handler was
// active.
func RestoreDefaultErrorHandler() {
C.mlx_set_default_error_mode()
}
// RandomKey creates a PRNG key from a seed

View File

@@ -0,0 +1,38 @@
// mlx_error_handler.c - Safe error handling for MLX initialization
// Provides a non-fatal error handler for use during init(), so that
// GPU failures are captured instead of calling exit(-1).
#include "mlx_error_handler.h"
#include "mlx.h"
#include <string.h>
// Most recent error text captured by the silent handler. Always kept
// NUL-terminated; messages longer than the buffer are truncated.
static char mlx_init_error_msg[1024] = {0};
// Set to 1 once the silent handler has run; cleared again only by
// mlx_set_safe_init_mode().
static int mlx_init_error_flag = 0;
// Non-fatal MLX error callback used while in safe init mode.
// Instead of terminating the process, it copies the message into the
// file-local buffer (truncating if necessary) and raises the error flag.
// The Go side later reads both via mlx_had_init_error() and
// mlx_get_init_error(), so the failure is only reported when MLX is needed.
// Each invocation overwrites the previously stored message.
static void mlx_silent_error_handler(const char* msg, void* data) {
    // Highest index in the buffer; reserved for the terminating NUL.
    const size_t last = sizeof(mlx_init_error_msg) - 1;

    (void)data; // no user data is registered alongside this handler

    strncpy(mlx_init_error_msg, msg, last);
    // strncpy does not terminate on truncation, so terminate explicitly.
    mlx_init_error_msg[last] = '\0';

    // Publish the flag only after the message is fully written.
    mlx_init_error_flag = 1;
}
// Enter safe init mode: forget any previously recorded error and register
// the silent handler so that subsequent MLX errors are stored rather than
// handled by the default handler.
void mlx_set_safe_init_mode(void) {
    // Start from a clean slate so stale errors are never reported.
    mlx_init_error_msg[0] = '\0';
    mlx_init_error_flag = 0;

    // No user data and no destructor are associated with the handler.
    mlx_set_error_handler(mlx_silent_error_handler, NULL, NULL);
}
// Leave safe init mode by un-registering the silent handler. Passing a NULL
// handler is expected to make MLX fall back to its built-in default handler
// (NOTE(review): confirm against the mlx-c version in use).
// The stored error flag and message are intentionally left untouched, so an
// error captured during init can still be queried afterwards.
void mlx_set_default_error_mode(void) {
mlx_set_error_handler(NULL, NULL, NULL);
}
// Report whether the silent handler has captured an error since the last
// call to mlx_set_safe_init_mode(). Returns 1 if so, 0 otherwise.
int mlx_had_init_error(void) {
    return mlx_init_error_flag != 0;
}
// Return the most recently captured error message, or NULL when no error
// has been recorded. The returned pointer refers to a file-local static
// buffer: callers must not free it, and its contents change if another
// error is captured later.
const char* mlx_get_init_error(void) {
    if (!mlx_init_error_flag) {
        return NULL;
    }
    return mlx_init_error_msg;
}

View File

@@ -0,0 +1,22 @@
// mlx_error_handler.h - Safe error handling for MLX initialization
// This replaces the default exit(-1) MLX error handler during init()
// so that GPU failures don't kill the process.
#ifndef MLX_ERROR_HANDLER_H
#define MLX_ERROR_HANDLER_H
// Enter safe mode before any MLX compute calls during init().
// Replaces the default exit(-1) handler with one that silently stores errors.
// Any error recorded by an earlier safe-mode session is cleared first.
void mlx_set_safe_init_mode(void);
// Restore the default MLX error handler (exit on error).
// Call from runner entry points after confirming MLX is available.
// Does not clear an error that was already recorded in safe mode.
void mlx_set_default_error_mode(void);
// Check whether an error occurred while in safe init mode.
// Returns non-zero if an error was captured since the last call to
// mlx_set_safe_init_mode(), and 0 otherwise.
int mlx_had_init_error(void);
// Get the error message from the last init error, or NULL if none.
// The pointer refers to an internal static buffer: do not free it, and copy
// the text if it must outlive a later error. Messages longer than the
// internal buffer (1024 bytes including the terminator) are truncated.
const char* mlx_get_init_error(void);
#endif // MLX_ERROR_HANDLER_H