glm4.7
@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
 
 include(FetchContent)
 
-set(MLX_C_GIT_TAG "v0.4.0" CACHE STRING "")
+set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
 
 FetchContent_Declare(
   mlx-c
@@ -15,6 +15,12 @@ func (m Linear) Forward(x *Array) *Array {
 	return x.Matmul(w)
 }
 
+func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
+	w := m.Weight.Transpose(0, 2, 1)
+	// TODO: bias
+	return x.GatherMM(w, lhs, rhs, sorted)
+}
+
 type Embedding struct {
 	Weight Array `weight:"weight"`
 }
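The new Linear.Gather is the batched expert matmul used by the MoE layer added below: the weight is a stacked [experts, out, in] tensor, and each input row is multiplied by the expert that the index argument selects for it. A minimal usage sketch (names and shapes illustrative, not part of this commit):

	// x: [rows, 1, 1, hidden], one row per (token, expert) pair
	// ids: [rows, 1], the expert index chosen for each row
	// sorted=true promises ids are grouped by expert, enabling a faster path
	y := moe.GateProj.Gather(x, nil, ids, true)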
@@ -37,6 +37,12 @@ func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
 	return out
 }
 
+func (t *Array) ArgsortAxis(axis int) *Array {
+	out := New("ARGSORT_AXIS", t)
+	C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
+	return out
+}
+
 func (t *Array) AsType(dtype DType) *Array {
 	out := New("AS_TYPE", t)
 	C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
@@ -79,12 +85,42 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array {
 	return out
 }
 
+func (t *Array) Divide(other *Array) *Array {
+	out := New("DIVIDE", t, other)
+	C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
+	return out
+}
+
 func (t *Array) ExpandDims(axis int) *Array {
 	out := New("EXPAND_DIMS", t)
 	C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }
 
+func (t *Array) Flatten(startAxis, endAxis int) *Array {
+	out := New("FLATTEN", t)
+	C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
+	return out
+}
+
+func (t *Array) FloorDivide(other *Array) *Array {
+	out := New("FLOOR_DIVIDE", t, other)
+	C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
+	return out
+}
+
+func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
+	if lhs == nil {
+		lhs = New("")
+	}
+	if rhs == nil {
+		rhs = New("")
+	}
+	out := New("GATHER_MM", t, other, lhs, rhs)
+	C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
+	return out
+}
+
 func (t *Array) Logsumexp(keepDims bool) *Array {
 	out := New("LOGSUMEXP", t)
 	C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
@@ -150,18 +186,45 @@ func (t *Array) Squeeze(axis int) *Array {
 	return out
 }
 
+func (t *Array) StackAxis(axis int, others ...*Array) *Array {
+	vectorData := make([]C.mlx_array, len(others)+1)
+	vectorData[0] = t.ctx
+	for i := range others {
+		vectorData[i+1] = others[i].ctx
+	}
+
+	vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
+	defer C.mlx_vector_array_free(vector)
+
+	out := New("STACK_AXIS", append(others, t)...)
+	C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
+	return out
+}
+
 func (t *Array) Subtract(other *Array) *Array {
 	out := New("SUBTRACT", t, other)
 	C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
 	return out
 }
 
 func (t *Array) SumAxis(axis int, keepDims bool) *Array {
 	out := New("SUM_AXIS", t)
 	C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
 	return out
 }
 
+func (t *Array) TakeAxis(indices *Array, axis int) *Array {
+	out := New("TAKE_AXIS", t, indices)
+	C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
+	return out
+}
+
+func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
+	out := New("TAKE_ALONG_AXIS", t, indices)
+	C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
+	return out
+}
+
 func (t *Array) Tanh() *Array {
 	out := New("TANH", t)
 	C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
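Several of these new ops exist to support the sort/unsort dance in the MoE dispatch further down; the round trip looks like this (a sketch, variable names illustrative):

	// Sorting by expert id groups each expert's rows together; the argsort
	// of a permutation is its own inverse permutation's index map, so a
	// second ArgsortAxis undoes the reordering.
	order := ids.ArgsortAxis(0)
	inverse := order.ArgsortAxis(0)
	sorted := rows.TakeAxis(order, 0)
	restored := sorted.TakeAxis(inverse, 0) // == rows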
x/mlxrunner/model/glm/4/moe/lite/model.go (new file, 316 lines)
@@ -0,0 +1,316 @@
package glm

import (
	"encoding/json"
	"math"

	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

type Options struct {
	HiddenSize        int     `json:"hidden_size"`
	NumHiddenLayers   int     `json:"num_hidden_layers"`
	IntermediateSize  int     `json:"intermediate_size"`
	NumAttentionHeads int     `json:"num_attention_heads"`
	NumKeyValueHeads  int     `json:"num_key_value_heads"`
	RMSNormEps        float32 `json:"rms_norm_eps"`
	RopeTheta         float32 `json:"rope_theta"`

	QLoraRank     int `json:"q_lora_rank"`
	KVLoraRank    int `json:"kv_lora_rank"`
	QKRopeHeadDim int `json:"qk_rope_head_dim"`
	QKNopeHeadDim int `json:"qk_nope_head_dim"`

	NumRoutedExperts    int     `json:"n_routed_experts"`
	NumSharedExperts    int     `json:"n_shared_experts"`
	NumExpertsPerTok    int     `json:"num_experts_per_tok"`
	RoutedScalingFactor float32 `json:"routed_scaling_factor"`
	NormTopKProb        bool    `json:"norm_topk_prob"`
	FirstKDenseReplace  int     `json:"first_k_dense_replace"`

	mlx.RoPE
}

type Model struct {
	EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
	Layers      []Layer       `weight:"model.layers"`
	Norm        mlx.RMSNorm   `weight:"model.norm"`
	LMHead      mlx.Linear    `weight:"lm_head"`

	Options
}

func (m Model) NumLayers() int {
	return len(m.Layers)
}

func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := inputs.Dim(0), inputs.Dim(1)
	h := m.EmbedTokens.Forward(inputs)
	for i, layer := range m.Layers {
		h = layer.Forward(h, caches[i], B, L, m.Options)
	}

	h = m.Norm.Forward(h, m.RMSNormEps)
	return h
}

func (m Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.LMHead.Forward(x)
}

type Layer struct {
	InputLayernorm         mlx.RMSNorm `weight:"input_layernorm"`
	Attention              Attention   `weight:"self_attn"`
	PostAttentionLayernorm mlx.RMSNorm `weight:"post_attention_layernorm"`
	MLP                    MLP         `weight:"mlp"`
}

func (m Layer) Forward(h *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	r := h
	h = m.InputLayernorm.Forward(h, opts.RMSNormEps)
	h = m.Attention.Forward(h, cache, B, L, opts)
	h = h.Add(r)

	r = h
	h = m.PostAttentionLayernorm.Forward(h, opts.RMSNormEps)
	h = m.MLP.Forward(h, B, L, opts)
	h = h.Add(r)
	return h
}

type MultiLinear struct {
	Weight mlx.Array `weight:"weight"`
}

func (m MultiLinear) Forward(x *mlx.Array) *mlx.Array {
	return x.Matmul(m.Weight.Transpose(0, 2, 1))
}

type Attention struct {
	QAProj      mlx.Linear  `weight:"q_a_proj"`
	QALayernorm mlx.RMSNorm `weight:"q_a_layernorm"`
	QBProj      mlx.Linear  `weight:"q_b_proj"`

	KVAProjWithMQA mlx.Linear  `weight:"kv_a_proj_with_mqa"`
	KVALayernorm   mlx.RMSNorm `weight:"kv_a_layernorm"`
	KVBProj        mlx.Linear  `weight:"kv_b_proj"`

	embedQ     MultiLinear
	unembedOut MultiLinear

	OProj mlx.Linear `weight:"o_proj"`
}

func (m *Attention) AfterLoad(root *model.Root) ([]*mlx.Array, error) {
	bts, err := root.ReadFile("config.json")
	if err != nil {
		return nil, err
	}

	var opts struct {
		NumAttentionHeads int `json:"num_attention_heads"`
		QKNopeHeadDim     int `json:"qk_nope_head_dim"`
		KVLoraRank        int `json:"kv_lora_rank"`
	}
	if err := json.Unmarshal(bts, &opts); err != nil {
		return nil, err
	}

	w := m.KVBProj.Weight.Reshape(opts.NumAttentionHeads, -1, opts.KVLoraRank)
	m.embedQ.Weight.Set(w.Slice(mlx.Slice(), mlx.Slice(0, opts.QKNopeHeadDim), mlx.Slice()).Transpose(0, 2, 1))
	m.unembedOut.Weight.Set(w.Slice(mlx.Slice(), mlx.Slice(opts.QKNopeHeadDim, 0), mlx.Slice()))

	return []*mlx.Array{
		&m.embedQ.Weight,
		&m.unembedOut.Weight,
	}, nil
}

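// Forward implements multi-head latent attention (MLA): queries pass through
// a low-rank bottleneck (q_a_proj -> q_b_proj), while keys and values share a
// compressed latent from kv_a_proj_with_mqa plus a small RoPE-carrying slice.
// The cache stores only [compressed KV ; RoPE key] per token; embedQ and
// unembedOut apply the absorbed halves of kv_b_proj on either side of the
// attention kernel, and value is simply the compressed slice of the cached key.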
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	query := m.QAProj.Forward(hiddenStates)
	query = m.QALayernorm.Forward(query, opts.RMSNormEps)
	query = m.QBProj.Forward(query)

	query = query.Reshape(B, L, opts.NumAttentionHeads, -1)
	query = query.Transpose(0, 2, 1, 3)

	queryNope := query.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.QKNopeHeadDim))
	queryRope := query.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(opts.QKNopeHeadDim, 0))

	compressedKV := m.KVAProjWithMQA.Forward(hiddenStates)

	keyRope := compressedKV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(opts.KVLoraRank, 0))
	keyRope = keyRope.Reshape(B, L, 1, opts.QKRopeHeadDim)
	keyRope = keyRope.Transpose(0, 2, 1, 3)

	kvCompressed := compressedKV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.KVLoraRank))

	var offset int
	if cache != nil {
		offset = cache.Offset()
	}

	queryRope = opts.RoPE.Forward(queryRope, offset)
	keyRope = opts.RoPE.Forward(keyRope, offset)

	key := m.KVALayernorm.Forward(kvCompressed, opts.RMSNormEps).
		ExpandDims(1).
		Concatenate(3, keyRope)

	if cache != nil {
		key, _ = cache.Update(key, mlx.Zeros(mlx.DTypeBFloat16, B, 1, L, 0))
	}

	value := key.Clone().Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.KVLoraRank))
	query = m.embedQ.Forward(queryNope).Concatenate(3, queryRope)

	attention := mlx.ScaledDotProductAttention(query, key, value, nil, float32(1.0/math.Sqrt(float64(opts.QKNopeHeadDim+opts.QKRopeHeadDim))))
	attention = m.unembedOut.Forward(attention)
	attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
	return m.OProj.Forward(attention)
}

type MLP interface {
	Forward(*mlx.Array, int, int, Options) *mlx.Array
}

type dense struct {
	GateProj mlx.Linear `weight:"gate_proj"`
	UpProj   mlx.Linear `weight:"up_proj"`
	DownProj mlx.Linear `weight:"down_proj"`
}

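// dense is a standard SwiGLU feed-forward block; it serves both the first
// FirstKDenseReplace layers and, stacked, the routed and shared experts.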
func (m dense) Forward(h *mlx.Array, _, _ int, opts Options) *mlx.Array {
	h = mlx.SILU(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h))
	return m.DownProj.Forward(h)
}

type Gate struct {
	Gate           mlx.Linear `weight:"gate"`
	CorrectionBias mlx.Array  `weight:"gate.e_score_correction_bias"`
}

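// Forward selects the top NumExpertsPerTok experts per token. Selection uses
// sigmoid scores plus a learned correction bias (in the style of DeepSeek-V3's
// aux-loss-free load balancing); the returned weights use the uncorrected
// scores, are optionally renormalized to sum to 1, and are scaled by
// RoutedScalingFactor.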
func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
	scores = m.Gate.Forward(h).AsType(mlx.DTypeFloat32).Sigmoid()
	original := scores
	scores = scores.Add(&m.CorrectionBias)

	indices = scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
	indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))

	scores = original.TakeAlongAxis(indices, -1)
	if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
		scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
	}

	scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
	return indices, scores
}

type sparse struct {
	Gate

	Experts []dense `weight:"experts"`
	fused   struct {
		GateProj mlx.Linear
		UpProj   mlx.Linear
		DownProj mlx.Linear
	}

	SharedExperts dense `weight:"shared_experts"`
}

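// AfterLoad stacks the per-expert 2-D weights into single [experts, out, in]
// tensors so a routed forward pass is three gathered matmuls instead of a
// loop over experts.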
func (m *sparse) AfterLoad(*model.Root) ([]*mlx.Array, error) {
	w1 := make([]*mlx.Array, len(m.Experts))
	w2 := make([]*mlx.Array, len(m.Experts))
	w3 := make([]*mlx.Array, len(m.Experts))

	for i := range m.Experts {
		w1[i] = &m.Experts[i].GateProj.Weight
		w2[i] = &m.Experts[i].UpProj.Weight
		w3[i] = &m.Experts[i].DownProj.Weight
	}

	m.fused.GateProj.Weight.Set(w1[0].StackAxis(0, w1[1:]...))
	m.fused.UpProj.Weight.Set(w2[0].StackAxis(0, w2[1:]...))
	m.fused.DownProj.Weight.Set(w3[0].StackAxis(0, w3[1:]...))

	return []*mlx.Array{
		&m.fused.GateProj.Weight,
		&m.fused.UpProj.Weight,
		&m.fused.DownProj.Weight,
	}, nil
}

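// Forward dispatches each token to its routed experts via gathered matmuls.
// For batches of 64 or more rows it first sorts the (token, expert) pairs by
// expert id so the gather kernels see grouped indices (sorted=true), then
// restores the original order with the inverse permutation before taking the
// score-weighted sum and adding the shared-expert path.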
func (m sparse) Forward(h *mlx.Array, B, L int, opts Options) *mlx.Array {
	indices, scores := m.Gate.Forward(h, opts)
	scores = scores.ExpandDims(-1)

	flat := h.ExpandDims(-2).ExpandDims(-2).Reshape(-1, 1, 1, opts.HiddenSize)
	indices = indices.Reshape(-1, opts.NumExpertsPerTok)

	sort := B*L >= 64
	var inverseOrder *mlx.Array
	if sort {
		indicesAll := indices.Flatten(0, len(indices.Dims())-1)
		order := indicesAll.ArgsortAxis(0)
		inverseOrder = order.ArgsortAxis(0)
		flat = flat.Squeeze(1).TakeAxis(order.FloorDivide(mlx.FromValue(opts.NumExpertsPerTok)), 0).ExpandDims(1)
		indices = indicesAll.TakeAxis(order, 0).Reshape(B*L*opts.NumExpertsPerTok, 1)
	}

	experts := mlx.SILU(m.fused.GateProj.Gather(flat, nil, indices, sort)).
		Multiply(m.fused.UpProj.Gather(flat, nil, indices, sort))
	experts = m.fused.DownProj.Gather(experts, nil, indices, sort)

	if sort {
		experts = experts.Squeeze(2).Squeeze(1).TakeAxis(inverseOrder, 0)
		experts = experts.Reshape(-1, opts.NumExpertsPerTok, opts.HiddenSize)
	} else {
		experts = experts.Squeeze(2)
	}

	experts = experts.Reshape(B, L, opts.NumExpertsPerTok, opts.HiddenSize)
	experts = experts.Multiply(scores).SumAxis(-2, false).AsType(experts.DType())
	experts = experts.Add(m.SharedExperts.Forward(h, B, L, opts))
	return experts.Reshape(B, L, -1)
}

func init() {
	base.Register("Glm4MoeLiteForCausalLM", func(root *model.Root) (base.Model, error) {
		bts, err := root.ReadFile("config.json")
		if err != nil {
			return nil, err
		}

		var opts Options
		if err := json.Unmarshal(bts, &opts); err != nil {
			return nil, err
		}

		opts.RoPE = mlx.RoPE{
			Dims:        opts.QKRopeHeadDim,
			Traditional: true,
			Base:        opts.RopeTheta,
			Scale:       1,
		}

		layers := make([]Layer, opts.NumHiddenLayers)
		for i := range layers {
			if i < opts.FirstKDenseReplace {
				layers[i].MLP = &dense{}
			} else {
				layers[i].MLP = &sparse{Experts: make([]dense, opts.NumRoutedExperts)}
			}
		}

		return &Model{
			Layers:  layers,
			Options: opts,
		}, nil
	})
}
@@ -2,5 +2,6 @@ package model
 
 import (
 	_ "github.com/ollama/ollama/x/mlxrunner/model/gemma/3"
+	_ "github.com/ollama/ollama/x/mlxrunner/model/glm/4/moe/lite"
 	_ "github.com/ollama/ollama/x/mlxrunner/model/llama"
 )
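Taken together, the MLA layout is what keeps this model's KV cache small: per token and layer the cache holds kv_lora_rank + qk_rope_head_dim values instead of num_attention_heads * (qk_nope_head_dim + qk_rope_head_dim). A back-of-the-envelope sketch (config numbers hypothetical, for illustration only):

	// Hypothetical config values, not taken from this commit.
	kvLoraRank, qkRopeHeadDim := 512, 64
	numHeads, qkNopeHeadDim := 32, 128
	mlaWidth := kvLoraRank + qkRopeHeadDim                  // 576 cached values per token
	mhaWidth := numHeads * (qkNopeHeadDim + qkRopeHeadDim)  // 6144 for uncompressed MHA
	_ = 100 * mlaWidth / mhaWidth                           // ~9% of the full-cache footprint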