Files
ollama/x/create/dtype.go
Patrick Devine fa69b833cd mlx: add prequantized tensor packing + changes for qwen35 (#14878)
This change adds a tensorImportTransform interface for model-specific
tensor transformations during safetensors import. This allows standard
HF-based weights, as well as pre-quantized safetensors repos derived from
mlx-community, to be modified and imported directly with `ollama create`.
Currently this only works for Qwen3.5, whose import performs tensor
renaming, norm weight shifting (adding +1 to each value of the norm
vectors), conv1d transposition, and casting of F32-based vectors to BF16.
2026-03-17 11:21:18 -07:00

110 lines
2.7 KiB
Go

package create
import (
"encoding/binary"
"fmt"
"math"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/x448/float16"
)
// DTypeSize reports how many bytes a single element of dtype occupies.
// The dtype name is matched case-insensitively against F64 (8 bytes),
// F32/I32/U32 (4 bytes) and F16/BF16 (2 bytes). Any other name yields a
// size of 0 together with an "unsupported dtype" error.
func DTypeSize(dtype string) (int, error) {
	var width int
	switch strings.ToUpper(dtype) {
	case "F64":
		width = 8
	case "F32", "I32", "U32":
		width = 4
	case "F16", "BF16":
		width = 2
	default:
		return 0, fmt.Errorf("unsupported dtype %q", dtype)
	}
	return width, nil
}
// DecodeFloatTensor decodes raw bytes into []float32 according to the given
// dtype, matched case-insensitively against BF16, F16, F32 and F64.
//
// F16, F32 and F64 elements are read as little-endian; BF16 decoding is
// delegated to bfloat16.DecodeFloat32. F64 values are narrowed to float32,
// so any precision beyond float32 is lost.
//
// It returns an error if dtype is unsupported or if len(raw) is not a whole
// multiple of the element width for that dtype.
func DecodeFloatTensor(dtype string, raw []byte) ([]float32, error) {
	switch strings.ToUpper(dtype) {
	case "BF16":
		// Validate like the other widths so a trailing partial element is
		// reported instead of being left to the decoding library.
		if len(raw)%2 != 0 {
			return nil, fmt.Errorf("invalid bf16 byte length %d", len(raw))
		}
		return bfloat16.DecodeFloat32(raw), nil
	case "F16":
		if len(raw)%2 != 0 {
			return nil, fmt.Errorf("invalid f16 byte length %d", len(raw))
		}
		values := make([]float32, len(raw)/2)
		for i := range values {
			values[i] = float16.Frombits(binary.LittleEndian.Uint16(raw[i*2:])).Float32()
		}
		return values, nil
	case "F32":
		if len(raw)%4 != 0 {
			return nil, fmt.Errorf("invalid f32 byte length %d", len(raw))
		}
		values := make([]float32, len(raw)/4)
		for i := range values {
			values[i] = math.Float32frombits(binary.LittleEndian.Uint32(raw[i*4:]))
		}
		return values, nil
	case "F64":
		if len(raw)%8 != 0 {
			return nil, fmt.Errorf("invalid f64 byte length %d", len(raw))
		}
		values := make([]float32, len(raw)/8)
		for i := range values {
			// Narrowing conversion: float64 -> float32.
			values[i] = float32(math.Float64frombits(binary.LittleEndian.Uint64(raw[i*8:])))
		}
		return values, nil
	default:
		return nil, fmt.Errorf("unsupported dtype %q", dtype)
	}
}
// EncodeFloatTensor serializes values into raw bytes for the given dtype,
// matched case-insensitively. BF16 is produced by bfloat16.EncodeFloat32;
// F16, F32 and F64 are written element by element in little-endian order
// (F64 widens each float32 to float64 first). Any other dtype returns a nil
// slice and an "unsupported dtype" error.
func EncodeFloatTensor(dtype string, values []float32) ([]byte, error) {
	kind := strings.ToUpper(dtype)
	if kind == "BF16" {
		return bfloat16.EncodeFloat32(values), nil
	}

	// Pick the element width and a writer for one element; the shared loop
	// below then lays the elements out back to back.
	var (
		width int
		put   func(dst []byte, v float32)
	)
	switch kind {
	case "F16":
		width = 2
		put = func(dst []byte, v float32) {
			binary.LittleEndian.PutUint16(dst, float16.Fromfloat32(v).Bits())
		}
	case "F32":
		width = 4
		put = func(dst []byte, v float32) {
			binary.LittleEndian.PutUint32(dst, math.Float32bits(v))
		}
	case "F64":
		width = 8
		put = func(dst []byte, v float32) {
			binary.LittleEndian.PutUint64(dst, math.Float64bits(float64(v)))
		}
	default:
		return nil, fmt.Errorf("unsupported dtype %q", dtype)
	}

	buf := make([]byte, width*len(values))
	for idx, v := range values {
		put(buf[idx*width:], v)
	}
	return buf, nil
}
// sourceQuantType maps a source quantization mode and bit width to the
// quant type name used during import (presumably taken from a pre-quantized
// checkpoint's quantization config — confirm at the call site). The mode is
// matched case-insensitively:
//   - "affine" maps to "int4" or "int8" for 4 or 8 bits; other widths are
//     not recognized.
//   - "nvfp4", "mxfp8" and "mxfp4" are returned (lowercased) regardless of
//     bits.
//
// Unrecognized combinations return the empty string.
func sourceQuantType(mode string, bits int) string {
	m := strings.ToLower(mode)
	switch {
	case m == "affine" && bits == 4:
		return "int4"
	case m == "affine" && bits == 8:
		return "int8"
	case m == "nvfp4", m == "mxfp8", m == "mxfp4":
		return m
	default:
		return ""
	}
}