From 4d5ff25724c6749f2e855471d9ca2ff26ef04059 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 25 Feb 2026 15:06:37 -0800 Subject: [PATCH] mlxrunner: Report actual memory usage from runner The MLX runner previously reported a static VRAM estimate that was computed at load time and consisted only of the weights. This is strictly less than the actual memory usage, as it does not include the KV cache or compute graph. --- llm/server.go | 36 ++++++++++--------------------- server/routes.go | 3 +++ server/sched.go | 10 +++++---- server/sched_test.go | 3 +-- x/imagegen/server.go | 11 +++------- x/mlxrunner/client.go | 49 ++++++++++++++++++++++++------------------- x/mlxrunner/server.go | 7 ++++--- 7 files changed, 56 insertions(+), 63 deletions(-) diff --git a/llm/server.go b/llm/server.go index de8ad0f75..291dd47fe 100644 --- a/llm/server.go +++ b/llm/server.go @@ -74,8 +74,7 @@ type LlamaServer interface { Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error - VRAMSize() uint64 // Total VRAM across all GPUs - TotalSize() uint64 + MemorySize() (total, vram uint64) VRAMByGPU(id ml.DeviceID) uint64 Pid() int GetPort() int @@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system // Windows CUDA should not use mmap for best performance // Linux with a model larger than free space, mmap leads to thrashing // For CPU loads we want the memory to be allocated, not FS cache + totalSize, _ := s.MemorySize() if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) || - (runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) || + (runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) || (len(gpus) == 0 && s.options.UseMMap == nil) || (len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) || (s.options.UseMMap != nil && !*s.options.UseMMap) { @@ -1848,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil } -func (s *llmServer) VRAMSize() uint64 { +func (s *llmServer) MemorySize() (total, vram uint64) { if s.mem == nil { - return 0 + return 0, 0 } - var mem uint64 - for _, g := range s.mem.GPUs { - mem += g.Size() + vram += g.Size() } + total = s.mem.InputWeights + s.mem.CPU.Size() + vram + // Some elements are always on CPU. However, if we have allocated all layers // on the GPU then include the CPU components as well, to represent complete offloading. noCPULayers := true @@ -1869,25 +1869,11 @@ func (s *llmServer) VRAMSize() uint64 { } } if noCPULayers { - mem += s.mem.InputWeights - mem += s.mem.CPU.Graph + vram += s.mem.InputWeights + vram += s.mem.CPU.Graph } - return mem -} - -func (s *llmServer) TotalSize() uint64 { - if s.mem == nil { - return 0 - } - - mem := s.mem.InputWeights - mem += s.mem.CPU.Size() - for _, g := range s.mem.GPUs { - mem += g.Size() - } - - return mem + return total, vram } func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 { diff --git a/server/routes.go b/server/routes.go index cbe771d9f..6fd9af5f4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1951,6 +1951,9 @@ func (s *Server) PsHandler(c *gin.Context) { } if v.llama != nil { mr.ContextLength = v.llama.ContextLength() + total, vram := v.llama.MemorySize() + mr.Size = int64(total) + mr.SizeVRAM = int64(vram) } // The scheduler waits to set expiresAt, so if a model is loading it's // possible that it will be set to the unix epoch. For those cases, just diff --git a/server/sched.go b/server/sched.go index 17315beb0..af768cf56 100644 --- a/server/sched.go +++ b/server/sched.go @@ -536,6 +536,7 @@ iGPUScan: } } + totalSize, vramSize := llama.MemorySize() runner := &runnerRef{ model: req.model, modelPath: req.model.ModelPath, @@ -545,8 +546,8 @@ iGPUScan: sessionDuration: sessionDuration, gpus: gpuIDs, discreteGPUs: discreteGPUs, - vramSize: llama.VRAMSize(), - totalSize: llama.TotalSize(), + totalSize: totalSize, + vramSize: vramSize, loading: true, pid: llama.Pid(), } @@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool { sessionDuration = req.sessionDuration.Duration } + totalSize, vramSize := server.MemorySize() runner := &runnerRef{ model: req.model, modelPath: req.model.ModelPath, @@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool { loading: false, isImagegen: isImagegen, sessionDuration: sessionDuration, - totalSize: server.TotalSize(), - vramSize: server.VRAMSize(), + totalSize: totalSize, + vramSize: vramSize, } s.loadedMu.Lock() diff --git a/server/sched_test.go b/server/sched_test.go index 4b1ed54f2..a21f0a709 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -861,8 +861,7 @@ func (s *mockLlm) Close() error { s.closeCalled = true return s.closeResp } -func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } -func (s *mockLlm) TotalSize() uint64 { return s.totalSize } +func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize } func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] } func (s *mockLlm) Pid() int { return -1 } func (s *mockLlm) GetPort() int { return -1 } diff --git a/x/imagegen/server.go b/x/imagegen/server.go index f79b8d3e9..2beb66bda 100644 --- a/x/imagegen/server.go +++ b/x/imagegen/server.go @@ -374,14 +374,9 @@ func (s *Server) Close() error { return nil } -// VRAMSize returns the estimated VRAM usage. -func (s *Server) VRAMSize() uint64 { - return s.vramSize -} - -// TotalSize returns the total memory usage. -func (s *Server) TotalSize() uint64 { - return s.vramSize +// MemorySize returns the total and VRAM memory usage. +func (s *Server) MemorySize() (total, vram uint64) { + return s.vramSize, s.vramSize } // VRAMByGPU returns VRAM usage for a specific GPU. diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index 2152c382f..48e2830a3 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -24,14 +24,13 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/x/imagegen" - "github.com/ollama/ollama/x/imagegen/manifest" ) // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models. type Client struct { port int modelName string - vramSize uint64 + memory uint done chan error client *http.Client lastErr string @@ -98,18 +97,9 @@ func NewClient(modelName string) (*Client, error) { slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) } - // Estimate VRAM based on tensor size from manifest - var vramSize uint64 - if modelManifest, err := manifest.LoadManifest(modelName); err == nil { - vramSize = uint64(modelManifest.TotalTensorSize()) - } else { - vramSize = 8 * 1024 * 1024 * 1024 - } - c := &Client{ port: port, modelName: modelName, - vramSize: vramSize, done: make(chan error, 1), client: &http.Client{Timeout: 10 * time.Minute}, cmd: cmd, @@ -347,9 +337,15 @@ func (c *Client) Pid() int { return -1 } +type statusResponse struct { + Status int + Progress int + Memory uint +} + // Ping implements llm.LlamaServer. func (c *Client) Ping(ctx context.Context) error { - reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port) + reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port) req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) if err != nil { return err @@ -362,6 +358,12 @@ func (c *Client) Ping(ctx context.Context) error { if resp.StatusCode != http.StatusOK { return fmt.Errorf("health check failed: %d", resp.StatusCode) } + + var status statusResponse + if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { + return err + } + c.memory = status.Memory return nil } @@ -388,19 +390,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) { return tokens, nil } -// TotalSize implements llm.LlamaServer. -func (c *Client) TotalSize() uint64 { - return c.vramSize +func (c *Client) currentMemory() uint64 { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := c.Ping(ctx); err != nil { + slog.Warn("failed to get current memory", "error", err) + } + return uint64(c.memory) +} + +// MemorySize implements llm.LlamaServer. +func (c *Client) MemorySize() (total, vram uint64) { + mem := c.currentMemory() + return mem, mem } // VRAMByGPU implements llm.LlamaServer. func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 { - return c.vramSize -} - -// VRAMSize implements llm.LlamaServer. -func (c *Client) VRAMSize() uint64 { - return c.vramSize + return c.currentMemory() } // WaitUntilRunning implements llm.LlamaServer. diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index 89688cfbc..19ebd59a8 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -50,9 +50,10 @@ func Execute(args []string) error { mux := http.NewServeMux() mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) { - if err := json.NewEncoder(w).Encode(map[string]any{ - "status": 0, - "progress": 100, + if err := json.NewEncoder(w).Encode(statusResponse{ + Status: 0, + Progress: 100, + Memory: uint(mlx.ActiveMemory() + mlx.CacheMemory()), }); err != nil { slog.Error("Failed to encode response", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError)