diff --git a/model/models/qwen3next/deltanet.go b/model/models/qwen3next/deltanet.go index 958d1e937..e0a6f7b25 100644 --- a/model/models/qwen3next/deltanet.go +++ b/model/models/qwen3next/deltanet.go @@ -406,8 +406,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked( gDiff := gCumsum.Neg(ctx).Add(ctx, gLast) gDiffExp := gDiff.Exp(ctx) - // key_gdiff = k * exp(g_diff) - keyGDiff := k.Mul(ctx, gDiffExp) + // Reshapes g_diff_exp to [1, chunkSize, nChunks, ...] + gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs) + keyGDiff := k.Mul(ctx, gDiffExpReshaped) + keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // Process chunks and update state var coreAttnOut ml.Tensor @@ -444,12 +446,9 @@ func (gdn *GatedDeltaNet) deltaNetChunked( coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1) } - // Update state for next chunk using pre-computed values + // Update state for next chunk gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1) - kGDiffChunk := keyGDiff.Slice(ctx, 2, chunk, chunk+1, 1) - - // kgdmulvnew = key_gdiff^T @ v_new - kGDiffChunkT := kGDiffChunk.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1) kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT) // state = state * g_last + kgdmulvnew