mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
* prefer rocm v6 on windows

  Avoid building with v7 - more changes are needed.

* MLX: add header vendoring and remove go build tag

  This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified.

* ci: harden for flaky choco repo servers

  CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without.

* review comments
175 lines
4.4 KiB
Go
175 lines
4.4 KiB
Go
package mlxrunner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
|
"github.com/ollama/ollama/x/mlxrunner/sample"
|
|
"github.com/ollama/ollama/x/tokenizer"
|
|
)
|
|
|
|
type Request struct {
|
|
TextCompletionsRequest
|
|
Responses chan CompletionResponse
|
|
Pipeline func(Request) error
|
|
|
|
Ctx context.Context
|
|
|
|
Sampler *sample.Sampler
|
|
}
|
|
|
|
// TextCompletionsRequest is the JSON payload of a text completion request:
// a prompt plus a nested "options" object.
type TextCompletionsRequest struct {
	// Prompt is decoded from the "prompt" key.
	Prompt string `json:"prompt"`

	// Options is decoded from the "options" key. Each field maps to the
	// JSON key named in its tag.
	Options struct {
		Temperature     float32 `json:"temperature"`
		TopP            float32 `json:"top_p"`
		MinP            float32 `json:"min_p"`
		TopK            int     `json:"top_k"`
		RepeatLastN     int     `json:"repeat_last_n"`
		PresencePenalty float32 `json:"presence_penalty"`
		MaxTokens       int     `json:"max_tokens"`

		// Deprecated: use MaxTokens instead
		NumPredict int `json:"num_predict"`
	} `json:"options"`
}
|
|
|
|
type Runner struct {
|
|
Model base.Model
|
|
Tokenizer *tokenizer.Tokenizer
|
|
Requests chan Request
|
|
cache kvCache
|
|
contextLength int
|
|
}
|
|
|
|
func (r *Runner) Load(modelName string) error {
|
|
root, err := model.Open(modelName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer root.Close()
|
|
|
|
m, err := base.New(root)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Load all tensor blobs from manifest
|
|
tensors, err := loadTensorsFromManifest(root)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Assign weights to model (model-specific logic)
|
|
loadWeights := base.Weights(m)
|
|
if err := loadWeights(tensors); err != nil {
|
|
return err
|
|
}
|
|
|
|
r.Model = m
|
|
r.Tokenizer = m.Tokenizer()
|
|
r.contextLength = m.MaxContextLength()
|
|
return nil
|
|
}
|
|
|
|
// loadTensorsFromManifest loads all tensor blobs from the manifest into a
|
|
// flat map, deduplicating by digest and remapping safetensors key suffixes.
|
|
//
|
|
// Uses a two-phase approach: first loads all raw tensors, then remaps
|
|
// .bias → _qbias with complete knowledge of which base names have .scale
|
|
// entries. This avoids a race condition where Go map iteration order could
|
|
// cause .bias to be processed before .scale within the same blob.
|
|
func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) {
|
|
// Phase 1: Load all tensors raw from all blobs
|
|
rawTensors := make(map[string]*mlx.Array)
|
|
seen := make(map[string]bool)
|
|
for _, layer := range root.Manifest.GetTensorLayers("") {
|
|
if seen[layer.Digest] {
|
|
continue
|
|
}
|
|
seen[layer.Digest] = true
|
|
blobPath := root.Manifest.BlobPath(layer.Digest)
|
|
for name, arr := range mlx.Load(blobPath) {
|
|
rawTensors[name] = arr
|
|
}
|
|
}
|
|
|
|
// Phase 2: Identify all base names that have .scale tensors and remap them
|
|
scaleBaseNames := make(map[string]bool)
|
|
allTensors := make(map[string]*mlx.Array, len(rawTensors))
|
|
for name, arr := range rawTensors {
|
|
if strings.HasSuffix(name, ".scale") {
|
|
baseName := strings.TrimSuffix(name, ".scale")
|
|
allTensors[baseName+"_scale"] = arr
|
|
scaleBaseNames[baseName] = true
|
|
}
|
|
}
|
|
|
|
// Phase 3: Process remaining tensors with complete scale knowledge
|
|
for name, arr := range rawTensors {
|
|
if strings.HasSuffix(name, ".scale") {
|
|
continue // already handled
|
|
}
|
|
if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
|
|
baseName := strings.TrimSuffix(name, ".bias")
|
|
if scaleBaseNames[baseName] {
|
|
allTensors[baseName+"_qbias"] = arr
|
|
} else {
|
|
allTensors[name] = arr
|
|
}
|
|
} else {
|
|
allTensors[name] = arr
|
|
}
|
|
}
|
|
|
|
slog.Info("Loaded tensors from manifest", "count", len(allTensors))
|
|
return allTensors, nil
|
|
}
|
|
|
|
func (r *Runner) Run(host, port string, mux http.Handler) error {
|
|
g, ctx := errgroup.WithContext(context.Background())
|
|
|
|
g.Go(func() error {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case request := <-r.Requests:
|
|
if err := request.Pipeline(request); err != nil {
|
|
slog.Info("Request terminated", "error", err)
|
|
var statusErr api.StatusError
|
|
if !errors.As(err, &statusErr) {
|
|
statusErr = api.StatusError{
|
|
StatusCode: http.StatusInternalServerError,
|
|
ErrorMessage: err.Error(),
|
|
}
|
|
}
|
|
select {
|
|
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
|
case <-request.Ctx.Done():
|
|
}
|
|
}
|
|
|
|
close(request.Responses)
|
|
}
|
|
}
|
|
})
|
|
|
|
g.Go(func() error {
|
|
slog.Info("Starting HTTP server", "host", host, "port", port)
|
|
return http.ListenAndServe(net.JoinHostPort(host, port), mux)
|
|
})
|
|
|
|
return g.Wait()
|
|
}
|