From 3490e9590bbaba174f89c03207b42727150460e0 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sun, 1 Mar 2026 18:44:04 -0800 Subject: [PATCH] model/qwen3next: avoid crash in in DeltaNet when offloading (#14541) Co-authored-by: Yossi Ovadia --- model/models/qwen3next/deltanet.go | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/model/models/qwen3next/deltanet.go b/model/models/qwen3next/deltanet.go index 6ce315649..9c7aa02b5 100644 --- a/model/models/qwen3next/deltanet.go +++ b/model/models/qwen3next/deltanet.go @@ -454,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked( vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs) stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs) + // Collect chunk outputs and concatenate at the end. + // Avoids SET on buffer-less intermediates under partial offload. + chunks := make([]ml.Tensor, nChunks) + for chunk := range nChunks { qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1) vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1) @@ -475,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked( vAttn := vTNewChunk.Mulmat(ctx, attnChunk) coreAttnOutChunk := attnInter.Add(ctx, vAttn) - v = v.SetInplace( - ctx, - coreAttnOutChunk, - v.Stride(1), - v.Stride(2), - v.Stride(3), - chunk*v.Stride(2), - ) + chunks[chunk] = coreAttnOutChunk // Update state for next chunk gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1) @@ -495,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked( stateT = stateT.Add(ctx, kgdMulVNew) } + // Use a balanced concat tree so concat work does not balloon on long prompts. + for len(chunks) > 1 { + merged := make([]ml.Tensor, 0, (len(chunks)+1)/2) + for i := 0; i < len(chunks); i += 2 { + if i+1 < len(chunks) { + merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2)) + } else { + merged = append(merged, chunks[i]) + } + } + chunks = merged + } + v = chunks[0] + // Final reshape coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)