mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
This change adds a tensorImportTransform interface for model-specific tensor transformations during safetensors import. This allows both the standard HF-based weights and the mlx-community-derived pre-quantized safetensors repos to be modified and imported directly with `ollama create`. Right now this only works for Qwen3.5 imports, which perform tensor renaming, norm weight shifting (adding +1 to each value of the norm vectors), conv1d transposition, and casting of F32-based vectors to BF16.
110 lines
2.7 KiB
Go
110 lines
2.7 KiB
Go
package create
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"math"
|
|
"strings"
|
|
|
|
"github.com/d4l3k/go-bfloat16"
|
|
"github.com/x448/float16"
|
|
)
|
|
|
|
// DTypeSize reports how many bytes one element of the given dtype occupies.
//
// The dtype name is compared case-insensitively. "BF16" and "F16" are 2 bytes,
// "F32", "U32" and "I32" are 4 bytes, and "F64" is 8 bytes. Any other name
// yields a size of 0 together with an "unsupported dtype" error that quotes
// the name exactly as the caller passed it.
func DTypeSize(dtype string) (int, error) {
	name := strings.ToUpper(dtype)

	if name == "BF16" || name == "F16" {
		return 2, nil
	}
	if name == "F32" || name == "U32" || name == "I32" {
		return 4, nil
	}
	if name == "F64" {
		return 8, nil
	}

	return 0, fmt.Errorf("unsupported dtype %q", dtype)
}
|
|
|
|
// DecodeFloatTensor decodes raw bytes into []float32 according to the given dtype.
|
|
func DecodeFloatTensor(dtype string, raw []byte) ([]float32, error) {
|
|
switch strings.ToUpper(dtype) {
|
|
case "BF16":
|
|
return bfloat16.DecodeFloat32(raw), nil
|
|
case "F16":
|
|
if len(raw)%2 != 0 {
|
|
return nil, fmt.Errorf("invalid f16 byte length %d", len(raw))
|
|
}
|
|
values := make([]float32, len(raw)/2)
|
|
for i := range values {
|
|
values[i] = float16.Frombits(binary.LittleEndian.Uint16(raw[i*2:])).Float32()
|
|
}
|
|
return values, nil
|
|
case "F32":
|
|
if len(raw)%4 != 0 {
|
|
return nil, fmt.Errorf("invalid f32 byte length %d", len(raw))
|
|
}
|
|
values := make([]float32, len(raw)/4)
|
|
for i := range values {
|
|
values[i] = math.Float32frombits(binary.LittleEndian.Uint32(raw[i*4:]))
|
|
}
|
|
return values, nil
|
|
case "F64":
|
|
if len(raw)%8 != 0 {
|
|
return nil, fmt.Errorf("invalid f64 byte length %d", len(raw))
|
|
}
|
|
values := make([]float32, len(raw)/8)
|
|
for i := range values {
|
|
values[i] = float32(math.Float64frombits(binary.LittleEndian.Uint64(raw[i*8:])))
|
|
}
|
|
return values, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported dtype %q", dtype)
|
|
}
|
|
}
|
|
|
|
// EncodeFloatTensor encodes []float32 into raw bytes according to the given dtype.
|
|
func EncodeFloatTensor(dtype string, values []float32) ([]byte, error) {
|
|
switch strings.ToUpper(dtype) {
|
|
case "BF16":
|
|
return bfloat16.EncodeFloat32(values), nil
|
|
case "F16":
|
|
out := make([]byte, len(values)*2)
|
|
for i, v := range values {
|
|
binary.LittleEndian.PutUint16(out[i*2:], float16.Fromfloat32(v).Bits())
|
|
}
|
|
return out, nil
|
|
case "F32":
|
|
out := make([]byte, len(values)*4)
|
|
for i, v := range values {
|
|
binary.LittleEndian.PutUint32(out[i*4:], math.Float32bits(v))
|
|
}
|
|
return out, nil
|
|
case "F64":
|
|
out := make([]byte, len(values)*8)
|
|
for i, v := range values {
|
|
binary.LittleEndian.PutUint64(out[i*8:], math.Float64bits(float64(v)))
|
|
}
|
|
return out, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported dtype %q", dtype)
|
|
}
|
|
}
|
|
|
|
// sourceQuantType maps a quantization mode and bit width to a quantization
// type name.
//
// mode is compared case-insensitively. "affine" yields "int4" for 4 bits and
// "int8" for 8 bits; "nvfp4", "mxfp8" and "mxfp4" map to themselves regardless
// of bits. Every other combination returns the empty string.
func sourceQuantType(mode string, bits int) string {
	normalized := strings.ToLower(mode)

	if normalized == "affine" {
		if bits == 4 {
			return "int4"
		}
		if bits == 8 {
			return "int8"
		}
		return ""
	}

	switch normalized {
	case "nvfp4", "mxfp8", "mxfp4":
		// The lowercased mode is already the type name for these modes.
		return normalized
	}
	return ""
}
|