From 77eb2ca619f97e0ae1b6397bd2ad8f1acf63434c Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 3 Feb 2026 23:27:21 -0800 Subject: [PATCH] model: add qwen3-next architecture (#14051) --- convert/convert.go | 2 + convert/convert_qwen3next.go | 512 +++++++++++++++ convert/reader.go | 1 + convert/reader_safetensors.go | 1 + fs/ggml/ggml.go | 2 + kvcache/cache.go | 7 + llama/patches/0033-ggml-metal-solve_tri.patch | 276 ++++++++ ml/backend.go | 26 + ml/backend/ggml/ggml.go | 72 ++- .../ggml/src/ggml-metal/ggml-metal-device.cpp | 20 + .../ggml/src/ggml-metal/ggml-metal-device.h | 1 + .../ggml/src/ggml-metal/ggml-metal-device.m | 11 + .../src/ggml-metal/ggml-metal-embed.metal | 81 +++ .../ggml/src/ggml-metal/ggml-metal-impl.h | 21 + .../ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 ++ .../ggml/ggml/src/ggml-metal/ggml-metal-ops.h | 1 + .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 60 ++ model/model_test.go | 12 + model/models/models.go | 1 + model/models/qwen3next/attention.go | 103 +++ model/models/qwen3next/cache.go | 596 ++++++++++++++++++ model/models/qwen3next/checkpoints.go | 498 +++++++++++++++ model/models/qwen3next/checkpoints_test.go | 300 +++++++++ model/models/qwen3next/deltanet.go | 473 ++++++++++++++ model/models/qwen3next/model.go | 383 +++++++++++ model/renderers/qwen3coder.go | 6 +- model/renderers/qwen3coder_test.go | 35 +- runner/ollamarunner/cache.go | 13 +- server/quantization.go | 49 ++ 29 files changed, 3614 insertions(+), 12 deletions(-) create mode 100644 convert/convert_qwen3next.go create mode 100644 llama/patches/0033-ggml-metal-solve_tri.patch create mode 100644 model/models/qwen3next/attention.go create mode 100644 model/models/qwen3next/cache.go create mode 100644 model/models/qwen3next/checkpoints.go create mode 100644 model/models/qwen3next/checkpoints_test.go create mode 100644 model/models/qwen3next/deltanet.go create mode 100644 model/models/qwen3next/model.go diff --git a/convert/convert.go b/convert/convert.go index 73b494747..1f318be90 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -317,6 +317,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { conv = &glmOcrModel{} case "Lfm2ForCausalLM": conv = &lfm2Model{} + case "Qwen3NextForCausalLM": + conv = &qwen3NextModel{} default: return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0]) } diff --git a/convert/convert_qwen3next.go b/convert/convert_qwen3next.go new file mode 100644 index 000000000..84db35e14 --- /dev/null +++ b/convert/convert_qwen3next.go @@ -0,0 +1,512 @@ +package convert + +import ( + "fmt" + "io/fs" + "math" + "slices" + "strings" + + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" + + "github.com/ollama/ollama/fs/ggml" +) + +type qwen3NextModel struct { + ModelParameters + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + HeadDim uint32 `json:"head_dim"` + RopeTheta float32 `json:"rope_theta"` + RMSNormEPS float32 `json:"rms_norm_eps"` + + // MoE config + NumExperts uint32 `json:"num_experts"` + NumExpertsPerToken uint32 `json:"num_experts_per_tok"` + NormTopkProb bool `json:"norm_topk_prob"` + MoEIntermediateSize uint32 `json:"moe_intermediate_size"` + SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"` + + // Hybrid attention 
config + FullAttentionInterval uint32 `json:"full_attention_interval"` + + // Linear attention (Gated Delta Net) config + LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"` + LinearKeyHeadDim uint32 `json:"linear_key_head_dim"` + LinearNumKeyHeads uint32 `json:"linear_num_key_heads"` + LinearNumValueHeads uint32 `json:"linear_num_value_heads"` + LinearValueHeadDim uint32 `json:"linear_value_head_dim"` + + // RoPE config + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + RopeScaling struct { + Type string `json:"type"` + Factor ropeFactor `json:"factor"` + } `json:"rope_scaling"` +} + +var _ ModelConverter = (*qwen3NextModel)(nil) + +func (q *qwen3NextModel) parseMore(_ fs.FS) error { + if q.NumHiddenLayers == 0 { + return fmt.Errorf("qwen3next: num_hidden_layers must be set") + } + if q.NumAttentionHeads == 0 { + return fmt.Errorf("qwen3next: num_attention_heads must be set") + } + if q.NumKeyValueHeads == 0 { + return fmt.Errorf("qwen3next: num_key_value_heads must be set") + } + if q.HeadDim == 0 { + return fmt.Errorf("qwen3next: head_dim must be set") + } + if q.RopeTheta == 0 { + return fmt.Errorf("qwen3next: rope_theta must be set") + } + if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 { + return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor) + } + if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 { + return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)") + } + if q.FullAttentionInterval == 0 { + return fmt.Errorf("qwen3next: full_attention_interval must be set") + } + if q.FullAttentionInterval > q.NumHiddenLayers { + return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers) + } + + hasFull := false + for i := range q.NumHiddenLayers { + if (i+1)%q.FullAttentionInterval == 0 { + hasFull = true + break + } + } + if !hasFull { + return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers) + } + + return nil +} + +func (q *qwen3NextModel) KV(t *Tokenizer) KV { + kv := q.ModelParameters.KV(t) + kv["general.architecture"] = "qwen3next" + kv["tokenizer.ggml.pre"] = "qwen2" + kv["block_count"] = q.NumHiddenLayers + kv["context_length"] = q.MaxPositionEmbeddings + kv["embedding_length"] = q.HiddenSize + kv["feed_forward_length"] = q.IntermediateSize + kv["attention.head_count"] = q.NumAttentionHeads + headDim := q.HeadDim + if headDim == 0 && q.NumAttentionHeads > 0 { + headDim = q.HiddenSize / q.NumAttentionHeads + } + kv["attention.key_length"] = headDim + kv["attention.value_length"] = headDim + kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS + kv["rope.freq_base"] = q.RopeTheta + + // RoPE dimension count (partial rotary) + // partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE + partialRotary := q.PartialRotaryFactor + if partialRotary > 0 && partialRotary <= 1 { + kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary) + } + + // MoE config + if q.NumExperts > 0 { + kv["expert_count"] = q.NumExperts + kv["expert_used_count"] = q.NumExpertsPerToken + kv["norm_top_k_prob"] = q.NormTopkProb + if q.MoEIntermediateSize > 0 { + kv["expert_feed_forward_length"] = q.MoEIntermediateSize + } + if q.SharedExpertIntermSize > 0 { + 
kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize + } + } + + // SSM/Linear attention config + // d_inner = linear_value_head_dim * linear_num_value_heads + dInner := q.LinearValueHeadDim * q.LinearNumValueHeads + kv["ssm.inner_size"] = dInner + kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim + kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads + kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads + kv["ssm.conv_kernel"] = q.LinearConvKernelDim + interval := q.FullAttentionInterval + kv["full_attention_interval"] = interval + + // Build per-layer KV head count array to identify layer types + // 0 = recurrent (linear attention), non-zero = full attention + kvHeadCounts := make([]uint32, q.NumHiddenLayers) + for i := range q.NumHiddenLayers { + // Full attention every full_attention_interval layers (starting at interval-1) + if interval > 0 && (i+1)%interval == 0 { + kvHeadCounts[i] = q.NumKeyValueHeads + } + // else stays 0 (recurrent layer) + } + kv["attention.head_count_kv"] = kvHeadCounts + + // RoPE scaling + if q.RopeScaling.Type != "" { + kv["rope.scaling.type"] = q.RopeScaling.Type + kv["rope.scaling.factor"] = q.RopeScaling.Factor + } + + return kv +} + +func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + + // Create merges for expert tensors - stack individual experts into batched tensors + merges := make([]merge, q.NumHiddenLayers*3) + for i := range q.NumHiddenLayers { + merges[i*3+0] = merge{ + fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i), + fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i), + } + merges[i*3+1] = merge{ + fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i), + fmt.Sprintf("blk.%d.ffn_up_exps.weight", i), + } + merges[i*3+2] = merge{ + fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i), + fmt.Sprintf("blk.%d.ffn_down_exps.weight", i), + } + } + + // Merge expert tensors + merged, remaining := mergeTensors(ts, merges...) + out = append(out, merged...) 
+ + // Process remaining tensors + for _, t := range remaining { + name := t.Name() + shape := t.Shape() + + // Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible + if strings.HasSuffix(name, ".ssm_in.weight") { + if qkv, gate, ok := q.splitQKVZTensor(t); ok { + out = append(out, qkv, gate) + continue + } + panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape)) + } + + switch { + // Add 1 to norm weights (except ssm_norm which is linear_attn.norm) + // This matches the Python converter behavior for qwen3next + case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"): + t.SetRepacker(q.addOne) + out = append(out, &ggml.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: slices.Clone(shape), + WriterTo: t, + }) + + // Handle linear attention A_log -> ssm_a (negate and exp) + // Note: name has already been transformed by Replacements at this point + case strings.HasSuffix(name, ".ssm_a"): + t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { + // Compute -exp(A_log) + result := make([]float32, len(data)) + for i, v := range data { + // -exp(v) + result[i] = -float32(math.Exp(float64(v))) + } + return result, nil + }) + out = append(out, &ggml.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: slices.Clone(shape), + WriterTo: t, + }) + + // Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K] + case strings.HasSuffix(name, ".ssm_conv1d.weight"): + newShape := slices.Clone(shape) + if len(shape) == 3 { + if shape[0] == 1 { + // [1, D, K] -> [D, K] + newShape = []uint64{shape[1], shape[2]} + } else if shape[1] == 1 { + // [D, 1, K] -> [D, K] + newShape = []uint64{shape[0], shape[2]} + } + } + out = append(out, &ggml.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: newShape, + WriterTo: t, + }) + // Squeeze shared expert gate: [D, 1] or [1, D] -> [D] + case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"): + newShape := slices.Clone(shape) + if len(shape) == 2 { + if shape[0] == 1 && shape[1] > 1 { + newShape = []uint64{shape[1]} + } else if shape[1] == 1 && shape[0] > 1 { + newShape = []uint64{shape[0]} + } + } + out = append(out, &ggml.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: newShape, + WriterTo: t, + }) + + default: + out = append(out, &ggml.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: slices.Clone(shape), + WriterTo: t, + }) + } + } + + return out +} + +type qkvzSplitSpec struct { + hidden int + headKDim int + headVDim int + numKHeads int + numVHeads int + qkvzDim int + qkvOut int + gateOut int +} + +func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) { + if len(shape) != 2 { + return qkvzSplitSpec{}, false + } + + numKHeads := int(q.LinearNumKeyHeads) + numVHeads := int(q.LinearNumValueHeads) + headKDim := int(q.LinearKeyHeadDim) + headVDim := int(q.LinearValueHeadDim) + if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 { + return qkvzSplitSpec{}, false + } + if numVHeads%numKHeads != 0 { + return qkvzSplitSpec{}, false + } + + hidden := int(shape[1]) + vPerHead := headVDim * (numVHeads / numKHeads) + qkvzDim := 2*headKDim + 2*vPerHead + expectedOut := qkvzDim * numKHeads + if int(shape[0]) != expectedOut { + return qkvzSplitSpec{}, false + } + + return qkvzSplitSpec{ + hidden: hidden, + headKDim: headKDim, + headVDim: headVDim, + numKHeads: numKHeads, + numVHeads: numVHeads, + qkvzDim: qkvzDim, + qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads, + gateOut: headVDim * numVHeads, + }, true +} + 
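// exampleQKVZSplit is a hypothetical, self-contained sketch of how the fused
// in_proj_qkvz rows are partitioned by qkvzSpec above. The constants are
// illustrative assumptions (16 key heads, 32 value heads, 128-dim key/value
// heads), not values read from any model.
func exampleQKVZSplit() (qkvOut, gateOut int) {
	const (
		numKHeads = 16
		numVHeads = 32
		headKDim  = 128
		headVDim  = 128
	)
	vPerHead := headVDim * (numVHeads / numKHeads) // 256: value rows per key-head group
	qkvzDim := 2*headKDim + 2*vPerHead             // 768: q, k, v, z packed per group
	_ = qkvzDim * numKHeads                        // 12288: total fused rows (shape[0])

	qkvOut = 2*headKDim*numKHeads + headVDim*numVHeads // 8192: rows emitted as attn_qkv
	gateOut = headVDim * numVHeads                     // 4096: rows emitted as attn_gate
	return qkvOut, gateOut
}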
+func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) { + spec, ok := q.qkvzSpec(t.Shape()) + if !ok { + return nil, nil, false + } + + qkvTensor := t.Clone() + qkvTensor.SetRepacker(q.repackQKVZ(spec, false)) + + gateTensor := t.Clone() + gateTensor.SetRepacker(q.repackQKVZ(spec, true)) + + qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1) + gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1) + + return &ggml.Tensor{ + Name: qkvName, + Kind: t.Kind(), + Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)}, + WriterTo: qkvTensor, + }, &ggml.Tensor{ + Name: gateName, + Kind: t.Kind(), + Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)}, + WriterTo: gateTensor, + }, true +} + +func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker { + vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads) + + return func(_ string, data []float32, shape []uint64) ([]float32, error) { + dims := make([]int, len(shape)) + for i := range shape { + dims[i] = int(shape[i]) + } + + var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + var err error + + // Convert to [hidden, out_features] layout for slicing + tt, err = tensor.Transpose(tt, 1, 0) + if err != nil { + return nil, err + } + tt = tensor.Materialize(tt) + + if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil { + return nil, err + } + + offset := 0 + qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim)) + if err != nil { + return nil, err + } + offset += spec.headKDim + kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim)) + if err != nil { + return nil, err + } + offset += spec.headKDim + vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead)) + if err != nil { + return nil, err + } + offset += vPerHead + zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead)) + if err != nil { + return nil, err + } + + qMat := tensor.Materialize(qSlice).(*tensor.Dense) + kMat := tensor.Materialize(kSlice).(*tensor.Dense) + vMat := tensor.Materialize(vSlice).(*tensor.Dense) + zMat := tensor.Materialize(zSlice).(*tensor.Dense) + + if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil { + return nil, err + } + if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil { + return nil, err + } + if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil { + return nil, err + } + if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil { + return nil, err + } + + var out tensor.Tensor + if extractGate { + out = zMat + } else { + out, err = tensor.Concat(1, qMat, kMat, vMat) + if err != nil { + return nil, err + } + } + + out = tensor.Materialize(out) + out, err = tensor.Transpose(out, 1, 0) + if err != nil { + return nil, err + } + out = tensor.Materialize(out) + + if err := out.Reshape(out.Shape().TotalSize()); err != nil { + return nil, err + } + + return native.VectorF32(out.(*tensor.Dense)) + } +} + +// addOne adds 1.0 to all elements in the tensor (for norm weights) +func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) { + n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data)) + ones := tensor.Ones(tensor.Float32, int(shape[0])) + + n, err := n.Add(ones) + if err != nil { + return nil, err + } + + ts, err := native.SelectF32(n, 0) + if err != nil { + return nil, err + } + + var f32s []float32 + for _, t := range ts { 
+ f32s = append(f32s, t...) + } + + return f32s, nil +} + +func (q *qwen3NextModel) Replacements() []string { + return []string{ + // Embeddings and output + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.norm", "output_norm", + "model.layers", "blk", + + // Layer norms + "input_layernorm", "attn_norm", + "post_attention_layernorm", "post_attention_norm", + + // Full attention (self_attn) + "self_attn.q_proj", "attn_q", + "self_attn.q_norm", "attn_q_norm", + "self_attn.k_proj", "attn_k", + "self_attn.k_norm", "attn_k_norm", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + + // Linear attention (Gated Delta Net) + "linear_attn.in_proj_qkvz", "ssm_in", + "linear_attn.in_proj_ba", "ssm_ba", + "linear_attn.conv1d", "ssm_conv1d", + "linear_attn.dt_bias", "ssm_dt", + "linear_attn.dt_proj", "ssm_dt", + "linear_attn.A_log", "ssm_a", + "linear_attn.norm", "ssm_norm", + "linear_attn.out_proj", "ssm_out", + + // MoE (experts are stacked via mergeTensors, not replaced here) + "mlp.gate.weight", "ffn_gate_inp.weight", + "mlp.shared_expert.down_proj", "ffn_down_shexp", + "mlp.shared_expert.gate_proj", "ffn_gate_shexp", + "mlp.shared_expert.up_proj", "ffn_up_shexp", + "mlp.shared_expert_gate", "ffn_gate_inp_shexp", + + // Dense FFN (if any layers use it) + "mlp.down_proj", "ffn_down", + "mlp.gate_proj", "ffn_gate", + "mlp.up_proj", "ffn_up", + } +} diff --git a/convert/reader.go b/convert/reader.go index a2ac41dc9..0cff12a22 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -41,6 +41,7 @@ func (t tensorBase) Kind() uint32 { if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") || strings.HasSuffix(t.name, ".bias") || strings.HasSuffix(t.name, ".shortconv.conv.weight") || + strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal t.name == "token_types.weight" || t.name == "v.positional_embedding_vlm" || t.name == "v.tile_position_embd.weight" || diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index da1a62e02..f7dae0646 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -100,6 +100,7 @@ func (st safetensor) Kind() uint32 { !strings.HasPrefix(st.name, "v.") && !strings.HasPrefix(st.name, "s.") && !strings.HasPrefix(st.name, "mm.") && + !strings.Contains(st.name, "ffn_gate_inp_shexp.weight") && kind != tensorKindFP32 { kind = tensorKindBF16 } diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index aa5377ebc..11bfb8c90 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -268,6 +268,7 @@ func (kv KV) OllamaEngineRequired() bool { "olmo3", "qwen25vl", "qwen3", "qwen3moe", + "qwen3next", "qwen3vl", "qwen3vlmoe", "glm4moelite", "glmocr", @@ -866,6 +867,7 @@ func (f GGML) FlashAttention() bool { "mistral3", "olmo3", "qwen3", "qwen3moe", + "qwen3next", "qwen3vl", "qwen3vlmoe", }, f.KV().String("general.architecture")) } diff --git a/kvcache/cache.go b/kvcache/cache.go index 405c79733..5c6fc250b 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -75,3 +75,10 @@ type Cache interface { // removed by calling Remove(seq, 0, math.MaxInt32) Remove(seq int, beginIndex, endIndex int32) error } + +// CheckpointCache optionally supports restoring recurrent state to a prior +// position to avoid full prompt reprocessing when a prefix mismatch occurs. +// The returned position is the number of tokens that can be kept (prefix length). 
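// For example (illustrative only; exact call sites are runner-specific): a caller
// that needs to rewind a sequence to position 700 could call
// PrepareRestore(seq, 700); if the nearest usable checkpoint covers 512 tokens,
// the cache would return (512, true) and only tokens from position 512 onward
// would need to be reprocessed.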
+type CheckpointCache interface { + PrepareRestore(seq int, targetPos int32) (int32, bool) +} diff --git a/llama/patches/0033-ggml-metal-solve_tri.patch b/llama/patches/0033-ggml-metal-solve_tri.patch new file mode 100644 index 000000000..7bc65fda7 --- /dev/null +++ b/llama/patches/0033-ggml-metal-solve_tri.patch @@ -0,0 +1,276 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jeffrey Morgan +Date: Tue, 3 Feb 2026 12:00:00 -0800 +Subject: [PATCH] ggml: metal solve_tri + +--- + ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++ + ggml/src/ggml-metal/ggml-metal-device.h | 1 + + ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++ + ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++ + ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++ + ggml/src/ggml-metal/ggml-metal-ops.h | 1 + + ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++ + 7 files changed, 177 insertions(+) + +diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp +index 680904d13..83385c9ef 100644 +--- a/ggml/src/ggml-metal/ggml-metal-device.cpp ++++ b/ggml/src/ggml-metal/ggml-metal-device.cpp +@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met + return res; + } + ++ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { ++ assert(op->op == GGML_OP_SOLVE_TRI); ++ ++ GGML_ASSERT(ggml_is_contiguous(op->src[0])); ++ GGML_ASSERT(ggml_is_contiguous(op->src[1])); ++ ++ char base[256]; ++ char name[256]; ++ ++ snprintf(base, 256, "kernel_solve_tri_f32"); ++ snprintf(name, 256, "%s", base); ++ ++ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); ++ if (!res.pipeline) { ++ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); ++ } ++ ++ return res; ++} ++ + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_GROUP_NORM); + +diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h +index 0a8b9211a..8a9d17460 100644 +--- a/ggml/src/ggml-metal/ggml-metal-device.h ++++ b/ggml/src/ggml-metal/ggml-metal-device.h +@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); ++struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); +diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m +index 7b5ee968c..4e5acfbe5 100644 +--- 
a/ggml/src/ggml-metal/ggml-metal-device.m ++++ b/ggml/src/ggml-metal/ggml-metal-device.m +@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_L2_NORM: + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); ++ case GGML_OP_SOLVE_TRI: ++ return ggml_is_contiguous(op->src[0]) && ++ ggml_is_contiguous(op->src[1]) && ++ op->src[0]->type == GGML_TYPE_F32 && ++ op->src[1]->type == GGML_TYPE_F32 && ++ op->type == GGML_TYPE_F32; ++ case GGML_OP_COUNT_EQUAL: ++ return has_simdgroup_reduction && ++ op->src[0]->type == GGML_TYPE_I32 && ++ op->src[1]->type == GGML_TYPE_I32 && ++ op->type == GGML_TYPE_I64; + case GGML_OP_ARGMAX: + return has_simdgroup_reduction; + case GGML_OP_NORM: +diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h +index 8944b07e9..cfdea9c07 100644 +--- a/ggml/src/ggml-metal/ggml-metal-impl.h ++++ b/ggml/src/ggml-metal/ggml-metal-impl.h +@@ -500,6 +500,27 @@ typedef struct { + float eps; + } ggml_metal_kargs_l2_norm; + ++typedef struct { ++ int32_t ne00; ++ int32_t ne01; ++ int32_t ne02; ++ int32_t ne03; ++ uint64_t nb00; ++ uint64_t nb01; ++ uint64_t nb02; ++ uint64_t nb03; ++ int32_t ne10; ++ int32_t ne11; ++ uint64_t nb10; ++ uint64_t nb11; ++ uint64_t nb12; ++ uint64_t nb13; ++ uint64_t nb0; ++ uint64_t nb1; ++ uint64_t nb2; ++ uint64_t nb3; ++} ggml_metal_kargs_solve_tri; ++ + typedef struct { + int64_t ne00; + int64_t ne01; +diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp +index 80864f303..4ac135603 100644 +--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp ++++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp +@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { + { + n_fuse = ggml_metal_op_l2_norm(ctx, idx); + } break; ++ case GGML_OP_SOLVE_TRI: ++ { ++ n_fuse = ggml_metal_op_solve_tri(ctx, idx); ++ } break; + case GGML_OP_GROUP_NORM: + { + n_fuse = ggml_metal_op_group_norm(ctx, idx); +@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { + return 1; + } + ++int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { ++ ggml_tensor * op = ctx->node(idx); ++ ++ ggml_metal_library_t lib = ctx->lib; ++ ggml_metal_encoder_t enc = ctx->enc; ++ ++ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); ++ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); ++ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); ++ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); ++ GGML_TENSOR_LOCALS( int32_t, ne, op, ne); ++ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); ++ ++ ggml_metal_kargs_solve_tri args = { ++ /*.ne00 =*/ ne00, ++ /*.ne01 =*/ ne01, ++ /*.ne02 =*/ ne02, ++ /*.ne03 =*/ ne03, ++ /*.nb00 =*/ nb00, ++ /*.nb01 =*/ nb01, ++ /*.nb02 =*/ nb02, ++ /*.nb03 =*/ nb03, ++ /*.ne10 =*/ ne10, ++ /*.ne11 =*/ ne11, ++ /*.nb10 =*/ nb10, ++ /*.nb11 =*/ nb11, ++ /*.nb12 =*/ nb12, ++ /*.nb13 =*/ nb13, ++ /*.nb0 =*/ nb0, ++ /*.nb1 =*/ nb1, ++ /*.nb2 =*/ nb2, ++ /*.nb3 =*/ nb3, ++ }; ++ ++ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); ++ ++ const int64_t ncols = ne10; ++ const int64_t n_batches = (int64_t)ne02 * ne03; ++ const int64_t nr = n_batches * ncols; ++ ++ int nth = 64; ++ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); ++ if (nth < 1) { ++ nth = 1; ++ } ++ ++ const int64_t n_tg = (nr + nth - 1) / nth; ++ ++ 
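    // Threading layout: one thread per output column per batch, nr = ne02*ne03*ne10
    // threads in total. Each thread performs a serial forward substitution down the
    // ne11 rows of its column: x[i] = (b[i] - sum_{t<i} A[i,t]*x[t]) / A[i,i],
    // which is what kernel_solve_tri_f32 below implements.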
ggml_metal_encoder_set_pipeline(enc, pipeline); ++ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ++ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ++ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ++ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); ++ ++ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1); ++ ++ return 1; ++} ++ + int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + +diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h +index 902b54452..a475183d3 100644 +--- a/ggml/src/ggml-metal/ggml-metal-ops.h ++++ b/ggml/src/ggml-metal/ggml-metal-ops.h +@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx); + int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx); + int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx); + int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx); ++int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); + int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx); + int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); + int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index d33c16079..c37447a10 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32( + } + } + ++kernel void kernel_solve_tri_f32( ++ constant ggml_metal_kargs_solve_tri & args, ++ device const char * src0, ++ device const char * src1, ++ device char * dst, ++ uint tgpig[[threadgroup_position_in_grid]], ++ ushort tpitg[[thread_position_in_threadgroup]], ++ ushort ntg[[threads_per_threadgroup]]) { ++ const uint64_t ncols = (uint64_t) args.ne10; ++ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03; ++ const uint64_t nr = n_batches * ncols; ++ ++ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg; ++ if (gid >= nr) { ++ return; ++ } ++ ++ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols); ++ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols; ++ const uint64_t i02 = rem / ncols; ++ const uint64_t i01 = rem - i02 * ncols; ++ ++ const uint64_t sa0 = args.nb00 / sizeof(float); ++ const uint64_t sa1 = args.nb01 / sizeof(float); ++ const uint64_t sa2 = args.nb02 / sizeof(float); ++ const uint64_t sa3 = args.nb03 / sizeof(float); ++ ++ const uint64_t sb0 = args.nb10 / sizeof(float); ++ const uint64_t sb1 = args.nb11 / sizeof(float); ++ const uint64_t sb2 = args.nb12 / sizeof(float); ++ const uint64_t sb3 = args.nb13 / sizeof(float); ++ ++ const uint64_t sx0 = args.nb0 / sizeof(float); ++ const uint64_t sx1 = args.nb1 / sizeof(float); ++ const uint64_t sx2 = args.nb2 / sizeof(float); ++ const uint64_t sx3 = args.nb3 / sizeof(float); ++ ++ device const float * A = (device const float *) src0; ++ device const float * B = (device const float *) src1; ++ device float * X = (device float *) dst; ++ ++ const uint64_t A_base = i02 * sa2 + i03 * sa3; ++ const uint64_t B_base = i02 * sb2 + i03 * sb3; ++ const uint64_t X_base = i02 * sx2 + i03 * sx3; ++ ++ const uint64_t n = (uint64_t) args.ne11; ++ ++ for (uint64_t i00 = 0; i00 < n; ++i00) { ++ float sum = 0.0f; ++ for (uint64_t t = 0; t < i00; ++t) { ++ sum += A[A_base + i00 * sa1 + t * sa0] * ++ X[X_base + t * sx1 + i01 * sx0]; ++ } ++ ++ const float diag = A[A_base + i00 * sa1 + 
i00 * sa0]; ++ X[X_base + i00 * sx1 + i01 * sx0] = ++ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag; ++ } ++} ++ + kernel void kernel_group_norm_f32( + constant ggml_metal_kargs_group_norm & args, + device const float * src0, diff --git a/ml/backend.go b/ml/backend.go index 624e2c773..7db8acc97 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -207,6 +207,32 @@ type Tensor interface { Stddev(ctx Context) Tensor Sqr(ctx Context) Tensor Sqrt(ctx Context) Tensor + Exp(ctx Context) Tensor + Neg(ctx Context) Tensor + + // Clamp clamps values to [min, max] range + Clamp(ctx Context, min, max float32) Tensor + + // Softplus computes ln(1 + exp(x)) + Softplus(ctx Context) Tensor + + // CumSum computes cumulative sum along dimension 0 + CumSum(ctx Context) Tensor + + // Diag creates a diagonal matrix from a 1D tensor + Diag(ctx Context) Tensor + + // Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower) + Tri(ctx Context, triType int) Tensor + + // Fill fills a tensor with a constant value (in-place) + Fill(ctx Context, value float32) Tensor + + // Repeat4D repeats tensor to match target shape + Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor + + // SolveTri solves a triangular system Ax = B + SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor } diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index ea69235b0..ee4e93197 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } } - maxGraphNodes := max(1024, len(meta.Tensors().Items())*8) + maxGraphNodes := max(1024, len(meta.Tensors().Items())*32) sched := C.ggml_backend_sched_new_ext( (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), @@ -1779,6 +1779,76 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor { } } +func (t *Tensor) Exp(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_exp(ctx.(*Context).ctx, t.t), + } +} + +func (t *Tensor) Neg(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_neg(ctx.(*Context).ctx, t.t), + } +} + +func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)), + } +} + +func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_softplus(ctx.(*Context).ctx, t.t), + } +} + +func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_cumsum(ctx.(*Context).ctx, t.t), + } +} + +func (t *Tensor) Diag(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_diag(ctx.(*Context).ctx, t.t), + } +} + +func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)), + } +} + +func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)), + } +} + +func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)), + } +} + +func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, 
C._Bool(lower), C._Bool(left), C._Bool(unitDiag)), + } +} + func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor { var mode C.uint32_t switch samplingMode { diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp index 680904d13..83385c9ef 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_SOLVE_TRI); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_solve_tri_f32"); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_GROUP_NORM); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h index 0a8b9211a..8a9d17460 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h @@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m index 7b5ee968c..4e5acfbe5 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m @@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_SOLVE_TRI: + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32; + case 
GGML_OP_COUNT_EQUAL: + return has_simdgroup_reduction && + op->src[0]->type == GGML_TYPE_I32 && + op->src[1]->type == GGML_TYPE_I32 && + op->type == GGML_TYPE_I64; case GGML_OP_ARGMAX: return has_simdgroup_reduction; case GGML_OP_NORM: diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 3235a18eb..9404c93ce 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -2385,6 +2385,27 @@ typedef struct { float eps; } ggml_metal_kargs_l2_norm; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_solve_tri; + typedef struct { int64_t ne00; int64_t ne01; @@ -5813,6 +5834,66 @@ kernel void kernel_l2_norm_f32( } } +kernel void kernel_solve_tri_f32( + constant ggml_metal_kargs_solve_tri & args, + device const char * src0, + device const char * src1, + device char * dst, + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort ntg[[threads_per_threadgroup]]) { + const uint64_t ncols = (uint64_t) args.ne10; + const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03; + const uint64_t nr = n_batches * ncols; + + const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg; + if (gid >= nr) { + return; + } + + const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols); + const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols; + const uint64_t i02 = rem / ncols; + const uint64_t i01 = rem - i02 * ncols; + + const uint64_t sa0 = args.nb00 / sizeof(float); + const uint64_t sa1 = args.nb01 / sizeof(float); + const uint64_t sa2 = args.nb02 / sizeof(float); + const uint64_t sa3 = args.nb03 / sizeof(float); + + const uint64_t sb0 = args.nb10 / sizeof(float); + const uint64_t sb1 = args.nb11 / sizeof(float); + const uint64_t sb2 = args.nb12 / sizeof(float); + const uint64_t sb3 = args.nb13 / sizeof(float); + + const uint64_t sx0 = args.nb0 / sizeof(float); + const uint64_t sx1 = args.nb1 / sizeof(float); + const uint64_t sx2 = args.nb2 / sizeof(float); + const uint64_t sx3 = args.nb3 / sizeof(float); + + device const float * A = (device const float *) src0; + device const float * B = (device const float *) src1; + device float * X = (device float *) dst; + + const uint64_t A_base = i02 * sa2 + i03 * sa3; + const uint64_t B_base = i02 * sb2 + i03 * sb3; + const uint64_t X_base = i02 * sx2 + i03 * sx3; + + const uint64_t n = (uint64_t) args.ne11; + + for (uint64_t i00 = 0; i00 < n; ++i00) { + float sum = 0.0f; + for (uint64_t t = 0; t < i00; ++t) { + sum += A[A_base + i00 * sa1 + t * sa0] * + X[X_base + t * sx1 + i01 * sx0]; + } + + const float diag = A[A_base + i00 * sa1 + i00 * sa0]; + X[X_base + i00 * sx1 + i01 * sx0] = + (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag; + } +} + kernel void kernel_group_norm_f32( constant ggml_metal_kargs_group_norm & args, device const float * src0, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h index 8944b07e9..cfdea9c07 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h @@ -500,6 +500,27 @@ typedef 
struct { float eps; } ggml_metal_kargs_l2_norm; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_solve_tri; + typedef struct { int64_t ne00; int64_t ne01; diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp index 80864f303..4ac135603 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_l2_norm(ctx, idx); } break; + case GGML_OP_SOLVE_TRI: + { + n_fuse = ggml_metal_op_solve_tri(ctx, idx); + } break; case GGML_OP_GROUP_NORM: { n_fuse = ggml_metal_op_group_norm(ctx, idx); @@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_solve_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); + + const int64_t ncols = ne10; + const int64_t n_batches = (int64_t)ne02 * ne03; + const int64_t nr = n_batches * ncols; + + int nth = 64; + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + if (nth < 1) { + nth = 1; + } + + const int64_t n_tg = (nr + nth - 1) / nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1); + + return 1; +} + int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h index 902b54452..a475183d3 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h @@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx); int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx); int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx); int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx); int 
ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index d33c16079..c37447a10 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32( } } +kernel void kernel_solve_tri_f32( + constant ggml_metal_kargs_solve_tri & args, + device const char * src0, + device const char * src1, + device char * dst, + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort ntg[[threads_per_threadgroup]]) { + const uint64_t ncols = (uint64_t) args.ne10; + const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03; + const uint64_t nr = n_batches * ncols; + + const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg; + if (gid >= nr) { + return; + } + + const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols); + const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols; + const uint64_t i02 = rem / ncols; + const uint64_t i01 = rem - i02 * ncols; + + const uint64_t sa0 = args.nb00 / sizeof(float); + const uint64_t sa1 = args.nb01 / sizeof(float); + const uint64_t sa2 = args.nb02 / sizeof(float); + const uint64_t sa3 = args.nb03 / sizeof(float); + + const uint64_t sb0 = args.nb10 / sizeof(float); + const uint64_t sb1 = args.nb11 / sizeof(float); + const uint64_t sb2 = args.nb12 / sizeof(float); + const uint64_t sb3 = args.nb13 / sizeof(float); + + const uint64_t sx0 = args.nb0 / sizeof(float); + const uint64_t sx1 = args.nb1 / sizeof(float); + const uint64_t sx2 = args.nb2 / sizeof(float); + const uint64_t sx3 = args.nb3 / sizeof(float); + + device const float * A = (device const float *) src0; + device const float * B = (device const float *) src1; + device float * X = (device float *) dst; + + const uint64_t A_base = i02 * sa2 + i03 * sa3; + const uint64_t B_base = i02 * sb2 + i03 * sb3; + const uint64_t X_base = i02 * sx2 + i03 * sx3; + + const uint64_t n = (uint64_t) args.ne11; + + for (uint64_t i00 = 0; i00 < n; ++i00) { + float sum = 0.0f; + for (uint64_t t = 0; t < i00; ++t) { + sum += A[A_base + i00 * sa1 + t * sa0] * + X[X_base + t * sx1 + i01 * sx0]; + } + + const float diag = A[A_base + i00 * sa1 + i00 * sa0]; + X[X_base + i00 * sx1 + i01 * sx0] = + (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag; + } +} + kernel void kernel_group_norm_f32( constant ggml_metal_kargs_group_norm & args, device const float * src0, diff --git a/model/model_test.go b/model/model_test.go index f6d75b230..ed2868ff3 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -56,6 +56,18 @@ type fakeTensor struct { Name string } +// Stub methods to satisfy ml.Tensor interface +func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f } +func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f } +func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f } +func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f } +func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f } +func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f } +func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f } +func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f } +func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f } +func (f *fakeTensor) SolveTri(ctx 
ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f } + func (m *fakeBackend) Get(name string) ml.Tensor { if slices.Contains(m.names, name) { return &fakeTensor{Name: name} diff --git a/model/models/models.go b/model/models/models.go index 4818518c9..d4a8dc536 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -20,5 +20,6 @@ import ( _ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen25vl" _ "github.com/ollama/ollama/model/models/qwen3" + _ "github.com/ollama/ollama/model/models/qwen3next" _ "github.com/ollama/ollama/model/models/qwen3vl" ) diff --git a/model/models/qwen3next/attention.go b/model/models/qwen3next/attention.go new file mode 100644 index 000000000..ee4a06bea --- /dev/null +++ b/model/models/qwen3next/attention.go @@ -0,0 +1,103 @@ +package qwen3next + +import ( + "errors" + "math" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible +// with the attention layer requirements. +var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout") + +// FullAttention implements gated attention with QK normalization and sigmoid-gated output. +// Key differences from standard attention: +// - Q projection outputs 2x size (Q + gate interleaved) +// - Both Q and K have RMSNorm +// - Output is gated: attn * sigmoid(gate) +type FullAttention struct { + Query *nn.Linear `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head] + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) { + // Use Dim() instead of Shape() for consistent behavior during graph construction + hiddenDim := hiddenStates.Dim(0) + batchSize := hiddenStates.Dim(1) + nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor + + if cache != nil && cache.IsSupportedForBatch() { + seqTokens := cache.seqTokens() + seqs := cache.numSeqs() + if seqTokens > 0 && seqs > 0 { + if nSeqs > 0 { + // 3D tensor: [hiddenDim, seqTokens, nSeqs] + if batchSize != seqTokens || nSeqs != seqs { + return nil, ErrUnsupportedBatchLayout + } + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs) + batchSize = seqTokens * seqs + } else if batchSize != seqTokens*seqs { + return nil, ErrUnsupportedBatchLayout + } + } + } + headDim := opts.headDim() + numHeads := opts.numHeads + + // Q projection outputs query + gate interleaved + qFull := sa.Query.Forward(ctx, hiddenStates) + + // Reshape to [headDim * 2, numHeads, batchSize] + qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize) + + // Split Q and gate along dimension 0 + // Q: first headDim elements, gate: second headDim elements + query := qFull.Slice(ctx, 0, 0, headDim, 1) + gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1) + + // Make query contiguous for further operations + query = query.Contiguous(ctx, headDim, numHeads, batchSize) + + // K and V projections + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + // Derive numKVHeads from tensor dimensions (per-layer value) + numKVHeads := key.Dim(0) / headDim + + key = key.Reshape(ctx, headDim, numKVHeads, batchSize) + value = value.Reshape(ctx, headDim, numKVHeads, batchSize) + + // Apply QK normalization + query = 
sa.QueryNorm.Forward(ctx, query, opts.eps) + key = sa.KeyNorm.Forward(ctx, key, opts.eps) + + // Apply RoPE + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) + + // Standard attention computation + scale := opts.attentionScale + if scale == 0 { + scale = 1.0 / math.Sqrt(float64(headDim)) + } + attention := nn.Attention(ctx, query, key, value, scale, cache) + + // Flatten heads + attention = attention.Reshape(ctx, headDim*numHeads, batchSize) + + // Apply sigmoid gate + // gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize] + gate = gate.Contiguous(ctx, headDim*numHeads, batchSize) + gateSigmoid := gate.Sigmoid(ctx) + attention = attention.Mul(ctx, gateSigmoid) + + return sa.Output.Forward(ctx, attention), nil +} diff --git a/model/models/qwen3next/cache.go b/model/models/qwen3next/cache.go new file mode 100644 index 000000000..86ee2b58d --- /dev/null +++ b/model/models/qwen3next/cache.go @@ -0,0 +1,596 @@ +package qwen3next + +import ( + "math" + "slices" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" +) + +var _ kvcache.Cache = (*HybridCache)(nil) + +// HybridCache stores: +// - a standard causal KV cache for full attention layers +// - per-sequence conv state for linear attention layers +// - per-sequence delta state for linear attention layers +// +// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels] +// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads] +type HybridCache struct { + kv *kvcache.Causal + + backend ml.Backend + dtype ml.DType + maxSequences int + + // Conv state dimensions + convDim int // convKernelSize - 1 + convChannels int // d_inner + 2 * num_k_heads * head_k_dim + + // Delta state dimensions + deltaStateSize int // headVDim * headVDim * numVHeads + + // slot mapping for recurrent state (copy-on-write) + slotForSeq map[int]int + refCount []int + freeSlots []int + + // per-layer conv state buffers (allocated lazily) + convCtxs map[int]ml.Context + convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots] + + // per-layer delta state buffers (allocated lazily) + deltaCtxs map[int]ml.Context + deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots] + + // recurrent checkpoints (per slot) + checkpointCount int + checkpointMinPos int32 + checkpointInterval int32 + checkpointCtxSize int + checkpoints map[int]*slotCheckpointStore + pendingRestore map[int]checkpointRestore + curCheckpointPos []int32 + curCheckpointSlots map[int]int + reserveCheckpoints bool + checkpointConvCtxs map[int]ml.Context + checkpointDeltaCtxs map[int]ml.Context + checkpointReserved map[int]struct{} + + // current forward batch (derived in StartForward) + curSeqs []int + curSlots []int + curSlotsInput ml.Tensor + curSeqTokens int + + // track if EnsureWritable has been called for this forward pass + writableEnsured bool + writableError error +} + +func NewHybridCache( + shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), + convDim, convChannels, deltaStateSize int, +) *HybridCache { + return &HybridCache{ + kv: kvcache.NewCausalCache(shift), + convDim: convDim, + convChannels: convChannels, + deltaStateSize: deltaStateSize, + slotForSeq: make(map[int]int), + convCtxs: make(map[int]ml.Context), + convStates: make(map[int]ml.Tensor), + deltaCtxs: make(map[int]ml.Context), + deltaStates: make(map[int]ml.Tensor), + checkpointCount: 
checkpointCountDefault, + checkpointMinPos: checkpointMinPosDefault, + checkpointInterval: checkpointIntervalDefault, + checkpoints: make(map[int]*slotCheckpointStore), + pendingRestore: make(map[int]checkpointRestore), + curCheckpointSlots: make(map[int]int), + checkpointConvCtxs: make(map[int]ml.Context), + checkpointDeltaCtxs: make(map[int]ml.Context), + checkpointReserved: make(map[int]struct{}), + } +} + +func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { + c.backend = backend + c.dtype = dtype + c.maxSequences = maxSequences + c.checkpoints = make(map[int]*slotCheckpointStore) + c.pendingRestore = make(map[int]checkpointRestore) + c.curCheckpointPos = c.curCheckpointPos[:0] + c.curCheckpointSlots = make(map[int]int) + c.checkpointReserved = make(map[int]struct{}) + c.checkpointCtxSize = c.checkpointCount * c.maxSequences + if c.checkpointCtxSize < 8 { + c.checkpointCtxSize = 8 + } + + // initialize slot allocator + c.refCount = make([]int, maxSequences) + c.freeSlots = c.freeSlots[:0] + for i := maxSequences - 1; i >= 0; i-- { + c.freeSlots = append(c.freeSlots, i) + } + + c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch) +} + +func (c *HybridCache) Close() { + for _, ctx := range c.convCtxs { + ctx.Close() + } + for _, ctx := range c.deltaCtxs { + ctx.Close() + } + for _, ctx := range c.checkpointConvCtxs { + ctx.Close() + } + for _, ctx := range c.checkpointDeltaCtxs { + ctx.Close() + } + c.kv.Close() +} + +func (c *HybridCache) SetConfig(config ml.CacheConfig) { + c.kv.SetConfig(config) +} + +func (c *HybridCache) SetLayer(layer int) { + c.kv.SetLayer(layer) +} + +func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + return c.kv.Get(ctx) +} + +func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) { + c.kv.Put(ctx, key, value) +} + +func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { + if err := c.kv.StartForward(ctx, batch, reserve); err != nil { + return err + } + + // Derive equal-length sequence layout for recurrent layers + seqCounts := make(map[int]int) + c.curSeqs = c.curSeqs[:0] + for _, s := range batch.Sequences { + if _, ok := seqCounts[s]; !ok { + c.curSeqs = append(c.curSeqs, s) + } + seqCounts[s]++ + } + + if len(c.curSeqs) == 0 { + return nil + } + + nTokens := len(batch.Sequences) + nSeqs := len(c.curSeqs) + want := nTokens / nSeqs + for _, s := range c.curSeqs { + if seqCounts[s] != want { + return kvcache.ErrNotSupported + } + } + + c.curSeqTokens = want + + // When reserving memory for estimation, use fake slot assignments + if reserve { + c.curSlots = c.curSlots[:0] + slots := make([]int32, nSeqs) + for i := range nSeqs { + c.curSlots = append(c.curSlots, i) + slots[i] = int32(i) + } + c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) + c.reserveCheckpoints = true + c.planCheckpoints(batch) + return nil + } + + // Ensure slots exist for sequences in this batch + c.curSlots = c.curSlots[:0] + var newSlots []int + for _, s := range c.curSeqs { + slot, ok := c.slotForSeq[s] + if !ok { + var err error + slot, err = c.allocSlot() + if err != nil { + return err + } + c.slotForSeq[s] = slot + c.refCount[slot] = 1 + newSlots = append(newSlots, slot) + } + c.curSlots = append(c.curSlots, slot) + } + + // Zero state for newly allocated slots + if len(newSlots) > 0 { + c.zeroSlots(ctx, newSlots) + } + + // Create a tensor for the current slots + slots := make([]int32, len(c.curSlots)) + for i, v := range c.curSlots { + slots[i] = 
int32(v) + } + c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) + + // Reset writable state for new forward pass + c.writableEnsured = false + c.writableError = nil + c.reserveCheckpoints = false + c.planCheckpoints(batch) + + return nil +} + +func (c *HybridCache) allocSlot() (int, error) { + if len(c.freeSlots) == 0 { + return 0, kvcache.ErrKvCacheFull + } + slot := c.freeSlots[len(c.freeSlots)-1] + c.freeSlots = c.freeSlots[:len(c.freeSlots)-1] + return slot, nil +} + +func (c *HybridCache) freeSlot(slot int) { + if slot >= 0 && slot < c.maxSequences { + c.freeSlots = append(c.freeSlots, slot) + } +} + +// zeroSlots zeros the recurrent state for the given slots across all layers. +func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) { + if len(slots) == 0 { + return + } + + inputCtx := ctx.Input() + + slotIndices := make([]int32, len(slots)) + for i, s := range slots { + slotIndices[i] = int32(s) + } + slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices)) + + // Zero conv states + if len(c.convStates) > 0 { + zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots)) + for _, buf := range c.convStates { + ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor)) + } + } + + // Zero delta states + if len(c.deltaStates) > 0 { + zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots)) + for _, buf := range c.deltaStates { + ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor)) + } + } +} + +// EnsureWritable ensures sequences have private slots (copy-on-write). +func (c *HybridCache) EnsureWritable(ctx ml.Context) error { + for i, seq := range c.curSeqs { + slot, ok := c.slotForSeq[seq] + if !ok { + continue + } + + if slot < 0 || slot >= len(c.refCount) { + continue + } + + if c.refCount[slot] <= 1 { + continue + } + + newSlot, err := c.allocSlot() + if err != nil { + return err + } + c.refCount[slot]-- + c.refCount[newSlot] = 1 + c.slotForSeq[seq] = newSlot + c.curSlots[i] = newSlot + + c.copyRecurrentState(ctx, slot, newSlot) + c.copyCheckpoints(ctx, slot, newSlot) + } + + // Rebuild current slots tensor + slots := make([]int32, len(c.curSlots)) + for i, v := range c.curSlots { + slots[i] = int32(v) + } + c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) + + return nil +} + +func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) { + src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1) + dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1) + + for _, buf := range c.convStates { + rows := buf.Rows(ctx, src) + rowsF32 := rows.Cast(ctx, ml.DTypeF32) + ctx.Forward(buf.SetRows(ctx, rowsF32, dst)) + } + + for _, buf := range c.deltaStates { + rows := buf.Rows(ctx, src) + rowsF32 := rows.Cast(ctx, ml.DTypeF32) + ctx.Forward(buf.SetRows(ctx, rowsF32, dst)) + } +} + +func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) { + c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen) + + // Copy-on-write for recurrent state + if dstSlot, ok := c.slotForSeq[dstSeq]; ok { + if c.validSlot(dstSlot) { + c.refCount[dstSlot]-- + if c.refCount[dstSlot] <= 0 { + c.refCount[dstSlot] = 0 + c.freeSlot(dstSlot) + } + } + delete(c.slotForSeq, dstSeq) + } + + srcSlot, ok := c.slotForSeq[srcSeq] + if !ok { + return + } + + if c.validSlot(srcSlot) { + c.slotForSeq[dstSeq] = srcSlot + c.refCount[srcSlot]++ + } +} + +func (c *HybridCache) CanResume(seq int, pos int32) bool { + if !c.kv.CanResume(seq, pos) { + return false + } + if pos == 0 { + return true + } + return c.hasCheckpoint(seq, pos) +} + +func (c *HybridCache) Remove(seq int, 
beginIndex, endIndex int32) error { + if beginIndex > 0 && endIndex != math.MaxInt32 { + return kvcache.ErrNotSupported + } + + if beginIndex > 0 { + restore, ok := c.pendingRestore[seq] + if !ok || restore.pos+1 != beginIndex { + return kvcache.ErrNotSupported + } + if !c.restoreComplete(restore) { + return kvcache.ErrNotSupported + } + // If the recurrent slot is shared, detach it before applying a restore. + if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 { + newSlot, err := c.allocSlot() + if err != nil { + return err + } + ctx := c.backend.NewContext() + c.copyRecurrentState(ctx, slot, newSlot) + c.copyCheckpoints(ctx, slot, newSlot) + if len(c.convStates) > 0 || len(c.deltaStates) > 0 { + ctx.Compute() + } + ctx.Close() + + c.refCount[slot]-- + c.refCount[newSlot] = 1 + c.slotForSeq[seq] = newSlot + + restore.slot = newSlot + c.pendingRestore[seq] = restore + } + } + + if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil { + return err + } + + if beginIndex > 0 { + restore := c.pendingRestore[seq] + delete(c.pendingRestore, seq) + return c.applyCheckpointRestore(restore) + } + + // Removal invalidates recurrent state + slot, ok := c.slotForSeq[seq] + delete(c.pendingRestore, seq) + if !ok { + return nil + } + + if !c.validSlot(slot) { + delete(c.slotForSeq, seq) + return nil + } + + c.refCount[slot]-- + if c.refCount[slot] <= 0 { + c.refCount[slot] = 0 + c.clearCheckpoints(slot) + c.freeSlot(slot) + } + delete(c.slotForSeq, seq) + + return nil +} + +func (c *HybridCache) validSlot(slot int) bool { + return slot >= 0 && slot < len(c.refCount) +} + +func (c *HybridCache) slotsTensor() ml.Tensor { + return c.curSlotsInput +} + +// contiguousSlots returns the starting slot if current slots are contiguous and ordered. +func (c *HybridCache) contiguousSlots() (int, bool) { + if len(c.curSlots) == 0 { + return 0, false + } + start := c.curSlots[0] + for i, s := range c.curSlots { + if s != start+i { + return 0, false + } + } + return start, true +} + +func (c *HybridCache) seqTokens() int { + return c.curSeqTokens +} + +func (c *HybridCache) numSeqs() int { + return len(c.curSeqs) +} + +func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor { + if buf, ok := c.convStates[layer]; ok { + return buf + } + + if _, ok := c.convCtxs[layer]; !ok { + c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer) + } + + // Recurrent state must stay in F32 (ssm_conv kernels are F32-only). + buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences) + c.convStates[layer] = buf + return buf +} + +func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor { + if buf, ok := c.deltaStates[layer]; ok { + return buf + } + + if _, ok := c.deltaCtxs[layer]; !ok { + c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer) + } + + // Recurrent delta state must stay in F32. 
+ buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences) + c.deltaStates[layer] = buf + return buf +} + +func (c *HybridCache) ensureWritableOnce(ctx ml.Context) { + if !c.writableEnsured { + needsWritable := false + for _, seq := range c.curSeqs { + slot, ok := c.slotForSeq[seq] + if !ok { + continue + } + if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 { + needsWritable = true + break + } + } + + if needsWritable { + if err := c.EnsureWritable(ctx); err != nil { + c.writableError = err + } + } + c.writableEnsured = true + } +} + +// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs]. +func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) { + c.ensureWritableOnce(ctx) + + if c.writableError != nil { + return nil, c.writableError + } + + buf := c.convBuffer(ctx, layer) + cur := buf.Rows(ctx, c.slotsTensor()) + return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil +} + +// UpdateConvState writes a new conv state for current batch sequences. +func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) { + buf := c.convBuffer(ctx, layer) + src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs()) + srcF32 := src.Cast(ctx, ml.DTypeF32) + if start, ok := c.contiguousSlots(); ok { + // Fast path: contiguous slots allow a single view + copy + offset := start * buf.Stride(1) + view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs()) + ctx.Forward(srcF32.Copy(ctx, view)) + } else { + ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor())) + } + + c.captureConvCheckpoint(ctx, layer, srcF32) +} + +// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs]. +func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) { + c.ensureWritableOnce(ctx) + + if c.writableError != nil { + return nil, c.writableError + } + + buf := c.deltaBuffer(ctx, layer) + cur := buf.Rows(ctx, c.slotsTensor()) + return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil +} + +// UpdateDeltaState writes a new delta state for current batch sequences. +func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) { + buf := c.deltaBuffer(ctx, layer) + src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs()) + srcF32 := src.Cast(ctx, ml.DTypeF32) + if start, ok := c.contiguousSlots(); ok { + // Fast path: contiguous slots allow a single view + copy + offset := start * buf.Stride(1) + view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs()) + ctx.Forward(srcF32.Copy(ctx, view)) + } else { + ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor())) + } + + c.captureDeltaCheckpoint(ctx, layer, srcF32) +} + +// IsSupportedForBatch returns true if the current batch layout supports recurrent layers. +func (c *HybridCache) IsSupportedForBatch() bool { + return c.curSeqTokens > 0 && len(c.curSeqs) > 0 +} + +// Seqs returns the ordered unique sequences for the current forward pass. 
+func (c *HybridCache) Seqs() []int { + return slices.Clone(c.curSeqs) +} diff --git a/model/models/qwen3next/checkpoints.go b/model/models/qwen3next/checkpoints.go new file mode 100644 index 000000000..913af1c05 --- /dev/null +++ b/model/models/qwen3next/checkpoints.go @@ -0,0 +1,498 @@ +package qwen3next + +import ( + "log/slog" + "math" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" +) + +const ( + checkpointCountDefault = 32 + checkpointMinPosDefault = int32(16) + checkpointIntervalDefault = int32(1280) +) + +// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU +// memory usage while preserving prefix reuse for recurrent state. + +type checkpointEntry struct { + pos int32 + conv map[int]ml.Tensor + delta map[int]ml.Tensor +} + +type slotCheckpointStore struct { + entries []checkpointEntry + size int + next int + lastPos int32 +} + +type checkpointRestore struct { + slot int + idx int + pos int32 +} + +func newSlotCheckpointStore(n int) *slotCheckpointStore { + entries := make([]checkpointEntry, n) + for i := range entries { + entries[i].pos = -1 + } + return &slotCheckpointStore{ + entries: entries, + lastPos: -1, + } +} + +func (s *slotCheckpointStore) reset() { + s.size = 0 + s.next = 0 + s.lastPos = -1 + for i := range s.entries { + s.entries[i].pos = -1 + } +} + +func (s *slotCheckpointStore) record(pos int32) int { + if len(s.entries) == 0 { + return -1 + } + idx := s.next + s.next = (s.next + 1) % len(s.entries) + if s.size < len(s.entries) { + s.size++ + } + s.entries[idx].pos = pos + s.lastPos = pos + return idx +} + +func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) { + bestIdx := -1 + bestPos := int32(-1) + for i := range s.entries { + pos := s.entries[i].pos + if pos < 0 || pos >= targetPos { + continue + } + if pos > bestPos { + bestPos = pos + bestIdx = i + } + } + if bestIdx < 0 { + return -1, -1, false + } + return bestIdx, bestPos, true +} + +func (s *slotCheckpointStore) pruneAfter(pos int32) { + if len(s.entries) == 0 { + s.size = 0 + s.next = 0 + s.lastPos = -1 + return + } + + size := 0 + next := -1 + minPos := int32(math.MaxInt32) + minIdx := 0 + for i := range s.entries { + if s.entries[i].pos > pos { + s.entries[i].pos = -1 + } + if s.entries[i].pos >= 0 { + size++ + if s.entries[i].pos < minPos { + minPos = s.entries[i].pos + minIdx = i + } + } else if next == -1 { + next = i + } + } + + s.size = size + if size == 0 { + s.next = 0 + s.lastPos = -1 + return + } + if next != -1 { + s.next = next + } else { + // Full ring: overwrite the oldest checkpoint next. 
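+		// minIdx points at the surviving checkpoint with the smallest position,
+		// so it is the one overwritten first.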
+ s.next = minIdx + } + s.lastPos = pos +} + +func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) { + minPos = int32(math.MaxInt32) + maxPos = int32(-1) + for i := range s.entries { + pos := s.entries[i].pos + if pos < 0 { + continue + } + size++ + if pos < minPos { + minPos = pos + } + if pos > maxPos { + maxPos = pos + } + } + if size == 0 { + minPos = -1 + maxPos = -1 + } + return size, minPos, maxPos, s.lastPos +} + +func (c *HybridCache) planCheckpoints(batch input.Batch) { + if c.checkpointCount == 0 || len(c.curSeqs) == 0 { + c.curCheckpointPos = c.curCheckpointPos[:0] + for k := range c.curCheckpointSlots { + delete(c.curCheckpointSlots, k) + } + return + } + + if cap(c.curCheckpointPos) < len(c.curSeqs) { + c.curCheckpointPos = make([]int32, len(c.curSeqs)) + } else { + c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)] + } + for i := range c.curCheckpointPos { + c.curCheckpointPos[i] = -1 + } + for k := range c.curCheckpointSlots { + delete(c.curCheckpointSlots, k) + } + + posMax := make(map[int]int32, len(c.curSeqs)) + for i, seq := range batch.Sequences { + pos := batch.Positions[i] + if cur, ok := posMax[seq]; !ok || pos > cur { + posMax[seq] = pos + } + } + + for i, seq := range c.curSeqs { + pos, ok := posMax[seq] + if !ok { + continue + } + if pos < c.checkpointMinPos { + continue + } + slot := c.curSlots[i] + store := c.checkpointStore(slot) + lastPos := store.lastPos + if lastPos < 0 || pos-lastPos >= c.checkpointInterval { + c.curCheckpointPos[i] = pos + } + } +} + +func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore { + store, ok := c.checkpoints[slot] + if ok { + return store + } + store = newSlotCheckpointStore(c.checkpointCount) + c.checkpoints[slot] = store + return store +} + +func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int { + if c.checkpointCount == 0 { + return -1 + } + if idx, ok := c.curCheckpointSlots[slot]; ok { + return idx + } + store := c.checkpointStore(slot) + idx := store.record(pos) + if idx >= 0 { + c.curCheckpointSlots[slot] = idx + } + return idx +} + +func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool { + if pos <= 0 { + return false + } + slot, ok := c.slotForSeq[seq] + if !ok { + return false + } + store, ok := c.checkpoints[slot] + if !ok { + return false + } + _, _, ok = store.bestIndex(pos) + return ok +} + +func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) { + if targetPos <= 0 { + return 0, false + } + slot, ok := c.slotForSeq[seq] + if !ok { + return 0, false + } + store, ok := c.checkpoints[slot] + if !ok { + slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0) + return 0, false + } + idx, pos, ok := store.bestIndex(targetPos) + if !ok { + size, minPos, maxPos, lastPos := store.window() + slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size, + "min", minPos, "max", maxPos, "last", lastPos) + return 0, false + } + c.pendingRestore[seq] = checkpointRestore{ + slot: slot, + idx: idx, + pos: pos, + } + return pos + 1, true +} + +func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error { + entry, ok := c.restoreEntry(restore) + if !ok { + return kvcache.ErrNotSupported + } + + ctx := c.backend.NewContext() + defer ctx.Close() + + slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1) + for layer, src := range entry.conv { + buf := c.convBuffer(ctx, layer) + ctx.Forward(buf.SetRows(ctx, src, slotIdx)) + } + for 
layer, src := range entry.delta { + buf := c.deltaBuffer(ctx, layer) + ctx.Forward(buf.SetRows(ctx, src, slotIdx)) + } + + if len(entry.conv) > 0 || len(entry.delta) > 0 { + ctx.Compute() + } + store := c.checkpoints[restore.slot] + store.pruneAfter(restore.pos) + return nil +} + +func (c *HybridCache) restoreComplete(restore checkpointRestore) bool { + _, ok := c.restoreEntry(restore) + return ok +} + +func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) { + store, ok := c.checkpoints[restore.slot] + if !ok || restore.idx < 0 || restore.idx >= len(store.entries) { + return nil, false + } + entry := &store.entries[restore.idx] + if entry.pos < 0 { + return nil, false + } + if !c.entryComplete(entry) { + return nil, false + } + return entry, true +} + +func (c *HybridCache) entryComplete(entry *checkpointEntry) bool { + for layer := range c.convStates { + if entry.conv == nil || entry.conv[layer] == nil { + return false + } + } + for layer := range c.deltaStates { + if entry.delta == nil || entry.delta[layer] == nil { + return false + } + } + return true +} + +func (c *HybridCache) clearCheckpoints(slot int) { + if store, ok := c.checkpoints[slot]; ok { + store.reset() + } +} + +func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) { + if c.checkpointCount == 0 { + return + } + srcStore, ok := c.checkpoints[srcSlot] + if !ok || srcStore.size == 0 { + return + } + dstStore := c.checkpointStore(dstSlot) + dstStore.size = srcStore.size + dstStore.next = srcStore.next + dstStore.lastPos = srcStore.lastPos + + for i := range srcStore.entries { + srcEntry := &srcStore.entries[i] + dstEntry := &dstStore.entries[i] + dstEntry.pos = srcEntry.pos + if srcEntry.conv != nil { + if dstEntry.conv == nil { + dstEntry.conv = make(map[int]ml.Tensor) + } + for layer, src := range srcEntry.conv { + dst := c.ensureCheckpointConv(layer, dstEntry) + ctx.Forward(src.Copy(ctx, dst)) + } + } + if srcEntry.delta != nil { + if dstEntry.delta == nil { + dstEntry.delta = make(map[int]ml.Tensor) + } + for layer, src := range srcEntry.delta { + dst := c.ensureCheckpointDelta(layer, dstEntry) + ctx.Forward(src.Copy(ctx, dst)) + } + } + } +} + +func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) { + if c.checkpointCount == 0 { + return + } + if c.reserveCheckpoints { + c.reserveCheckpointConv(layer) + return + } + if len(c.curCheckpointPos) == 0 { + return + } + for i, pos := range c.curCheckpointPos { + if pos < 0 { + continue + } + slot := c.curSlots[i] + idx := c.checkpointIndexForSlot(slot, pos) + if idx < 0 { + continue + } + entry := &c.checkpoints[slot].entries[idx] + dst := c.ensureCheckpointConv(layer, entry) + seqSlice := src.Slice(ctx, 1, i, i+1, 1) + ctx.Forward(seqSlice.Copy(ctx, dst)) + } +} + +func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) { + if c.checkpointCount == 0 { + return + } + if c.reserveCheckpoints { + c.reserveCheckpointDelta(layer) + return + } + if len(c.curCheckpointPos) == 0 { + return + } + for i, pos := range c.curCheckpointPos { + if pos < 0 { + continue + } + slot := c.curSlots[i] + idx := c.checkpointIndexForSlot(slot, pos) + if idx < 0 { + continue + } + entry := &c.checkpoints[slot].entries[idx] + dst := c.ensureCheckpointDelta(layer, entry) + seqSlice := src.Slice(ctx, 1, i, i+1, 1) + ctx.Forward(seqSlice.Copy(ctx, dst)) + } +} + +func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor { + if entry.conv == nil { + entry.conv 
= make(map[int]ml.Tensor) + } + if t, ok := entry.conv[layer]; ok { + return t + } + ctx, ok := c.checkpointConvCtxs[layer] + if !ok { + ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer) + c.checkpointConvCtxs[layer] = ctx + } + t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1) + entry.conv[layer] = t + return t +} + +func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor { + if entry.delta == nil { + entry.delta = make(map[int]ml.Tensor) + } + if t, ok := entry.delta[layer]; ok { + return t + } + ctx, ok := c.checkpointDeltaCtxs[layer] + if !ok { + ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer) + c.checkpointDeltaCtxs[layer] = ctx + } + t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1) + entry.delta[layer] = t + return t +} + +func (c *HybridCache) reserveCheckpointConv(layer int) { + key := checkpointReserveKey(layer, 0) + if _, ok := c.checkpointReserved[key]; ok { + return + } + for slot := range c.maxSequences { + store := c.checkpointStore(slot) + for i := range store.entries { + entry := &store.entries[i] + _ = c.ensureCheckpointConv(layer, entry) + } + } + c.checkpointReserved[key] = struct{}{} +} + +func (c *HybridCache) reserveCheckpointDelta(layer int) { + key := checkpointReserveKey(layer, 1) + if _, ok := c.checkpointReserved[key]; ok { + return + } + for slot := range c.maxSequences { + store := c.checkpointStore(slot) + for i := range store.entries { + entry := &store.entries[i] + _ = c.ensureCheckpointDelta(layer, entry) + } + } + c.checkpointReserved[key] = struct{}{} +} + +func checkpointReserveKey(layer int, kind int) int { + return layer*2 + kind +} diff --git a/model/models/qwen3next/checkpoints_test.go b/model/models/qwen3next/checkpoints_test.go new file mode 100644 index 000000000..440a3a2cf --- /dev/null +++ b/model/models/qwen3next/checkpoints_test.go @@ -0,0 +1,300 @@ +package qwen3next + +import ( + "errors" + "math" + "os" + "testing" + + "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" +) + +func newTestBackend(tb testing.TB) ml.Backend { + tb.Helper() + + f, err := os.CreateTemp(tb.TempDir(), "*.gguf") + if err != nil { + tb.Fatal(err) + } + if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil { + _ = f.Close() + tb.Fatal(err) + } + if err := f.Close(); err != nil { + tb.Fatal(err) + } + + b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true}) + if err != nil { + tb.Fatal(err) + } + tb.Cleanup(func() { + b.Close() + }) + + return b +} + +func TestSlotCheckpointStoreBestIndex(t *testing.T) { + store := newSlotCheckpointStore(2) + store.record(10) + store.record(20) + + _, pos, ok := store.bestIndex(15) + if !ok || pos != 10 { + t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok) + } + + store.record(30) // overwrite oldest (10) + + if _, _, ok := store.bestIndex(15); ok { + t.Fatalf("expected no checkpoint for targetPos=15 after overwrite") + } + + _, pos, ok = store.bestIndex(40) + if !ok || pos != 30 { + t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok) + } +} + +func TestHybridCachePrepareRestore(t *testing.T) { + cache := NewHybridCache(nil, 1, 1, 1) + cache.checkpointCount = 3 + cache.checkpoints = make(map[int]*slotCheckpointStore) + cache.pendingRestore = make(map[int]checkpointRestore) + + cache.slotForSeq[1] = 0 + store := cache.checkpointStore(0) + store.record(5) + store.record(9) + store.record(15) + + restorePos, ok := cache.PrepareRestore(1, 12) + if !ok 
{ + t.Fatalf("expected restore ok") + } + if restorePos != 10 { + t.Fatalf("expected restorePos 10, got %d", restorePos) + } + rest, ok := cache.pendingRestore[1] + if !ok { + t.Fatalf("expected pending restore entry") + } + if rest.pos != 9 { + t.Fatalf("expected pending restore pos 9, got %d", rest.pos) + } +} + +func TestSlotCheckpointStorePruneAfter(t *testing.T) { + store := newSlotCheckpointStore(3) + store.record(10) + store.record(20) + store.record(30) + + store.pruneAfter(20) + + if store.lastPos != 20 { + t.Fatalf("expected lastPos 20, got %d", store.lastPos) + } + + _, pos, ok := store.bestIndex(25) + if !ok || pos != 20 { + t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok) + } + + _, pos, ok = store.bestIndex(35) + if !ok || pos != 20 { + t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok) + } +} + +func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) { + backend := newTestBackend(t) + + cache := NewHybridCache(nil, 1, 2, 2) + cache.Init(backend, ml.DTypeF16, 2, 8, 2) + + cache.slotForSeq[1] = 0 + cache.slotForSeq[2] = 0 + cache.refCount[0] = 2 + cache.refCount[1] = 0 + cache.freeSlots = []int{1} + + store := cache.checkpointStore(0) + idx := store.record(9) + cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9} + + if err := cache.Remove(1, 10, math.MaxInt32); err != nil { + t.Fatalf("Remove failed: %v", err) + } + + if cache.slotForSeq[1] == cache.slotForSeq[2] { + t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1]) + } + if cache.slotForSeq[1] != 1 { + t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1]) + } + if cache.slotForSeq[2] != 0 { + t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2]) + } + if cache.refCount[0] != 1 || cache.refCount[1] != 1 { + t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1]) + } + if _, ok := cache.pendingRestore[1]; ok { + t.Fatalf("expected pending restore to be cleared") + } +} + +func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) { + cache := NewHybridCache(nil, 1, 2, 2) + cache.checkpointCount = 3 + cache.checkpoints = make(map[int]*slotCheckpointStore) + cache.pendingRestore = make(map[int]checkpointRestore) + + cache.slotForSeq[1] = 0 + cache.refCount = []int{1} + cache.freeSlots = nil + + // Simulate that layer 0 has both conv and delta state (so entryComplete expects both) + cache.convStates[0] = nil // placeholder to indicate layer 0 exists + cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists + + store := cache.checkpointStore(0) + idx := store.record(9) + entry := &store.entries[idx] + // Only set conv checkpoint, not delta - making it incomplete + entry.conv = map[int]ml.Tensor{0: nil} + // entry.delta is not set, so checkpoint is incomplete + + cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9} + + err := cache.Remove(1, 10, math.MaxInt32) + if !errors.Is(err, kvcache.ErrNotSupported) { + t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err) + } +} + +func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) { + cache := NewHybridCache(nil, 1, 2, 2) + cache.checkpointCount = 3 + cache.checkpoints = make(map[int]*slotCheckpointStore) + cache.pendingRestore = make(map[int]checkpointRestore) + + cache.slotForSeq[1] = 0 + cache.refCount = []int{1} + cache.freeSlots = nil + + // Don't set convStates/deltaStates - with no layers to check, + // 
entryComplete will return true as long as entry.pos >= 0 + + store := cache.checkpointStore(0) + idx := store.record(9) + + cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9} + + // Test that restoreComplete returns true when no layers need checkpoints + restore := cache.pendingRestore[1] + if !cache.restoreComplete(restore) { + t.Fatalf("expected restoreComplete to return true for complete checkpoint") + } +} + +func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) { + // Test that ring buffer wrap-around reuses entries without clearing maps. + store := newSlotCheckpointStore(3) + + // Fill the buffer + store.record(10) + store.record(20) + store.record(30) + + // Create fake tensor data in the first entry's maps + store.entries[0].conv = make(map[int]ml.Tensor) + store.entries[0].conv[0] = nil // Simulated tensor reference + store.entries[0].delta = make(map[int]ml.Tensor) + store.entries[0].delta[0] = nil // Simulated tensor reference + + // Record another entry, which should wrap around and overwrite entry 0 + store.record(40) + + // Verify the maps are still present (we reuse tensors) + if store.entries[0].conv == nil { + t.Fatalf("expected conv map to be preserved on reuse") + } + if store.entries[0].delta == nil { + t.Fatalf("expected delta map to be preserved on reuse") + } + + // Verify the new position was recorded + if store.entries[0].pos != 40 { + t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos) + } +} + +func TestSlotCheckpointStoreFullCapacity(t *testing.T) { + // Test behavior when buffer is exactly at capacity + store := newSlotCheckpointStore(2) + + idx1 := store.record(10) + idx2 := store.record(20) + + if idx1 != 0 || idx2 != 1 { + t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2) + } + + if store.size != 2 { + t.Fatalf("expected size 2, got %d", store.size) + } + + // Verify both checkpoints are accessible + _, pos1, ok1 := store.bestIndex(15) + _, pos2, ok2 := store.bestIndex(25) + + if !ok1 || pos1 != 10 { + t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1) + } + if !ok2 || pos2 != 20 { + t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2) + } +} + +func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) { + // Test behavior with zero-size buffer + store := newSlotCheckpointStore(0) + + idx := store.record(10) + if idx != -1 { + t.Fatalf("expected record to return -1 for empty buffer, got %d", idx) + } + + _, _, ok := store.bestIndex(15) + if ok { + t.Fatalf("expected no checkpoint for empty buffer") + } +} + +func TestSlotCheckpointStorePruneAfterAll(t *testing.T) { + // Test pruning that removes all checkpoints + store := newSlotCheckpointStore(3) + store.record(10) + store.record(20) + store.record(30) + + // Prune everything by setting threshold below all positions + store.pruneAfter(5) + + if store.size != 0 { + t.Fatalf("expected size 0 after pruning all, got %d", store.size) + } + // When all checkpoints are pruned, lastPos is reset to -1 + if store.lastPos != -1 { + t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos) + } + + _, _, ok := store.bestIndex(100) + if ok { + t.Fatalf("expected no checkpoint after pruning all") + } +} diff --git a/model/models/qwen3next/deltanet.go b/model/models/qwen3next/deltanet.go new file mode 100644 index 000000000..958d1e937 --- /dev/null +++ b/model/models/qwen3next/deltanet.go @@ -0,0 +1,473 @@ +package qwen3next + +import ( + "errors" + "log/slog" + "math" + + "github.com/ollama/ollama/ml" + 
"github.com/ollama/ollama/ml/nn" +) + +const chunkSize = 64 + +// TriType constants for triangular matrix operations +const ( + TriTypeUpperDiag = 0 + TriTypeUpper = 1 + TriTypeLowerDiag = 2 + TriTypeLower = 3 +) + +// convKernel wraps the 1D convolution kernel tensor +type convKernel struct { + Weight ml.Tensor `gguf:"weight"` +} + +// Masks holds pre-computed mask tensors for chunked attention +type Masks struct { + Causal ml.Tensor // Lower triangular [chunkSize, chunkSize] + Identity ml.Tensor // Diagonal [chunkSize, chunkSize] + Diag ml.Tensor // causal + identity +} + +// GatedDeltaNet implements linear attention with SSM convolution and recurrent state. +// It implements the Operator interface directly. +type GatedDeltaNet struct { + // Optimized path: pre-split QKV and gate + SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated) + SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate + SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha + SSMConv1D *convKernel `gguf:"ssm_conv1d"` + SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias + SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp() + SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` + SSMOut *nn.Linear `gguf:"ssm_out"` + + // Layer index for cache access (set during model construction) + Layer int +} + +// createMasks builds the constant mask tensors (called once, reused for all chunks) +func createMasks(ctx ml.Context) *Masks { + ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize) + ones = ones.Fill(ctx, 1.0) + causalMask := ones.Tri(ctx, TriTypeLower) + + onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize) + onesVec = onesVec.Fill(ctx, 1.0) + identity := onesVec.Diag(ctx) + + diagMask := causalMask.Add(ctx, identity) + + return &Masks{ + Causal: causalMask, + Identity: identity, + Diag: diagMask, + } +} + +func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) { + layer := gdn.Layer + nSeqTokens := hiddenStates.Dim(1) + nSeqs := hiddenStates.Dim(2) + if cache != nil && cache.IsSupportedForBatch() { + seqTokens := cache.seqTokens() + seqs := cache.numSeqs() + if seqTokens > 0 && seqs > 0 { + if nSeqs > 1 { + if nSeqTokens != seqTokens || nSeqs != seqs { + return nil, ErrUnsupportedBatchLayout + } + } else { + if nSeqTokens != seqTokens*seqs { + return nil, ErrUnsupportedBatchLayout + } + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs) + nSeqTokens = seqTokens + nSeqs = seqs + } + } + } + + headKDim := opts.ssmDState + numKHeads := opts.ssmNGroup + numVHeads := opts.ssmDtRank + headVDim := opts.ssmDInner / numVHeads + convKernelSize := opts.convKernelSize + + mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates) + qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads + + if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil { + return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)") + } + // Optimized path: pre-split QKV and gate + qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs) + z := gdn.SSMQKVGate.Forward(ctx, hiddenStates) + + baNewDim := 2 * numVHeads / numKHeads + mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs) + + // Split beta and alpha + betaSize := numVHeads / numKHeads + alphaSize := numVHeads / numKHeads + + b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1) + a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1) + + // Reshape to merge head dimensions + beta := 
b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs) + alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs) + + // Compute gate: softplus(alpha + dt_bias) * -A + alphaBiased := alpha.Add(ctx, gdn.SSMDT) + alphaSoftplus := alphaBiased.Softplus(ctx) + gate := alphaSoftplus.Mul(ctx, gdn.SSMA) + qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3) + + // Get conv state from cache + convStates, err := cache.ConvState(ctx, layer) + if err != nil { + // Log this - if it happens, short-term context will be lost + slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err) + convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs) + } + + // Reshape conv states + convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs) + + // Concatenate with input for convolution + convInput := convStates.Concat(ctx, qkvMixed, 0) + + // Save new conv state (last convKernelSize-1 tokens) + lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1) + cache.UpdateConvState(ctx, layer, lastConvStates) + + // Apply SSM convolution (kernel must be F32 for Metal) + convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight) + convOutput = convOutput.SILU(ctx) + + // Reshape for extraction + convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs) + + // Extract convolved Q, K, V + qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1) + kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1) + vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1) + + // Reshape to 4D + qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs) + kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs) + vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs) + + // Get delta state from cache + state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads) + if err != nil { + // Log this - if it happens frequently, context will degrade + slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err) + state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs) + } + state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs) + + // Repeat interleave Q and K if numKHeads != numVHeads + if numKHeads != numVHeads { + repeatFactor := numVHeads / numKHeads + + qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs) + kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs) + + qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1) + kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1) + + qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs) + kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs) + } + + // Choose computation mode based on sequence length + var attnOut ml.Tensor + if nSeqTokens == 1 { + attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache) + } else { + // Use pre-computed masks from opts (created once in Model.Forward) + attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache) + } + + // Apply gated normalization + attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs) + z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs) + + // norm(attnOut, z) = RMSNorm(attnOut) * silu(z) + attnOutNorm := 
gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps) + zSilu := z2D.SILU(ctx) + attnOutGated := attnOutNorm.Mul(ctx, zSilu) + + // Reshape for output projection + finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs) + + out := gdn.SSMOut.Forward(ctx, finalOutput) + return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil +} + +// deltaNetAutoregressive implements single-token state update. +// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]). +func (gdn *GatedDeltaNet) deltaNetAutoregressive( + ctx ml.Context, + q, k, v, gate, beta, state ml.Tensor, + opts *Options, + layer int, + cache *HybridCache, +) ml.Tensor { + numVHeads := v.Dim(1) + headVDim := v.Dim(0) + nSeqs := q.Dim(3) + + // L2 normalize Q and K + q = q.L2Norm(ctx, opts.eps) + k = k.L2Norm(ctx, opts.eps) + + // Scale Q + scale := 1.0 / math.Sqrt(float64(headVDim)) + q = q.Scale(ctx, scale) + + // Sigmoid beta + beta = beta.Sigmoid(ctx) + + // Reshape state: [headVDim, headVDim, numVHeads, nSeqs] + state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs) + + // Reshape gate and beta for broadcasting + gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs) + betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs) + + // Apply exponential to gate + gT = gT.Exp(ctx) + + // state = state * g_t + state = state.Mul(ctx, gT) + + // kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) + kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs) + kvMem := state.Mul(ctx, kTUnsqueezed) + // Sum over dim=-2 (second dimension after permute) + kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + kvMem = kvMem.SumRows(ctx) + kvMem = kvMem.Permute(ctx, 1, 0, 2, 3) + + // v_t with singleton dimension + vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs) + + // delta = (v_t - kv_mem) * beta_t + vDiff := vT.Sub(ctx, kvMem) + delta := vDiff.Mul(ctx, betaT) + + // state = state + k_t.unsqueeze(-1) * delta + kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs) + kTDelta := kTUnsqueezedBroad.Mul(ctx, delta) + state = state.Add(ctx, kTDelta) + + // core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2) + qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs) + stateQ := state.Mul(ctx, qTUnsqueezed) + stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + coreAttnOut := stateQ.SumRows(ctx) + coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3) + + // Update delta state in cache + cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)) + + return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs) +} + +// deltaNetChunked implements chunked computation for prefill. +// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]). 
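+// The input is padded to a multiple of chunkSize; within each chunk the delta
+// rule is applied via a lower-triangular solve, and the recurrent state is
+// carried from chunk to chunk before being written back to the cache.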
+func (gdn *GatedDeltaNet) deltaNetChunked( + ctx ml.Context, + q, k, v, gate, beta, state ml.Tensor, + masks *Masks, + opts *Options, + layer int, + cache *HybridCache, +) ml.Tensor { + headKDim := q.Dim(0) + numVHeads := v.Dim(1) + headVDim := v.Dim(0) + nTokens := q.Dim(2) + nSeqs := q.Dim(3) + + // L2 normalize Q and K + q = q.L2Norm(ctx, opts.eps) + k = k.L2Norm(ctx, opts.eps) + + // Scale Q + scale := 1.0 / math.Sqrt(float64(headVDim)) + q = q.Scale(ctx, scale) + + // Sigmoid beta + beta = beta.Sigmoid(ctx) + + // Permute tensors for chunked computation + q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs) + k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs) + v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs) + gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs) + + beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx) + state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs) + + // Compute padding + pad := (chunkSize - nTokens%chunkSize) % chunkSize + nChunks := (nTokens + pad) / chunkSize + + // Pad tensors + if pad > 0 { + q = q.Pad(ctx, 0, pad, 0, 0) + k = k.Pad(ctx, 0, pad, 0, 0) + v = v.Pad(ctx, 0, pad, 0, 0) + gate = gate.Pad(ctx, pad, 0, 0, 0) + beta = beta.Pad(ctx, 0, pad, 0, 0) + } + + // Use pre-computed masks (passed in, not recreated) + causalMask := masks.Causal + identity := masks.Identity + diagMask := masks.Diag + identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1) + + // v_beta = v * beta, k_beta = k * beta + vBeta := v.Mul(ctx, beta) + kBeta := k.Mul(ctx, beta) + + // Reshape for chunked computation + q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs) + k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs) + kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs) + vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs) + + gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs) + + // g_cumsum = cumsum(gate) + gCumsum := gate.CumSum(ctx) + + // Compute decay mask + gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs) + gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs) + gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs) + decayMask := gcsBroadcast.Sub(ctx, gcsI) + + decayMask = decayMask.Mul(ctx, diagMask) + decayMask = decayMask.Exp(ctx) + decayMask = decayMask.Mul(ctx, diagMask) + + // k @ k_beta^T + kMulKBeta := k.Mulmat(ctx, kBeta) + + // k_decay = k @ k_beta^T * decay_mask + kDecay := kMulKBeta.Mul(ctx, decayMask) + + // attn = -k_decay * causal_mask + attn := kDecay.Neg(ctx).Mul(ctx, causalMask) + + // Triangular solve: (I - attn_lower)^-1 @ attn + attnLower := attn.Mul(ctx, causalMask) + lhs := attnLower.Neg(ctx).Add(ctx, identity4D) + linSolve := lhs.SolveTri(ctx, attn, true, true, false) + attn = linSolve.Mul(ctx, causalMask) + attn = attn.Add(ctx, identity4D) + + // v = v_beta^T @ attn + vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + v = vBetaT.Mulmat(ctx, attn) + + // Compute g_exp for state update + gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + gExp := gCumsumT.Exp(ctx) + + // kbeta_gexp = k_beta * g_exp + kBetaGExp := kBeta.Mul(ctx, gExp) + + // k_cumdecay = attn @ kbeta_gexp^T + kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + kCumdecay := attn.Mulmat(ctx, kBetaGExpT) + kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 
3).Contiguous(ctx) + + // Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask + attnKQ := k.Mulmat(ctx, q) + attnKQ = attnKQ.Mul(ctx, decayMask) + attnKQ = attnKQ.Mul(ctx, diagMask) + + // Pre-compute g_last and key_gdiff + // g_last = view of last element in g_cumsum along chunk_size dimension + // We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs] + gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs) + gLastExp := gLast.Exp(ctx) + + // g_diff = -(g_cumsum - g_last) = g_last - g_cumsum + gDiff := gCumsum.Neg(ctx).Add(ctx, gLast) + gDiffExp := gDiff.Exp(ctx) + + // key_gdiff = k * exp(g_diff) + keyGDiff := k.Mul(ctx, gDiffExp) + + // Process chunks and update state + var coreAttnOut ml.Tensor + newState := state + + for chunk := range nChunks { + qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1) + vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1) + gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1) + kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1) + attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed! + + // state^T - permute is needed but Contiguous creates a copy + stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs) + + // v_prime = k_cumdecay @ state + vPrime := stateT.Mulmat(ctx, kCumdecayChunk) + + // v_new = v - v_prime + vNew := vChunk.Sub(ctx, vPrime) + vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + + // attn_inter = (q * g_exp) @ state + qGExp := qChunk.Mul(ctx, gExpChunk) + attnInter := stateT.Mulmat(ctx, qGExp) + + // core_attn_out = attn_inter + attn @ v_new + vAttn := vNewT.Mulmat(ctx, attnChunk) + coreAttnOutChunk := attnInter.Add(ctx, vAttn) + + if coreAttnOut == nil { + coreAttnOut = coreAttnOutChunk + } else { + coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1) + } + + // Update state for next chunk using pre-computed values + gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1) + kGDiffChunk := keyGDiff.Slice(ctx, 2, chunk, chunk+1, 1) + + // kgdmulvnew = key_gdiff^T @ v_new + kGDiffChunkT := kGDiffChunk.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT) + + // state = state * g_last + kgdmulvnew + gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs) + newState = newState.Mul(ctx, gExpLastReshaped) + newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)) + } + + // Final reshape + coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs) + + // Slice to remove padding + if pad > 0 { + coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1) + } + + // Update delta state in cache + cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)) + + return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs) +} diff --git a/model/models/qwen3next/model.go b/model/models/qwen3next/model.go new file mode 100644 index 000000000..4ee4eebc8 --- /dev/null +++ b/model/models/qwen3next/model.go @@ -0,0 +1,383 @@ +package qwen3next + +import ( + "cmp" + "fmt" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +// Options contains model configuration +type Options struct { + hiddenSize int + numHeads 
int + numKVHeads int + keyLength int + valueLength int + ropeDim int + + eps float32 + ropeBase float32 + ropeScale float32 + ropeType string + originalContextLength int + attentionScale float64 + + // MoE config + numExperts int + numExpertsUsed int + normTopKProb bool + + // Linear attention (Gated Delta Net) config + ssmDInner int // d_inner = head_v_dim * num_v_heads + ssmDState int // head_k_dim + ssmNGroup int // num_k_heads + ssmDtRank int // num_v_heads + convKernelSize int // SSM conv kernel size + + // Per-layer type from GGUF metadata + isRecurrent []bool + + // Pre-computed masks for chunked attention (created once per forward pass) + masks *Masks +} + +func (o Options) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + opts := []func(*rope.Options){rope.WithTypeNeoX()} + if o.ropeType == "yarn" { + attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) + opts = append(opts, + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(1.), + rope.WithAttentionFactor(attnFactor), + ) + } + ropeDim := cmp.Or(o.ropeDim, o.headDim()) + return nn.RoPE(ctx, states, positions, ropeDim, o.ropeBase, 1./o.ropeScale, opts...) +} + +// Operator is the interface for attention-like operators +type Operator interface { + Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) +} + +// MLP is the interface for feedforward networks +type MLP interface { + Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor +} + +// sparse implements MoE with shared experts +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` + + // Shared experts + SharedGateInp *nn.Linear `gguf:"ffn_gate_inp_shexp"` + SharedGate *nn.Linear `gguf:"ffn_gate_shexp"` + SharedUp *nn.Linear `gguf:"ffn_up_shexp"` + SharedDown *nn.Linear `gguf:"ffn_down_shexp"` +} + +func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2) + if batchSize == 0 { + batchSize = 1 + } + hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize) + + // Router logits + routerLogits := mlp.Router.Forward(ctx, hiddenStates2D) + + // Softmax routing weights + routingWeights := routerLogits.Softmax(ctx) + selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts) + if opts.normTopKProb { + routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1)) + routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx)) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1)) + } + + hiddenStates3D := hiddenStates2D.Reshape(ctx, hiddenStates2D.Dim(0), 1, hiddenStates2D.Dim(1)) + + // Expert computation with SILU activation + gateOut := mlp.Gate.Forward(ctx, hiddenStates3D, selectedExperts) + upOut := mlp.Up.Forward(ctx, hiddenStates3D, selectedExperts) + experts := gateOut.SILU(ctx, upOut) + experts = mlp.Down.Forward(ctx, experts, selectedExperts) + experts = experts.Mul(ctx, routingWeights) + + // Sum over 
experts + moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + + // Add shared experts if present + if mlp.SharedUp != nil { + sharedGate := mlp.SharedGate.Forward(ctx, hiddenStates2D) + sharedUp := mlp.SharedUp.Forward(ctx, hiddenStates2D) + sharedOut := sharedGate.SILU(ctx, sharedUp) + sharedOut = mlp.SharedDown.Forward(ctx, sharedOut) + + // Apply shared expert gating + if mlp.SharedGateInp != nil { + sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D) + sharedGateVal = sharedGateVal.Sigmoid(ctx) + // Broadcast gate to match dimensions + sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0)) + sharedOut = sharedOut.Mul(ctx, sharedGateVal) + } + + moeOut = moeOut.Add(ctx, sharedOut) + } + + return moeOut +} + +// dense implements standard feedforward +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +// Layer represents a single transformer layer +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + AttentionPostNorm *nn.RMSNorm `gguf:"post_attention_norm"` // Post-attention norm before FFN + Operator Operator + + FFNNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP MLP +} + +func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) { + residual := hiddenStates + + // Pre-attention norm + hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + + // Attention (full or linear) + var err error + hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, positions, cache, opts) + if err != nil { + return nil, err + } + + // Output projection for last layer + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + // First residual connection + hiddenStates = hiddenStates.Add(ctx, residual) + + // Save for FFN residual + ffnResidual := hiddenStates + + // Post-attention norm (before FFN) + hiddenStates = l.AttentionPostNorm.Forward(ctx, hiddenStates, opts.eps) + + // FFN + hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts) + + // Second residual connection + return hiddenStates.Add(ctx, ffnResidual), nil +} + +// Model is the main Qwen3-Next model +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Layers []Layer `gguf:"blk"` + + *Options +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + cache := m.Cache.(*HybridCache) + + // Create masks once per forward pass + m.Options.masks = createMasks(ctx) + + for i, layer := range m.Layers { + cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + var err error + hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options) + if err != 
nil { + return nil, err + } + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil +} + +var _ model.Model = (*Model)(nil) + +func New(c fs.Config) (model.Model, error) { + numLayers := int(c.Uint("block_count")) + layers := make([]Layer, numLayers) + + // Get per-layer head counts (for detecting layer type) + type headCounts interface { + HeadCount() []uint64 + HeadCountKV() []uint64 + } + + var isRecurrent []bool + var headCountKV []uint64 + if hc, ok := c.(headCounts); ok { + headCountKV = hc.HeadCountKV() + } + + isRecurrent = make([]bool, numLayers) + hasZero := false + hasFull := false + for i := range numLayers { + // If KV head count is 0, it's a recurrent layer + if i < len(headCountKV) && headCountKV[i] == 0 { + isRecurrent[i] = true + hasZero = true + } else if i < len(headCountKV) && headCountKV[i] > 0 { + hasFull = true + } + } + if !hasZero || !hasFull { + return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values") + } + + // Determine if MoE + isMoE := c.Uint("expert_count") > 0 + + for i := range layers { + if isRecurrent[i] { + layers[i].Operator = &GatedDeltaNet{Layer: i} + } else { + layers[i].Operator = &FullAttention{} + } + + if isMoE { + layers[i].MLP = &sparse{} + } else { + layers[i].MLP = &dense{} + } + } + + opts := &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: func() int { + for _, v := range headCountKV { + if v > 0 { + return int(v) + } + } + return 0 + }(), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + ropeDim: int(c.Uint("rope.dimension_count")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeType: c.String("rope.scaling.type"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + attentionScale: float64(c.Float("attention.scale")), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("norm_top_k_prob", true), + ssmDInner: int(c.Uint("ssm.inner_size")), + ssmDState: int(c.Uint("ssm.state_size")), + ssmNGroup: int(c.Uint("ssm.group_count")), + ssmDtRank: int(c.Uint("ssm.time_step_rank")), + convKernelSize: int(c.Uint("ssm.conv_kernel")), + isRecurrent: isRecurrent, + } + if opts.numKVHeads == 0 { + return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value") + } + + // Calculate cache dimensions + convDim := max(0, opts.convKernelSize-1) + convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState + headVDim := 0 + numVHeads := opts.ssmDtRank + if numVHeads > 0 { + headVDim = opts.ssmDInner / numVHeads + } + deltaStateSize := headVDim * headVDim * numVHeads + + // Validate dimension assumption: headKDim == headVDim is required for state computations + headKDim := opts.ssmDState + if headKDim != headVDim && headKDim > 0 && headVDim > 0 { + return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim) + } + + m := Model{ + BytePairEncoding: model.NewBytePairEncoding( + &model.Vocabulary{ + Values: 
c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + // Qwen3 tokenizers typically set add_bos_token=false and bos_token=null. + // Default to false when the GGUF key is missing to avoid injecting a spurious BOS. + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + ), + Layers: layers, + Options: opts, + } + + m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize) + return &m, nil +} + +func init() { + model.Register("qwen3next", New) +} diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go index 2b5a5ae95..33466ba81 100644 --- a/model/renderers/qwen3coder.go +++ b/model/renderers/qwen3coder.go @@ -167,12 +167,12 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _ // only start a new user block if this is the first tool response if i == 0 || filteredMessages[i-1].Role != "tool" { - sb.WriteString(imStartTag + "user\n") + sb.WriteString(imStartTag + "user") } - sb.WriteString("\n") + sb.WriteString("\n\n") sb.WriteString(message.Content) - sb.WriteString("\n\n") + sb.WriteString("\n") // close the user block only if this is the last tool response if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" { diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go index b6ca56e75..9f91c1f67 100644 --- a/model/renderers/qwen3coder_test.go +++ b/model/renderers/qwen3coder_test.go @@ -1,6 +1,7 @@ package renderers import ( + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -127,8 +128,7 @@ fahrenheit <|im_start|>user {"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12} - -<|im_end|> +<|im_end|> <|im_start|>user That sounds nice! What about New York?<|im_end|> <|im_start|>assistant @@ -233,8 +233,7 @@ I'll call double(1) and triple(2) for you. 
{"number": 6} - -<|im_end|> +<|im_end|> <|im_start|>assistant `, }, @@ -280,8 +279,7 @@ call tool<|im_end|> <|im_start|>user {"payload": {"foo": "bar"}} - -<|im_end|> +<|im_end|> <|im_start|>assistant `, }, @@ -337,6 +335,31 @@ func TestFormatToolCallArgument(t *testing.T) { } } +func TestQwen3CoderRendererToolResponseNoTrailingNewline(t *testing.T) { + msgs := []api.Message{ + {Role: "user", Content: "call tool"}, + {Role: "assistant", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{ + Name: "echo", + Arguments: testArgs(map[string]any{"payload": "ok"}), + }}, + }}, + {Role: "tool", Content: "{\"payload\":\"ok\"}", ToolName: "echo"}, + } + + rendered, err := (&Qwen3CoderRenderer{}).Render(msgs, nil, nil) + if err != nil { + t.Fatal(err) + } + + if strings.Contains(rendered, "\n<|im_end|>") { + t.Fatalf("expected no newline after , got:\n%s", rendered) + } + if !strings.Contains(rendered, "<|im_end|>") { + t.Fatalf("expected to be immediately followed by <|im_end|>, got:\n%s", rendered) + } +} + func TestQwen3ToolDefinitionTypes(t *testing.T) { tests := []struct { name string diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index faab1b229..895a8fb77 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -124,8 +124,17 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In } if c.cache != nil { - if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) { - numPast = 0 + if numPast > 0 { + // Recurrent caches use checkpoints to pick a safe resume position. + if cc, ok := c.cache.(kvcache.CheckpointCache); ok { + if restored, ok := cc.PrepareRestore(slot.Id, numPast); ok { + numPast = restored + } else { + numPast = 0 + } + } else if !c.cache.CanResume(slot.Id, numPast) { + numPast = 0 + } } err = c.cache.Remove(slot.Id, numPast, math.MaxInt32) diff --git a/server/quantization.go b/server/quantization.go index edfd4b470..76c54d8f6 100644 --- a/server/quantization.go +++ b/server/quantization.go @@ -58,6 +58,48 @@ func useMoreBits(iLayer, nLayers int) bool { return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2 } +func qwen3nextQuantType(name string) (fsggml.TensorType, bool) { + switch { + // Full attention + case strings.HasSuffix(name, ".attn_q.weight"): + return fsggml.TensorTypeQ4_K, true + case strings.HasSuffix(name, ".attn_k.weight"): + return fsggml.TensorTypeQ4_K, true + case strings.HasSuffix(name, ".attn_v.weight"): + return fsggml.TensorTypeQ6_K, true + case strings.HasSuffix(name, ".attn_output.weight"): + return fsggml.TensorTypeQ4_K, true + + // Linear attention (Gated Delta Net) after split + case strings.HasSuffix(name, ".attn_qkv.weight"): + return fsggml.TensorTypeQ4_K, true + case strings.HasSuffix(name, ".attn_gate.weight"): + return fsggml.TensorTypeQ4_K, true + + // SSM + case strings.HasSuffix(name, ".ssm_ba.weight"): + return fsggml.TensorTypeQ4_K, true + case strings.HasSuffix(name, ".ssm_out.weight"): + return fsggml.TensorTypeQ4_K, true + + // MoE experts + shared experts + case strings.HasSuffix(name, ".ffn_down_exps.weight"): + return fsggml.TensorTypeQ6_K, true + case strings.HasSuffix(name, ".ffn_down_shexp.weight"): + return fsggml.TensorTypeQ6_K, true + case strings.HasSuffix(name, ".ffn_gate_exps.weight"): + return fsggml.TensorTypeQ4_K, true + case strings.HasSuffix(name, ".ffn_gate_shexp.weight"): + return fsggml.TensorTypeQ4_K, true + case strings.HasSuffix(name, ".ffn_up_exps.weight"): + return fsggml.TensorTypeQ4_K, true + case 
strings.HasSuffix(name, ".ffn_up_shexp.weight"): + return fsggml.TensorTypeQ4_K, true + } + + return 0, false +} + func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType { // Ported from llama_tensor_get_type, removed unsupported quantization types nExperts := max(1, kv.Uint("expert_count", 0)) @@ -217,6 +259,7 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil // do not quantize expert gating tensors quantize = quantize && !strings.Contains(name, "ffn_gate_inp.weight") + quantize = quantize && !strings.Contains(name, "ffn_gate_inp_shexp.weight") // do not quantize positional embeddings and token types (BERT) quantize = quantize && (name != "position_embd.weight") @@ -244,6 +287,12 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil newType := fsggml.TensorType(t.Kind) if quantize { + if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) { + if qt, ok := qwen3nextQuantType(name); ok { + return qt + } + } + // get more optimal quantization type based on the tensor shape, layer, etc. newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype) if newType != defaultType {