Files
ollama/x/model/models/gemma3/model_text.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified GitHub artifact, but it is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

212 lines
7.7 KiB
Go

//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model/input"
)
// TextConfig holds the text-decoder hyperparameters read by the forward
// passes in this file. newTextModel fills it in from the model's fs.Config.
type TextConfig struct {
	// hiddenSize is the embedding width; TextModel.Forward scales the token
	// embeddings by its square root.
	hiddenSize int
	// numHeads is the number of query heads q is reshaped into.
	numHeads int
	// numKVHeads is the number of heads k and v are reshaped into.
	numKVHeads int
	// attnKeyLen is the dimension count passed to RoPE for q and k.
	attnKeyLen int
	// eps is the epsilon passed to every RMSNorm.
	eps float32
	// ropeScale is the scale passed to RoPE; newTextModel pins it to 1.
	ropeScale float32
	// ropeLocalBase is the RoPE base for layers that are not a multiple of
	// gemmaGlobalCacheCount.
	ropeLocalBase float32
	// ropeGlobalBase is the RoPE base for every gemmaGlobalCacheCount-th layer.
	ropeGlobalBase float32
	// largeModelScaling is set by newTextModel when the block count equals
	// gemma27BLayerCount.
	largeModelScaling bool
}
// TextModel is the Gemma 3 text decoder: a token embedding, a stack of
// TextLayer blocks, and a final RMSNorm. The embedded *TextConfig exposes the
// hyperparameters directly (e.g. m.eps).
// The gguf struct tags presumably name the weight tensors bound to each field
// at load time — confirm against the loader.
type TextModel struct {
// TokenEmbedding is looked up with batch.Inputs at the start of Forward.
TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
// Layers is sized by newTextModel from the "block_count" config value.
Layers []TextLayer `gguf:"layers"`
// OutputNorm is applied to the hidden state after the last layer.
OutputNorm *nn.RMSNorm `gguf:"norm"`
// Output carries the same gguf tag as TokenEmbedding (presumably tied
// weights). It is not referenced by the Forward method visible in this file.
Output *nn.Linear `gguf:"embed_tokens"`
*TextConfig
}
const (
// gemmaGlobalCacheCount is the layer period for the global RoPE base: a layer
// whose (index+1) is a multiple of this value uses ropeGlobalBase, every other
// layer uses ropeLocalBase.
gemmaGlobalCacheCount = 6
// gemma27BLayerCount is the block count that makes newTextModel enable
// largeModelScaling (per the name, the 27B variant).
gemma27BLayerCount = 62
)
// const (
// cacheTypeSWA = iota
// cacheTypeCausal
// )
// newTextModel allocates a TextModel whose layer slice and TextConfig are
// derived from the key/value metadata in c. Weight fields are left nil here.
func newTextModel(c fs.Config) *TextModel {
	blockCount := int(c.Uint("block_count"))

	cfg := &TextConfig{
		// 2560 -- config.json: text_config.hidden_size
		hiddenSize: int(c.Uint("embedding_length")),
		// 8 -- hard coded in the python implementation for the model
		// (4 in some places, then overridden as 8)
		numHeads: int(c.Uint("attention.head_count")),
		// 4 -- same as above
		numKVHeads: int(c.Uint("attention.head_count_kv")),
		// 256 -- rope settings, hardcoded in the python model definition
		attnKeyLen: int(c.Uint("attention.key_length", 256)),
		// 1e-06 -- hardcoded in the python model definition
		eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
		// 10000 -- hardcoded in python
		ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
		// 1e+06 -- hardcoded in python
		ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
		// Fixed at 1 (the default implied by the python code).
		// NOTE: rope.scaling.factor is set incorrectly in the official QAT
		// weights (8 instead of 1), so it is deliberately not read:
		// ropeScale: c.Float("rope.scaling.factor", 1.0),
		ropeScale: 1,
		// Not yet wired up:
		// vocabSize:  vocabSize, // 262144
		// attnValLen: int(c.Uint("attention.value_length", 256)), // 256
		largeModelScaling: blockCount == gemma27BLayerCount,
	}

	return &TextModel{
		Layers:     make([]TextLayer, blockCount),
		TextConfig: cfg,
	}
}
// TextSelfAttention holds the projections and norms for one layer's
// self-attention. The gguf tags presumably name the weight tensors bound to
// each field at load time — confirm against the loader.
type TextSelfAttention struct {
// Query projects the hidden state to q.
Query *nn.Linear `gguf:"q_proj"`
// QueryNorm is an RMSNorm applied to q after it is split into heads.
QueryNorm *nn.RMSNorm `gguf:"q_norm"`
// Key projects the hidden state to k.
Key *nn.Linear `gguf:"k_proj"`
// KeyNorm is an RMSNorm applied to k after it is split into heads.
KeyNorm *nn.RMSNorm `gguf:"k_norm"`
// Value projects the hidden state to v.
Value *nn.Linear `gguf:"v_proj"`
// Output projects the merged attention result back out of the block.
Output *nn.Linear `gguf:"o_proj"`
}
// Forward computes self-attention for one decoder layer and returns the
// output projection of the result.
//
// layer is the zero-based layer index: every gemmaGlobalCacheCount-th layer
// rotates q/k with opts.ropeGlobalBase, all other layers use
// opts.ropeLocalBase. offset is forwarded to RoPE for both q and k, and cache
// is handed to nn.Attention unchanged.
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	// assumes hiddenState is laid out (batch, sequence, hidden) — TODO confirm
	batchSize := hiddenState.Dim(0)
	seqLen := hiddenState.Dim(1)

	ropeBase := opts.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	k := sa.Key.Forward(ctx, hiddenState)
	v := sa.Value.Forward(ctx, hiddenState)

	// Split the projection into heads, then swap axes 1 and 2 so the head axis
	// precedes the sequence axis.
	q = q.Reshape(ctx, batchSize, seqLen, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
	k = k.Reshape(ctx, batchSize, seqLen, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
	v = v.Reshape(ctx, batchSize, seqLen, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)

	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)

	traditional := false
	q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
	k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))

	// The attention scale is 1/sqrt(d), where d is hiddenSize/numHeads for the
	// large (gemma27BLayerCount-layer) model and the key length otherwise.
	// With the default key length of 256 this is 0.0625. The scale is passed
	// to nn.Attention as its scale argument rather than pre-multiplied into q,
	// so it is applied only once.
	scaleDim := opts.attnKeyLen
	if opts.largeModelScaling {
		scaleDim = opts.hiddenSize / opts.numHeads
	}
	scaleFactor := 1.0 / math.Sqrt(float64(scaleDim))

	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	// Undo the head/sequence swap and merge the heads back into one axis.
	kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, batchSize, seqLen, -1)
	return sa.Output.Forward(ctx, kqv)
}
// shiftUnsupportedError is the error returned by (*TextModel).Shift while
// key re-rotation is not implemented for this backend.
type shiftUnsupportedError struct{}

// Error implements the error interface.
func (shiftUnsupportedError) Error() string {
	return "gemma3: shift not yet implemented"
}

// Shift is meant to re-apply RoPE to cached keys of the given layer when the
// cache shifts positions. It is not implemented yet: it always returns a nil
// tensor and a non-nil error so callers can fail gracefully instead of
// crashing the process.
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	// TODO: implementation sketch, to be ported to the MLX RoPE API:
	//
	// ropeBase := m.TextConfig.ropeLocalBase
	// if (layer+1)%gemmaGlobalCacheCount == 0 {
	// 	ropeBase = m.TextConfig.ropeGlobalBase
	// }
	// q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
	// return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
	return nil, shiftUnsupportedError{}
}
// TextMLP holds the three linear projections of a layer's feed-forward block.
// The gguf tags presumably name the weight tensors bound to each field at load
// time — confirm against the loader.
type TextMLP struct {
// Up is applied to the block input; its result is the argument to GELU.
Up *nn.Linear `gguf:"up_proj"`
// Down projects the combined gate/up result to the block output.
Down *nn.Linear `gguf:"down_proj"`
// Gate is applied to the block input; GELU is called on its result.
Gate *nn.Linear `gguf:"gate_proj"`
}
// Forward runs the feed-forward block: the gate and up projections of
// hiddenState are combined through GELU and the result is sent through the
// down projection. opts is accepted for signature symmetry and is not read.
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	gate := mlp.Gate.Forward(ctx, hiddenState)
	up := mlp.Up.Forward(ctx, hiddenState)
	// presumably GELU(gate) * up (gated GELU) — confirm against ml.Tensor.GELU
	activated := gate.GELU(ctx, up)
	return mlp.Down.Forward(ctx, activated)
}
// TextLayer is one decoder block: self-attention and an MLP, each wrapped in
// a pre-norm and a post-norm, with a residual add after each half.
// The gguf tags presumably name the weight tensors bound to each field at load
// time — confirm against the loader.
type TextLayer struct {
// AttentionNorm is applied to the layer input before self-attention.
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
SelfAttention *TextSelfAttention `gguf:"self_attn"`
// PostAttentionNorm is applied to the attention output before the residual add.
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
// MLPNorm is applied before the MLP.
MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"`
MLP *TextMLP `gguf:"mlp"`
// PostMLPNorm is applied to the MLP output before the residual add.
PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"`
}
// Forward runs one decoder block over hiddenState and returns the result.
//
// layer, offset, and cache are passed straight through to self-attention.
// When outputs is non-nil, the attention result and its residual are both
// narrowed along axis 1 to the indices in outputs before the MLP half runs.
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	// Attention half: pre-norm -> attention -> post-norm.
	attnOut := l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	attnOut = l.SelfAttention.Forward(ctx, layer, attnOut, offset, cache, opts)
	attnOut = l.PostAttentionNorm.Forward(ctx, attnOut, opts.eps)

	// The caller passes outputs only for the final layer; pruning here means
	// the MLP runs solely on the positions that need logits.
	skip := hiddenState
	if outputs != nil {
		attnOut = attnOut.TakeAxes(ctx, outputs, 1)
		skip = skip.TakeAxes(ctx, outputs, 1)
	}
	postAttn := attnOut.Add(ctx, skip)

	// MLP half: pre-norm -> MLP -> post-norm, then the second residual add.
	mlpOut := l.MLPNorm.Forward(ctx, postAttn, opts.eps)
	// TODO: flagged by the original author as the most likely place where
	// outputs go bad.
	mlpOut = l.MLP.Forward(ctx, mlpOut, opts)
	mlpOut = l.PostMLPNorm.Forward(ctx, mlpOut, opts.eps)
	return mlpOut.Add(ctx, postAttn)
}
// Forward runs the text decoder over batch and returns the hidden state after
// the final RMSNorm. It does not apply the output projection.
//
// batch.Inputs is embedded and scaled by sqrt(hiddenSize), then passed through
// every layer in order. batch.Offset is given to every layer as its RoPE
// offset (assumed to be the position offset of this batch — TODO confirm
// against input.Batch). batch.Outputs is given only to the last layer, which
// uses it to prune the sequence to the requested positions. cache may be nil.
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// set image embeddings
	// var except []int
	// for _, image := range batch.Multimodal {
	// 	visionOutputs := image.Multimodal[0].Tensor
	// 	ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
	// 		[]int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
	// 		[]int{image.Index * hiddenState.Stride(1)}, 0)))
	// 	for i := range visionOutputs.Dim(1) {
	// 		except = append(except, image.Index+i)
	// 	}
	// }

	lastLayer := len(m.Layers) - 1
	for i := range m.Layers {
		// gemma alternates between the sliding window (local) and causal
		// (global) kv cache every 6 layers
		if cache != nil {
			// cacheType := cacheTypeSWA
			// if (i+1)%gemmaGlobalCacheCount == 0 {
			// 	cacheType = cacheTypeCausal
			// }
			cache.SetLayer(i)

			// TODO this needs to come back
			// wc := cache.(*kvcache.WrapperCache)
			// wc.SetLayerType(cacheType)
			// if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
			// 	causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
			// }
		}

		// Only the last layer prunes to the requested output positions.
		var lastLayerOutputs ml.Tensor
		if i == lastLayer {
			lastLayerOutputs = batch.Outputs
		}

		// Every layer sees the same token positions, so each one gets the same
		// RoPE offset; leaving it at zero for the earlier layers would rotate
		// their q/k as if the batch started at position 0.
		hiddenState = m.Layers[i].Forward(ctx, i, hiddenState, lastLayerOutputs, batch.Offset, cache, m.TextConfig)
	}

	return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
}