mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
mlx: quantized embeddings, fast SwiGLU, and runtime fixes (#14884)
Add QuantizedEmbedding and EmbeddingLayer interface so models can use quantized embedding weights and expose tied output projections. This change updates gemma3, glm4_moe_lite, llama, qwen3, and qwen3_5 to use the new interface.
This commit is contained in:
@@ -310,6 +310,12 @@ func Log(a *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Logaddexp(a, b *Array) *Array {
|
||||
out := New("LOGADDEXP")
|
||||
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS")
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
|
||||
42
x/mlxrunner/model/embedding.go
Normal file
42
x/mlxrunner/model/embedding.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// MakeEmbeddingLayer constructs an embedding layer from a tensor map.
|
||||
//
|
||||
// For quantized tensors (path.weight + path.weight_scale), it returns a
|
||||
// QuantizedEmbedding using the same quant metadata path that linear layers use.
|
||||
// For non-quantized tensors, it returns a standard dense embedding.
|
||||
func MakeEmbeddingLayer(
|
||||
tensors map[string]*mlx.Array,
|
||||
path string,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) nn.EmbeddingLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
return nn.NewQuantizedEmbedding(w, scales, qbiases, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
return nn.NewEmbedding(w)
|
||||
}
|
||||
78
x/mlxrunner/model/embedding_test.go
Normal file
78
x/mlxrunner/model/embedding_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeEmbeddingLayerDense(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
weight := mlx.FromValues([]float32{
|
||||
1, 2, 3, 4,
|
||||
5, 6, 7, 8,
|
||||
}, 2, 4).AsType(mlx.DTypeBFloat16)
|
||||
|
||||
emb := MakeEmbeddingLayer(map[string]*mlx.Array{
|
||||
"model.embed_tokens.weight": weight,
|
||||
}, "model.embed_tokens", 0, 0, "", nil)
|
||||
|
||||
dense, ok := emb.(*nn.Embedding)
|
||||
if !ok {
|
||||
t.Fatalf("embedding type = %T, want *nn.Embedding", emb)
|
||||
}
|
||||
if dense.Weight.DType() != mlx.DTypeBFloat16 {
|
||||
t.Fatalf("embedding dtype = %v, want %v", dense.Weight.DType(), mlx.DTypeBFloat16)
|
||||
}
|
||||
if _, ok := emb.AsLinear().(*nn.Linear); !ok {
|
||||
t.Fatalf("AsLinear type = %T, want *nn.Linear", emb.AsLinear())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeEmbeddingLayerQuantized(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
denseWeight := mlx.FromValues(func() []float32 {
|
||||
out := make([]float32, 2*64)
|
||||
for i := range out {
|
||||
out[i] = float32(i%17) / 8
|
||||
}
|
||||
return out
|
||||
}(), 2, 64).AsType(mlx.DTypeBFloat16)
|
||||
|
||||
qw, scales, qbiases := mlx.Quantize(denseWeight, 64, 4, "affine")
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
|
||||
emb := MakeEmbeddingLayer(map[string]*mlx.Array{
|
||||
"model.embed_tokens.weight": qw,
|
||||
"model.embed_tokens.weight_scale": scales,
|
||||
"model.embed_tokens.weight_qbias": qbiases,
|
||||
}, "model.embed_tokens", 64, 4, "affine", nil)
|
||||
|
||||
qemb, ok := emb.(*nn.QuantizedEmbedding)
|
||||
if !ok {
|
||||
t.Fatalf("embedding type = %T, want *nn.QuantizedEmbedding", emb)
|
||||
}
|
||||
if qemb.GroupSize != 64 || qemb.Bits != 4 || qemb.Mode != "affine" {
|
||||
t.Fatalf("quant params = (%d, %d, %q), want (64, 4, %q)", qemb.GroupSize, qemb.Bits, qemb.Mode, "affine")
|
||||
}
|
||||
|
||||
indices := mlx.FromValues([]int32{1, 0}, 2)
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
if dims := out.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 64 {
|
||||
t.Fatalf("embedding output dims = %v, want [2 64]", dims)
|
||||
}
|
||||
if _, ok := emb.AsLinear().(*nn.QuantizedLinear); !ok {
|
||||
t.Fatalf("AsLinear type = %T, want *nn.QuantizedLinear", emb.AsLinear())
|
||||
}
|
||||
}
|
||||
@@ -147,7 +147,7 @@ func Execute(args []string) error {
|
||||
return
|
||||
}
|
||||
|
||||
tokens := runner.Tokenizer.Encode(b.String(), true)
|
||||
tokens := runner.Tokenizer.Encode(b.String(), runner.Tokenizer.AddBOS())
|
||||
|
||||
if err := json.NewEncoder(w).Encode(tokens); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
|
||||
@@ -91,7 +91,7 @@ type DecoderLayer struct {
|
||||
|
||||
// Model is the Gemma 3 text-only model.
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
EmbedTokens nn.EmbeddingLayer
|
||||
Layers []*DecoderLayer
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
@@ -310,11 +310,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
prefix := m.weightPrefix
|
||||
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
|
||||
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||
if embedWeight == nil {
|
||||
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
if embedTokens == nil {
|
||||
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||
}
|
||||
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||
m.EmbedTokens = embedTokens
|
||||
|
||||
normWeight := tensors[prefix+"model.norm.weight"]
|
||||
if normWeight == nil {
|
||||
@@ -328,7 +328,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.LMHead = lmHead
|
||||
} else {
|
||||
// Gemma usually ties output projection to embeddings.
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
m.LMHead = m.EmbedTokens.AsLinear()
|
||||
}
|
||||
|
||||
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||
|
||||
@@ -345,7 +345,7 @@ type Block interface {
|
||||
|
||||
// Model represents the complete GLM4-MoE-Lite model
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
EmbedTokens nn.EmbeddingLayer
|
||||
Layers []Block
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
@@ -586,9 +586,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
// Load embedding
|
||||
if w := tensors["model.embed_tokens.weight"]; w != nil {
|
||||
m.EmbedTokens = nn.NewEmbedding(w)
|
||||
}
|
||||
m.EmbedTokens = model.MakeEmbeddingLayer(tensors, "model.embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
|
||||
|
||||
// Load final norm
|
||||
if w := tensors["model.norm.weight"]; w != nil {
|
||||
|
||||
@@ -44,7 +44,7 @@ type Config struct {
|
||||
|
||||
// Model is a Llama text model.
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
EmbedTokens nn.EmbeddingLayer
|
||||
Layers []*Layer
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
@@ -170,11 +170,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
prefix := m.weightPrefix
|
||||
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
|
||||
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||
if embedWeight == nil {
|
||||
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
if embedTokens == nil {
|
||||
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||
}
|
||||
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||
m.EmbedTokens = embedTokens
|
||||
|
||||
normWeight := tensors[prefix+"model.norm.weight"]
|
||||
if normWeight == nil {
|
||||
@@ -183,14 +183,14 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||
|
||||
if m.TieWordEmbeddings {
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
m.LMHead = m.EmbedTokens.AsLinear()
|
||||
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else {
|
||||
// Fallback used by many Llama checkpoints where output is tied.
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
m.LMHead = m.EmbedTokens.AsLinear()
|
||||
}
|
||||
|
||||
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||
|
||||
@@ -13,6 +13,13 @@ type LinearLayer interface {
|
||||
OutputDim() int32
|
||||
}
|
||||
|
||||
// EmbeddingLayer is an interface for embedding layers that can also expose a
// tied-output projection when the model reuses embedding weights as the LM head.
type EmbeddingLayer interface {
	// Forward looks up the embedding rows for the given token indices.
	Forward(indices *mlx.Array) *mlx.Array
	// AsLinear returns the embedding weights repackaged as a LinearLayer,
	// for checkpoints that tie the output projection to the embedding table.
	AsLinear() LinearLayer
}
|
||||
|
||||
// Conv1d applies 1D convolution over NLC input.
|
||||
type Conv1d struct {
|
||||
Weight *mlx.Array
|
||||
@@ -140,6 +147,53 @@ func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
|
||||
return e.Weight.TakeAxis(indices, 0)
|
||||
}
|
||||
|
||||
// AsLinear exposes the dense embedding weight as a bias-free linear layer,
// used when a model ties its LM head to the embedding table.
func (e *Embedding) AsLinear() LinearLayer {
	return NewLinear(e.Weight, nil)
}
|
||||
|
||||
// QuantizedEmbedding performs row-wise embedding lookup from affine/nvfp4/etc.
// packed weights and dequantizes only the selected rows.
type QuantizedEmbedding struct {
	Weight    *mlx.Array // packed (quantized) weight rows
	Scales    *mlx.Array // per-group scales, row-indexed like Weight
	QBiases   *mlx.Array // optional per-group biases; may be nil or invalid
	GroupSize int        // quantization group size
	Bits      int        // bits per quantized element
	Mode      string     // quantization mode, e.g. "affine"
}
|
||||
|
||||
func NewQuantizedEmbedding(weight, scales, qbiases *mlx.Array, groupSize, bits int, mode string) *QuantizedEmbedding {
|
||||
return &QuantizedEmbedding{
|
||||
Weight: weight,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
func (qe *QuantizedEmbedding) Forward(indices *mlx.Array) *mlx.Array {
|
||||
weight := qe.Weight.TakeAxis(indices, 0)
|
||||
scales := qe.Scales.TakeAxis(indices, 0)
|
||||
var qbiases *mlx.Array
|
||||
if qe.QBiases != nil && qe.QBiases.Valid() {
|
||||
qbiases = qe.QBiases.TakeAxis(indices, 0)
|
||||
}
|
||||
return mlx.Dequantize(weight, scales, qbiases, qe.GroupSize, qe.Bits, qe.Mode)
|
||||
}
|
||||
|
||||
func (qe *QuantizedEmbedding) AsLinear() LinearLayer {
|
||||
return &QuantizedLinear{
|
||||
Weight: qe.Weight,
|
||||
Scales: qe.Scales,
|
||||
QBiases: qe.QBiases,
|
||||
GroupSize: qe.GroupSize,
|
||||
Bits: qe.Bits,
|
||||
Mode: qe.Mode,
|
||||
}
|
||||
}
|
||||
|
||||
// LayerNorm represents a standard layer normalization layer (with bias).
|
||||
type LayerNorm struct {
|
||||
Weight *mlx.Array
|
||||
@@ -175,7 +229,6 @@ func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
return x.Matmul(wT)
|
||||
}
|
||||
|
||||
|
||||
// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
|
||||
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
|
||||
shape := scores.Dims()
|
||||
|
||||
@@ -45,7 +45,7 @@ type Config struct {
|
||||
|
||||
// Model is the Qwen3 text-only model.
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
EmbedTokens nn.EmbeddingLayer
|
||||
Layers []*Layer
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
@@ -177,11 +177,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
prefix := m.weightPrefix
|
||||
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
|
||||
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||
if embedWeight == nil {
|
||||
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
if embedTokens == nil {
|
||||
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||
}
|
||||
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||
m.EmbedTokens = embedTokens
|
||||
|
||||
normWeight := tensors[prefix+"model.norm.weight"]
|
||||
if normWeight == nil {
|
||||
@@ -190,14 +190,14 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||
|
||||
if m.TieWordEmbeddings {
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
m.LMHead = m.EmbedTokens.AsLinear()
|
||||
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else {
|
||||
// Qwen3 checkpoints commonly tie output projection to embeddings.
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
m.LMHead = m.EmbedTokens.AsLinear()
|
||||
}
|
||||
|
||||
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||
|
||||
@@ -81,7 +81,7 @@ type Config struct {
|
||||
|
||||
// Model is the Qwen 3.5 model.
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
EmbedTokens nn.EmbeddingLayer
|
||||
Layers []*Layer
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
@@ -824,12 +824,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
freeTensorKeys(tensors, mtpKeys...)
|
||||
}
|
||||
|
||||
embedKey := modelPrefix + "embed_tokens.weight"
|
||||
embedWeight := tensors[embedKey]
|
||||
if embedWeight == nil {
|
||||
embedTokens := model.MakeEmbeddingLayer(tensors, modelPrefix+"embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
|
||||
if embedTokens == nil {
|
||||
return fmt.Errorf("missing embedding weight: %sembed_tokens.weight", modelPrefix)
|
||||
}
|
||||
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||
m.EmbedTokens = embedTokens
|
||||
|
||||
normKey := modelPrefix + "norm.weight"
|
||||
normWeight := maybeShiftNormWeight(normKey, tensors[normKey], shouldShiftNormWeights)
|
||||
@@ -839,13 +838,13 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.Norm = nn.NewRMSNorm(normWeight, cfg.RMSNormEps)
|
||||
|
||||
if cfg.TieWordEmbeddings {
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
m.LMHead = m.EmbedTokens.AsLinear()
|
||||
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else {
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
m.LMHead = m.EmbedTokens.AsLinear()
|
||||
}
|
||||
|
||||
useQuantizedExperts := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
||||
@@ -1065,7 +1064,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
func softplus(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Log(mlx.AddScalar(mlx.Exp(x), 1.0))
|
||||
return mlx.Logaddexp(x, mlx.Zeros(x.DType(), x.Dims()...))
|
||||
}
|
||||
|
||||
func depthwiseCausalConv1d(x, w *mlx.Array, outLen int32) *mlx.Array {
|
||||
@@ -1150,7 +1149,8 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
out = mlx.Mul(out, mlx.Sigmoid(gate))
|
||||
gateSigmoid := mlx.Sigmoid(gate)
|
||||
out = mlx.Mul(out, gateSigmoid)
|
||||
out = a.OProj.Forward(out)
|
||||
return out
|
||||
}
|
||||
@@ -1175,7 +1175,6 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
mlx.Reshape(v, B, L, cfg.LinearNumValueHeads*cfg.LinearValueHeadDim),
|
||||
}, -1)
|
||||
}
|
||||
|
||||
convTail := cfg.LinearConvKernelDim - 1
|
||||
var convState *mlx.Array
|
||||
var rc *cache.RecurrentCache
|
||||
@@ -1216,9 +1215,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
q = mlx.MulScalar(mlx.RMSNormFn(q, nil, 1e-6), invScale*invScale)
|
||||
k = mlx.MulScalar(mlx.RMSNormFn(k, nil, 1e-6), invScale)
|
||||
|
||||
aF32 := a.AsType(mlx.DTypeFloat32)
|
||||
dtBiasF32 := g.DtBias.AsType(mlx.DTypeFloat32)
|
||||
gDecay := softplus(mlx.Add(aF32, dtBiasF32))
|
||||
gDecay := softplus(mlx.Add(a, g.DtBias))
|
||||
gDecay = mlx.Mul(gDecay, g.AExp)
|
||||
gDecay = mlx.Exp(mlx.MulScalar(gDecay, -1))
|
||||
gDecay = gDecay.AsType(a.DType())
|
||||
@@ -1234,8 +1231,9 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
}
|
||||
|
||||
out, state := mlx.GatedDelta(q, k, v, gDecay, beta, state)
|
||||
outDType := out.DType()
|
||||
out = mlx.RMSNormFn(out, g.NormWeight, cfg.RMSNormEps)
|
||||
out = mlx.Mul(out, mlx.SiLU(z))
|
||||
out = mlx.Mul(out.AsType(mlx.DTypeFloat32), mlx.SiLU(z.AsType(mlx.DTypeFloat32))).AsType(outDType)
|
||||
out = mlx.Reshape(out, B, L, valueDim)
|
||||
out = g.OutProj.Forward(out)
|
||||
if rc != nil {
|
||||
|
||||
@@ -7,6 +7,13 @@ import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||
@@ -155,3 +162,184 @@ func TestNewCachesLayout(t *testing.T) {
|
||||
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadWeightsPreservesLinearAttentionNormWeightDType loads a tiny
// two-layer model (one linear-attention layer, one full-attention layer)
// from hand-built float32/bfloat16 tensors and asserts that every norm
// weight keeps float32 precision after LoadWeights.
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
	skipIfNoMLX(t)

	// Minimal config: hidden size 4, layer 0 linear attention, layer 1 full
	// attention, tied embeddings so no separate lm_head tensor is needed.
	cfg := &Config{
		HiddenSize:            4,
		IntermediateSize:      8,
		NumHiddenLayers:       2,
		NumAttentionHeads:     1,
		NumKeyValueHeads:      1,
		HeadDim:               4,
		RMSNormEps:            1e-6,
		TieWordEmbeddings:     true,
		LayerTypes:            []string{"linear", "full"},
		LinearNumValueHeads:   1,
		LinearNumKeyHeads:     1,
		LinearKeyHeadDim:      2,
		LinearValueHeadDim:    2,
		LinearConvKernelDim:   4,
		FullAttentionInterval: 2,
	}

	m := &Model{
		Config: cfg,
		Layers: make([]*Layer, cfg.NumHiddenLayers),
	}

	bf16 := mlx.DTypeBFloat16
	f32 := mlx.DTypeFloat32
	// Only the embedding is bfloat16; every other tensor is float32 so the
	// dtype checks below can detect unwanted downcasts of the norm weights.
	tensors := map[string]*mlx.Array{
		"model.embed_tokens.weight": mlx.FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 2, 4).AsType(bf16),
		"model.norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		// Layer 0: linear (gated-delta) attention weights.
		"model.layers.0.input_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.0.post_attention_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.0.linear_attn.in_proj_qkv.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
		}, 6, 4),
		"model.layers.0.linear_attn.in_proj_z.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
		}, 2, 4),
		"model.layers.0.linear_attn.in_proj_b.weight": mlx.FromValues([]float32{1, 0, 0, 0}, 1, 4),
		"model.layers.0.linear_attn.in_proj_a.weight": mlx.FromValues([]float32{0, 1, 0, 0}, 1, 4),
		"model.layers.0.linear_attn.out_proj.weight": mlx.FromValues([]float32{
			1, 0,
			0, 1,
			1, 1,
			0, 0,
		}, 4, 2),
		"model.layers.0.linear_attn.conv1d.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
		}, 6, 4),
		"model.layers.0.linear_attn.norm.weight": mlx.FromValues([]float32{1, 1}, 2),
		"model.layers.0.linear_attn.dt_bias": mlx.FromValues([]float32{0}, 1),
		"model.layers.0.linear_attn.A_log": mlx.FromValues([]float32{0}, 1),
		"model.layers.0.mlp.gate_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.0.mlp.up_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.0.mlp.down_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0, 0, 0, 0, 0,
			0, 1, 0, 0, 0, 0, 0, 0,
			0, 0, 1, 0, 0, 0, 0, 0,
			0, 0, 0, 1, 0, 0, 0, 0,
		}, 4, 8),
		// Layer 1: standard full-attention weights.
		"model.layers.1.input_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.1.post_attention_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.1.self_attn.q_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.1.self_attn.k_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
		}, 4, 4),
		"model.layers.1.self_attn.v_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
		}, 4, 4),
		"model.layers.1.self_attn.o_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
		}, 4, 4),
		"model.layers.1.self_attn.q_norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.1.self_attn.k_norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
		"model.layers.1.mlp.gate_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.1.mlp.up_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0,
			0, 1, 0, 0,
			0, 0, 1, 0,
			0, 0, 0, 1,
			1, 1, 0, 0,
			0, 1, 1, 0,
			0, 0, 1, 1,
			1, 0, 0, 1,
		}, 8, 4),
		"model.layers.1.mlp.down_proj.weight": mlx.FromValues([]float32{
			1, 0, 0, 0, 0, 0, 0, 0,
			0, 1, 0, 0, 0, 0, 0, 0,
			0, 0, 1, 0, 0, 0, 0, 0,
			0, 0, 0, 1, 0, 0, 0, 0,
		}, 4, 8),
	}

	if err := m.LoadWeights(tensors); err != nil {
		t.Fatalf("LoadWeights failed: %v", err)
	}

	// Per-layer norms must stay float32.
	if got := m.Layers[0].InputNorm.Weight.DType(); got != f32 {
		t.Fatalf("layer 0 input norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[0].PostAttentionNorm.Weight.DType(); got != f32 {
		t.Fatalf("layer 0 post-attn norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[1].InputNorm.Weight.DType(); got != f32 {
		t.Fatalf("layer 1 input norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[1].PostAttentionNorm.Weight.DType(); got != f32 {
		t.Fatalf("layer 1 post-attn norm dtype = %v, want %v", got, f32)
	}

	// Final norm plus the attention-internal norms as well.
	if got := m.Norm.Weight.DType(); got != f32 {
		t.Fatalf("final norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[0].Linear.NormWeight.DType(); got != f32 {
		t.Fatalf("linear-attn norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[1].FullAttn.QNorm.Weight.DType(); got != f32 {
		t.Fatalf("q norm dtype = %v, want %v", got, f32)
	}
	if got := m.Layers[1].FullAttn.KNorm.Weight.DType(); got != f32 {
		t.Fatalf("k norm dtype = %v, want %v", got, f32)
	}
}
|
||||
|
||||
@@ -71,6 +71,11 @@ func (t *Tokenizer) BOS() int32 {
|
||||
return t.vocab.BOS
|
||||
}
|
||||
|
||||
// AddBOS returns whether a BOS token should be prepended during encoding.
// It mirrors the AddBOS flag carried by the loaded vocabulary.
func (t *Tokenizer) AddBOS() bool {
	return t.vocab.AddBOS
}
|
||||
|
||||
// EOS returns the first end of sequence token ID (for backwards compatibility)
|
||||
func (t *Tokenizer) EOS() int32 {
|
||||
if len(t.vocab.EOS) > 0 {
|
||||
|
||||
Reference in New Issue
Block a user