mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
models: add nemotronh architecture support (#14356)
This commit is contained in:
752
kvcache/recurrent.go
Normal file
752
kvcache/recurrent.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCheckpointCount = 32
|
||||
DefaultCheckpointMinPos = int32(16)
|
||||
DefaultCheckpointInterval = int32(1280)
|
||||
)
|
||||
|
||||
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
|
||||
|
||||
// Config configures a shared hybrid recurrent cache.
|
||||
type RecurrentConfig struct {
|
||||
Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||
ConvDim int
|
||||
ConvChannels int
|
||||
RecurrentStateSize int
|
||||
CheckpointLogPrefix string
|
||||
}
|
||||
|
||||
var (
|
||||
_ Cache = (*Recurrent)(nil)
|
||||
_ CheckpointCache = (*Recurrent)(nil)
|
||||
)
|
||||
|
||||
// Cache stores:
|
||||
// - a standard causal KV cache
|
||||
// - per-sequence conv state for recurrent operators
|
||||
// - per-sequence recurrent state for recurrent operators
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [convDim, convChannels]
|
||||
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
|
||||
type Recurrent struct {
|
||||
kv *Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
// Conv state dimensions
|
||||
convDim int
|
||||
convChannels int
|
||||
|
||||
// Recurrent state dimensions
|
||||
recurrentStateSize int
|
||||
|
||||
logPrefix string
|
||||
|
||||
// slot mapping for recurrent state (copy-on-write)
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
seqCounts map[int]int
|
||||
slotScratch [1]int32
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
||||
|
||||
// per-layer recurrent state buffers (allocated lazily)
|
||||
recurrentCtxs map[int]ml.Context
|
||||
recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]
|
||||
|
||||
// recurrent checkpoints (per slot)
|
||||
checkpointCount int
|
||||
checkpointMinPos int32
|
||||
checkpointInterval int32
|
||||
checkpointCtxSize int
|
||||
checkpoints map[int]*slotCheckpointStore
|
||||
pendingRestore map[int]checkpointRestore
|
||||
curCheckpointPos []int32
|
||||
curCheckpointSlots map[int]int
|
||||
reserveCheckpoints bool
|
||||
checkpointConvCtxs map[int]ml.Context
|
||||
checkpointRecurCtxs map[int]ml.Context
|
||||
checkpointReserved map[int]struct{}
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
writableError error
|
||||
}
|
||||
|
||||
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
|
||||
return &Recurrent{
|
||||
kv: NewCausalCache(config.Shift),
|
||||
convDim: config.ConvDim,
|
||||
convChannels: config.ConvChannels,
|
||||
recurrentStateSize: config.RecurrentStateSize,
|
||||
logPrefix: config.CheckpointLogPrefix,
|
||||
slotForSeq: make(map[int]int),
|
||||
seqCounts: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
recurrentCtxs: make(map[int]ml.Context),
|
||||
recurrentStates: make(map[int]ml.Tensor),
|
||||
checkpointCount: DefaultCheckpointCount,
|
||||
checkpointMinPos: DefaultCheckpointMinPos,
|
||||
checkpointInterval: DefaultCheckpointInterval,
|
||||
checkpoints: make(map[int]*slotCheckpointStore),
|
||||
pendingRestore: make(map[int]checkpointRestore),
|
||||
curCheckpointSlots: make(map[int]int),
|
||||
checkpointConvCtxs: make(map[int]ml.Context),
|
||||
checkpointRecurCtxs: make(map[int]ml.Context),
|
||||
checkpointReserved: make(map[int]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
c.pendingRestore = make(map[int]checkpointRestore)
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
c.curCheckpointSlots = make(map[int]int)
|
||||
c.checkpointReserved = make(map[int]struct{})
|
||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
||||
if c.checkpointCtxSize < 8 {
|
||||
c.checkpointCtxSize = 8
|
||||
}
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.recurrentCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointConvCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointRecurCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *Recurrent) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *Recurrent) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
if nTokens == 0 {
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
c.curSlots = c.curSlots[:0]
|
||||
c.curSlotsInput = nil
|
||||
c.curSeqTokens = 0
|
||||
c.reserveCheckpoints = false
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path for single-sequence batches (common during decode and prefill).
|
||||
firstSeq := batch.Sequences[0]
|
||||
singleSeq := true
|
||||
for _, s := range batch.Sequences[1:] {
|
||||
if s != firstSeq {
|
||||
singleSeq = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if singleSeq {
|
||||
return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for recurrent layers.
|
||||
seqCounts := c.seqCounts
|
||||
for s := range seqCounts {
|
||||
delete(seqCounts, s)
|
||||
}
|
||||
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if seqCounts[s] == 0 {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
}
|
||||
c.finalizeStartForward(ctx, batch, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch.
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
c.finalizeStartForward(ctx, batch, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
|
||||
c.curSeqs = append(c.curSeqs[:0], seq)
|
||||
c.curSeqTokens = seqTokens
|
||||
|
||||
if reserve {
|
||||
c.curSlots = append(c.curSlots[:0], 0)
|
||||
c.finalizeStartForward(ctx, batch, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.slotForSeq[seq] = slot
|
||||
c.refCount[slot] = 1
|
||||
slotList := [1]int{slot}
|
||||
c.zeroSlots(ctx, slotList[:])
|
||||
}
|
||||
|
||||
c.curSlots = append(c.curSlots[:0], slot)
|
||||
c.finalizeStartForward(ctx, batch, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
|
||||
c.setCurSlotsInput(ctx)
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
c.reserveCheckpoints = reserve
|
||||
c.planCheckpoints(batch)
|
||||
}
|
||||
|
||||
func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
|
||||
c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
|
||||
}
|
||||
|
||||
func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
|
||||
switch len(slots) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
c.slotScratch[0] = int32(slots[0])
|
||||
return ctx.Input().FromInts(c.slotScratch[:], 1)
|
||||
default:
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, v := range slots {
|
||||
slotIndices[i] = int32(v)
|
||||
}
|
||||
return ctx.Input().FromInts(slotIndices, len(slotIndices))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) freeSlot(slot int) {
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroSlots zeros recurrent state for the given slots across all cached layers.
|
||||
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
inputCtx := ctx.Input()
|
||||
slotsTensor := c.slotsInput(ctx, slots)
|
||||
|
||||
if len(c.convStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.recurrentStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
|
||||
for _, buf := range c.recurrentStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
||||
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
}
|
||||
|
||||
c.setCurSlotsInput(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
if rows.DType() != ml.DTypeF32 {
|
||||
rows = rows.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.recurrentStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
if rows.DType() != ml.DTypeF32 {
|
||||
rows = rows.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(c.pendingRestore, seq)
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok || !c.validSlot(slot) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
|
||||
if c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
slot = newSlot
|
||||
}
|
||||
|
||||
c.shiftCheckpoints(slot, beginIndex, endIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return ErrNotSupported
|
||||
}
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
func (c *Recurrent) SlotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *Recurrent) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
|
||||
|
||||
func (c *Recurrent) SeqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *Recurrent) NumSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *Recurrent) convBuffer(layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
|
||||
if buf, ok := c.recurrentStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.recurrentCtxs[layer]; !ok {
|
||||
c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences)
|
||||
c.recurrentStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureWritable(ctx ml.Context) error {
|
||||
c.ensureWritableOnce(ctx)
|
||||
return c.writableError
|
||||
}
|
||||
|
||||
func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
offset := start * buf.Stride(1)
|
||||
return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
|
||||
}
|
||||
|
||||
return buf.Rows(ctx, c.SlotsTensor())
|
||||
}
|
||||
|
||||
func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
|
||||
ctx.Forward(src.Copy(ctx, view))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
}
|
||||
|
||||
// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
||||
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := c.convBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
|
||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes new conv state for current batch sequences.
|
||||
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(layer)
|
||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
|
||||
srcF32 := src
|
||||
if src.DType() != ml.DTypeF32 {
|
||||
srcF32 = src.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)
|
||||
|
||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
|
||||
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dims) == 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
|
||||
size := 1
|
||||
for _, d := range dims {
|
||||
if d <= 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
size *= d
|
||||
}
|
||||
if size != c.recurrentStateSize {
|
||||
return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
|
||||
}
|
||||
|
||||
buf := c.recurrentBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
|
||||
shape := make([]int, 0, len(dims)+1)
|
||||
shape = append(shape, dims...)
|
||||
shape = append(shape, c.NumSeqs())
|
||||
return cur.Reshape(ctx, shape...), nil
|
||||
}
|
||||
|
||||
// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
|
||||
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
|
||||
size := dim0 * dim1 * dim2
|
||||
if size != c.recurrentStateSize {
|
||||
return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize)
|
||||
}
|
||||
|
||||
buf := c.recurrentBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
|
||||
return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateRecurrentState writes new recurrent state for current batch sequences.
|
||||
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.recurrentBuffer(layer)
|
||||
src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
|
||||
srcF32 := src
|
||||
if src.DType() != ml.DTypeF32 {
|
||||
srcF32 = src.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)
|
||||
|
||||
c.captureRecurrentCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
||||
func (c *Recurrent) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *Recurrent) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
}
|
||||
561
kvcache/recurrent_checkpoints.go
Normal file
561
kvcache/recurrent_checkpoints.go
Normal file
@@ -0,0 +1,561 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
|
||||
// memory usage while preserving prefix reuse for recurrent state.
|
||||
|
||||
type checkpointEntry struct {
|
||||
pos int32
|
||||
conv map[int]ml.Tensor
|
||||
recurrent map[int]ml.Tensor
|
||||
}
|
||||
|
||||
type slotCheckpointStore struct {
|
||||
entries []checkpointEntry
|
||||
size int
|
||||
next int
|
||||
lastPos int32
|
||||
}
|
||||
|
||||
type checkpointRestore struct {
|
||||
slot int
|
||||
idx int
|
||||
pos int32
|
||||
}
|
||||
|
||||
func newSlotCheckpointStore(n int) *slotCheckpointStore {
|
||||
entries := make([]checkpointEntry, n)
|
||||
for i := range entries {
|
||||
entries[i].pos = -1
|
||||
}
|
||||
return &slotCheckpointStore{
|
||||
entries: entries,
|
||||
lastPos: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) reset() {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
for i := range s.entries {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) record(pos int32) int {
|
||||
if len(s.entries) == 0 {
|
||||
return -1
|
||||
}
|
||||
idx := s.next
|
||||
s.next = (s.next + 1) % len(s.entries)
|
||||
if s.size < len(s.entries) {
|
||||
s.size++
|
||||
}
|
||||
s.entries[idx].pos = pos
|
||||
s.lastPos = pos
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
|
||||
bestIdx := -1
|
||||
bestPos := int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 || pos >= targetPos {
|
||||
continue
|
||||
}
|
||||
if pos > bestPos {
|
||||
bestPos = pos
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
return bestIdx, bestPos, true
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) pruneAfter(pos int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
minIdx := 0
|
||||
for i := range s.entries {
|
||||
if s.entries[i].pos > pos {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
if s.entries[i].pos >= 0 {
|
||||
size++
|
||||
if s.entries[i].pos < minPos {
|
||||
minPos = s.entries[i].pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = pos
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) shiftRange(beginIndex, endIndex int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
offset := beginIndex - endIndex
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
maxPos := int32(-1)
|
||||
minIdx := 0
|
||||
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos >= 0 {
|
||||
if pos >= beginIndex && pos < endIndex {
|
||||
s.entries[i].pos = -1
|
||||
} else if pos >= endIndex {
|
||||
s.entries[i].pos = pos + offset
|
||||
}
|
||||
}
|
||||
|
||||
pos = s.entries[i].pos
|
||||
if pos >= 0 {
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
minIdx = i
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = maxPos
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
|
||||
minPos = int32(math.MaxInt32)
|
||||
maxPos = int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
}
|
||||
if size == 0 {
|
||||
minPos = -1
|
||||
maxPos = -1
|
||||
}
|
||||
return size, minPos, maxPos, s.lastPos
|
||||
}
|
||||
|
||||
func (c *Recurrent) checkpointTag() string {
|
||||
if c.logPrefix == "" {
|
||||
return "kvcache.recurrent"
|
||||
}
|
||||
return c.logPrefix
|
||||
}
|
||||
|
||||
func (c *Recurrent) planCheckpoints(batch input.Batch) {
|
||||
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cap(c.curCheckpointPos) < len(c.curSeqs) {
|
||||
c.curCheckpointPos = make([]int32, len(c.curSeqs))
|
||||
} else {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
|
||||
}
|
||||
for i := range c.curCheckpointPos {
|
||||
c.curCheckpointPos[i] = -1
|
||||
}
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
|
||||
posMax := make(map[int]int32, len(c.curSeqs))
|
||||
for i, seq := range batch.Sequences {
|
||||
pos := batch.Positions[i]
|
||||
if cur, ok := posMax[seq]; !ok || pos > cur {
|
||||
posMax[seq] = pos
|
||||
}
|
||||
}
|
||||
|
||||
for i, seq := range c.curSeqs {
|
||||
pos, ok := posMax[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if pos < c.checkpointMinPos {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
store := c.checkpointStore(slot)
|
||||
lastPos := store.lastPos
|
||||
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
|
||||
c.curCheckpointPos[i] = pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) checkpointStore(slot int) *slotCheckpointStore {
|
||||
store, ok := c.checkpoints[slot]
|
||||
if ok {
|
||||
return store
|
||||
}
|
||||
store = newSlotCheckpointStore(c.checkpointCount)
|
||||
c.checkpoints[slot] = store
|
||||
return store
|
||||
}
|
||||
|
||||
func (c *Recurrent) checkpointIndexForSlot(slot int, pos int32) int {
|
||||
if c.checkpointCount == 0 {
|
||||
return -1
|
||||
}
|
||||
if idx, ok := c.curCheckpointSlots[slot]; ok {
|
||||
return idx
|
||||
}
|
||||
store := c.checkpointStore(slot)
|
||||
idx := store.record(pos)
|
||||
if idx >= 0 {
|
||||
c.curCheckpointSlots[slot] = idx
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (c *Recurrent) hasCheckpoint(seq int, pos int32) bool {
|
||||
if pos <= 0 {
|
||||
return false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, _, ok = store.bestIndex(pos)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
if targetPos <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
|
||||
return 0, false
|
||||
}
|
||||
idx, pos, ok := store.bestIndex(targetPos)
|
||||
if !ok {
|
||||
size, minPos, maxPos, lastPos := store.window()
|
||||
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
|
||||
"min", minPos, "max", maxPos, "last", lastPos)
|
||||
return 0, false
|
||||
}
|
||||
c.pendingRestore[seq] = checkpointRestore{
|
||||
slot: slot,
|
||||
idx: idx,
|
||||
pos: pos,
|
||||
}
|
||||
return pos + 1, true
|
||||
}
|
||||
|
||||
func (c *Recurrent) applyCheckpointRestore(restore checkpointRestore) error {
|
||||
entry, ok := c.restoreEntry(restore)
|
||||
if !ok {
|
||||
return ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
|
||||
for layer, src := range entry.conv {
|
||||
buf := c.convBuffer(layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
for layer, src := range entry.recurrent {
|
||||
buf := c.recurrentBuffer(layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
|
||||
if len(entry.conv) > 0 || len(entry.recurrent) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
store := c.checkpoints[restore.slot]
|
||||
store.pruneAfter(restore.pos)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) restoreComplete(restore checkpointRestore) bool {
|
||||
_, ok := c.restoreEntry(restore)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *Recurrent) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
|
||||
store, ok := c.checkpoints[restore.slot]
|
||||
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
|
||||
return nil, false
|
||||
}
|
||||
entry := &store.entries[restore.idx]
|
||||
if entry.pos < 0 {
|
||||
return nil, false
|
||||
}
|
||||
if !c.entryComplete(entry) {
|
||||
return nil, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *Recurrent) entryComplete(entry *checkpointEntry) bool {
|
||||
for layer := range c.convStates {
|
||||
if entry.conv == nil || entry.conv[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for layer := range c.recurrentStates {
|
||||
if entry.recurrent == nil || entry.recurrent[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Recurrent) clearCheckpoints(slot int) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) shiftCheckpoints(slot int, beginIndex, endIndex int32) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.shiftRange(beginIndex, endIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
srcStore, ok := c.checkpoints[srcSlot]
|
||||
if !ok || srcStore.size == 0 {
|
||||
return
|
||||
}
|
||||
dstStore := c.checkpointStore(dstSlot)
|
||||
dstStore.size = srcStore.size
|
||||
dstStore.next = srcStore.next
|
||||
dstStore.lastPos = srcStore.lastPos
|
||||
|
||||
for i := range srcStore.entries {
|
||||
srcEntry := &srcStore.entries[i]
|
||||
dstEntry := &dstStore.entries[i]
|
||||
dstEntry.pos = srcEntry.pos
|
||||
if srcEntry.conv != nil {
|
||||
if dstEntry.conv == nil {
|
||||
dstEntry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.conv {
|
||||
dst := c.ensureCheckpointConv(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
if srcEntry.recurrent != nil {
|
||||
if dstEntry.recurrent == nil {
|
||||
dstEntry.recurrent = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.recurrent {
|
||||
dst := c.ensureCheckpointRecurrent(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointConv(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointConv(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) captureRecurrentCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointRecurrent(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointRecurrent(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.conv == nil {
|
||||
entry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.conv[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointConvCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointConvCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
|
||||
entry.conv[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureCheckpointRecurrent(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.recurrent == nil {
|
||||
entry.recurrent = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.recurrent[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointRecurCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointRecurCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.recurrentStateSize, 1)
|
||||
entry.recurrent[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *Recurrent) reserveCheckpointConv(layer int) {
|
||||
key := checkpointReserveKey(layer, 0)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointConv(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func (c *Recurrent) reserveCheckpointRecurrent(layer int) {
|
||||
key := checkpointReserveKey(layer, 1)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointRecurrent(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func checkpointReserveKey(layer int, kind int) int {
|
||||
return layer*2 + kind
|
||||
}
|
||||
288
kvcache/recurrent_checkpoints_test.go
Normal file
288
kvcache/recurrent_checkpoints_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func newTestCache() *Recurrent {
|
||||
return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2})
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
|
||||
_, pos, ok := store.bestIndex(15)
|
||||
if !ok || pos != 10 {
|
||||
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
store.record(30) // overwrite oldest (10)
|
||||
|
||||
if _, _, ok := store.bestIndex(15); ok {
|
||||
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(40)
|
||||
if !ok || pos != 30 {
|
||||
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachePrepareRestore(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(5)
|
||||
store.record(9)
|
||||
store.record(15)
|
||||
|
||||
restorePos, ok := cache.PrepareRestore(1, 12)
|
||||
if !ok {
|
||||
t.Fatalf("expected restore ok")
|
||||
}
|
||||
if restorePos != 10 {
|
||||
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
||||
}
|
||||
rest, ok := cache.pendingRestore[1]
|
||||
if !ok {
|
||||
t.Fatalf("expected pending restore entry")
|
||||
}
|
||||
if rest.pos != 9 {
|
||||
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(20)
|
||||
|
||||
if store.lastPos != 20 {
|
||||
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, pos, ok := store.bestIndex(25)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(35)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Simulate layer 0 requires both conv and recurrent checkpoints.
|
||||
cache.convStates[0] = nil
|
||||
cache.recurrentStates[0] = nil
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
entry := &store.entries[idx]
|
||||
entry.conv = map[int]ml.Tensor{0: nil}
|
||||
// entry.recurrent intentionally missing
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
err := cache.Remove(1, 10, math.MaxInt32)
|
||||
if !errors.Is(err, ErrNotSupported) {
|
||||
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
restore := cache.pendingRestore[1]
|
||||
if !cache.restoreComplete(restore) {
|
||||
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRecurrentStateShapeValidation(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
_, err := cache.RecurrentState(nil, 0, 3)
|
||||
if !errors.Is(err, ErrInvalidRecurrentShape) {
|
||||
t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreShiftRange(t *testing.T) {
|
||||
store := newSlotCheckpointStore(5)
|
||||
store.record(1)
|
||||
store.record(4)
|
||||
store.record(7)
|
||||
store.record(10)
|
||||
|
||||
store.shiftRange(2, 6)
|
||||
|
||||
var positions []int32
|
||||
for i := range store.entries {
|
||||
if store.entries[i].pos >= 0 {
|
||||
positions = append(positions, store.entries[i].pos)
|
||||
}
|
||||
}
|
||||
slices.Sort(positions)
|
||||
|
||||
want := []int32{1, 3, 6}
|
||||
if !slices.Equal(positions, want) {
|
||||
t.Fatalf("unexpected shifted positions: got=%v want=%v", positions, want)
|
||||
}
|
||||
if store.lastPos != 6 {
|
||||
t.Fatalf("expected lastPos 6, got %d", store.lastPos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1}
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(1)
|
||||
store.record(4)
|
||||
store.record(7)
|
||||
store.record(10)
|
||||
|
||||
if err := cache.Remove(1, 2, 6); err != nil {
|
||||
t.Fatalf("expected middle remove to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if _, ok := cache.pendingRestore[1]; ok {
|
||||
t.Fatalf("expected pending restore to be cleared after middle remove")
|
||||
}
|
||||
|
||||
var positions []int32
|
||||
for i := range store.entries {
|
||||
if store.entries[i].pos >= 0 {
|
||||
positions = append(positions, store.entries[i].pos)
|
||||
}
|
||||
}
|
||||
slices.Sort(positions)
|
||||
|
||||
want := []int32{1, 3, 6}
|
||||
if !slices.Equal(positions, want) {
|
||||
t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", positions, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.entries[0].conv = make(map[int]ml.Tensor)
|
||||
store.entries[0].conv[0] = nil
|
||||
store.entries[0].recurrent = make(map[int]ml.Tensor)
|
||||
store.entries[0].recurrent[0] = nil
|
||||
|
||||
store.record(40)
|
||||
|
||||
if store.entries[0].conv == nil {
|
||||
t.Fatalf("expected conv map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].recurrent == nil {
|
||||
t.Fatalf("expected recurrent map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].pos != 40 {
|
||||
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
|
||||
idx1 := store.record(10)
|
||||
idx2 := store.record(20)
|
||||
|
||||
if idx1 != 0 || idx2 != 1 {
|
||||
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
||||
}
|
||||
if store.size != 2 {
|
||||
t.Fatalf("expected size 2, got %d", store.size)
|
||||
}
|
||||
|
||||
_, pos1, ok1 := store.bestIndex(15)
|
||||
_, pos2, ok2 := store.bestIndex(25)
|
||||
|
||||
if !ok1 || pos1 != 10 {
|
||||
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
||||
}
|
||||
if !ok2 || pos2 != 20 {
|
||||
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
||||
store := newSlotCheckpointStore(0)
|
||||
|
||||
idx := store.record(10)
|
||||
if idx != -1 {
|
||||
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(15)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint for empty buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(5)
|
||||
|
||||
if store.size != 0 {
|
||||
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
||||
}
|
||||
if store.lastPos != -1 {
|
||||
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(100)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint after pruning all")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user