diff --git a/cmd/cmd.go b/cmd/cmd.go index 031b200a8..518458d07 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -606,17 +606,6 @@ func RunHandler(cmd *cobra.Command, args []string) error { } opts.WordWrap = !nowrap - useImagegen := false - if cmd.Flags().Lookup("imagegen") != nil { - useImagegen, err = cmd.Flags().GetBool("imagegen") - if err != nil { - return err - } - } - if useImagegen { - opts.Options["use_imagegen_runner"] = true - } - // Fill out the rest of the options based on information about the // model. client, err := api.ClientFromEnvironment() diff --git a/server/routes.go b/server/routes.go index 43f76ea08..053b2fdc1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -161,7 +161,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return nil, nil, nil, fmt.Errorf("%s %w", name, err) } - useImagegen, _ := requestOpts["use_imagegen_runner"].(bool) + // Deprecated runner override option; ignore if present. delete(requestOpts, "use_imagegen_runner") opts, err := s.modelOptions(model, requestOpts) @@ -169,7 +169,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return nil, nil, nil, err } - runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen) + runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive) var runner *runnerRef select { case runner = <-runnerCh: diff --git a/server/sched.go b/server/sched.go index 4a64223e5..3d0dac863 100644 --- a/server/sched.go +++ b/server/sched.go @@ -33,7 +33,6 @@ type LlmRequest struct { successCh chan *runnerRef errCh chan error schedAttempts uint - useImagegen bool } type Scheduler struct { @@ -106,7 +105,7 @@ func schedulerModelKey(m *Model) string { } // context must be canceled to decrement ref count and release the runner -func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) { +func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) { if opts.NumCtx < 4 { opts.NumCtx = 4 } @@ -123,7 +122,6 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses sessionDuration: sessionDuration, successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), - useImagegen: useImagegen, } key := schedulerModelKey(req.model) @@ -593,20 +591,15 @@ iGPUScan: return false } -// loadMLX loads an experimental safetensors model using the unified MLX runner. -// This supports both LLM (completion) and image generation models. +// loadMLX loads an experimental safetensors model using MLX runners. +// Image models use x/imagegen; LLM models use x/mlxrunner. func (s *Scheduler) loadMLX(req *LlmRequest) bool { modelName := req.model.ShortName var server llm.LlamaServer var err error - isImagegen := false if slices.Contains(req.model.Config.Capabilities, "image") { - server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen) - isImagegen = true - } else if req.useImagegen { - server, err = imagegen.NewServer(modelName, imagegen.ModeLLM) - isImagegen = true + server, err = imagegen.NewServer(modelName) } else { server, err = mlxrunner.NewClient(modelName) } @@ -628,7 +621,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool { llama: server, Options: &req.opts, loading: false, - isImagegen: isImagegen, + isImagegen: slices.Contains(req.model.Config.Capabilities, "image"), sessionDuration: sessionDuration, totalSize: totalSize, vramSize: vramSize, @@ -737,8 +730,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool runner.refMu.Lock() defer runner.refMu.Unlock() - // Check if runner type (imagegen vs mlxrunner) matches what's requested - wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image") + // Check if runner type (imagegen vs mlxrunner) matches what's requested. + wantImagegen := slices.Contains(req.model.Config.Capabilities, "image") if runner.isImagegen != wantImagegen { return true } diff --git a/server/sched_test.go b/server/sched_test.go index a21f0a709..0b79c7834 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) { s.getSystemInfoFn = getSystemInfoFn s.newServerFn = a.newServer slog.Info("a") - successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false) + successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) slog.Info("b") - successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false) + successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) require.Empty(t, successCh1b) require.Len(t, errCh1b, 1) @@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) { c.req.model.ModelPath = "bad path" slog.Info("c") - successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false) + successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration) // Starts in pending channel, then should be quickly processed to return an error time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload require.Empty(t, successCh1c) @@ -470,7 +470,7 @@ func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) { s.loadedMu.Unlock() reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"} - successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false) + successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil) require.Empty(t, successCh) require.Empty(t, errCh) @@ -499,7 +499,7 @@ func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) { s.loadedMu.Unlock() reqCtx, cancelReq := context.WithCancel(ctx) - successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false) + successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil) cancelReq() select { @@ -574,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) { s.getGpuFn = getGpuFn s.getSystemInfoFn = getSystemInfoFn s.newServerFn = scenario1a.newServer - successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false) + successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) s.Run(ctx) select { diff --git a/x/imagegen/cmd/engine/README.md b/x/imagegen/cmd/engine/README.md index 02dc0b979..3991c02a8 100644 --- a/x/imagegen/cmd/engine/README.md +++ b/x/imagegen/cmd/engine/README.md @@ -10,17 +10,7 @@ go build -tags mlx -o engine ./x/imagegen/cmd/engine ## Text Generation -```bash -./engine -model /path/to/model -prompt "Hello" -max-tokens 100 -``` - -Options: - -- `-temperature` - sampling temperature (default 0.7) -- `-top-p` - nucleus sampling (default 0.9) -- `-top-k` - top-k sampling (default 40) - -Supports: Llama, Gemma3, GPT-OSS +Text generation models are no longer supported by this engine. ## Image Generation diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go index f0e705d1c..6ec7de9e1 100644 --- a/x/imagegen/cmd/engine/main.go +++ b/x/imagegen/cmd/engine/main.go @@ -18,9 +18,6 @@ import ( "github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/models/flux2" - "github.com/ollama/ollama/x/imagegen/models/gemma3" - "github.com/ollama/ollama/x/imagegen/models/gpt_oss" - "github.com/ollama/ollama/x/imagegen/models/llama" "github.com/ollama/ollama/x/imagegen/models/zimage" "github.com/ollama/ollama/x/imagegen/safetensors" ) @@ -170,11 +167,11 @@ func main() { log.Fatal(err) } - // Load image if provided and model supports it + // Load image if provided and model supports it. var image *mlx.Array if *imagePath != "" { if mm, ok := m.(interface{ ImageSize() int32 }); ok { - image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize()) + image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize()) if err != nil { log.Fatal("load image:", err) } @@ -236,14 +233,8 @@ func load(modelPath string) (Model, error) { } switch kind { - case "gpt_oss": - return gpt_oss.Load(modelPath) - case "gemma3": - return gemma3.Load(modelPath) - case "gemma3_text": - return gemma3.LoadText(modelPath) default: - return llama.Load(modelPath) + return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind) } } diff --git a/x/imagegen/models/gemma3/image.go b/x/imagegen/image_processor.go similarity index 73% rename from x/imagegen/models/gemma3/image.go rename to x/imagegen/image_processor.go index 9532d852d..7a562feb5 100644 --- a/x/imagegen/models/gemma3/image.go +++ b/x/imagegen/image_processor.go @@ -1,6 +1,6 @@ //go:build mlx -package gemma3 +package imagegen import ( "fmt" @@ -13,8 +13,8 @@ import ( "golang.org/x/image/draw" ) -// ProcessImage loads and preprocesses an image for the vision tower -// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP +// ProcessImage loads and preprocesses an image for multimodal vision towers. +// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP. func ProcessImage(path string, imageSize int32) (*mlx.Array, error) { f, err := os.Open(path) if err != nil { @@ -30,20 +30,20 @@ func ProcessImage(path string, imageSize int32) (*mlx.Array, error) { return ProcessImageData(img, imageSize) } -// ProcessImageData preprocesses an image.Image for the vision tower +// ProcessImageData preprocesses an image.Image for multimodal vision towers. func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) { - // Resize to target size using bilinear interpolation + // Resize to target size using bilinear interpolation. resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize))) draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil) - // Convert to float32 array [H, W, C] and normalize - // SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0 + // Convert to float32 array [H, W, C] and normalize. + // SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0. data := make([]float32, imageSize*imageSize*3) idx := 0 for y := int32(0); y < imageSize; y++ { for x := int32(0); x < imageSize; x++ { r, g, b, _ := resized.At(int(x), int(y)).RGBA() - // RGBA returns 16-bit values, convert to 8-bit + // RGBA returns 16-bit values, convert to 8-bit. data[idx] = float32(r>>8)/127.5 - 1.0 data[idx+1] = float32(g>>8)/127.5 - 1.0 data[idx+2] = float32(b>>8)/127.5 - 1.0 @@ -51,8 +51,8 @@ func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) { } } - // Create MLX array [1, H, W, C] for NHWC layout + // Create MLX array [1, H, W, C] for NHWC layout. arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3}) - mlx.Eval(arr) // Materialize to prevent use-after-free + mlx.Eval(arr) // Materialize to prevent use-after-free. return arr, nil } diff --git a/x/imagegen/llm.go b/x/imagegen/llm.go deleted file mode 100644 index eda3b64b1..000000000 --- a/x/imagegen/llm.go +++ /dev/null @@ -1,420 +0,0 @@ -//go:build mlx - -package imagegen - -import ( - "encoding/json" - "errors" - "fmt" - "log/slog" - "net/http" - "strings" - "sync" - "time" - - "github.com/ollama/ollama/x/imagegen/cache" - "github.com/ollama/ollama/x/imagegen/manifest" - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite" - "github.com/ollama/ollama/x/imagegen/tokenizer" -) - -// TextModel is the interface for LLM text generation models. -type TextModel interface { - Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array - NewCache(maxSeqLen int32) []cache.Cache - Tokenizer() *tokenizer.Tokenizer - VocabSize() int32 - MaxContextLength() int32 - NumLayers() int -} - -// llmState holds the state for LLM generation -type llmState struct { - model TextModel -} - -var llmMu sync.Mutex - -// Dedicated stream for generation (like mlx-lm's generation_stream) -var generationStream *mlx.Stream - -// withStream runs fn with the generation stream as default -func withStream(fn func()) { - // Lazy initialization of generationStream - if generationStream == nil { - generationStream = mlx.NewStream() - } - orig := mlx.GetDefaultStream() - mlx.SetDefaultStream(generationStream) - fn() - mlx.SetDefaultStream(orig) -} - -// Decoder wraps model + cache for autoregressive generation. -// This matches the pattern from cmd/engine/generate.go -type Decoder struct { - model TextModel - caches []cache.Cache - vocabSize int32 - temp float32 - token *mlx.Array // Current token (kept across iterations) - oldCacheState []*mlx.Array // Preallocated slice for old cache state -} - -func NewDecoder(m TextModel, temp float32) *Decoder { - caches := m.NewCache(0) - return &Decoder{ - model: m, - caches: caches, - vocabSize: m.VocabSize(), - temp: temp, - oldCacheState: make([]*mlx.Array, 0, len(caches)*2), - } -} - -func (d *Decoder) prefill(inputIDs []int32) int { - processed := 0 - - // Track old cache state to free after each chunk - var oldCacheState []*mlx.Array - - // Process all-but-1 tokens in chunks, eval cache state for memory management - for len(inputIDs) > 1 { - chunkSize := min(2048, len(inputIDs)-1) - if chunkSize <= 0 { - break - } - chunk := inputIDs[:chunkSize] - - // Save old cache state before forward - oldCacheState = oldCacheState[:0] - for _, c := range d.caches { - oldCacheState = append(oldCacheState, c.State()...) - } - - var cacheState []*mlx.Array - withStream(func() { - x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))}) - d.model.Forward(x, d.caches) - for _, c := range d.caches { - cacheState = append(cacheState, c.State()...) - } - }) - mlx.Eval(cacheState...) - - // Free old cache state - for _, arr := range oldCacheState { - if arr != nil { - arr.Free() - } - } - - inputIDs = inputIDs[chunkSize:] - processed += chunkSize - } - - // Save old cache state before final step - oldCacheState = oldCacheState[:0] - for _, c := range d.caches { - oldCacheState = append(oldCacheState, c.State()...) - } - - // Final token + sampling - withStream(func() { - x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))}) - mlx.Eval(x) // Materialize before any other evals - logits := d.model.Forward(x, d.caches) - d.token = sample(logits, d.temp, d.vocabSize) - }) - // Keep cache state (token auto-kept by AsyncEval) - for _, c := range d.caches { - mlx.Keep(c.State()...) - } - mlx.AsyncEval(d.token) - - // Free old cache state from before final step - for _, arr := range oldCacheState { - if arr != nil { - arr.Free() - } - } - - mlx.ClearCache() - - return processed + len(inputIDs) -} - -func (d *Decoder) step() int32 { - prevToken := d.token - - // Save old cache state (reuse preallocated slice) - d.oldCacheState = d.oldCacheState[:0] - for _, c := range d.caches { - d.oldCacheState = append(d.oldCacheState, c.State()...) - } - - withStream(func() { - logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches) - d.token = sample(logits, d.temp, d.vocabSize) - }) - // Keep token and new cache state so they survive cleanup - mlx.Keep(d.token) - for _, c := range d.caches { - mlx.Keep(c.State()...) - } - mlx.AsyncEval(d.token) - - // Sync on previous token (GPU already working on next step) - val := prevToken.ItemInt32() - - // Free old token and old cache state - prevToken.Free() - for _, arr := range d.oldCacheState { - arr.Free() - } - return val -} - -// sample samples from logits using temperature scaling -func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array { - // Get last position logits: [1, L, vocab] -> [vocab] - shape := logits.Shape() - seqLen := shape[1] - lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize}) - lastLogits = mlx.Reshape(lastLogits, vocabSize) - - if temp <= 0 || temp < 0.01 { - // Greedy decoding - return mlx.Argmax(lastLogits, -1, false) - } - - // Apply temperature scaling - scaled := mlx.DivScalar(lastLogits, temp) - return mlx.RandomCategorical(scaled, -1, 1) -} - -// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage. -func (s *server) loadLLMModel() error { - // Load the manifest to get model information - modelManifest, err := manifest.LoadManifest(s.modelName) - if err != nil { - return fmt.Errorf("failed to load manifest: %w", err) - } - - // Detect model architecture from config.json - configData, err := modelManifest.ReadConfig("config.json") - if err != nil { - return fmt.Errorf("failed to read config.json: %w", err) - } - - var modelConfig struct { - Architectures []string `json:"architectures"` - ModelType string `json:"model_type"` - } - if err := json.Unmarshal(configData, &modelConfig); err != nil { - return fmt.Errorf("failed to parse config.json: %w", err) - } - - arch := "" - if len(modelConfig.Architectures) > 0 { - arch = modelConfig.Architectures[0] - } - if arch == "" { - arch = modelConfig.ModelType - } - - slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType) - - // Load the appropriate model based on architecture - var model TextModel - archLower := strings.ToLower(arch) - - switch { - case strings.Contains(archLower, "glm4moelite"): - m, err := glm4_moe_lite.LoadFromManifest(modelManifest) - if err != nil { - return fmt.Errorf("failed to load glm4-moe-lite model: %w", err) - } - model = m - slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers()) - - default: - return fmt.Errorf("LLM architecture %q is not yet supported. "+ - "Supported architectures: glm4-moe-lite. "+ - "Please convert your model to GGUF format or use a supported architecture", arch) - } - - s.llmModel = &llmState{ - model: model, - } - - return nil -} - -// handleLLMCompletion handles LLM text generation requests. -func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) { - if s.llmModel == nil { - http.Error(w, "LLM model not loaded", http.StatusInternalServerError) - return - } - - // Serialize generation requests - llmMu.Lock() - defer llmMu.Unlock() - - if err := s.llmGenerate(w, r, req); err != nil { - slog.Error("LLM generation failed", "error", err) - // Don't send error if we've already started streaming - } -} - -// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine -func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error { - state := s.llmModel - - // Set up streaming response - w.Header().Set("Content-Type", "application/x-ndjson") - w.Header().Set("Transfer-Encoding", "chunked") - flusher, ok := w.(http.Flusher) - if !ok { - return errors.New("streaming not supported") - } - - tok := state.model.Tokenizer() - - // The prompt is already formatted by the server using the model's renderer - // (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here. - prompt := req.Prompt - - // Tokenize the prompt - inputIDs := tok.Encode(prompt, true) - slog.Debug("tokenized prompt", "num_tokens", len(inputIDs)) - - // Generation parameters - maxTokens := int(state.model.MaxContextLength()) - if maxTokens <= 0 { - maxTokens = 4096 - } - if req.Options != nil && req.Options.NumPredict > 0 { - maxTokens = req.Options.NumPredict - } - - temperature := float32(0.7) - if req.Options != nil && req.Options.Temperature > 0 { - temperature = float32(req.Options.Temperature) - } - - // Enable MLX compilation for better performance - mlx.EnableCompile() - - // Create decoder with fresh caches - dec := NewDecoder(state.model, temperature) - - prefillStart := time.Now() - prefillTokens := dec.prefill(inputIDs) - // Prefill measurement includes time to first token - firstToken := dec.step() - prefillDuration := time.Since(prefillStart) - promptEvalDuration := prefillDuration - - enc := json.NewEncoder(w) - ctx := r.Context() - generated := 0 - stopReason := "max_tokens" - - // Handle first token - generated++ - if tok.IsEOS(firstToken) { - resp := Response{ - Done: true, - StopReason: fmt.Sprintf("first_token_eos:%d", firstToken), - PromptEvalCount: prefillTokens, - PromptEvalDuration: int(promptEvalDuration.Nanoseconds()), - } - enc.Encode(resp) - flusher.Flush() - return nil - } - - text := tok.Decode([]int32{firstToken}) - resp := Response{Content: text} - enc.Encode(resp) - flusher.Flush() - - genStart := time.Now() - - // Generation loop - for n := 1; n < maxTokens; n++ { - // Check for cancellation - select { - case <-ctx.Done(): - stopReason = fmt.Sprintf("context_cancelled:%d", generated) - break - default: - } - if stopReason != "max_tokens" { - break - } - - token := dec.step() - generated++ - - if tok.IsEOS(token) { - stopReason = fmt.Sprintf("eos_token:%d", token) - break - } - - text := tok.Decode([]int32{token}) - - // Check for stop sequences - if req.Options != nil && len(req.Options.Stop) > 0 { - shouldStop := false - var matchedStop string - for _, stop := range req.Options.Stop { - if strings.Contains(text, stop) { - text = strings.Split(text, stop)[0] - shouldStop = true - matchedStop = stop - break - } - } - if shouldStop { - if text != "" { - resp := Response{Content: text} - enc.Encode(resp) - flusher.Flush() - } - stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop) - break - } - } - - resp := Response{Content: text} - enc.Encode(resp) - flusher.Flush() - - // Periodically clear MLX cache - if n%256 == 0 { - mlx.ClearCache() - } - } - - // Clean up - mlx.ClearCache() - - // Send final response with stats - evalDuration := time.Since(genStart) - resp = Response{ - Done: true, - StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated), - PromptEvalCount: prefillTokens, - PromptEvalDuration: int(promptEvalDuration.Nanoseconds()), - EvalCount: generated, - EvalDuration: int(evalDuration.Nanoseconds()), - } - enc.Encode(resp) - flusher.Flush() - - return nil -} diff --git a/x/imagegen/models/gemma3/gemma3.go b/x/imagegen/models/gemma3/gemma3.go deleted file mode 100644 index b56adc797..000000000 --- a/x/imagegen/models/gemma3/gemma3.go +++ /dev/null @@ -1,614 +0,0 @@ -//go:build mlx - -package gemma3 - -import ( - "encoding/json" - "fmt" - "math" - "os" - "path/filepath" - - "github.com/ollama/ollama/x/imagegen/cache" - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/nn" - "github.com/ollama/ollama/x/imagegen/safetensors" - "github.com/ollama/ollama/x/imagegen/tokenizer" -) - -// TextConfig holds configuration for the text model -type TextConfig struct { - HiddenSize int32 `json:"hidden_size"` - NumHiddenLayers int32 `json:"num_hidden_layers"` - IntermediateSize int32 `json:"intermediate_size"` - NumAttentionHeads int32 `json:"num_attention_heads"` - NumKeyValueHeads int32 `json:"num_key_value_heads"` - HeadDim int32 `json:"head_dim"` - VocabSize int32 `json:"vocab_size"` - RMSNormEps float32 `json:"rms_norm_eps"` - RopeTheta float32 `json:"rope_theta"` - RopeLocalBaseFreq float32 `json:"rope_local_base_freq"` - MaxPositionEmbeddings int32 `json:"max_position_embeddings"` - SlidingWindow int32 `json:"sliding_window"` - SlidingWindowPattern int32 `json:"sliding_window_pattern"` - - // Computed fields - Scale float32 `json:"-"` -} - -// TextModel is the Gemma 3 text-only model -type TextModel struct { - EmbedTokens *nn.Embedding `weight:"model.embed_tokens"` - Layers []*DecoderLayer `weight:"model.layers"` - Norm *nn.RMSNorm `weight:"model.norm"` - Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually - - // Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward - NormScaled *mlx.Array `weight:"-"` - - tok *tokenizer.Tokenizer - *TextConfig -} - -// DecoderLayer is a single transformer block -type DecoderLayer struct { - InputNorm *nn.RMSNorm `weight:"input_layernorm"` - Attention *Attention - PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"` - PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"` - MLP *MLP - PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"` - - // Precomputed (1 + weight) for Gemma-style RMSNorm - InputNormScaled *mlx.Array `weight:"-"` - PostAttnNormScaled *mlx.Array `weight:"-"` - PreFFNormScaled *mlx.Array `weight:"-"` - PostFFNormScaled *mlx.Array `weight:"-"` - - // Whether this layer uses sliding window attention - IsSliding bool - LayerIdx int32 -} - -// Attention implements Gemma 3 attention with Q/K normalization -type Attention struct { - QProj *nn.Linear `weight:"self_attn.q_proj"` - KProj *nn.Linear `weight:"self_attn.k_proj"` - VProj *nn.Linear `weight:"self_attn.v_proj"` - OProj *nn.Linear `weight:"self_attn.o_proj"` - QNorm *nn.RMSNorm `weight:"self_attn.q_norm"` - KNorm *nn.RMSNorm `weight:"self_attn.k_norm"` - - // Precomputed (1 + weight) for Gemma-style RMSNorm - QNormScaled *mlx.Array `weight:"-"` - KNormScaled *mlx.Array `weight:"-"` -} - -// MLP is the feed-forward network with GELU activation -type MLP struct { - GateProj *nn.Linear `weight:"mlp.gate_proj"` - UpProj *nn.Linear `weight:"mlp.up_proj"` - DownProj *nn.Linear `weight:"mlp.down_proj"` -} - -// LoadText loads the text-only Gemma 3 model -func LoadText(modelPath string) (*TextModel, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("load config: %w", err) - } - var cfg TextConfig - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("parse config: %w", err) - } - - // Compute scale - cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) - - // Set defaults if not specified - if cfg.RopeTheta == 0 { - cfg.RopeTheta = 1000000 - } - if cfg.RopeLocalBaseFreq == 0 { - cfg.RopeLocalBaseFreq = 10000 - } - if cfg.RMSNormEps == 0 { - cfg.RMSNormEps = 1e-6 - } - - weights, err := safetensors.LoadModelWeights(modelPath) - if err != nil { - return nil, fmt.Errorf("load weights: %w", err) - } - - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) - if err != nil { - return nil, fmt.Errorf("load tokenizer: %w", err) - } - - m := &TextModel{ - Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), - TextConfig: &cfg, - tok: tok, - } - - // Initialize layer metadata - for i := range m.Layers { - m.Layers[i] = &DecoderLayer{ - LayerIdx: int32(i), - IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern), - } - } - - if err := safetensors.LoadModule(m, weights, ""); err != nil { - return nil, err - } - - // Tied embeddings for output - m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil) - - mlx.Eval(mlx.Collect(m)...) - weights.ReleaseAll() - - // Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation - precomputeGemmaScaledWeights(m) - - return m, nil -} - -// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers -// This avoids creating temporary arrays on every forward pass -func precomputeGemmaScaledWeights(m *TextModel) { - m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0) - - for _, layer := range m.Layers { - layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0) - layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0) - layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0) - layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0) - - layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0) - layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0) - } - - // Eval all the precomputed weights - var scaled []*mlx.Array - scaled = append(scaled, m.NormScaled) - for _, layer := range m.Layers { - scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled, - layer.PreFFNormScaled, layer.PostFFNormScaled, - layer.Attention.QNormScaled, layer.Attention.KNormScaled) - } - mlx.Eval(scaled...) -} - -// isLayerSliding determines if a layer uses sliding window attention -// Pattern N means: layers 0 to N-1 sliding, N full, N+1 to 2N-1 sliding, 2N full, etc. -func isLayerSliding(layerIdx, pattern int32) bool { - if pattern <= 0 { - return false // No sliding window - } - // Layer is full attention if (layerIdx + 1) % pattern == 0 - return (layerIdx+1)%pattern != 0 -} - -// Forward runs the text model forward pass -func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - B, L := tokens.Shape()[0], tokens.Shape()[1] - - // Get embeddings and scale by sqrt(hidden_size) - h := m.EmbedTokens.Forward(tokens) - h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize)))) - - for i, layer := range m.Layers { - h = layer.Forward(h, caches[i], B, L, m.TextConfig) - } - - // Final norm and output projection - return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps)) -} - -// Forward runs a decoder layer -func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array { - // Pre-attention norm (use precomputed scaled weight) - normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps) - - // Attention - attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg) - - // Post-attention norm and residual - attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) - h := mlx.Add(x, attnOut) - - // Pre-FFN norm - normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps) - - // MLP - mlpOut := l.MLP.Forward(normed) - - // Post-FFN norm and residual - mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) - return mlx.Add(h, mlpOut) -} - -// Forward runs attention with Q/K normalization -func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array { - q := a.QProj.Forward(x) - k := a.KProj.Forward(x) - v := a.VProj.Forward(x) - - // Reshape to [B, num_heads, L, head_dim] - q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - - // Q/K normalization after reshaping (use precomputed scaled weight) - q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps) - k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps) - - // Apply RoPE with appropriate theta - ropeTheta := cfg.RopeTheta - if isSliding { - ropeTheta = cfg.RopeLocalBaseFreq - } - q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) - k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) - - // Update cache - k, v = c.Update(k, v, int(L)) - - // Repeat K/V for GQA if needed - repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads - if repeatFactor > 1 { - k = nn.RepeatKV(k, repeatFactor) - v = nn.RepeatKV(v, repeatFactor) - } - - // Attention - out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) - return a.OProj.Forward(out) -} - -// compiledGeluApprox is a singleton compiled GELU function shared across all layers -var compiledGeluApprox *mlx.CompiledFunc - -// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed -func getCompiledGeluApprox() *mlx.CompiledFunc { - if compiledGeluApprox == nil { - compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array { - return []*mlx.Array{geluApproxImpl(inputs[0])} - }, true) - } - return compiledGeluApprox -} - -// Forward runs the MLP with GELU approximation (tanh variant) -func (m *MLP) Forward(x *mlx.Array) *mlx.Array { - gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0] - return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) -} - -// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh): -// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) -func geluApproxImpl(x *mlx.Array) *mlx.Array { - // Constants - const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi) - const coeff = 0.044715 - - // x^3 - x3 := mlx.Mul(mlx.Mul(x, x), x) - // x + 0.044715 * x^3 - inner := mlx.Add(x, mlx.MulScalar(x3, coeff)) - // sqrt(2/pi) * (x + 0.044715 * x^3) - scaled := mlx.MulScalar(inner, sqrt2OverPi) - // tanh(...) - tanh := mlx.Tanh(scaled) - // 1 + tanh(...) - onePlusTanh := mlx.AddScalar(tanh, 1.0) - // 0.5 * x * (1 + tanh(...)) - return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh) -} - -// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight) -// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight) -func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array { - // Gemma uses (1 + weight) instead of weight - scaledWeight := mlx.AddScalar(weight, 1.0) - return mlx.RMSNorm(x, scaledWeight, eps) -} - -// Interface methods -func (m *TextModel) NumLayers() int { return len(m.Layers) } -func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings } -func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize } - -// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template -func (m *TextModel) Tokenizer() *tokenizer.Tokenizer { - return m.tok -} - -// FormatPrompt applies the Gemma 3 chat template to a prompt -func (m *TextModel) FormatPrompt(prompt string) string { - // Gemma 3 chat format: user\n{prompt}\nmodel\n - return fmt.Sprintf("user\n%s\nmodel\n", prompt) -} - -func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache { - caches := make([]cache.Cache, len(m.Layers)) - for i := range caches { - if m.Layers[i].IsSliding { - // Use rotating cache for sliding window layers - caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow)) - } else { - // Use regular cache for global attention layers - caches[i] = cache.NewKVCache() - } - } - return caches -} - -// Config holds config for the full multimodal model -type Config struct { - TextConfig TextConfig `json:"text_config"` - VisionConfig VisionConfig `json:"vision_config"` - - // Image token config (from config.json) - BOITokenIndex int32 `json:"boi_token_index"` // = 255999 - EOITokenIndex int32 `json:"eoi_token_index"` // = 256000 - ImageTokenIndex int32 `json:"image_token_index"` // = 262144 - MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256 -} - -// Model is the full Gemma 3 multimodal model -type Model struct { - VisionTower *VisionTower `weight:"vision_tower"` - Projector *MultiModalProjector `weight:"multi_modal_projector"` - TextModel *TextModel `weight:"language_model"` - Config *Config - tok *tokenizer.Tokenizer -} - -// Load loads the full multimodal Gemma 3 model -func Load(modelPath string) (*Model, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("load config: %w", err) - } - - var cfg Config - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("parse config: %w", err) - } - - // Set defaults for text config (multimodal config often has incomplete text_config) - // These defaults match transformers.Gemma3TextConfig defaults - tc := &cfg.TextConfig - if tc.HeadDim == 0 { - tc.HeadDim = 256 // Gemma 3 uses head_dim=256 - } - if tc.NumAttentionHeads == 0 { - // Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim) - tc.NumAttentionHeads = 8 - } - if tc.NumKeyValueHeads == 0 { - // Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio) - tc.NumKeyValueHeads = 4 - } - if tc.VocabSize == 0 { - tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!) - } - if tc.RopeTheta == 0 { - tc.RopeTheta = 1000000 - } - if tc.RopeLocalBaseFreq == 0 { - tc.RopeLocalBaseFreq = 10000 - } - if tc.RMSNormEps == 0 { - tc.RMSNormEps = 1e-6 - } - if tc.SlidingWindowPattern == 0 { - tc.SlidingWindowPattern = 6 - } - if tc.MaxPositionEmbeddings == 0 { - tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default - } - - // Compute text model scale - tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim))) - - // Set defaults for image token config - if cfg.BOITokenIndex == 0 { - cfg.BOITokenIndex = 255999 // - } - if cfg.EOITokenIndex == 0 { - cfg.EOITokenIndex = 256000 // - } - if cfg.ImageTokenIndex == 0 { - cfg.ImageTokenIndex = 262144 // - } - if cfg.MMTokensPerImage == 0 { - cfg.MMTokensPerImage = 256 - } - - weights, err := safetensors.LoadModelWeights(modelPath) - if err != nil { - return nil, fmt.Errorf("load weights: %w", err) - } - - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) - if err != nil { - return nil, fmt.Errorf("load tokenizer: %w", err) - } - - m := &Model{ - VisionTower: &VisionTower{ - Embeddings: &VisionEmbeddings{}, - Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers), - Config: &cfg.VisionConfig, - }, - Projector: &MultiModalProjector{}, - TextModel: &TextModel{ - Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers), - TextConfig: &cfg.TextConfig, - }, - Config: &cfg, - tok: tok, - } - - // Initialize text layer metadata - for i := range m.TextModel.Layers { - m.TextModel.Layers[i] = &DecoderLayer{ - LayerIdx: int32(i), - IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern), - } - } - - // Initialize vision encoder layers - for i := range m.VisionTower.Encoder { - m.VisionTower.Encoder[i] = &VisionEncoderLayer{} - } - - if err := safetensors.LoadModule(m, weights, ""); err != nil { - return nil, err - } - - // Tied embeddings for text output - m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil) - m.TextModel.tok = tok - - mlx.Eval(mlx.Collect(m)...) - weights.ReleaseAll() - - // Precompute (1 + weight) for Gemma-style RMSNorm - precomputeGemmaScaledWeights(m.TextModel) - - // Precompute projector's scaled weight - m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0) - mlx.Eval(m.Projector.SoftEmbNormScaled) - - return m, nil -} - -// Forward runs the text-only forward pass -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - return m.TextModel.Forward(tokens, caches) -} - -// ForwardWithImage runs the multimodal forward pass -// tokens: [B, L] input token IDs (with image placeholder tokens) -// image: [B, H, W, C] preprocessed image tensor -func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array { - B, L := tokens.Shape()[0], tokens.Shape()[1] - cfg := m.Config.TextConfig - - // Find image token position FIRST before any eval that might free tokens - imageStartPos := int32(-1) - if image != nil && B == 1 { - tokenData := tokens.DataInt32() // This evals tokens - for i, t := range tokenData { - if t == m.Config.ImageTokenIndex { - imageStartPos = int32(i) - break - } - } - } - - // Get text embeddings and scale - h := m.TextModel.EmbedTokens.Forward(tokens) - h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize)))) - - // Process image if provided - if image != nil && imageStartPos >= 0 { - // Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden] - visionFeatures := m.VisionTower.Forward(image) - - // Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden] - imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps) - - // Eval h and imageEmbeds together so neither gets freed - mlx.Eval(h, imageEmbeds) - - // Cast imageEmbeds to match text embeddings dtype (bf16) - if imageEmbeds.Dtype() != h.Dtype() { - imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype()) - mlx.Eval(imageEmbeds) - } - - // Insert image embeddings at the known position - h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos) - } - - // Run through text model layers - for i, layer := range m.TextModel.Layers { - h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig) - } - - // Final norm and output projection - return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps)) -} - -// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings -// at a known position (to avoid re-scanning tokens after eval) -// textEmbeds: [B, L, hidden_size] text embeddings -// imageEmbeds: [B, 256, hidden_size] image embeddings from projector -// startPos: starting position of image tokens in the sequence -func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array { - numImageTokens := imageEmbeds.Shape()[1] - L := textEmbeds.Shape()[1] - - // Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L] - afterStart := startPos + numImageTokens - - // Slice before image tokens: textEmbeds[:, 0:startPos, :] - before := mlx.SliceAxis(textEmbeds, 1, 0, startPos) - - // Slice after image tokens: textEmbeds[:, startPos+256:L, :] - after := mlx.SliceAxis(textEmbeds, 1, afterStart, L) - - // Concatenate: before + imageEmbeds + after along axis 1 - return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1) -} - -// Interface methods for Model -func (m *Model) NumLayers() int { return len(m.TextModel.Layers) } -func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings } -func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize } -func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } -func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) } -func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize } - -// FormatPrompt applies the Gemma 3 multimodal chat template -func (m *Model) FormatPrompt(prompt string) string { - return fmt.Sprintf("user\n%s\nmodel\n", prompt) -} - -// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image -func (m *Model) FormatPromptWithImage(prompt string) string { - return fmt.Sprintf("user\n%s\nmodel\n", prompt) -} - -// ExpandImageTokens expands into 256 image placeholder tokens -// Input tokens containing boi_token (255999) are expanded to: -// boi_token + 256 * image_token + eoi_token -func (m *Model) ExpandImageTokens(tokens []int32) []int32 { - result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1) - - for _, t := range tokens { - if t == m.Config.BOITokenIndex { - // Expand: boi + 256 * image_token + eoi - result = append(result, m.Config.BOITokenIndex) - for i := int32(0); i < m.Config.MMTokensPerImage; i++ { - result = append(result, m.Config.ImageTokenIndex) - } - result = append(result, m.Config.EOITokenIndex) - } else { - result = append(result, t) - } - } - - return result -} diff --git a/x/imagegen/models/gemma3/projector.go b/x/imagegen/models/gemma3/projector.go deleted file mode 100644 index ecdbe6941..000000000 --- a/x/imagegen/models/gemma3/projector.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build mlx - -package gemma3 - -import ( - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/nn" -) - -// MultiModalProjector projects vision features to text embedding space -type MultiModalProjector struct { - // mm_input_projection_weight: [vision_hidden, text_hidden] - InputProjection *mlx.Array `weight:"mm_input_projection_weight"` - SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"` - - // Precomputed (1 + weight) for Gemma-style RMSNorm - SoftEmbNormScaled *mlx.Array `weight:"-"` -} - -// Forward projects vision features to text space -// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152]) -// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560]) -func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array { - // Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152] - // 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens - B := visionFeatures.Shape()[0] - visionHidden := visionFeatures.Shape()[2] - - // Reshape to [B, 64, 64, hidden] - gridSize := int32(64) // sqrt(4096) - pooledSize := int32(16) // 64/4 - h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden) - - // Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling - h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden) - - // Average over pooling dimensions (axes 2 and 4) - h = mlx.Mean(h, 4, false) - h = mlx.Mean(h, 2, false) - - // h is now [B, 16, 16, hidden], reshape to [B, 256, hidden] - numTokens := pooledSize * pooledSize - h = mlx.Reshape(h, B, numTokens, visionHidden) - - // Apply Gemma-style RMS norm (use precomputed 1 + weight) - h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps) - - // Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden] - return mlx.Linear(h, p.InputProjection) -} diff --git a/x/imagegen/models/gemma3/vision.go b/x/imagegen/models/gemma3/vision.go deleted file mode 100644 index 1c4d8e54f..000000000 --- a/x/imagegen/models/gemma3/vision.go +++ /dev/null @@ -1,138 +0,0 @@ -//go:build mlx - -package gemma3 - -import ( - "math" - - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/nn" -) - -// VisionConfig holds configuration for the SigLIP vision tower -type VisionConfig struct { - HiddenSize int32 `json:"hidden_size"` - ImageSize int32 `json:"image_size"` - IntermediateSize int32 `json:"intermediate_size"` - NumAttentionHeads int32 `json:"num_attention_heads"` - NumHiddenLayers int32 `json:"num_hidden_layers"` - PatchSize int32 `json:"patch_size"` -} - -// VisionTower is the SigLIP vision encoder -type VisionTower struct { - Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"` - Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"` - PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"` - Config *VisionConfig -} - -// VisionEmbeddings handles patch and position embeddings -type VisionEmbeddings struct { - // PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX - PatchWeight *mlx.Array `weight:"patch_embedding.weight"` - PatchBias *mlx.Array `weight:"patch_embedding.bias"` - PosEmbed *nn.Embedding `weight:"position_embedding"` -} - -// VisionEncoderLayer is a single transformer encoder layer -type VisionEncoderLayer struct { - LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"` - Attention *VisionAttention `weight:"self_attn"` - LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"` - MLP *VisionMLP `weight:"mlp"` -} - -// VisionAttention implements multi-head self-attention -type VisionAttention struct { - QProj *nn.Linear `weight:"q_proj"` - KProj *nn.Linear `weight:"k_proj"` - VProj *nn.Linear `weight:"v_proj"` - OutProj *nn.Linear `weight:"out_proj"` -} - -// VisionMLP is the feed-forward network -type VisionMLP struct { - FC1 *nn.Linear `weight:"fc1"` - FC2 *nn.Linear `weight:"fc2"` -} - -// Forward runs the vision tower on preprocessed images -// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX) -// Output: [B, num_patches, hidden_size] -func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array { - // Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O] - // Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C] - weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1) - h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding - - // Add bias: [O] -> [1, 1, 1, O] for broadcasting - bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0]) - h = mlx.Add(h, bias) - - // h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden] - B := h.Shape()[0] - gridH, gridW := h.Shape()[1], h.Shape()[2] - hidden := h.Shape()[3] - numPatches := gridH * gridW - h = mlx.Reshape(h, B, numPatches, hidden) - - // Add position embeddings - posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32) - posEmbed := v.Embeddings.PosEmbed.Forward(posIds) - h = mlx.Add(h, posEmbed) - - // Encoder layers - headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads) - scale := float32(1.0 / math.Sqrt(float64(headDim))) - for _, layer := range v.Encoder { - h = layer.Forward(h, v.Config, scale) - } - - // Final layer norm - h = v.PostLayerNorm.Forward(h) - - return h -} - -// Forward runs a vision encoder layer -func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array { - // Pre-norm attention - h := l.LayerNorm1.Forward(x) - h = l.Attention.Forward(h, cfg, scale) - x = mlx.Add(x, h) - - // Pre-norm MLP - h = l.LayerNorm2.Forward(x) - h = l.MLP.Forward(h) - return mlx.Add(x, h) -} - -// Forward runs multi-head self-attention -func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array { - B, L := x.Shape()[0], x.Shape()[1] - headDim := cfg.HiddenSize / cfg.NumAttentionHeads - - q := a.QProj.Forward(x) - k := a.KProj.Forward(x) - v := a.VProj.Forward(x) - - // Reshape to [B, num_heads, L, head_dim] - q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3) - k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3) - v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3) - - // Scaled dot-product attention (no causal mask for vision) - out := mlx.ScaledDotProductAttention(q, k, v, scale, false) - - // Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden] - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize) - - return a.OutProj.Forward(out) -} - -// Forward runs the MLP with GELU activation -func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array { - h := mlx.GELU(m.FC1.Forward(x)) - return m.FC2.Forward(h) -} diff --git a/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go deleted file mode 100644 index 3931693b8..000000000 --- a/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go +++ /dev/null @@ -1,840 +0,0 @@ -//go:build mlx - -// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX. -// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE). -package glm4_moe_lite - -import ( - "encoding/json" - "fmt" - "math" - - "github.com/ollama/ollama/x/imagegen/cache" - "github.com/ollama/ollama/x/imagegen/manifest" - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/nn" - "github.com/ollama/ollama/x/imagegen/safetensors" - "github.com/ollama/ollama/x/imagegen/tokenizer" -) - -// RopeScaling holds RoPE scaling configuration -type RopeScaling struct { - Factor float32 `json:"factor"` - MscaleAllDim float32 `json:"mscale_all_dim"` -} - -// Config holds GLM4-MoE-Lite model configuration -type Config struct { - HiddenSize int32 `json:"hidden_size"` - NumHiddenLayers int32 `json:"num_hidden_layers"` - IntermediateSize int32 `json:"intermediate_size"` - MoEIntermediateSize int32 `json:"moe_intermediate_size"` - NumAttentionHeads int32 `json:"num_attention_heads"` - NumKeyValueHeads int32 `json:"num_key_value_heads"` - VocabSize int32 `json:"vocab_size"` - RMSNormEps float32 `json:"rms_norm_eps"` - RopeTheta float32 `json:"rope_theta"` - MaxPositionEmbeddings int32 `json:"max_position_embeddings"` - AttentionBias bool `json:"attention_bias"` - - // MLA (Multi-head Latent Attention) parameters - QLoraRank int32 `json:"q_lora_rank"` - KVLoraRank int32 `json:"kv_lora_rank"` - QKRopeHeadDim int32 `json:"qk_rope_head_dim"` - QKNopeHeadDim int32 `json:"qk_nope_head_dim"` - VHeadDim int32 `json:"v_head_dim"` - - // MoE parameters - NRoutedExperts int32 `json:"n_routed_experts"` - NSharedExperts int32 `json:"n_shared_experts"` - NumExpertsPerTok int32 `json:"num_experts_per_tok"` - RoutedScalingFactor float32 `json:"routed_scaling_factor"` - NormTopKProb bool `json:"norm_topk_prob"` - FirstKDenseReplace int32 `json:"first_k_dense_replace"` - NGroup int32 `json:"n_group"` - TopKGroup int32 `json:"topk_group"` - - // RoPE scaling - RopeScaling *RopeScaling `json:"rope_scaling"` - - // Quantization parameters (set during load based on model quantization) - QuantGroupSize int `json:"-"` // Group size for quantization (default 64) - QuantBits int `json:"-"` // Bits per weight (4 or 8) - QuantMode string `json:"-"` // Quantization mode ("affine", etc.) - - // Computed fields - QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim - Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment -} - -// MLAAttention implements Multi-head Latent Attention with absorption. -// This uses absorbed MLA which operates in latent space for reduced KV cache. -type MLAAttention struct { - // Low-rank query projections - QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"` - QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"` - QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"` - - // Low-rank KV projections (with shared rope component) - KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"` - KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"` - - // Absorbed MLA projections (derived from kv_b_proj) - // EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim] - // UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank] - EmbedQ *nn.MultiLinear `weight:"-"` - UnembedOut *nn.MultiLinear `weight:"-"` - - // Output projection - OProj nn.LinearLayer `weight:"self_attn.o_proj"` -} - -// Forward computes absorbed MLA attention output. -// This operates in latent space for reduced KV cache memory. -func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - // Query path: q_a_proj -> layernorm -> q_b_proj - q := a.QAProj.Forward(x) - q = a.QALayerNorm.Forward(q, cfg.RMSNormEps) - q = a.QBProj.Forward(q) - - // Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim] - q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim) - q = mlx.Transpose(q, 0, 2, 1, 3) - - // Split Q into nope and rope parts - qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim}) - qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim}) - - // KV path: get compressed KV and k_pe - compressedKV := a.KVAProjWithMQA.Forward(x) - - // Split into compressed_kv and k_pe (shared rope component) - kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank}) - kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim}) - - // k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim] - kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim) - kPE = mlx.Transpose(kPE, 0, 2, 1, 3) - - // Apply layernorm to get kv latent representation - kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps) - // kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting - kvLatent = mlx.ExpandDims(kvLatent, 1) - - // Apply RoPE to the rope parts - offset := 0 - if c != nil { - offset = c.Offset() - } - qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset) - kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset) - - // ABSORBED MLA: project q_nope to latent space - // qNope: [B, num_heads, L, qk_nope_head_dim] - // EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim] - // Result: [B, num_heads, L, kv_lora_rank] - qLatent := a.EmbedQ.Forward(qNope) - - // Keys = concat(kvLatent, kPE) - // kvLatent: [B, 1, L, kv_lora_rank] - // kPE: [B, 1, L, qk_rope_head_dim] - // keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim] - keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3) - - // Cache the smaller latent representation - // We cache keys (latent + rope) and use empty values since values are derived from keys - cachedL := L - if c != nil { - // Create placeholder values with 0 dims for cache (we don't actually use cached values) - placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32) - keys, _ = c.Update(keys, placeholderValues, int(L)) - cachedL = int32(keys.Shape()[2]) - } - - // Values are the first kv_lora_rank dims of keys (slice off rope part) - values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank}) - - // Queries = concat(qLatent, qPE) - // qLatent: [B, num_heads, L, kv_lora_rank] - // qPE: [B, num_heads, L, qk_rope_head_dim] - // queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim] - queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3) - - // Attention in latent space - // queries: [B, num_heads, L, kv_lora_rank + rope_dim] - // keys: [B, 1, cachedL, kv_lora_rank + rope_dim] - // values: [B, 1, cachedL, kv_lora_rank] - out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1) - - // ABSORBED MLA: unembed from latent space - // out: [B, num_heads, L, kv_lora_rank] - // UnembedOut: [num_heads, v_head_dim, kv_lora_rank] - // Result: [B, num_heads, L, v_head_dim] - out = a.UnembedOut.Forward(out) - - // Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim] - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim) - - return a.OProj.Forward(out) -} - -// DenseMLP implements the standard SwiGLU MLP for dense layers -type DenseMLP struct { - GateProj nn.LinearLayer `weight:"mlp.gate_proj"` - UpProj nn.LinearLayer `weight:"mlp.up_proj"` - DownProj nn.LinearLayer `weight:"mlp.down_proj"` -} - -// Forward applies the SwiGLU MLP -func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array { - gate := mlx.SiLU(m.GateProj.Forward(x)) - up := m.UpProj.Forward(x) - return m.DownProj.Forward(mlx.Mul(gate, up)) -} - -// MoEGate implements the expert gating mechanism -type MoEGate struct { - Gate nn.LinearLayer `weight:"mlp.gate"` - EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"` -} - -// Forward computes expert selection indices and scores -func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) { - // Compute gate logits through linear layer (handles both quantized and non-quantized) - gates := g.Gate.Forward(x) - - // Sigmoid scoring - scores := mlx.Sigmoid(gates) - origScores := scores - - // Add correction bias if present - if g.EScoreCorrectionBias != nil { - scores = mlx.Add(scores, g.EScoreCorrectionBias) - } - - // Group-wise expert selection (simplified for n_group=1) - // Select top-k experts - topK := cfg.NumExpertsPerTok - negScores := mlx.Neg(scores) - inds := mlx.Argpartition(negScores, int(topK)-1, -1) - - shape := inds.Shape() - inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK}) - - // Get scores for selected experts - scores = mlx.TakeAlongAxis(origScores, inds, -1) - - // Normalize if configured - if topK > 1 && cfg.NormTopKProb { - sumScores := mlx.Sum(scores, -1, true) - scores = mlx.Div(scores, sumScores) - } - - // Apply routing scaling factor - scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor) - - return inds, scores -} - -// SwitchMLP implements the MoE expert computation using stacked weights -// Note: No weight tags - these are populated manually by stacking expert weights -type SwitchMLP struct { - // Dequantized weights (used when GatherQMM not available) - GateWeight *mlx.Array - UpWeight *mlx.Array - DownWeight *mlx.Array - - // Quantized weights (used with GatherQMM for 4/8-bit affine) - GateWeightQ, GateScales, GateBiases *mlx.Array - UpWeightQ, UpScales, UpBiases *mlx.Array - DownWeightQ, DownScales, DownBiases *mlx.Array - - // Quantization bits per projection (supports mixed precision Q4/Q8) - GateBits int - UpBits int - DownBits int - - // Quantization group size per projection (detected from tensor shapes) - GateGroupSize int - UpGroupSize int - DownGroupSize int - - // If true, use GatherQMM with quantized weights - UseQuantized bool -} - -// Forward applies the switched expert MLP -func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array { - shape := x.Shape() - B, L := shape[0], shape[1] - topK := cfg.NumExpertsPerTok - - // Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D] - xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2) - - // Flatten for gather_mm: [B*L, 1, 1, D] - xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize) - - // Flatten indices: [B, L, topK] -> [B*L, topK] - idxFlat := mlx.Reshape(indices, B*L, topK) - - // Sort for efficient gather (when we have many tokens) - doSort := B*L >= 64 - var invOrder *mlx.Array - n := B * L * topK - - if doSort { - idxAll := mlx.Flatten(idxFlat) - order := mlx.Argsort(idxAll, 0) - invOrder = mlx.Argsort(order, 0) - // Reorder x based on sorted indices - xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1) - idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1) - } - - var gate, up, hidden, down *mlx.Array - - if s.UseQuantized { - // Use GatherQMM for quantized weights (faster, keeps weights quantized) - // Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down) - gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases, - nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort) - up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, - nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort) - - hidden = mlx.Mul(mlx.SiLU(gate), up) - - down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, - nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort) - } else { - // Use GatherMM for dequantized/non-quantized weights - gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort) - up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort) - - hidden = mlx.Mul(mlx.SiLU(gate), up) - - down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort) - } - - // Unsort if we sorted - if doSort { - down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize) - } else { - down = mlx.Squeeze(down, 2) - } - - return mlx.Reshape(down, B, L, topK, cfg.HiddenSize) -} - -// SharedExperts implements the shared expert MLP -type SharedExperts struct { - GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"` - UpProj nn.LinearLayer `weight:"mlp.shared_experts.up_proj"` - DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"` -} - -// Forward applies the shared expert MLP -func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array { - gate := mlx.SiLU(s.GateProj.Forward(x)) - up := s.UpProj.Forward(x) - return s.DownProj.Forward(mlx.Mul(gate, up)) -} - -// MoE implements the full Mixture of Experts layer -type MoE struct { - Gate *MoEGate - SwitchMLP *SwitchMLP - SharedExperts *SharedExperts -} - -// Forward applies the MoE layer -func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array { - shape := x.Shape() - B, L := shape[0], shape[1] - - // Get expert indices and scores - inds, scores := m.Gate.Forward(x, cfg) - - // Apply routed experts - expertOut := m.SwitchMLP.Forward(x, inds, cfg) - - // Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK - scoresExpanded := mlx.ExpandDims(scores, -1) - y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false) - - // Add shared experts if present - if m.SharedExperts != nil { - y = mlx.Add(y, m.SharedExperts.Forward(x)) - } - - return mlx.Reshape(y, B, L, cfg.HiddenSize) -} - -// DenseBlock represents a dense transformer block (for first_k_dense_replace layers) -type DenseBlock struct { - Attention *MLAAttention - MLP *DenseMLP - InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"` - PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"` -} - -// Forward applies the dense block -func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - // Pre-norm attention with residual - r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg) - h := mlx.Add(x, r) - - // Pre-norm MLP with residual - r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps)) - return mlx.Add(h, r) -} - -// MoEBlock represents a MoE transformer block -type MoEBlock struct { - Attention *MLAAttention - MoE *MoE - InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"` - PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"` -} - -// Forward applies the MoE block -func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - // Pre-norm attention with residual - r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg) - h := mlx.Add(x, r) - - // Pre-norm MoE with residual - r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg) - return mlx.Add(h, r) -} - -// Block interface for both dense and MoE blocks -type Block interface { - Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array -} - -// Model represents the complete GLM4-MoE-Lite model -type Model struct { - EmbedTokens *nn.Embedding `weight:"model.embed_tokens"` - Layers []Block `weight:"-"` // Loaded manually due to different block types - Norm *nn.RMSNorm `weight:"model.norm"` - LMHead nn.LinearLayer `weight:"lm_head"` - - tok *tokenizer.Tokenizer - *Config -} - -// computeScale computes the attention scale. -// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner. -func computeScale(cfg *Config) float32 { - keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim - scale := float32(1.0 / math.Sqrt(float64(keyLength))) - if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 { - s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0 - scale *= s * s - } - return scale -} - -// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support. -// Currently only 4-bit and 8-bit affine quantization are supported. -func supportsGatherQMM(mode string, bits int) bool { - return mode == "affine" && (bits == 4 || bits == 8) -} - -// ExpertWeight holds a single expert's weight with optional quantization components. -type ExpertWeight struct { - Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight - Scales *mlx.Array // Quantization scales (nil if not quantized) - Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases) - Bits int // Quantization bits (4 or 8), 0 if not quantized - GroupSize int // Quantization group size, 0 if not quantized -} - -// getQuantParams returns quantization parameters from model metadata. -// Returns groupSize, bits, and mode for the model's quantization type. -func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) { - groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization()) - // Use metadata group_size if available (overrides default) - if gs := weights.GroupSize(); gs > 0 { - groupSize = gs - } - return groupSize, bits, mode -} - -// loadExpertWeight loads an expert weight. -// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components. -// Otherwise dequantizes and returns only the weight. -func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight { - w, _ := weights.GetTensor(path + ".weight") - if w == nil { - return nil - } - - // Check if this is a quantized weight by looking for scales - scalePath := path + ".weight_scale" - if weights.HasTensor(scalePath) { - scales, _ := weights.GetTensor(scalePath) - var qbiases *mlx.Array - qbiasPath := path + ".weight_qbias" - if weights.HasTensor(qbiasPath) { - qbiases, _ = weights.GetTensor(qbiasPath) - } - - // Get quantization params from metadata - groupSize, bits, mode := getQuantParams(weights) - - // Update config with group size (for GatherQMM calls) - if cfg.QuantGroupSize == 0 { - cfg.QuantGroupSize = groupSize - } - - // If GatherQMM is supported and requested, return quantized components - if useQuantized && supportsGatherQMM(mode, bits) { - return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize} - } - - // Otherwise dequantize - return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)} - } - - return &ExpertWeight{Weight: w} -} - -// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format. -// Returns embed_q and unembed_out weights for per-head projections. -// -// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] -// Output: -// - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent -// - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output -func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) { - path := prefix + ".self_attn.kv_b_proj" - w, err := weights.GetTensor(path + ".weight") - if err != nil || w == nil { - return nil, nil - } - - // Check if quantized and dequantize - scalePath := path + ".weight_scale" - if weights.HasTensor(scalePath) { - scales, _ := weights.GetTensor(scalePath) - var qbiases *mlx.Array - qbiasPath := path + ".weight_qbias" - if weights.HasTensor(qbiasPath) { - qbiases, _ = weights.GetTensor(qbiasPath) - } - - groupSize, bits, mode := getQuantParams(weights) - w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode) - } - - // w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] - // Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank] - headDim := cfg.QKNopeHeadDim + cfg.VHeadDim - w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank) - - // Split into wk and wv - // wk: [num_heads, qk_nope_head_dim, kv_lora_rank] - // wv: [num_heads, v_head_dim, kv_lora_rank] - wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank}) - wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank}) - - // Transform for absorbed MLA: - // embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim] - // This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection) - embedQ := mlx.Transpose(wk, 0, 2, 1) - - // unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank] - // This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection) - unembedOut := wv - - return embedQ, unembedOut -} - -// StackedExpertWeights holds stacked weights for all experts. -type StackedExpertWeights struct { - Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed] - Scales *mlx.Array // Stacked scales (nil if not quantized) - Biases *mlx.Array // Stacked biases (nil if not quantized) - Bits int // Quantization bits (4 or 8), 0 if not quantized - GroupSize int // Quantization group size, 0 if not quantized -} - -// collectAndStackExpertWeights loads and stacks expert weights for one projection type. -func collectAndStackExpertWeights( - weights safetensors.WeightSource, - prefix string, - projName string, - numExperts int32, - useQuantized bool, - cfg *Config, -) *StackedExpertWeights { - var w, s, b []*mlx.Array - var bits, groupSize int - - for e := int32(0); e < numExperts; e++ { - path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName) - ew := loadExpertWeight(weights, path, useQuantized, cfg) - if ew == nil { - continue - } - w = append(w, ew.Weight) - if ew.Scales != nil { - s = append(s, ew.Scales) - } - if ew.Biases != nil { - b = append(b, ew.Biases) - } - if e == 0 { - bits = ew.Bits - groupSize = ew.GroupSize - } - } - - result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize} - if len(w) > 0 { - result.Weight = mlx.Stack(w, 0) - if len(s) > 0 { - result.Scales = mlx.Stack(s, 0) - } - if len(b) > 0 { - result.Biases = mlx.Stack(b, 0) - } - } - return result -} - -// sanitizeExpertWeights stacks individual expert weights into tensors. -// If useQuantized is true and weights support GatherQMM, returns quantized components. -// Otherwise returns dequantized weights with nil scales/biases. -// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down). -func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) { - gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg) - up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg) - down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg) - return gate, up, down -} - -// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage). -func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { - // Read config from manifest - configData, err := modelManifest.ReadConfig("config.json") - if err != nil { - return nil, fmt.Errorf("load config: %w", err) - } - - var cfg Config - if err := json.Unmarshal(configData, &cfg); err != nil { - return nil, fmt.Errorf("parse config: %w", err) - } - - // Compute derived fields - cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim - cfg.Scale = computeScale(&cfg) - - // Load weights from manifest blobs - weights, err := manifest.LoadWeightsFromManifest(modelManifest, "") - if err != nil { - return nil, fmt.Errorf("load weights: %w", err) - } - - if err := weights.Load(0); err != nil { - return nil, fmt.Errorf("load weight data: %w", err) - } - - // Set up quantization parameters (only if model is actually quantized) - // Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading - quantization := weights.Quantization() - useQuantized := false - if quantization != "" { - _, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization) - useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) - } - - // Load tokenizer from manifest with config files for EOS token detection - tokData, err := modelManifest.ReadConfig("tokenizer.json") - if err != nil { - return nil, fmt.Errorf("load tokenizer config: %w", err) - } - - // Build tokenizer config with companion files for EOS/BOS token loading - tokConfig := &tokenizer.TokenizerConfig{ - ConfigJSON: configData, // Already loaded above, contains eos_token_id - } - - // Try to load generation_config.json if available (preferred source for EOS) - if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil { - tokConfig.GenerationConfigJSON = genConfigData - } - - // Try to load tokenizer_config.json if available - if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil { - tokConfig.TokenizerConfigJSON = tokConfigData - } - - tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig) - if err != nil { - return nil, fmt.Errorf("parse tokenizer: %w", err) - } - - m := &Model{ - Layers: make([]Block, cfg.NumHiddenLayers), - Config: &cfg, - tok: tok, - } - - // Load embedding, norm, and lm_head - if err := safetensors.LoadModule(m, weights, ""); err != nil { - return nil, err - } - - // Load layers manually due to different block types - for i := int32(0); i < cfg.NumHiddenLayers; i++ { - prefix := fmt.Sprintf("model.layers.%d", i) - - // Load attention (same for both block types) - attn := &MLAAttention{} - if err := safetensors.LoadModule(attn, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d attention: %w", i, err) - } - - // Sanitize MLA weights for absorbed attention - embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg) - attn.EmbedQ = nn.NewMultiLinear(embedQ) - attn.UnembedOut = nn.NewMultiLinear(unembedOut) - - if i < cfg.FirstKDenseReplace { - // Dense block - block := &DenseBlock{Attention: attn} - if err := safetensors.LoadModule(block, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d dense: %w", i, err) - } - m.Layers[i] = block - } else { - // MoE block - block := &MoEBlock{Attention: attn} - if err := safetensors.LoadModule(block, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d moe block: %w", i, err) - } - - // Stack expert weights (pass cfg so group sizes can be detected) - gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg) - - switchMLP := &SwitchMLP{UseQuantized: useQuantized} - if useQuantized { - switchMLP.GateWeightQ = gate.Weight - switchMLP.GateScales = gate.Scales - switchMLP.GateBiases = gate.Biases - switchMLP.GateBits = gate.Bits - switchMLP.GateGroupSize = gate.GroupSize - switchMLP.UpWeightQ = up.Weight - switchMLP.UpScales = up.Scales - switchMLP.UpBiases = up.Biases - switchMLP.UpBits = up.Bits - switchMLP.UpGroupSize = up.GroupSize - switchMLP.DownWeightQ = down.Weight - switchMLP.DownScales = down.Scales - switchMLP.DownBiases = down.Biases - switchMLP.DownBits = down.Bits - switchMLP.DownGroupSize = down.GroupSize - } else { - switchMLP.GateWeight = gate.Weight - switchMLP.UpWeight = up.Weight - switchMLP.DownWeight = down.Weight - } - - block.MoE = &MoE{ - Gate: &MoEGate{}, - SwitchMLP: switchMLP, - } - - // Load gate weights - if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d gate: %w", i, err) - } - - // Load shared experts if present - if cfg.NSharedExperts > 0 { - block.MoE.SharedExperts = &SharedExperts{} - if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d shared experts: %w", i, err) - } - } - - m.Layers[i] = block - } - } - - mlx.Eval(mlx.Collect(m)...) - weights.ReleaseAll() - - return m, nil -} - -// Forward computes the forward pass of the model -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - B, L := tokens.Shape()[0], tokens.Shape()[1] - - h := m.EmbedTokens.Forward(tokens) - - for i, layer := range m.Layers { - var c cache.Cache - if caches != nil { - c = caches[i] - } - h = layer.Forward(h, c, B, L, m.Config) - } - - h = m.Norm.Forward(h, m.RMSNormEps) - return m.LMHead.Forward(h) -} - -// Interface methods - -// NumLayers returns the number of transformer layers -func (m *Model) NumLayers() int { return len(m.Layers) } - -// MaxContextLength returns the maximum context length -func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings } - -// VocabSize returns the vocabulary size -func (m *Model) VocabSize() int32 { return m.Config.VocabSize } - -// Tokenizer returns the model's tokenizer -func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } - -// NewCache creates a new KV cache for the model -func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { - caches := make([]cache.Cache, len(m.Layers)) - for i := range caches { - caches[i] = cache.NewKVCache() - } - return caches -} - -// FormatPrompt applies the GLM-4 chat template with thinking enabled by default. -// This follows the GLM-4.7 format with tag for reasoning mode. -func (m *Model) FormatPrompt(prompt string) string { - return "[gMASK]<|user|>" + prompt + "<|assistant|>" -} - -// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control. -// When think is true, the prompt ends with to enable reasoning mode. -// When think is false, the prompt ends with to skip reasoning. -func (m *Model) FormatPromptWithThinking(prompt string, think bool) string { - if think { - return "[gMASK]<|user|>" + prompt + "<|assistant|>" - } - return "[gMASK]<|user|>" + prompt + "<|assistant|>" -} - -// NewRenderer returns a new Renderer for formatting multi-turn conversations. -func (m *Model) NewRenderer() *Renderer { - return &Renderer{} -} - -// NewParser returns a new Parser for extracting thinking and tool calls from output. -func (m *Model) NewParser() *Parser { - return &Parser{} -} diff --git a/x/imagegen/models/glm4_moe_lite/parser.go b/x/imagegen/models/glm4_moe_lite/parser.go deleted file mode 100644 index c81ec5a40..000000000 --- a/x/imagegen/models/glm4_moe_lite/parser.go +++ /dev/null @@ -1,479 +0,0 @@ -//go:build mlx - -package glm4_moe_lite - -import ( - "context" - "encoding/json" - "encoding/xml" - "fmt" - "log/slog" - "strings" - "unicode" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/logutil" -) - -type parserState int - -const ( - parserState_LookingForThinkingOpen parserState = iota - parserState_ThinkingStartedEatingWhitespace - parserState_CollectingThinking - parserState_ThinkingDoneEatingWhitespace - parserState_CollectingContent - parserState_ToolStartedEatingWhitespace - parserState_CollectingToolContent -) - -const ( - thinkingOpenTag = "" - thinkingCloseTag = "" - toolOpenTag = "" - toolCloseTag = "" -) - -// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls. -// GLM-4's prompt ends with when thinking is enabled, so the parser -// must start in CollectingThinking state (the model outputs thinking content directly). -type Parser struct { - state parserState - buffer strings.Builder - tools []api.Tool -} - -// HasToolSupport returns true as GLM4 supports tool calling. -func (p *Parser) HasToolSupport() bool { - return true -} - -// HasThinkingSupport returns true as GLM4 supports thinking mode. -func (p *Parser) HasThinkingSupport() bool { - return true -} - -// Init initializes the parser with tools and thinking configuration. -func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { - p.tools = tools - // When thinking is enabled (nil or true), the prompt ends with , - // so model output starts directly with thinking content (no opening tag). - if thinkValue == nil || thinkValue.Bool() { - p.state = parserState_CollectingThinking - } - return tools -} - -type parserEvent interface { - isParserEvent() -} - -type eventContent struct { - content string -} - -func (eventContent) isParserEvent() {} - -type eventRawToolCall struct { - raw string -} - -func (eventRawToolCall) isParserEvent() {} - -type eventThinkingContent struct { - content string -} - -func (eventThinkingContent) isParserEvent() {} - -// Add processes new output text and returns parsed content, thinking, and tool calls. -func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { - p.buffer.WriteString(s) - events := p.parseEvents() - - var toolCalls []api.ToolCall - var contentSb strings.Builder - var thinkingSb strings.Builder - - for _, event := range events { - switch event := event.(type) { - case eventRawToolCall: - toolCall, err := parseToolCall(event, p.tools) - if err != nil { - slog.Warn("glm-4 tool call parsing failed", "error", err) - return "", "", nil, err - } - toolCalls = append(toolCalls, toolCall) - case eventThinkingContent: - thinkingSb.WriteString(event.content) - case eventContent: - contentSb.WriteString(event.content) - } - } - - return contentSb.String(), thinkingSb.String(), toolCalls, nil -} - -func (p *Parser) parseEvents() []parserEvent { - var all []parserEvent - - keepLooping := true - for keepLooping { - var events []parserEvent - events, keepLooping = p.eat() - if len(events) > 0 { - all = append(all, events...) - } - } - - if len(all) > 0 { - slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String()) - } - - return all -} - -// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer -// and transitions to the next state. Returns (nil, false) if only whitespace remains -// in the buffer (needs more input), or (nil, true) if we successfully transitioned. -func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) { - trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) - p.buffer.Reset() - if trimmed == "" { - return nil, false // Still only whitespace, keep waiting for more input - } - p.state = nextState - p.buffer.WriteString(trimmed) - return nil, true // Successfully transitioned -} - -// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace), -// the content after (optionally trimmed of leading whitespace), and updates the buffer -func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) { - split := strings.SplitN(p.buffer.String(), tag, 2) - before := split[0] - before = strings.TrimRightFunc(before, unicode.IsSpace) - after := split[1] - if trimAfter { - after = strings.TrimLeftFunc(after, unicode.IsSpace) - } - p.buffer.Reset() - p.buffer.WriteString(after) - return before, after -} - -func (p *Parser) eat() ([]parserEvent, bool) { - var events []parserEvent - - switch p.state { - case parserState_LookingForThinkingOpen: - trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) - if strings.HasPrefix(trimmed, thinkingOpenTag) { - // Found opening tag - after := strings.TrimPrefix(trimmed, thinkingOpenTag) - after = strings.TrimLeftFunc(after, unicode.IsSpace) - p.buffer.Reset() - p.buffer.WriteString(after) - if after == "" { - p.state = parserState_ThinkingStartedEatingWhitespace - } else { - p.state = parserState_CollectingThinking - } - return events, true - } else if strings.HasPrefix(thinkingOpenTag, trimmed) { - // Partial opening tag seen, keep accumulating - return events, false - } else if trimmed == "" { - // Only whitespace, keep accumulating - return events, false - } else { - // No thinking tag found, skip to content collection - p.state = parserState_CollectingContent - // Don't trim - we want to keep the original content - return events, true - } - - case parserState_ThinkingStartedEatingWhitespace: - return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking) - - case parserState_CollectingThinking: - acc := p.buffer.String() - if strings.Contains(acc, thinkingCloseTag) { - thinking, remaining := p.splitAtTag(thinkingCloseTag, true) - if len(thinking) > 0 { - events = append(events, eventThinkingContent{content: thinking}) - } - if remaining == "" { - p.state = parserState_ThinkingDoneEatingWhitespace - } else { - p.state = parserState_CollectingContent - } - return events, true - } else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 { - // Partial closing tag - withhold it along with any trailing whitespace before it - beforePartialTag := acc[:len(acc)-overlapLen] - trailingWsLen := trailingWhitespaceLen(beforePartialTag) - ambiguousStart := len(beforePartialTag) - trailingWsLen - - unambiguous := acc[:ambiguousStart] - ambiguous := acc[ambiguousStart:] - p.buffer.Reset() - p.buffer.WriteString(ambiguous) - if len(unambiguous) > 0 { - events = append(events, eventThinkingContent{content: unambiguous}) - } - return events, false - } else { - // Pure thinking content - withhold trailing whitespace (might precede closing tag) - whitespaceLen := trailingWhitespaceLen(acc) - ambiguousStart := len(acc) - whitespaceLen - - unambiguous := acc[:ambiguousStart] - ambiguous := acc[ambiguousStart:] - p.buffer.Reset() - p.buffer.WriteString(ambiguous) - if len(unambiguous) > 0 { - events = append(events, eventThinkingContent{content: unambiguous}) - } - return events, false - } - - case parserState_ThinkingDoneEatingWhitespace: - return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent) - - case parserState_CollectingContent: - if strings.Contains(p.buffer.String(), toolOpenTag) { - before, after := p.splitAtTag(toolOpenTag, true) - if len(before) > 0 { - events = append(events, eventContent{content: before}) - } - if after == "" { - p.state = parserState_ToolStartedEatingWhitespace - } else { - p.state = parserState_CollectingToolContent - } - return events, true - } else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 { - beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen] - trailingWsLen := trailingWhitespaceLen(beforePartialTag) - ambiguousStart := len(beforePartialTag) - trailingWsLen - - unambiguous := p.buffer.String()[:ambiguousStart] - ambiguous := p.buffer.String()[ambiguousStart:] - p.buffer.Reset() - p.buffer.WriteString(ambiguous) - if len(unambiguous) > 0 { - events = append(events, eventContent{content: unambiguous}) - } - return events, false - } else { - whitespaceLen := trailingWhitespaceLen(p.buffer.String()) - ambiguousStart := len(p.buffer.String()) - whitespaceLen - - unambiguous := p.buffer.String()[:ambiguousStart] - ambiguous := p.buffer.String()[ambiguousStart:] - p.buffer.Reset() - p.buffer.WriteString(ambiguous) - if len(unambiguous) > 0 { - events = append(events, eventContent{content: unambiguous}) - } - return events, false - } - - case parserState_ToolStartedEatingWhitespace: - return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent) - - case parserState_CollectingToolContent: - acc := p.buffer.String() - if strings.Contains(acc, toolCloseTag) { - toolContent, _ := p.splitAtTag(toolCloseTag, true) - if len(toolContent) == 0 { - slog.Warn("glm4 tool call closing tag found but no content before it") - } - events = append(events, eventRawToolCall{raw: toolContent}) - p.state = parserState_CollectingContent - return events, true - } else { - // Keep accumulating - tool calls are not streamed - // We just wait for the closing tag - return events, false - } - - default: - panic("unreachable") - } -} - -// overlap returns the length of the overlap between the end of s and the start of tag. -func overlap(s, tag string) int { - for i := 1; i <= len(tag) && i <= len(s); i++ { - if strings.HasSuffix(s, tag[:i]) { - return i - } - } - return 0 -} - -// trailingWhitespaceLen returns the length of trailing whitespace in s. -func trailingWhitespaceLen(s string) int { - trimmed := strings.TrimRightFunc(s, unicode.IsSpace) - return len(s) - len(trimmed) -} - -// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing -type ToolCallXML struct { - XMLName xml.Name `xml:"tool_call"` - Content string `xml:",chardata"` // Function name (text nodes between tags) - Keys []string `xml:"arg_key"` // All arg_key elements in document order - Values []string `xml:"arg_value"` // All arg_value elements in document order -} - -// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags -func escapeContent(s string) string { - var result strings.Builder - inTag := false - - for i := range len(s) { - ch := s[i] - - if ch == '<' { - // Check if this is a known tag - if strings.HasPrefix(s[i:], "") || - strings.HasPrefix(s[i:], "") || - strings.HasPrefix(s[i:], "") || - strings.HasPrefix(s[i:], "") { - inTag = true - } - } - - if inTag { - result.WriteByte(ch) - if ch == '>' { - inTag = false - } - } else { - // Escape special characters in text content - switch ch { - case '&': - result.WriteString("&") - case '<': - result.WriteString("<") - case '>': - result.WriteString(">") - default: - result.WriteByte(ch) - } - } - } - - return result.String() -} - -func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) { - // Escape any unescaped entities in text content - escaped := escapeContent(raw.raw) - - // Wrap the content in a root element to make it valid XML - xmlString := "" + escaped + "" - - // Parse XML into struct - var parsed ToolCallXML - if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil { - return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err) - } - - // Extract and trim function name - functionName := strings.TrimSpace(parsed.Content) - if functionName == "" { - return api.ToolCall{}, fmt.Errorf("empty function name") - } - - // Verify keys and values are paired correctly - if len(parsed.Keys) != len(parsed.Values) { - return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values)) - } - - // Find the matching tool to get parameter types - var matchedTool *api.Tool - for i := range tools { - if tools[i].Function.Name == functionName { - matchedTool = &tools[i] - break - } - } - - // Build arguments map by pairing keys and values - toolCall := api.ToolCall{ - Function: api.ToolCallFunction{ - Name: functionName, - Arguments: api.NewToolCallFunctionArguments(), - }, - } - - for i := range parsed.Keys { - key := strings.TrimSpace(parsed.Keys[i]) - value := parsed.Values[i] // Don't trim here - parseValue handles it - - // Look up parameter type - var paramType api.PropertyType - if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { - if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok { - // Handle anyOf by collecting all types from the union - if len(prop.AnyOf) > 0 { - for _, anyOfProp := range prop.AnyOf { - paramType = append(paramType, anyOfProp.Type...) - } - } else { - paramType = prop.Type - } - } - } - - // Parse value with type coercion - toolCall.Function.Arguments.Set(key, parseValue(value, paramType)) - } - - return toolCall, nil -} - -// parseValue parses a string value and coerces it to the appropriate type based on paramType. -func parseValue(value string, paramType api.PropertyType) any { - value = strings.TrimSpace(value) - - // If no type specified, return as string - if len(paramType) == 0 { - return value - } - - // Try to parse based on specified types - for _, t := range paramType { - switch t { - case "boolean": - if value == "true" { - return true - } - if value == "false" { - return false - } - case "integer": - var i int64 - if _, err := fmt.Sscanf(value, "%d", &i); err == nil { - return i - } - case "number": - var f float64 - if _, err := fmt.Sscanf(value, "%f", &f); err == nil { - return f - } - case "array", "object": - // Try to parse as JSON - var result any - if err := json.Unmarshal([]byte(value), &result); err == nil { - return result - } - } - } - - // Default to string - return value -} diff --git a/x/imagegen/models/glm4_moe_lite/parser_test.go b/x/imagegen/models/glm4_moe_lite/parser_test.go deleted file mode 100644 index 0ce382709..000000000 --- a/x/imagegen/models/glm4_moe_lite/parser_test.go +++ /dev/null @@ -1,192 +0,0 @@ -//go:build mlx - -package glm4_moe_lite - -import ( - "testing" - - "github.com/ollama/ollama/api" -) - -func TestParserThinking(t *testing.T) { - tests := []struct { - name string - input string - thinkEnabled bool - wantContent string - wantThinking string - wantToolCalls int - }{ - { - name: "thinking enabled - simple thinking then content", - input: "Let me think about this...Here is my answer.", - thinkEnabled: true, - wantThinking: "Let me think about this...", - wantContent: "Here is my answer.", - }, - { - name: "thinking enabled - only thinking", - input: "I need to consider multiple factors...", - thinkEnabled: true, - wantThinking: "I need to consider multiple factors...", - wantContent: "", - }, - { - name: "thinking disabled - direct content", - input: "Here is my direct answer.", - thinkEnabled: false, - wantThinking: "", - wantContent: "Here is my direct answer.", - }, - { - name: "thinking with tool call", - input: "Let me search for that...I'll use a tool.searchquerytest", - thinkEnabled: true, - wantThinking: "Let me search for that...", - wantContent: "I'll use a tool.", - wantToolCalls: 1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &Parser{} - - var thinkValue *api.ThinkValue - if tt.thinkEnabled { - thinkValue = &api.ThinkValue{Value: true} - } else { - thinkValue = &api.ThinkValue{Value: false} - } - - // Define tools for tool call tests - props := api.NewToolPropertiesMap() - props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}}) - tools := []api.Tool{ - { - Function: api.ToolFunction{ - Name: "search", - Parameters: api.ToolFunctionParameters{ - Properties: props, - }, - }, - }, - } - - p.Init(tools, nil, thinkValue) - - content, thinking, calls, err := p.Add(tt.input, true) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if thinking != tt.wantThinking { - t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking) - } - if content != tt.wantContent { - t.Errorf("content = %q, want %q", content, tt.wantContent) - } - if len(calls) != tt.wantToolCalls { - t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls) - } - }) - } -} - -func TestParserToolCall(t *testing.T) { - p := &Parser{} - - props := api.NewToolPropertiesMap() - props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}}) - props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}}) - tools := []api.Tool{ - { - Function: api.ToolFunction{ - Name: "get_weather", - Parameters: api.ToolFunctionParameters{ - Properties: props, - }, - }, - }, - } - - // Initialize with thinking disabled - tv := &api.ThinkValue{Value: false} - p.Init(tools, nil, tv) - - input := "get_weatherlocationSan Franciscounitcelsius" - - _, _, calls, err := p.Add(input, true) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(calls) != 1 { - t.Fatalf("expected 1 tool call, got %d", len(calls)) - } - - call := calls[0] - if call.Function.Name != "get_weather" { - t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather") - } - - location, ok := call.Function.Arguments.Get("location") - if !ok || location != "San Francisco" { - t.Errorf("location = %v, want %q", location, "San Francisco") - } - - unit, ok := call.Function.Arguments.Get("unit") - if !ok || unit != "celsius" { - t.Errorf("unit = %v, want %q", unit, "celsius") - } -} - -func TestOverlap(t *testing.T) { - tests := []struct { - s string - tag string - want int - }{ - {"hello<", "", 1}, - {"hello", 2}, - {"hello", 3}, - {"hello", 4}, - {"hello", 5}, - {"hello", 6}, - {"hello", 7}, - {"hello", "", 8}, // Complete tag at end returns full length - {"hello", "", 0}, - {"", "", 0}, - } - - for _, tt := range tests { - t.Run(tt.s+"_"+tt.tag, func(t *testing.T) { - got := overlap(tt.s, tt.tag) - if got != tt.want { - t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want) - } - }) - } -} - -func TestTrailingWhitespaceLen(t *testing.T) { - tests := []struct { - s string - want int - }{ - {"hello ", 3}, - {"hello\n\t ", 3}, - {"hello", 0}, - {"", 0}, - {" ", 3}, - } - - for _, tt := range tests { - t.Run(tt.s, func(t *testing.T) { - got := trailingWhitespaceLen(tt.s) - if got != tt.want { - t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want) - } - }) - } -} diff --git a/x/imagegen/models/glm4_moe_lite/render.go b/x/imagegen/models/glm4_moe_lite/render.go deleted file mode 100644 index 4998604bf..000000000 --- a/x/imagegen/models/glm4_moe_lite/render.go +++ /dev/null @@ -1,175 +0,0 @@ -//go:build mlx - -package glm4_moe_lite - -import ( - "encoding/json" - "fmt" - "strings" - - "github.com/ollama/ollama/api" -) - -// Renderer renders messages for GLM4-MoE-Lite models. -// -// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode): -// -// 1. INTERLEAVED THINKING -// The model thinks between tool calls and after receiving tool results. -// This enables complex step-by-step reasoning: interpreting each tool output -// before deciding what to do next. Thinking blocks are preserved and returned -// with tool results to maintain reasoning continuity. -// -// 2. PRESERVED THINKING -// The model retains reasoning content from previous assistant turns in context. -// This preserves reasoning continuity across multi-turn conversations. The -// upstream API has a "clear_thinking" parameter to control this: -// - clear_thinking=true: clears reasoning from previous turns (outputs ) -// - clear_thinking=false: preserves ... blocks from previous turns -// -// 3. TURN-LEVEL THINKING -// Controls whether the model should reason on each turn. The upstream API -// uses "enable_thinking" parameter: -// - enable_thinking=true: outputs to start reasoning -// - enable_thinking=false: outputs to skip reasoning -// -// OLLAMA DEFAULTS: -// - Thinking is ENABLED by default (thinkValue=nil or true outputs ) -// - Thinking is PRESERVED by default (reasoning content from previous turns is always -// included in ... blocks, equivalent to clear_thinking=false) -// - Users can disable thinking per-turn via thinkValue=false -type Renderer struct{} - -// Render renders messages into the GLM4 chat format. -func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { - var sb strings.Builder - - sb.WriteString("[gMASK]") - - if len(tools) > 0 { - sb.WriteString("<|system|>\n") - sb.WriteString("# Tools\n\n") - sb.WriteString("You may call one or more functions to assist with the user query.\n\n") - sb.WriteString("You are provided with function signatures within XML tags:\n") - sb.WriteString("\n") - for _, tool := range tools { - d, _ := json.Marshal(tool) - sb.WriteString(formatToolJSON(d)) - sb.WriteString("\n") - } - sb.WriteString("\n\n") - sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n") - sb.WriteString("{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...") - } - - think := true - if thinkValue != nil && !thinkValue.Bool() { - think = false - } - - for i, message := range messages { - switch message.Role { - case "user": - sb.WriteString("<|user|>") - sb.WriteString(message.Content) - case "assistant": - sb.WriteString("<|assistant|>") - if message.Thinking != "" { - sb.WriteString("" + message.Thinking + "") - } else { - sb.WriteString("") - } - if message.Content != "" { - sb.WriteString(message.Content) - } - if len(message.ToolCalls) > 0 { - for _, toolCall := range message.ToolCalls { - sb.WriteString("" + toolCall.Function.Name) - sb.WriteString(renderToolArguments(toolCall.Function.Arguments)) - sb.WriteString("") - } - } - case "tool": - if i == 0 || messages[i-1].Role != "tool" { - sb.WriteString("<|observation|>") - } - sb.WriteString("") - sb.WriteString(message.Content) - sb.WriteString("") - case "system": - sb.WriteString("<|system|>") - sb.WriteString(message.Content) - } - } - - sb.WriteString("<|assistant|>") - if think { - sb.WriteString("") - } else { - sb.WriteString("") - } - - return sb.String(), nil -} - -// renderToolArguments converts tool call arguments to GLM4 XML format. -func renderToolArguments(args api.ToolCallFunctionArguments) string { - var sb strings.Builder - for key, value := range args.All() { - sb.WriteString("" + key + "") - var valueStr string - if str, ok := value.(string); ok { - valueStr = str - } else { - jsonBytes, err := json.Marshal(value) - if err != nil { - valueStr = fmt.Sprintf("%v", value) - } else { - valueStr = string(jsonBytes) - } - } - - sb.WriteString("" + valueStr + "") - } - - return sb.String() -} - -// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and , -func formatToolJSON(raw []byte) string { - var sb strings.Builder - sb.Grow(len(raw) + len(raw)/10) - - inString := false - escaped := false - for i := range raw { - ch := raw[i] - sb.WriteByte(ch) - - if inString { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == '"' { - inString = false - } - continue - } - - if ch == '"' { - inString = true - continue - } - - if ch == ':' || ch == ',' { - sb.WriteByte(' ') - } - } - - return sb.String() -} diff --git a/x/imagegen/models/glm4_moe_lite/render_test.go b/x/imagegen/models/glm4_moe_lite/render_test.go deleted file mode 100644 index f0d576bec..000000000 --- a/x/imagegen/models/glm4_moe_lite/render_test.go +++ /dev/null @@ -1,205 +0,0 @@ -//go:build mlx - -package glm4_moe_lite - -import ( - "strings" - "testing" - - "github.com/ollama/ollama/api" -) - -func TestRendererSimple(t *testing.T) { - r := &Renderer{} - - messages := []api.Message{ - {Role: "user", Content: "Hello"}, - } - - // Thinking enabled (default) - result, err := r.Render(messages, nil, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - expected := "[gMASK]<|user|>Hello<|assistant|>" - if result != expected { - t.Errorf("result = %q, want %q", result, expected) - } -} - -func TestRendererThinkingDisabled(t *testing.T) { - r := &Renderer{} - - messages := []api.Message{ - {Role: "user", Content: "Hello"}, - } - - tv := &api.ThinkValue{Value: false} - - result, err := r.Render(messages, nil, tv) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - expected := "[gMASK]<|user|>Hello<|assistant|>" - if result != expected { - t.Errorf("result = %q, want %q", result, expected) - } -} - -func TestRendererMultiTurn(t *testing.T) { - r := &Renderer{} - - messages := []api.Message{ - {Role: "user", Content: "What is 2+2?"}, - {Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"}, - {Role: "user", Content: "And 3+3?"}, - } - - result, err := r.Render(messages, nil, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Check key parts - if !strings.Contains(result, "[gMASK]") { - t.Error("missing [gMASK] prefix") - } - if !strings.Contains(result, "<|user|>What is 2+2?") { - t.Error("missing first user message") - } - if !strings.Contains(result, "<|assistant|>Let me calculate: 2+2=44") { - t.Error("missing assistant message with thinking") - } - if !strings.Contains(result, "<|user|>And 3+3?") { - t.Error("missing second user message") - } - if !strings.HasSuffix(result, "<|assistant|>") { - t.Errorf("should end with <|assistant|>, got suffix: %q", result[len(result)-30:]) - } -} - -func TestRendererWithSystem(t *testing.T) { - r := &Renderer{} - - messages := []api.Message{ - {Role: "system", Content: "You are a helpful assistant."}, - {Role: "user", Content: "Hello"}, - } - - result, err := r.Render(messages, nil, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if !strings.Contains(result, "<|system|>You are a helpful assistant.") { - t.Error("missing system message") - } -} - -func TestRendererWithTools(t *testing.T) { - r := &Renderer{} - - messages := []api.Message{ - {Role: "user", Content: "What's the weather?"}, - } - - props := api.NewToolPropertiesMap() - props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"}) - tools := []api.Tool{ - { - Function: api.ToolFunction{ - Name: "get_weather", - Description: "Get the weather for a location", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: props, - Required: []string{"location"}, - }, - }, - }, - } - - result, err := r.Render(messages, tools, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Check for tool system prompt - if !strings.Contains(result, "<|system|>") { - t.Error("missing system tag for tools") - } - if !strings.Contains(result, "# Tools") { - t.Error("missing tools header") - } - if !strings.Contains(result, "") { - t.Error("missing tools tag") - } - if !strings.Contains(result, "get_weather") { - t.Error("missing tool name") - } - if !strings.Contains(result, "") { - t.Error("missing closing tools tag") - } -} - -func TestRendererWithToolCalls(t *testing.T) { - r := &Renderer{} - - args := api.NewToolCallFunctionArguments() - args.Set("location", "San Francisco") - - messages := []api.Message{ - {Role: "user", Content: "What's the weather in SF?"}, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_weather", - Arguments: args, - }, - }, - }, - }, - {Role: "tool", Content: "Sunny, 72F"}, - } - - result, err := r.Render(messages, nil, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if !strings.Contains(result, "get_weather") { - t.Error("missing tool call") - } - if !strings.Contains(result, "location") { - t.Error("missing arg_key") - } - if !strings.Contains(result, "San Francisco") { - t.Error("missing arg_value") - } - if !strings.Contains(result, "") { - t.Error("missing tool call closing tag") - } - if !strings.Contains(result, "<|observation|>") { - t.Error("missing observation tag") - } - if !strings.Contains(result, "Sunny, 72F") { - t.Error("missing tool response") - } -} - -func TestFormatToolJSON(t *testing.T) { - input := []byte(`{"name":"test","value":123}`) - result := formatToolJSON(input) - - // Should add spaces after : and , - if !strings.Contains(result, ": ") { - t.Error("should add space after colon") - } - if !strings.Contains(result, ", ") { - t.Error("should add space after comma") - } -} diff --git a/x/imagegen/models/gpt_oss/gpt_oss.go b/x/imagegen/models/gpt_oss/gpt_oss.go deleted file mode 100644 index bbf01370f..000000000 --- a/x/imagegen/models/gpt_oss/gpt_oss.go +++ /dev/null @@ -1,487 +0,0 @@ -//go:build mlx - -package gpt_oss - -import ( - "encoding/json" - "fmt" - "math" - "os" - "path/filepath" - - "github.com/ollama/ollama/x/imagegen/cache" - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/nn" - "github.com/ollama/ollama/x/imagegen/safetensors" - "github.com/ollama/ollama/x/imagegen/tokenizer" -) - -// RopeScaling holds YaRN or other RoPE scaling configuration -type RopeScaling struct { - RopeType string `json:"rope_type"` - Factor float32 `json:"factor"` - OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"` - BetaFast float32 `json:"beta_fast"` - BetaSlow float32 `json:"beta_slow"` -} - -type Config struct { - HiddenSize int32 `json:"hidden_size"` - NumHiddenLayers int32 `json:"num_hidden_layers"` - IntermediateSize int32 `json:"intermediate_size"` - NumAttentionHeads int32 `json:"num_attention_heads"` - NumKeyValueHeads int32 `json:"num_key_value_heads"` - VocabSize int32 `json:"vocab_size"` - RMSNormEps float32 `json:"rms_norm_eps"` - RopeTheta float32 `json:"rope_theta"` - HeadDim int32 `json:"head_dim"` - SlidingWindow int32 `json:"sliding_window"` - NumLocalExperts int32 `json:"num_local_experts"` - NumExpertsPerTok int32 `json:"num_experts_per_tok"` - LayerTypes []string `json:"layer_types"` - SwiGLULimit float32 `json:"swiglu_limit"` - RopeScaling *RopeScaling `json:"rope_scaling"` - Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim) -} - -type Attention struct { - QProj *nn.Linear `weight:"self_attn.q_proj"` - KProj *nn.Linear `weight:"self_attn.k_proj"` - VProj *nn.Linear `weight:"self_attn.v_proj"` - OProj *nn.Linear `weight:"self_attn.o_proj"` - Sinks *mlx.Array `weight:"self_attn.sinks,optional"` - YarnFreqs *mlx.Array // computed - YarnMscale float32 -} - -// swiGLU applies the GPT-OSS custom SwiGLU activation. -// Formula: (gate * sigmoid(alpha * gate)) * (up + 1) -// with clipping: gate to [None, limit], up to [-limit, limit] -func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array { - // Clip gate to [None, limit] - gateClipped := mlx.ClipScalar(gate, 0, limit, false, true) - - // Clip up to [-limit, limit] - upClipped := mlx.ClipScalar(up, -limit, limit, true, true) - - // glu_scaled = alpha * gate_clipped - gluScaled := mlx.MulScalar(gateClipped, alpha) - - // sig = sigmoid(glu_scaled) - sig := mlx.Sigmoid(gluScaled) - - // out_glu = gate_clipped * sig - outGlu := mlx.Mul(gateClipped, sig) - - // result = out_glu * (up_clipped + 1) - return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0)) -} - -// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers -var compiledSwiGLU *mlx.CompiledFunc - -// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed -func getCompiledSwiGLU() *mlx.CompiledFunc { - if compiledSwiGLU == nil { - const alpha float32 = 1.702 - const limit float32 = 7.0 - compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array { - return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)} - }, true) // shapeless=true so it works for any input size - } - return compiledSwiGLU -} - -// ComputeYarnFreqs computes YaRN-modified RoPE frequencies -// Based on mlx-lm's YarnRoPE implementation -func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) { - // yarn_find_correction_dim - yarnFindCorrectionDim := func(numRotations float64) float64 { - return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base))) - } - - // yarn_find_correction_range - low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast)))) - high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow)))) - if low < 0 { - low = 0 - } - if high > int(dims)-1 { - high = int(dims) - 1 - } - - // yarn_get_mscale - yarnGetMscale := func(scale, mscale float64) float64 { - if scale <= 1 { - return 1.0 - } - return 0.1*mscale*math.Log(scale) + 1.0 - } - mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0)) - - // Compute frequencies - // freq_extra = base ** (arange(0, dims, 2) / dims) - // freq_inter = scaling_factor * freq_extra - halfDims := dims / 2 - freqData := make([]float32, halfDims) - for i := int32(0); i < halfDims; i++ { - exp := float64(2*i) / float64(dims) - freqExtra := math.Pow(float64(base), exp) - freqInter := float64(scalingFactor) * freqExtra - - // linear ramp mask - var freqMask float64 - if low == high { - freqMask = 0.0 - } else { - t := (float64(i) - float64(low)) / float64(high-low) - if t < 0 { - t = 0 - } - if t > 1 { - t = 1 - } - freqMask = 1.0 - t - } - - // Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask)) - freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask))) - } - - return mlx.NewArray(freqData, []int32{halfDims}), mscale -} - -// initYarn initializes YaRN RoPE if configured -func (a *Attention) initYarn(cfg *Config) { - a.YarnMscale = 1.0 - if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" { - a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs( - cfg.HeadDim, - cfg.RopeTheta, - cfg.RopeScaling.Factor, - cfg.RopeScaling.OriginalMaxPositionEmbeddings, - cfg.RopeScaling.BetaFast, - cfg.RopeScaling.BetaSlow, - ) - } -} - -func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array { - q := a.QProj.Forward(x) - k := a.KProj.Forward(x) - v := a.VProj.Forward(x) - - // Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim] - q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - - offset := 0 - if c != nil { - offset = c.Offset() - } - if a.YarnFreqs != nil { - if a.YarnMscale != 1.0 { - q = mlx.MulScalar(q, a.YarnMscale) - } - q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset) - k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset) - } else { - q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) - k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) - } - - if c != nil { - k, v = c.Update(k, v, int(L)) - } - - out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks) - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) - return a.OProj.Forward(out) -} - -// CreateSlidingWindowMask creates a causal mask with sliding window -// Mirrors mlx-lm's create_causal_mask with window_size -func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array { - // Build mask aligned to actual cache length (may be rotated) - // rinds covers existing keys: [keyStart, keyStart+keyLen) - // linds covers new queries: [queryStart, queryStart+seqLen) - rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen] - linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen] - - linds = mlx.ExpandDims(linds, 1) // [seqLen, 1] - rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen] - - causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen] - windowLimit := mlx.AddScalar(rinds, float32(windowSize)) - windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen] - - return mlx.LogicalAnd(causalMask, windowMask) -} - -// MoE represents the Mixture of Experts SwiGLU layer with quantized experts. -type MoE struct { - Router *nn.Linear `weight:"mlp.router"` - TopK int32 - HiddenSize int32 - GroupSize int - Bits int - // Expert weights (loaded manually via sanitizeExpertWeights) - GateBlocks, GateScales, GateBias *mlx.Array - UpBlocks, UpScales, UpBias *mlx.Array - DownBlocks, DownScales, DownBias *mlx.Array -} - -func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array { - logits := moe.Router.Forward(x) - neg := mlx.Neg(logits) - part := mlx.Argpartition(neg, int(moe.TopK)-1, -1) - topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK}) - topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1) - weights := mlx.Softmax(topKVal, -1) - - xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize) - idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK) - - doSort := B*L >= 64 - var invOrder *mlx.Array - sorted := false - n := B * L * moe.TopK - - if doSort { - idxAll := mlx.Flatten(idxFlat) - order := mlx.Argsort(idxAll, 0) - invOrder = mlx.Argsort(order, 0) - xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1) - idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1) - sorted = true - } - - gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted) - up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted) - - if moe.GateBias != nil { - gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2)) - } - if moe.UpBias != nil { - up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2)) - } - - hidden := getCompiledSwiGLU().Call(gate, up)[0] - - down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted) - if moe.DownBias != nil { - down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2)) - } - - if doSort { - down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize) - } else { - down = mlx.Squeeze(down, 2) - } - - ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1) - return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize) -} - -type Block struct { - Attention *Attention - MLP *MoE - InputNorm *nn.RMSNorm `weight:"input_layernorm"` - PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"` - LayerType string // "sliding_attention" or "full_attention" -} - -func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array { - h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg)) - return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L)) -} - -type Model struct { - EmbedTokens *nn.Embedding `weight:"model.embed_tokens"` - Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization - Norm *nn.RMSNorm `weight:"model.norm"` - LMHead *nn.Linear `weight:"lm_head"` - - tok *tokenizer.Tokenizer - *Config -} - -func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } -func (m *Model) NumLayers() int { return len(m.Layers) } -func (m *Model) VocabSize() int32 { return m.Config.VocabSize } - -func (m *Model) NewCache(int32) []cache.Cache { - caches := make([]cache.Cache, len(m.Layers)) - for i, layer := range m.Layers { - if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 { - caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow)) - } else { - caches[i] = cache.NewKVCache() - } - } - return caches -} - -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - B, L := tokens.Shape()[0], tokens.Shape()[1] - x := m.EmbedTokens.Forward(tokens) - - // Find representative cache indices for sliding window attention - var swaIdx int = -1 - for i, layer := range m.Layers { - if layer.LayerType == "sliding_attention" { - swaIdx = i - break - } - } - - // Create masks once at model level - var fullMask, swaMask *mlx.Array - var fullMaskMode, swaMaskMode string - - if L > 1 { - fullMaskMode = "causal" - if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil { - c := caches[swaIdx] - offset := c.Offset() - windowSize := int(m.SlidingWindow) - cacheLen := min(int(L), windowSize) - if offset > 0 { - cacheLen = min(c.Len()+int(L), windowSize) - } - if int(L) > windowSize { - swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize) - } else { - swaMaskMode = "causal" - } - } else { - swaMaskMode = "causal" - } - } - - for i, layer := range m.Layers { - var c cache.Cache - if caches != nil { - c = caches[i] - } - mask, maskMode := fullMask, fullMaskMode - if layer.LayerType == "sliding_attention" { - mask, maskMode = swaMask, swaMaskMode - } - x = layer.Forward(x, c, B, L, mask, maskMode, m.Config) - } - - return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps)) -} - -// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays. -// MXFP4 quantized weights require contiguous memory - strided views give wrong results. -func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) { - gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks") - gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales") - gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias") - downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks") - downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales") - downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias") - - moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias} - - if gateUpBlocks != nil { - gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1) - s := gub.Shape() - moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1})) - moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1})) - } - if gateUpScales != nil { - s := gateUpScales.Shape() - moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1})) - moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1})) - } - if gateUpBias != nil { - s := gateUpBias.Shape() - moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2})) - moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2})) - } - if downBlocks != nil { - moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1) - } - return moe -} - -func Load(modelPath string) (*Model, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("load config: %w", err) - } - var cfg Config - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("parse config: %w", err) - } - cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) - - weights, err := safetensors.LoadModelWeights(modelPath) - if err != nil { - return nil, fmt.Errorf("load weights: %w", err) - } - - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) - if err != nil { - return nil, fmt.Errorf("load tokenizer: %w", err) - } - - m := &Model{ - Layers: make([]*Block, cfg.NumHiddenLayers), - Config: &cfg, - tok: tok, - } - - // Load simple weights via struct tags - if err := safetensors.LoadModule(m, weights, ""); err != nil { - return nil, err - } - - // Load layers with custom MoE handling - for i := int32(0); i < cfg.NumHiddenLayers; i++ { - prefix := fmt.Sprintf("model.layers.%d", i) - layer := &Block{} - if err := safetensors.LoadModule(layer, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d: %w", i, err) - } - - // Initialize attention YaRN - layer.Attention.initYarn(&cfg) - - // Load MoE with weight sanitization - moe := sanitizeExpertWeights(weights, prefix) - moe.Router = layer.MLP.Router // Router was loaded by LoadModule - moe.TopK = cfg.NumExpertsPerTok - moe.HiddenSize = cfg.HiddenSize - layer.MLP = moe - - // Set layer type - layer.LayerType = "full_attention" - if int(i) < len(cfg.LayerTypes) { - layer.LayerType = cfg.LayerTypes[i] - } - - m.Layers[i] = layer - } - - // Release safetensors BEFORE eval - lazy arrays have captured data, - // this reduces peak memory by freeing mmap during materialization - weights.ReleaseAll() - mlx.Eval(mlx.Collect(m)...) - - return m, nil -} - -func (m *Model) MaxContextLength() int32 { - if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 { - return m.RopeScaling.OriginalMaxPositionEmbeddings - } - return 131072 -} diff --git a/x/imagegen/models/llama/llama.go b/x/imagegen/models/llama/llama.go deleted file mode 100644 index 2b695f78e..000000000 --- a/x/imagegen/models/llama/llama.go +++ /dev/null @@ -1,152 +0,0 @@ -//go:build mlx - -package llama - -import ( - "encoding/json" - "fmt" - "math" - "os" - "path/filepath" - - "github.com/ollama/ollama/x/imagegen/cache" - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/nn" - "github.com/ollama/ollama/x/imagegen/safetensors" - "github.com/ollama/ollama/x/imagegen/tokenizer" -) - -type Config struct { - HiddenSize int32 `json:"hidden_size"` - NumHiddenLayers int32 `json:"num_hidden_layers"` - IntermediateSize int32 `json:"intermediate_size"` - NumAttentionHeads int32 `json:"num_attention_heads"` - NumKeyValueHeads int32 `json:"num_key_value_heads"` - VocabSize int32 `json:"vocab_size"` - RMSNormEps float32 `json:"rms_norm_eps"` - RopeTheta float32 `json:"rope_theta"` - MaxPositionEmbeddings int32 `json:"max_position_embeddings"` - HeadDim int32 `json:"-"` - Scale float32 `json:"-"` -} - -type Model struct { - EmbedTokens *nn.Embedding `weight:"model.embed_tokens"` - Layers []*Layer `weight:"model.layers"` - Norm *nn.RMSNorm `weight:"model.norm"` - Output *nn.Linear `weight:"lm_head,optional"` - - tok *tokenizer.Tokenizer - *Config -} - -type Layer struct { - Attention *Attention - MLP *MLP - AttentionNorm *nn.RMSNorm `weight:"input_layernorm"` - MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"` -} - -type Attention struct { - QProj *nn.Linear `weight:"self_attn.q_proj"` - KProj *nn.Linear `weight:"self_attn.k_proj"` - VProj *nn.Linear `weight:"self_attn.v_proj"` - OProj *nn.Linear `weight:"self_attn.o_proj"` -} - -type MLP struct { - GateProj *nn.Linear `weight:"mlp.gate_proj"` - UpProj *nn.Linear `weight:"mlp.up_proj"` - DownProj *nn.Linear `weight:"mlp.down_proj"` -} - -func Load(modelPath string) (*Model, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("load config: %w", err) - } - var cfg Config - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("parse config: %w", err) - } - cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads - cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) - - weights, err := safetensors.LoadModelWeights(modelPath) - if err != nil { - return nil, fmt.Errorf("load weights: %w", err) - } - - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) - if err != nil { - return nil, fmt.Errorf("load tokenizer: %w", err) - } - - m := &Model{ - Layers: make([]*Layer, cfg.NumHiddenLayers), - Config: &cfg, - tok: tok, - } - if err := safetensors.LoadModule(m, weights, ""); err != nil { - return nil, err - } - m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil) - - mlx.Eval(mlx.Collect(m)...) - weights.ReleaseAll() - - return m, nil -} - -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - B, L := tokens.Shape()[0], tokens.Shape()[1] - h := m.EmbedTokens.Forward(tokens) - for i, layer := range m.Layers { - h = layer.Forward(h, caches[i], B, L, m.Config) - } - return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps)) -} - -func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)) - return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps))) -} - -func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - q := a.QProj.Forward(x) - k := a.KProj.Forward(x) - v := a.VProj.Forward(x) - - q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - - q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) - k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) - - k, v = c.Update(k, v, int(L)) - out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) - return a.OProj.Forward(out) -} - -func (m *MLP) Forward(x *mlx.Array) *mlx.Array { - return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) -} - -// Interface methods -func (m *Model) NumLayers() int { return len(m.Layers) } -func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings } -func (m *Model) VocabSize() int32 { return m.Config.VocabSize } -func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } - -func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { - caches := make([]cache.Cache, len(m.Layers)) - for i := range caches { - caches[i] = cache.NewKVCache() - } - return caches -} diff --git a/x/imagegen/runner.go b/x/imagegen/runner.go index e24383ad5..0409c4bf7 100644 --- a/x/imagegen/runner.go +++ b/x/imagegen/runner.go @@ -39,19 +39,23 @@ func Execute(args []string) error { return fmt.Errorf("--port is required") } - // Initialize MLX + // Detect model type from capabilities + mode := detectModelMode(*modelName) + slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode) + + if mode != ModeImageGen { + return fmt.Errorf("imagegen runner only supports image generation models") + } + + // Initialize MLX only for image generation mode. if err := mlx.InitMLX(); err != nil { slog.Error("unable to initialize MLX", "error", err) return err } slog.Info("MLX library initialized") - // Detect model type from capabilities - mode := detectModelMode(*modelName) - slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode) - // Create and start server - server, err := newServer(*modelName, *port, mode) + server, err := newServer(*modelName, *port) if err != nil { return fmt.Errorf("failed to create server: %w", err) } @@ -61,12 +65,6 @@ func Execute(args []string) error { mux.HandleFunc("/health", server.healthHandler) mux.HandleFunc("/completion", server.completionHandler) - // LLM-specific endpoints - if mode == ModeLLM { - mux.HandleFunc("/tokenize", server.tokenizeHandler) - mux.HandleFunc("/embedding", server.embeddingHandler) - } - httpServer := &http.Server{ Addr: fmt.Sprintf("127.0.0.1:%d", *port), Handler: mux, @@ -112,34 +110,22 @@ func detectModelMode(modelName string) ModelMode { // server holds the model and handles HTTP requests. type server struct { - mode ModelMode modelName string port int - // Image generation model (when mode == ModeImageGen) + // Image generation model. imageModel ImageModel - - // LLM model (when mode == ModeLLM) - llmModel *llmState } -// newServer creates a new server instance and loads the appropriate model. -func newServer(modelName string, port int, mode ModelMode) (*server, error) { +// newServer creates a new server instance for image generation models. +func newServer(modelName string, port int) (*server, error) { s := &server{ - mode: mode, modelName: modelName, port: port, } - switch mode { - case ModeImageGen: - if err := s.loadImageModel(); err != nil { - return nil, fmt.Errorf("failed to load image model: %w", err) - } - case ModeLLM: - if err := s.loadLLMModel(); err != nil { - return nil, fmt.Errorf("failed to load LLM model: %w", err) - } + if err := s.loadImageModel(); err != nil { + return nil, fmt.Errorf("failed to load image model: %w", err) } return s, nil @@ -163,41 +149,5 @@ func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) { return } - switch s.mode { - case ModeImageGen: - s.handleImageCompletion(w, r, req) - case ModeLLM: - s.handleLLMCompletion(w, r, req) - } -} - -func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) { - if s.llmModel == nil { - http.Error(w, "LLM model not loaded", http.StatusInternalServerError) - return - } - - var req struct { - Content string `json:"content"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - tok := s.llmModel.model.Tokenizer() - tokens := tok.Encode(req.Content, false) - - // Convert int32 to int for JSON response - intTokens := make([]int, len(tokens)) - for i, t := range tokens { - intTokens[i] = int(t) - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens}) -} - -func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) { - http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented) + s.handleImageCompletion(w, r, req) } diff --git a/x/imagegen/server.go b/x/imagegen/server.go index 2beb66bda..102cb0c55 100644 --- a/x/imagegen/server.go +++ b/x/imagegen/server.go @@ -30,13 +30,12 @@ import ( // Server wraps an MLX runner subprocess to implement llm.LlamaServer. // // This implementation is compatible with Ollama's scheduler and can be loaded/unloaded -// like any other model. It supports both LLM (safetensors) and image generation models. +// like any other model. It is used for image generation models. type Server struct { mu sync.Mutex cmd *exec.Cmd port int modelName string - mode ModelMode vramSize uint64 done chan error client *http.Client @@ -45,7 +44,7 @@ type Server struct { } // NewServer spawns a new MLX runner subprocess and waits until it's ready. -func NewServer(modelName string, mode ModelMode) (*Server, error) { +func NewServer(modelName string) (*Server, error) { // Validate platform support before attempting to start if err := CheckPlatformSupport(); err != nil { return nil, err @@ -119,7 +118,6 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) { cmd: cmd, port: port, modelName: modelName, - mode: mode, vramSize: vramSize, done: make(chan error, 1), client: &http.Client{Timeout: 10 * time.Minute}, @@ -145,7 +143,7 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) { } }() - slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode) + slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port) if err := cmd.Start(); err != nil { return nil, fmt.Errorf("failed to start mlx runner: %w", err) } @@ -396,36 +394,7 @@ func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, e // Tokenize tokenizes the input content. func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) { - body, err := json.Marshal(map[string]string{"content": content}) - if err != nil { - return nil, err - } - - url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port) - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - - resp, err := s.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode) - } - - var result struct { - Tokens []int `json:"tokens"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, err - } - - return result.Tokens, nil + return nil, errors.New("tokenization not supported for image generation models") } // Detokenize converts tokens back to text.