mlx: quantized embeddings, fast SwiGLU, and runtime fixes (#14884)

Add QuantizedEmbedding and EmbeddingLayer interface so models can
use quantized embedding weights and expose tied output projections.
This change updates gemma3, glm4_moe_lite, llama, qwen3, and qwen3_5
to use the new interface.
This commit is contained in:
Patrick Devine
2026-03-17 11:21:38 -07:00
committed by GitHub
parent fa69b833cd
commit d727aacd04
12 changed files with 405 additions and 37 deletions

View File

@@ -310,6 +310,12 @@ func Log(a *Array) *Array {
return out
}
// Logaddexp computes the element-wise log-add-exp of a and b by delegating to
// the mlx_logaddexp C kernel on the default stream. The result is a new lazy
// Array; evaluation follows MLX's usual deferred-execution model.
func Logaddexp(a, b *Array) *Array {
	out := New("LOGADDEXP")
	C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
out := New("SOFTMAX_AXIS")
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)

View File

@@ -0,0 +1,42 @@
package model
import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// MakeEmbeddingLayer constructs an embedding layer from a tensor map.
//
// When quantized tensors are present (path.weight alongside path.weight_scale),
// it builds a QuantizedEmbedding, resolving group size, bit width, and mode
// through the same quant-metadata path that linear layers use. Otherwise it
// returns a standard dense embedding. Returns nil when the weight tensor is
// missing entirely.
func MakeEmbeddingLayer(
	tensors map[string]*mlx.Array,
	path string,
	defaultGroupSize, defaultBits int,
	defaultMode string,
	tensorQuant map[string]*TensorQuantInfo,
) nn.EmbeddingLayer {
	weightKey := path + ".weight"
	weight := tensors[weightKey]
	if weight == nil {
		return nil
	}

	scales := tensors[path+".weight_scale"]
	if scales == nil {
		// No scale tensor means the weight is stored dense.
		return nn.NewEmbedding(weight)
	}

	groupSize, bits, mode := ResolveLinearQuantParams(
		defaultGroupSize,
		defaultBits,
		defaultMode,
		tensorQuant,
		weightKey,
		weight,
		scales,
	)
	return nn.NewQuantizedEmbedding(weight, scales, tensors[path+".weight_qbias"], groupSize, bits, mode)
}

View File

@@ -0,0 +1,78 @@
package model
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// skipIfNoMLX skips the calling test when the MLX runtime cannot be
// initialized on this machine.
func skipIfNoMLX(t *testing.T) {
	t.Helper()
	err := mlx.CheckInit()
	if err == nil {
		return
	}
	t.Skipf("MLX not available: %v", err)
}
// TestMakeEmbeddingLayerDense verifies that a bare weight tensor (no
// weight_scale sibling) produces a dense *nn.Embedding whose dtype is
// preserved and whose tied-output projection is a plain *nn.Linear.
func TestMakeEmbeddingLayerDense(t *testing.T) {
	skipIfNoMLX(t)

	w := mlx.FromValues([]float32{
		1, 2, 3, 4,
		5, 6, 7, 8,
	}, 2, 4).AsType(mlx.DTypeBFloat16)

	layer := MakeEmbeddingLayer(map[string]*mlx.Array{
		"model.embed_tokens.weight": w,
	}, "model.embed_tokens", 0, 0, "", nil)

	dense, ok := layer.(*nn.Embedding)
	if !ok {
		t.Fatalf("embedding type = %T, want *nn.Embedding", layer)
	}
	if got := dense.Weight.DType(); got != mlx.DTypeBFloat16 {
		t.Fatalf("embedding dtype = %v, want %v", got, mlx.DTypeBFloat16)
	}
	if _, ok := layer.AsLinear().(*nn.Linear); !ok {
		t.Fatalf("AsLinear type = %T, want *nn.Linear", layer.AsLinear())
	}
}
// TestMakeEmbeddingLayerQuantized verifies that when weight_scale (and
// weight_qbias) tensors accompany the weight, MakeEmbeddingLayer returns a
// *nn.QuantizedEmbedding carrying the resolved quant parameters, that Forward
// dequantizes the selected rows to the expected shape, and that AsLinear
// exposes a quantized tied-output projection.
func TestMakeEmbeddingLayerQuantized(t *testing.T) {
	skipIfNoMLX(t)
	// Build a 2x64 dense weight with non-trivial values, then quantize it
	// with group size 64, 4 bits, affine mode.
	denseWeight := mlx.FromValues(func() []float32 {
		out := make([]float32, 2*64)
		for i := range out {
			out[i] = float32(i%17) / 8
		}
		return out
	}(), 2, 64).AsType(mlx.DTypeBFloat16)
	qw, scales, qbiases := mlx.Quantize(denseWeight, 64, 4, "affine")
	mlx.Eval(qw, scales, qbiases)
	emb := MakeEmbeddingLayer(map[string]*mlx.Array{
		"model.embed_tokens.weight":       qw,
		"model.embed_tokens.weight_scale": scales,
		"model.embed_tokens.weight_qbias": qbiases,
	}, "model.embed_tokens", 64, 4, "affine", nil)
	qemb, ok := emb.(*nn.QuantizedEmbedding)
	if !ok {
		t.Fatalf("embedding type = %T, want *nn.QuantizedEmbedding", emb)
	}
	if qemb.GroupSize != 64 || qemb.Bits != 4 || qemb.Mode != "affine" {
		t.Fatalf("quant params = (%d, %d, %q), want (64, 4, %q)", qemb.GroupSize, qemb.Bits, qemb.Mode, "affine")
	}
	// Look up two rows (in swapped order) and confirm the dequantized output
	// yields one 64-wide row per requested index.
	indices := mlx.FromValues([]int32{1, 0}, 2)
	out := emb.Forward(indices)
	mlx.Eval(out)
	if dims := out.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 64 {
		t.Fatalf("embedding output dims = %v, want [2 64]", dims)
	}
	if _, ok := emb.AsLinear().(*nn.QuantizedLinear); !ok {
		t.Fatalf("AsLinear type = %T, want *nn.QuantizedLinear", emb.AsLinear())
	}
}

View File

@@ -147,7 +147,7 @@ func Execute(args []string) error {
return
}
tokens := runner.Tokenizer.Encode(b.String(), true)
tokens := runner.Tokenizer.Encode(b.String(), runner.Tokenizer.AddBOS())
if err := json.NewEncoder(w).Encode(tokens); err != nil {
slog.Error("Failed to encode response", "error", err)

View File

@@ -91,7 +91,7 @@ type DecoderLayer struct {
// Model is the Gemma 3 text-only model.
type Model struct {
EmbedTokens *nn.Embedding
EmbedTokens nn.EmbeddingLayer
Layers []*DecoderLayer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
@@ -310,11 +310,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
if embedTokens == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
m.EmbedTokens = embedTokens
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
@@ -328,7 +328,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.LMHead = lmHead
} else {
// Gemma usually ties output projection to embeddings.
m.LMHead = nn.NewLinear(embedWeight, nil)
m.LMHead = m.EmbedTokens.AsLinear()
}
for i := int32(0); i < m.NumHiddenLayers; i++ {

View File

@@ -345,7 +345,7 @@ type Block interface {
// Model represents the complete GLM4-MoE-Lite model
type Model struct {
EmbedTokens *nn.Embedding
EmbedTokens nn.EmbeddingLayer
Layers []Block
Norm *nn.RMSNorm
LMHead nn.LinearLayer
@@ -586,9 +586,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
// Load embedding
if w := tensors["model.embed_tokens.weight"]; w != nil {
m.EmbedTokens = nn.NewEmbedding(w)
}
m.EmbedTokens = model.MakeEmbeddingLayer(tensors, "model.embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
// Load final norm
if w := tensors["model.norm.weight"]; w != nil {

View File

@@ -44,7 +44,7 @@ type Config struct {
// Model is a Llama text model.
type Model struct {
EmbedTokens *nn.Embedding
EmbedTokens nn.EmbeddingLayer
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
@@ -170,11 +170,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
if embedTokens == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
m.EmbedTokens = embedTokens
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
@@ -183,14 +183,14 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
m.LMHead = m.EmbedTokens.AsLinear()
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Fallback used by many Llama checkpoints where output is tied.
m.LMHead = nn.NewLinear(embedWeight, nil)
m.LMHead = m.EmbedTokens.AsLinear()
}
for i := int32(0); i < m.NumHiddenLayers; i++ {

View File

@@ -13,6 +13,13 @@ type LinearLayer interface {
OutputDim() int32
}
// EmbeddingLayer is an interface for embedding layers that can also expose a
// tied-output projection when the model reuses embedding weights as the LM head.
type EmbeddingLayer interface {
	// Forward maps token indices to their embedding rows.
	Forward(indices *mlx.Array) *mlx.Array
	// AsLinear returns a LinearLayer view over the same weights, for use as
	// a tied LM head.
	AsLinear() LinearLayer
}
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
Weight *mlx.Array
@@ -140,6 +147,53 @@ func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
return e.Weight.TakeAxis(indices, 0)
}
// AsLinear exposes the embedding weights as a bias-free linear projection,
// used when a model ties its output head to the embedding table.
func (e *Embedding) AsLinear() LinearLayer {
	tied := NewLinear(e.Weight, nil)
	return tied
}
// QuantizedEmbedding performs row-wise embedding lookup from affine/nvfp4/etc.
// packed weights and dequantizes only the selected rows.
type QuantizedEmbedding struct {
	Weight    *mlx.Array // packed quantized embedding rows
	Scales    *mlx.Array // per-row quantization scales, gathered alongside Weight
	QBiases   *mlx.Array // optional per-row quantization biases; nil/invalid for modes without them
	GroupSize int        // elements per quantization group
	Bits      int        // bit width of the packed values
	Mode      string     // quantization mode (e.g. "affine")
}
// NewQuantizedEmbedding builds a QuantizedEmbedding over packed weights with
// the given scales, optional qbiases, and quantization parameters.
func NewQuantizedEmbedding(weight, scales, qbiases *mlx.Array, groupSize, bits int, mode string) *QuantizedEmbedding {
	qe := &QuantizedEmbedding{Weight: weight, Scales: scales, QBiases: qbiases}
	qe.GroupSize = groupSize
	qe.Bits = bits
	qe.Mode = mode
	return qe
}
// Forward gathers the requested rows from the packed weight (with the matching
// scale and optional bias rows) and dequantizes only that selection, avoiding
// a full-table dequantize.
func (qe *QuantizedEmbedding) Forward(indices *mlx.Array) *mlx.Array {
	var biasRows *mlx.Array
	if qe.QBiases != nil && qe.QBiases.Valid() {
		biasRows = qe.QBiases.TakeAxis(indices, 0)
	}
	packedRows := qe.Weight.TakeAxis(indices, 0)
	scaleRows := qe.Scales.TakeAxis(indices, 0)
	return mlx.Dequantize(packedRows, scaleRows, biasRows, qe.GroupSize, qe.Bits, qe.Mode)
}
// AsLinear returns a QuantizedLinear that shares this embedding's packed
// weights and quant metadata, so models can tie their output projection to the
// quantized embedding table without duplicating tensors.
func (qe *QuantizedEmbedding) AsLinear() LinearLayer {
	return &QuantizedLinear{
		Weight:    qe.Weight,
		Scales:    qe.Scales,
		QBiases:   qe.QBiases,
		GroupSize: qe.GroupSize,
		Bits:      qe.Bits,
		Mode:      qe.Mode,
	}
}
// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
Weight *mlx.Array
@@ -175,7 +229,6 @@ func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
return x.Matmul(wT)
}
// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
shape := scores.Dims()

View File

@@ -45,7 +45,7 @@ type Config struct {
// Model is the Qwen3 text-only model.
type Model struct {
EmbedTokens *nn.Embedding
EmbedTokens nn.EmbeddingLayer
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
@@ -177,11 +177,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
if embedTokens == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
m.EmbedTokens = embedTokens
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
@@ -190,14 +190,14 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
m.LMHead = m.EmbedTokens.AsLinear()
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Qwen3 checkpoints commonly tie output projection to embeddings.
m.LMHead = nn.NewLinear(embedWeight, nil)
m.LMHead = m.EmbedTokens.AsLinear()
}
for i := int32(0); i < m.NumHiddenLayers; i++ {

View File

@@ -81,7 +81,7 @@ type Config struct {
// Model is the Qwen 3.5 model.
type Model struct {
EmbedTokens *nn.Embedding
EmbedTokens nn.EmbeddingLayer
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
@@ -824,12 +824,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
freeTensorKeys(tensors, mtpKeys...)
}
embedKey := modelPrefix + "embed_tokens.weight"
embedWeight := tensors[embedKey]
if embedWeight == nil {
embedTokens := model.MakeEmbeddingLayer(tensors, modelPrefix+"embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
if embedTokens == nil {
return fmt.Errorf("missing embedding weight: %sembed_tokens.weight", modelPrefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
m.EmbedTokens = embedTokens
normKey := modelPrefix + "norm.weight"
normWeight := maybeShiftNormWeight(normKey, tensors[normKey], shouldShiftNormWeights)
@@ -839,13 +838,13 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Norm = nn.NewRMSNorm(normWeight, cfg.RMSNormEps)
if cfg.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
m.LMHead = m.EmbedTokens.AsLinear()
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
m.LMHead = nn.NewLinear(embedWeight, nil)
m.LMHead = m.EmbedTokens.AsLinear()
}
useQuantizedExperts := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
@@ -1065,7 +1064,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
func softplus(x *mlx.Array) *mlx.Array {
return mlx.Log(mlx.AddScalar(mlx.Exp(x), 1.0))
return mlx.Logaddexp(x, mlx.Zeros(x.DType(), x.Dims()...))
}
func depthwiseCausalConv1d(x, w *mlx.Array, outLen int32) *mlx.Array {
@@ -1150,7 +1149,8 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
out = mlx.Mul(out, mlx.Sigmoid(gate))
gateSigmoid := mlx.Sigmoid(gate)
out = mlx.Mul(out, gateSigmoid)
out = a.OProj.Forward(out)
return out
}
@@ -1175,7 +1175,6 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
mlx.Reshape(v, B, L, cfg.LinearNumValueHeads*cfg.LinearValueHeadDim),
}, -1)
}
convTail := cfg.LinearConvKernelDim - 1
var convState *mlx.Array
var rc *cache.RecurrentCache
@@ -1216,9 +1215,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
q = mlx.MulScalar(mlx.RMSNormFn(q, nil, 1e-6), invScale*invScale)
k = mlx.MulScalar(mlx.RMSNormFn(k, nil, 1e-6), invScale)
aF32 := a.AsType(mlx.DTypeFloat32)
dtBiasF32 := g.DtBias.AsType(mlx.DTypeFloat32)
gDecay := softplus(mlx.Add(aF32, dtBiasF32))
gDecay := softplus(mlx.Add(a, g.DtBias))
gDecay = mlx.Mul(gDecay, g.AExp)
gDecay = mlx.Exp(mlx.MulScalar(gDecay, -1))
gDecay = gDecay.AsType(a.DType())
@@ -1234,8 +1231,9 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
}
out, state := mlx.GatedDelta(q, k, v, gDecay, beta, state)
outDType := out.DType()
out = mlx.RMSNormFn(out, g.NormWeight, cfg.RMSNormEps)
out = mlx.Mul(out, mlx.SiLU(z))
out = mlx.Mul(out.AsType(mlx.DTypeFloat32), mlx.SiLU(z.AsType(mlx.DTypeFloat32))).AsType(outDType)
out = mlx.Reshape(out, B, L, valueDim)
out = g.OutProj.Forward(out)
if rc != nil {

View File

@@ -7,6 +7,13 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// skipIfNoMLX skips the calling test when the MLX runtime cannot be
// initialized on this machine.
func skipIfNoMLX(t *testing.T) {
	t.Helper()
	err := mlx.CheckInit()
	if err == nil {
		return
	}
	t.Skipf("MLX not available: %v", err)
}
func TestParseConfigNestedDefaults(t *testing.T) {
data := []byte(`{
"model_type": "Qwen3_5MoeForConditionalGeneration",
@@ -155,3 +162,184 @@ func TestNewCachesLayout(t *testing.T) {
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
}
}
// TestLoadWeightsPreservesLinearAttentionNormWeightDType loads a tiny
// two-layer model (one linear-attention layer, one full-attention layer) from
// hand-built float32 tensors and asserts that every RMS-norm weight — layer
// input/post-attention norms, the final norm, the linear-attention norm, and
// the q/k norms — keeps float32 dtype after LoadWeights, even though the
// embedding table is bfloat16.
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
	skipIfNoMLX(t)
	// Minimal config: hidden size 4, two layers alternating linear/full
	// attention, tied word embeddings so no lm_head tensor is needed.
	cfg := &Config{
		HiddenSize:            4,
		IntermediateSize:      8,
		NumHiddenLayers:       2,
		NumAttentionHeads:     1,
		NumKeyValueHeads:      1,
		HeadDim:               4,
		RMSNormEps:            1e-6,
		TieWordEmbeddings:     true,
		LayerTypes:            []string{"linear", "full"},
		LinearNumValueHeads:   1,
		LinearNumKeyHeads:     1,
		LinearKeyHeadDim:      2,
		LinearValueHeadDim:    2,
		LinearConvKernelDim:   4,
		FullAttentionInterval: 2,
	}
	m := &Model{
		Config: cfg,
		Layers: make([]*Layer, cfg.NumHiddenLayers),
	}
	bf16 := mlx.DTypeBFloat16
	f32 := mlx.DTypeFloat32
	// Hand-built weight tensors. Only the embedding is cast to bfloat16; all
	// norm weights are left as float32 so the dtype checks below are meaningful.
	tensors := map[string]*mlx.Array{
		"model.embed_tokens.weight": mlx.FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 2, 4).AsType(bf16),
		"model.norm.weight":         mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		// Layer 0: linear (gated-delta) attention.
		"model.layers.0.input_layernorm.weight":          mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.0.post_attention_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.0.linear_attn.in_proj_qkv.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
		}, 6, 4),
		"model.layers.0.linear_attn.in_proj_z.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
		}, 2, 4),
		"model.layers.0.linear_attn.in_proj_b.weight": mlx.FromValues([]float32{1, 0, 0, 0}, 1, 4),
		"model.layers.0.linear_attn.in_proj_a.weight": mlx.FromValues([]float32{0, 1, 0, 0}, 1, 4),
		"model.layers.0.linear_attn.out_proj.weight": mlx.FromValues([]float32{
			1, 0,
			0, 1,
			1, 1,
			0, 0,
		}, 4, 2),
		"model.layers.0.linear_attn.conv1d.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
		}, 6, 4),
		"model.layers.0.linear_attn.norm.weight": mlx.FromValues([]float32{1, 1}, 2),
		"model.layers.0.linear_attn.dt_bias":     mlx.FromValues([]float32{0}, 1),
		"model.layers.0.linear_attn.A_log":       mlx.FromValues([]float32{0}, 1),
		"model.layers.0.mlp.gate_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.0.mlp.up_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.0.mlp.down_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0, 0, 0, 0, 0,
			0, 1, 0, 0, 0, 0, 0, 0,
			0, 0, 1, 0, 0, 0, 0, 0,
			0, 0, 0, 1, 0, 0, 0, 0,
		}, 4, 8),
		// Layer 1: full attention with q/k norms.
		"model.layers.1.input_layernorm.weight":          mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.1.post_attention_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.1.self_attn.q_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.1.self_attn.k_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
		}, 4, 4),
		"model.layers.1.self_attn.v_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
		}, 4, 4),
		"model.layers.1.self_attn.o_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
		}, 4, 4),
		"model.layers.1.self_attn.q_norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.1.self_attn.k_norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.1.mlp.gate_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.1.mlp.up_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.1.mlp.down_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0, 0, 0, 0, 0,
			0, 1, 0, 0, 0, 0, 0, 0,
			0, 0, 1, 0, 0, 0, 0, 0,
			0, 0, 0, 1, 0, 0, 0, 0,
		}, 4, 8),
	}
	if err := m.LoadWeights(tensors); err != nil {
		t.Fatalf("LoadWeights failed: %v", err)
	}
	// All norm weights must remain float32 after loading.
	if got := m.Layers[0].InputNorm.Weight.DType(); got != f32 {
		t.Fatalf("layer 0 input norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[0].PostAttentionNorm.Weight.DType(); got != f32 {
		t.Fatalf("layer 0 post-attn norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[1].InputNorm.Weight.DType(); got != f32 {
		t.Fatalf("layer 1 input norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[1].PostAttentionNorm.Weight.DType(); got != f32 {
		t.Fatalf("layer 1 post-attn norm dtype = %v, want %v", got, f32)
	}
	if got := m.Norm.Weight.DType(); got != f32 {
		t.Fatalf("final norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[0].Linear.NormWeight.DType(); got != f32 {
		t.Fatalf("linear-attn norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[1].FullAttn.QNorm.Weight.DType(); got != f32 {
		t.Fatalf("q norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[1].FullAttn.KNorm.Weight.DType(); got != f32 {
		t.Fatalf("k norm dtype = %v, want %v", got, f32)
	}
}

View File

@@ -71,6 +71,11 @@ func (t *Tokenizer) BOS() int32 {
return t.vocab.BOS
}
// AddBOS reports whether a BOS token should be prepended during encoding, as
// configured by the loaded vocabulary's AddBOS flag.
func (t *Tokenizer) AddBOS() bool {
	return t.vocab.AddBOS
}
// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
if len(t.vocab.EOS) > 0 {