mlx: add prequantized tensor packing + changes for qwen35 (#14878)

This change adds a tensorImportTransform interface for model-specific
tensor transformations during safetensors import. It allows both the
standard HF-based weights and the mlx-community derived pre-quantized
safetensors repos to be imported directly into `ollama create`, with
model-specific modifications applied along the way. Right now this only
works with Qwen3.5, whose import performs tensor renaming, norm weight
shifting (adding +1 to each value of the norm vectors), conv1d
transposition, and casting of F32-based vectors to BF16.
This commit is contained in:
Patrick Devine
2026-03-17 11:21:18 -07:00
committed by GitHub
parent bbbad97686
commit fa69b833cd
5 changed files with 1039 additions and 36 deletions

View File

@@ -9,6 +9,7 @@ import (
"regexp"
"slices"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/envconfig"
@@ -422,6 +423,117 @@ type PackedTensorInput struct {
// groupName is the group prefix (e.g., "model.layers.1.mlp.experts").
type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error)
// sourceQuantization mirrors the MLX-style quantization settings that may
// appear in a model's config.json, either as "quantization" or
// "quantization_config".
type sourceQuantization struct {
	Bits      int    `json:"bits"`       // bit width of quantized weights; 0 when absent
	GroupSize int    `json:"group_size"` // number of elements sharing one scale
	Mode      string `json:"mode"`       // quantization mode, e.g. "affine", "nvfp4", "mxfp8", "mxfp4"
}

// sourceModelConfig is the subset of a Hugging Face config.json that the
// importer needs: architecture identification plus quantization settings,
// which may live at the top level or nested under "text_config".
type sourceModelConfig struct {
	ModelType          string             `json:"model_type"`
	Architectures      []string           `json:"architectures"`
	Quantization       sourceQuantization `json:"quantization"`
	QuantizationConfig sourceQuantization `json:"quantization_config"`
	TextConfig         struct {
		ModelType          string             `json:"model_type"`
		Quantization       sourceQuantization `json:"quantization"`
		QuantizationConfig sourceQuantization `json:"quantization_config"`
	} `json:"text_config"`
}
// readSourceModelConfig loads and parses config.json from modelDir.
func readSourceModelConfig(modelDir string) (sourceModelConfig, error) {
	var cfg sourceModelConfig
	raw, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
	if err != nil {
		return sourceModelConfig{}, err
	}
	if err := json.Unmarshal(raw, &cfg); err != nil {
		return sourceModelConfig{}, err
	}
	return cfg, nil
}
// Architecture returns the best available architecture identifier for the
// model: the first non-empty entry of "architectures", else the top-level
// "model_type", else the nested text_config "model_type".
func (cfg sourceModelConfig) Architecture() string {
	switch {
	case len(cfg.Architectures) > 0 && cfg.Architectures[0] != "":
		return cfg.Architectures[0]
	case cfg.ModelType != "":
		return cfg.ModelType
	default:
		return cfg.TextConfig.ModelType
	}
}
// QuantMetadata derives safetensors metadata (quant_type, group_size) from
// the first populated quantization section in the config. It returns nil
// when the model is not prequantized or the mode/bits pair is unsupported.
func (cfg sourceModelConfig) QuantMetadata() map[string]string {
	// The first candidate with a non-zero bit width wins.
	var chosen sourceQuantization
	for _, candidate := range []sourceQuantization{
		cfg.Quantization,
		cfg.QuantizationConfig,
		cfg.TextConfig.Quantization,
		cfg.TextConfig.QuantizationConfig,
	} {
		if candidate.Bits != 0 {
			chosen = candidate
			break
		}
	}
	qt := sourceQuantType(chosen.Mode, chosen.Bits)
	if qt == "" {
		return nil
	}
	md := map[string]string{"quant_type": qt}
	if chosen.GroupSize > 0 {
		md["group_size"] = strconv.Itoa(chosen.GroupSize)
	}
	return md
}
// tensorImportTransform hooks model-specific tensor rewriting into the
// safetensors import pipeline: dropping tensors, mapping one source tensor
// to several outputs, and overriding per-tensor quantization choices.
type tensorImportTransform interface {
	// skipTensor reports whether the named source tensor should be dropped.
	skipTensor(name string) bool
	// transformTensor maps one source tensor to zero or more output tensors.
	transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error)
	// quantizationType returns the quantization to apply to the (possibly
	// renamed) tensor, or "" to store it unquantized.
	quantizationType(name string, shape []int32, quantize string) string
}

// noopImportTransform passes tensors through unchanged; used for
// architectures with no registered transform.
type noopImportTransform struct{}

func (noopImportTransform) skipTensor(string) bool { return false }

func (noopImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
	if td == nil {
		return nil, nil
	}
	return []*safetensors.TensorData{td}, nil
}

func (noopImportTransform) quantizationType(name string, shape []int32, quantize string) string {
	return GetTensorQuantization(name, shape, quantize)
}
// tensorImportTransformFactory constructs a tensorImportTransform for a
// model directory and its parsed config.
type tensorImportTransformFactory func(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error)

// tensorImportTransformRegistry maps HF architecture strings to the factory
// that builds their import transform. All Qwen3.5 / Qwen3-Next variants
// share the same factory.
var tensorImportTransformRegistry = func() map[string]tensorImportTransformFactory {
	registry := make(map[string]tensorImportTransformFactory, 8)
	for _, arch := range []string{
		"Qwen3_5ForCausalLM",
		"Qwen3_5ForConditionalGeneration",
		"Qwen3NextForCausalLM",
		"Qwen3NextForConditionalGeneration",
		"Qwen3_5MoeForCausalLM",
		"Qwen3_5MoeForConditionalGeneration",
		"Qwen3NextMoeForCausalLM",
		"Qwen3NextMoeForConditionalGeneration",
	} {
		registry[arch] = newQwen35ImportTransform
	}
	return registry
}()

// newTensorImportTransform returns the registered transform for the model's
// architecture, or a pass-through transform when none is registered.
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
	factory, ok := tensorImportTransformRegistry[cfg.Architecture()]
	if !ok {
		return noopImportTransform{}, nil
	}
	return factory(modelDir, cfg)
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
@@ -430,6 +542,15 @@ type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string), createPackedLayer ...PackedTensorLayerCreator) error {
var layers []LayerInfo
var configLayer LayerInfo
sourceConfig, err := readSourceModelConfig(modelDir)
if err != nil {
return fmt.Errorf("failed to read source config.json: %w", err)
}
sourceQuantMetadata := sourceConfig.QuantMetadata()
importTransform, err := newTensorImportTransform(modelDir, sourceConfig)
if err != nil {
return fmt.Errorf("failed to construct import transform for architecture %q: %w", sourceConfig.Architecture(), err)
}
// Resolve the optional packed layer creator
var packedCreator PackedTensorLayerCreator
@@ -474,6 +595,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
}
tensorNames := extractor.ListTensors()
tensorSet := make(map[string]struct{}, len(tensorNames))
for _, name := range tensorNames {
tensorSet[name] = struct{}{}
}
quantizeMsg := ""
if quantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
@@ -484,6 +609,13 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
hasExpertTensors := false
for _, tensorName := range tensorNames {
if importTransform.skipTensor(tensorName) {
continue
}
if shouldSkipPrequantizedCompanion(tensorName, tensorSet) {
continue
}
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
@@ -491,45 +623,67 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Determine quantization type for this tensor (empty string if not quantizing)
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
quantizeType := ""
if quantize != "" {
quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
}
// Check if this tensor belongs to an expert group for packing
groupPrefix := ""
if packedCreator != nil {
groupPrefix = ExpertGroupPrefix(tensorName)
}
if groupPrefix != "" {
// Accumulate expert tensor for packed blob.
// The Reader uses a file-backed SectionReader, so we must
// keep the extractor open until this group is flushed.
hasExpertTensors = true
if _, exists := expertGroups[groupPrefix]; !exists {
expertGroupOrder = append(expertGroupOrder, groupPrefix)
}
expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{
Name: tensorName,
Dtype: td.Dtype,
Shape: td.Shape,
Quantize: quantizeType,
Reader: td.SafetensorsReader(),
})
} else {
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
if quantize == "" {
layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
return err
}
if ok {
layers = append(layers, layer)
continue
}
}
outputTensors, err := importTransform.transformTensor(td)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to transform tensor %s: %w", tensorName, err)
}
for _, outTD := range outputTensors {
// Determine quantization type for this tensor (empty string if not quantizing)
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
quantizeType := ""
if quantize != "" {
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, quantize)
}
// Check if this tensor belongs to an expert group for packing
groupPrefix := ""
if packedCreator != nil {
groupPrefix = ExpertGroupPrefix(outTD.Name)
}
if groupPrefix != "" {
// Accumulate expert tensor for packed blob.
// The Reader uses a file-backed SectionReader, so we must
// keep the extractor open until this group is flushed.
hasExpertTensors = true
if _, exists := expertGroups[groupPrefix]; !exists {
expertGroupOrder = append(expertGroupOrder, groupPrefix)
}
expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{
Name: outTD.Name,
Dtype: outTD.Dtype,
Shape: outTD.Shape,
Quantize: quantizeType,
Reader: outTD.SafetensorsReader(),
})
} else {
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to create layer for %s: %w", outTD.Name, err)
}
layers = append(layers, newLayers...)
}
layers = append(layers, newLayers...)
}
}
@@ -605,3 +759,74 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}
// shouldSkipPrequantizedCompanion reports whether name is a ".scales" or
// ".biases" companion of a weight tensor that is also present in tensorSet.
// Such companions are packed alongside their weight rather than stored as
// standalone layers.
func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{}) bool {
	for _, suffix := range []string{".scales", ".biases"} {
		if !strings.HasSuffix(name, suffix) {
			continue
		}
		_, hasWeight := tensorSet[strings.TrimSuffix(name, suffix)+".weight"]
		return hasWeight
	}
	return false
}
// createPrequantizedLayer packs a prequantized weight together with its
// ".scales" companion (renamed to "<weight>.scale") and optional ".biases"
// companion (renamed to "<weight>.bias") into a single safetensors layer
// carrying the source quantization metadata. It returns ok=false when the
// tensor has no ".scales" companion, i.e. it is not prequantized.
func createPrequantizedLayer(
	extractor *safetensors.TensorExtractor,
	td *safetensors.TensorData,
	tensorName string,
	tensorSet map[string]struct{},
	metadata map[string]string,
	createLayer LayerCreator,
) (LayerInfo, bool, error) {
	scaleName, biasName, ok := prequantizedCompanions(tensorName, tensorSet)
	if !ok {
		return LayerInfo{}, false, nil
	}
	packed := []*safetensors.TensorData{td.WithName(tensorName)}
	scales, err := extractor.GetTensor(scaleName)
	if err != nil {
		return LayerInfo{}, false, fmt.Errorf("failed to get tensor %s: %w", scaleName, err)
	}
	packed = append(packed, scales.WithName(tensorName+".scale"))
	if biasName != "" {
		biases, err := extractor.GetTensor(biasName)
		if err != nil {
			return LayerInfo{}, false, fmt.Errorf("failed to get tensor %s: %w", biasName, err)
		}
		packed = append(packed, biases.WithName(tensorName+".bias"))
	}
	layer, err := createLayer(
		safetensors.BuildPackedSafetensorsReaderWithMetadata(packed, metadata),
		"application/vnd.ollama.image.tensor",
		tensorName,
	)
	if err != nil {
		return LayerInfo{}, false, fmt.Errorf("failed to create prequantized layer for %s: %w", tensorName, err)
	}
	return layer, true, nil
}
// prequantizedCompanions looks up the ".scales"/".biases" siblings of a
// ".weight" tensor in tensorSet. ok is true only when a scales tensor
// exists; biasName is "" when there is no biases tensor.
func prequantizedCompanions(weightName string, tensorSet map[string]struct{}) (scaleName, biasName string, ok bool) {
	const weightSuffix = ".weight"
	if !strings.HasSuffix(weightName, weightSuffix) {
		return "", "", false
	}
	base := weightName[:len(weightName)-len(weightSuffix)]
	scaleName = base + ".scales"
	if _, found := tensorSet[scaleName]; !found {
		return "", "", false
	}
	if _, found := tensorSet[base+".biases"]; found {
		biasName = base + ".biases"
	}
	return scaleName, biasName, true
}

View File

@@ -7,8 +7,12 @@ import (
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/d4l3k/go-bfloat16"
st "github.com/ollama/ollama/x/imagegen/safetensors"
)
func TestIsTensorModelDir(t *testing.T) {
@@ -173,6 +177,75 @@ func createMinimalSafetensors(t *testing.T, path string) {
}
}
// createTestSafetensors writes the given tensors to path as a packed
// safetensors file, failing the test on any error.
func createTestSafetensors(t *testing.T, path string, tensors []*st.TensorData) {
	t.Helper()
	data, err := io.ReadAll(st.BuildPackedSafetensorsReader(tensors))
	if err != nil {
		t.Fatalf("failed to build packed safetensors: %v", err)
	}
	if err := os.WriteFile(path, data, 0o644); err != nil {
		t.Fatalf("failed to write safetensors: %v", err)
	}
}
// readSingleTensorHeader parses a safetensors blob (8-byte little-endian
// header length, JSON header, raw data) and returns the dtype and shape of
// its single non-metadata tensor entry.
func readSingleTensorHeader(t *testing.T, data []byte) (string, []int32) {
	t.Helper()
	headerSize := binary.LittleEndian.Uint64(data[:8])
	var header map[string]struct {
		Dtype string  `json:"dtype"`
		Shape []int32 `json:"shape"`
	}
	if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}
	for name, entry := range header {
		if name != "__metadata__" {
			return entry.Dtype, entry.Shape
		}
	}
	t.Fatal("no tensor entry found in header")
	return "", nil
}
// readSingleTensorRaw parses a safetensors blob and returns the raw data
// bytes of its single non-metadata tensor entry, located via the entry's
// data_offsets relative to the end of the header.
func readSingleTensorRaw(t *testing.T, data []byte) []byte {
	t.Helper()
	headerSize := binary.LittleEndian.Uint64(data[:8])
	var header map[string]struct {
		Dtype       string  `json:"dtype"`
		Shape       []int32 `json:"shape"`
		DataOffsets [2]int  `json:"data_offsets"`
	}
	if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}
	dataStart := 8 + int(headerSize)
	for name, entry := range header {
		if name == "__metadata__" {
			continue
		}
		return data[dataStart+entry.DataOffsets[0] : dataStart+entry.DataOffsets[1]]
	}
	t.Fatal("no tensor entry found in header")
	return nil
}
func TestCreateSafetensorsModel(t *testing.T) {
dir := t.TempDir()
@@ -374,6 +447,252 @@ func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
}
}
// TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets verifies that a
// prequantized weight and its ".scales"/".biases" companions are packed into
// one tensor layer carrying the source quantization metadata, while plain
// tensors still go through the per-tensor createTensorLayer path.
func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) {
	dir := t.TempDir()
	// config.json advertises 4-bit affine quantization with group size 64.
	configJSON := `{
"model_type": "test",
"architectures": ["TestModel"],
"quantization": {"group_size": 64, "bits": 4, "mode": "affine"}
}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}
	// linear.weight has both companions; plain.weight has none.
	createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
		st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
		st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
		st.NewTensorDataFromBytes("linear.biases", "BF16", []int32{4, 1}, make([]byte, 8)),
		st.NewTensorDataFromBytes("plain.weight", "F32", []int32{2, 2}, make([]byte, 16)),
	})
	var packedHeader map[string]json.RawMessage
	var tensorLayerNames []string
	var createTensorLayerNames []string
	// createLayer captures the packed safetensors header written for
	// linear.weight so its contents can be asserted below.
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		data, err := io.ReadAll(r)
		if err != nil {
			return LayerInfo{}, err
		}
		if mediaType == "application/vnd.ollama.image.tensor" && name == "linear.weight" {
			var headerSize uint64
			if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
				return LayerInfo{}, err
			}
			if err := json.Unmarshal(data[8:8+headerSize], &packedHeader); err != nil {
				return LayerInfo{}, err
			}
		}
		tensorLayerNames = append(tensorLayerNames, name)
		return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType, Size: int64(len(data))}, nil
	}
	// createTensorLayer records which tensors took the ordinary path.
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		if _, err := io.ReadAll(r); err != nil {
			return nil, err
		}
		createTensorLayerNames = append(createTensorLayerNames, name)
		return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	progressFn := func(status string) {}
	if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}
	// The packed layer must contain the weight plus the renamed scale/bias.
	if packedHeader == nil {
		t.Fatal("expected packed quantized header for linear.weight")
	}
	if _, ok := packedHeader["linear.weight"]; !ok {
		t.Fatalf("packed header missing linear.weight: %v", packedHeader)
	}
	if _, ok := packedHeader["linear.weight.scale"]; !ok {
		t.Fatalf("packed header missing linear.weight.scale: %v", packedHeader)
	}
	if _, ok := packedHeader["linear.weight.bias"]; !ok {
		t.Fatalf("packed header missing linear.weight.bias: %v", packedHeader)
	}
	// The source quantization settings must surface in the layer metadata.
	var metadata map[string]string
	if metaRaw, ok := packedHeader["__metadata__"]; ok {
		if err := json.Unmarshal(metaRaw, &metadata); err != nil {
			t.Fatalf("failed to parse packed metadata: %v", err)
		}
	}
	if metadata["quant_type"] != "int4" {
		t.Fatalf("quant_type = %q, want %q", metadata["quant_type"], "int4")
	}
	if metadata["group_size"] != "64" {
		t.Fatalf("group_size = %q, want %q", metadata["group_size"], "64")
	}
	// The triplet must not leak into the per-tensor path, and companions
	// must never become standalone layers.
	if slices.Contains(createTensorLayerNames, "linear.weight") {
		t.Fatalf("linear.weight unexpectedly handled by createTensorLayer: %v", createTensorLayerNames)
	}
	if slices.Contains(createTensorLayerNames, "linear.scales") || slices.Contains(createTensorLayerNames, "linear.biases") {
		t.Fatalf("quantized companions unexpectedly handled separately: %v", createTensorLayerNames)
	}
	if !slices.Contains(createTensorLayerNames, "plain.weight") {
		t.Fatalf("plain.weight missing from createTensorLayer calls: %v", createTensorLayerNames)
	}
	if slices.Contains(tensorLayerNames, "linear.scales") || slices.Contains(tensorLayerNames, "linear.biases") {
		t.Fatalf("quantized companions unexpectedly emitted as layers: %v", tensorLayerNames)
	}
}
// TestCreateSafetensorsModel_Qwen35Transforms exercises the Qwen3.5 import
// transform end to end: tensor renaming, mtp dropping, the +1 norm-weight
// shift, conv1d transposition, gate_up_proj splitting, the BF16 cast, and
// per-tensor quantization decisions.
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
	dir := t.TempDir()
	configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"text_config": {"dtype": "bfloat16"}
}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}
	// Fused gate_up weight: per expert, the first 64 rows (gate) hold 1s and
	// the next 64 rows (up) hold 2s, so the axis-1 split can be checked by
	// value below.
	gateUpValues := make([]float32, 2*128*64)
	for expert := range 2 {
		base := expert * 128 * 64
		for i := range 64 * 64 {
			gateUpValues[base+i] = 1
			gateUpValues[base+64*64+i] = 2
		}
	}
	createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
		st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.input_layernorm.weight", "F32", []int32{64}, make([]byte, 64*4)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.A_log", "F32", []int32{32}, make([]byte, 32*4)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.conv1d.weight", "BF16", []int32{64, 1, 4}, make([]byte, 64*1*4*2)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert.down_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.visual.blocks.0.attn.proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("mtp.layers.0.foo.weight", "F32", []int32{64, 64}, make([]byte, 64*64*4)),
	})
	// tensorCall records what createTensorLayer received for each tensor.
	type tensorCall struct {
		dtype    string
		shape    []int32
		quantize string
		raw      []byte
	}
	calls := make(map[string]tensorCall)
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r)
		return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		data, err := io.ReadAll(r)
		if err != nil {
			return nil, err
		}
		headerDType, headerShape := readSingleTensorHeader(t, data)
		calls[name] = tensorCall{
			dtype:    headerDType,
			shape:    headerShape,
			quantize: quantize,
			raw:      readSingleTensorRaw(t, data),
		}
		return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	if err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}
	// MTP tensors are unsupported and must be dropped entirely.
	if _, ok := calls["mtp.layers.0.foo.weight"]; ok {
		t.Fatal("mtp tensor should have been dropped")
	}
	// Norm weight: renamed, cast F32->BF16, never quantized, and shifted +1
	// (the zero-filled source vector becomes all 1.0).
	layerNorm := calls["language_model.model.layers.0.input_layernorm.weight"]
	if layerNorm.dtype != "BF16" {
		t.Fatalf("input_layernorm dtype = %q, want %q", layerNorm.dtype, "BF16")
	}
	if layerNorm.quantize != "" {
		t.Fatalf("input_layernorm quantize = %q, want empty", layerNorm.quantize)
	}
	layerNormValues := bfloat16.DecodeFloat32(layerNorm.raw)
	if len(layerNormValues) == 0 || layerNormValues[0] != 1.0 {
		t.Fatalf("input_layernorm first value = %v, want 1.0 after +1 shift", layerNormValues[0])
	}
	// A_log is exempt from the BF16 cast and keeps its source dtype.
	alog := calls["language_model.model.layers.0.linear_attn.A_log"]
	if alog.dtype != "F32" {
		t.Fatalf("A_log dtype = %q, want %q", alog.dtype, "F32")
	}
	// conv1d kernels with last dim != 1 get their last two axes swapped.
	conv := calls["language_model.model.layers.0.linear_attn.conv1d.weight"]
	if !slices.Equal(conv.shape, []int32{64, 4, 1}) {
		t.Fatalf("conv1d shape = %v, want %v", conv.shape, []int32{64, 4, 1})
	}
	// Large 2D language-model weights are quantized to int4.
	if got := calls["language_model.model.embed_tokens.weight"].quantize; got != "int4" {
		t.Fatalf("embed_tokens quantize = %q, want %q", got, "int4")
	}
	if got := calls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "int4" {
		t.Fatalf("mlp.gate quantize = %q, want %q", got, "int4")
	}
	if got := calls["language_model.model.layers.0.mlp.shared_expert.down_proj.weight"].quantize; got != "int4" {
		t.Fatalf("down_proj quantize = %q, want %q", got, "int4")
	}
	// The fused expert tensors must be replaced by switch_mlp.* outputs.
	if _, ok := calls["language_model.model.layers.0.mlp.experts.gate_up_proj"]; ok {
		t.Fatal("combined gate_up_proj tensor should have been rewritten")
	}
	if _, ok := calls["language_model.model.layers.0.mlp.experts.down_proj"]; ok {
		t.Fatal("combined down_proj tensor should have been rewritten")
	}
	gateProj := calls["language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight"]
	if !slices.Equal(gateProj.shape, []int32{2, 64, 64}) {
		t.Fatalf("gate_proj shape = %v, want %v", gateProj.shape, []int32{2, 64, 64})
	}
	gateProjValues := bfloat16.DecodeFloat32(gateProj.raw)
	if len(gateProjValues) == 0 || gateProjValues[0] != 1.0 {
		t.Fatalf("gate_proj first value = %v, want 1.0", gateProjValues[0])
	}
	upProj := calls["language_model.model.layers.0.mlp.switch_mlp.up_proj.weight"]
	if !slices.Equal(upProj.shape, []int32{2, 64, 64}) {
		t.Fatalf("up_proj shape = %v, want %v", upProj.shape, []int32{2, 64, 64})
	}
	upProjValues := bfloat16.DecodeFloat32(upProj.raw)
	if len(upProjValues) == 0 || upProjValues[0] != 2.0 {
		t.Fatalf("up_proj first value = %v, want 2.0", upProjValues[0])
	}
	if got := calls["language_model.model.layers.0.mlp.switch_mlp.down_proj.weight"].quantize; got != "int4" {
		t.Fatalf("switch_mlp down_proj quantize = %q, want %q", got, "int4")
	}
	// Vision tensors are renamed to vision_tower.* and never quantized.
	vision := calls["vision_tower.blocks.0.attn.proj.weight"]
	if vision.dtype != "BF16" {
		t.Fatalf("vision weight dtype = %q, want %q", vision.dtype, "BF16")
	}
	if vision.quantize != "" {
		t.Fatalf("vision weight quantize = %q, want empty", vision.quantize)
	}
	if _, ok := calls["language_model.model.visual.blocks.0.attn.proj.weight"]; ok {
		t.Fatal("vision tensor should have been rewritten to vision_tower.*")
	}
}
func TestResolveManifestPath(t *testing.T) {
tests := []struct {
name string

109
x/create/dtype.go Normal file
View File

@@ -0,0 +1,109 @@
package create
import (
"encoding/binary"
"fmt"
"math"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/x448/float16"
)
// DTypeSize returns the byte size of a single element for the given dtype string.
// DTypeSize returns the byte size of a single element for the given dtype
// string. Matching is case-insensitive; unknown dtypes yield an error.
func DTypeSize(dtype string) (int, error) {
	switch strings.ToUpper(dtype) {
	case "F64":
		return 8, nil
	case "F32", "U32", "I32":
		return 4, nil
	case "BF16", "F16":
		return 2, nil
	}
	return 0, fmt.Errorf("unsupported dtype %q", dtype)
}
// DecodeFloatTensor decodes raw bytes into []float32 according to the given dtype.
// DecodeFloatTensor decodes raw bytes into []float32 according to the given
// dtype (case-insensitive; BF16, F16, F32, or F64, all little-endian).
// It returns an error for unsupported dtypes or when len(raw) is not a
// multiple of the element size.
func DecodeFloatTensor(dtype string, raw []byte) ([]float32, error) {
	switch strings.ToUpper(dtype) {
	case "BF16":
		// Validate length like the other fixed-width paths do; bfloat16
		// values are 2 bytes each, and an odd-length buffer indicates a
		// corrupt or mis-declared tensor.
		if len(raw)%2 != 0 {
			return nil, fmt.Errorf("invalid bf16 byte length %d", len(raw))
		}
		return bfloat16.DecodeFloat32(raw), nil
	case "F16":
		if len(raw)%2 != 0 {
			return nil, fmt.Errorf("invalid f16 byte length %d", len(raw))
		}
		values := make([]float32, len(raw)/2)
		for i := range values {
			values[i] = float16.Frombits(binary.LittleEndian.Uint16(raw[i*2:])).Float32()
		}
		return values, nil
	case "F32":
		if len(raw)%4 != 0 {
			return nil, fmt.Errorf("invalid f32 byte length %d", len(raw))
		}
		values := make([]float32, len(raw)/4)
		for i := range values {
			values[i] = math.Float32frombits(binary.LittleEndian.Uint32(raw[i*4:]))
		}
		return values, nil
	case "F64":
		if len(raw)%8 != 0 {
			return nil, fmt.Errorf("invalid f64 byte length %d", len(raw))
		}
		values := make([]float32, len(raw)/8)
		for i := range values {
			// F64 narrows to float32; precision loss is accepted here.
			values[i] = float32(math.Float64frombits(binary.LittleEndian.Uint64(raw[i*8:])))
		}
		return values, nil
	default:
		return nil, fmt.Errorf("unsupported dtype %q", dtype)
	}
}
// EncodeFloatTensor encodes []float32 into raw bytes according to the given dtype.
// EncodeFloatTensor encodes []float32 into little-endian raw bytes according
// to the given dtype (case-insensitive; BF16, F16, F32, or F64). Unsupported
// dtypes yield an error.
func EncodeFloatTensor(dtype string, values []float32) ([]byte, error) {
	switch strings.ToUpper(dtype) {
	case "BF16":
		return bfloat16.EncodeFloat32(values), nil
	case "F16":
		buf := make([]byte, 2*len(values))
		for i, v := range values {
			binary.LittleEndian.PutUint16(buf[2*i:], float16.Fromfloat32(v).Bits())
		}
		return buf, nil
	case "F32":
		buf := make([]byte, 4*len(values))
		for i, v := range values {
			binary.LittleEndian.PutUint32(buf[4*i:], math.Float32bits(v))
		}
		return buf, nil
	case "F64":
		buf := make([]byte, 8*len(values))
		for i, v := range values {
			binary.LittleEndian.PutUint64(buf[8*i:], math.Float64bits(float64(v)))
		}
		return buf, nil
	}
	return nil, fmt.Errorf("unsupported dtype %q", dtype)
}
// sourceQuantType maps an MLX quantization mode/bit-width pair to the
// importer's quant_type string, or "" when the combination is unsupported.
// Mode matching is case-insensitive; only "affine" consults the bit width.
func sourceQuantType(mode string, bits int) string {
	switch lowered := strings.ToLower(mode); lowered {
	case "affine":
		if bits == 4 {
			return "int4"
		}
		if bits == 8 {
			return "int8"
		}
		return ""
	case "nvfp4", "mxfp8", "mxfp4":
		return lowered
	default:
		return ""
	}
}

323
x/create/qwen35.go Normal file
View File

@@ -0,0 +1,323 @@
package create
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// qwen35ImportTransform rewrites Qwen3.5 / Qwen3-Next safetensors during
// import: renaming tensors, splitting fused expert weights, shifting norm
// weights by +1, transposing conv1d kernels, and casting float data to BF16.
type qwen35ImportTransform struct {
	shouldShiftNormWeights bool // add 1.0 to the norm vectors matched by qwen35ShouldShiftNormKey
	rewriteLanguageModel   bool // prefix language-model tensors with "language_model." (set for *ConditionalGeneration architectures)
}

// qwen35SourceInfo summarizes what a scan of the source safetensors found.
type qwen35SourceInfo struct {
	hasPrequantizedWeights bool // a ".scales" tensor was seen, i.e. the checkpoint is already quantized
	shouldShiftNormWeights bool // the checkpoint's norms need the +1 shift on import
}
// newQwen35ImportTransform inspects the source checkpoint and builds the
// Qwen3.5 transform. Pre-quantized checkpoints need no rewriting, so they
// get the pass-through transform instead.
func newQwen35ImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
	info, err := qwen35InspectSource(modelDir)
	if err != nil {
		return qwen35ImportTransform{}, err
	}
	if info.hasPrequantizedWeights {
		return noopImportTransform{}, nil
	}
	return qwen35ImportTransform{
		shouldShiftNormWeights: info.shouldShiftNormWeights,
		// Conditional-generation (multimodal) checkpoints carry prefixes
		// that must be rewritten to the language_model.* layout.
		rewriteLanguageModel: strings.Contains(cfg.Architecture(), "ConditionalGeneration"),
	}, nil
}
// qwen35InspectSource scans every *.safetensors file in modelDir to decide
// how the checkpoint must be imported:
//   - hasPrequantizedWeights: any ".scales" tensor means the checkpoint is
//     already quantized; scanning stops immediately in that case.
//   - shouldShiftNormWeights: set when the checkpoint contains "mtp."
//     tensors or an untransposed conv1d weight (3D with last dim != 1) —
//     both markers of a source whose norms still need the +1 shift.
func qwen35InspectSource(modelDir string) (qwen35SourceInfo, error) {
	entries, err := os.ReadDir(modelDir)
	if err != nil {
		return qwen35SourceInfo{}, err
	}
	var info qwen35SourceInfo
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
			continue
		}
		extractor, err := safetensors.OpenForExtraction(filepath.Join(modelDir, entry.Name()))
		if err != nil {
			return qwen35SourceInfo{}, err
		}
		// NOTE: the extractor is closed explicitly on every exit path below
		// rather than deferred; a defer inside this loop would keep all
		// files open until the function returns.
		for _, name := range extractor.ListTensors() {
			if strings.HasSuffix(name, ".scales") {
				extractor.Close()
				info.hasPrequantizedWeights = true
				return info, nil
			}
			// This should change when MTP is supported
			if strings.Contains(name, "mtp.") {
				info.shouldShiftNormWeights = true
				continue
			}
			// Only probe conv1d shapes while the answer is still unknown.
			if info.shouldShiftNormWeights || !strings.Contains(name, "conv1d.weight") {
				continue
			}
			td, err := extractor.GetTensor(name)
			if err != nil {
				extractor.Close()
				return qwen35SourceInfo{}, err
			}
			// A conv1d kernel whose last dim != 1 has not been transposed
			// yet, so this is an unprocessed source checkpoint.
			if len(td.Shape) == 3 && td.Shape[2] != 1 {
				info.shouldShiftNormWeights = true
			}
		}
		extractor.Close()
	}
	return info, nil
}
// skipTensor drops multi-token-prediction tensors; MTP is not yet supported.
func (t qwen35ImportTransform) skipTensor(name string) bool {
	return strings.Contains(name, "mtp.")
}
// quantizationType decides the quantization for an output tensor given its
// canonical name and shape, returning "" for tensors that must stay in their
// source dtype: vision-tower weights, quantization companions, non-weight
// tensors, norm weights, small tensors, and tensors whose trailing dimension
// is not a multiple of the quantization group size.
func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
	// Vision tower is never quantized.
	if strings.HasPrefix(name, "vision_tower.") {
		return ""
	}
	stackedExpert := isStackedExpertWeight(name)
	// Quantization companions are never re-quantized.
	for _, suffix := range []string{".bias", ".scale", ".qbias", ".biases", ".scales"} {
		if strings.HasSuffix(name, suffix) {
			return ""
		}
	}
	// Only ".weight" tensors are candidates, except stacked expert weights
	// which lack the suffix.
	if !stackedExpert && !strings.HasSuffix(name, ".weight") {
		return ""
	}
	// Norm vectors keep full precision.
	if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
		return ""
	}
	// Quantize 2D matrices plus 3D stacks of expert matrices.
	if !(len(shape) == 2 || (len(shape) == 3 && stackedExpert)) {
		return ""
	}
	elems := int64(1)
	for _, dim := range shape {
		elems *= int64(dim)
	}
	if elems < 1024 {
		return ""
	}
	quantNorm := normalizeQuantType(quantize)
	// The trailing dimension must divide evenly into quantization groups.
	var groupSize int32
	switch quantNorm {
	case "nvfp4":
		groupSize = 16
	case "int4", "int8":
		groupSize = 64
	default:
		groupSize = 32
	}
	if shape[len(shape)-1]%groupSize != 0 {
		return ""
	}
	return quantNorm
}
// rewriteTensorData applies element-level rewrites to a single tensor: the
// +1 norm-weight shift, conv1d axis transposition, and a cast of float data
// to BF16. Tensors needing none of these are returned unchanged without
// reading their data.
func (t qwen35ImportTransform) rewriteTensorData(td *safetensors.TensorData) (*safetensors.TensorData, error) {
	if td == nil {
		return td, nil
	}
	needShift := t.shouldShiftNormWeights && qwen35ShouldShiftNormKey(td.Name)
	needTranspose := strings.Contains(td.Name, "conv1d.weight") && len(td.Shape) == 3 && td.Shape[2] != 1
	needCast := qwen35NeedsCastToBF16(td.Name, td.Dtype)
	if !needShift && !needTranspose && !needCast {
		return td, nil
	}
	raw, err := io.ReadAll(td.Reader())
	if err != nil {
		return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
	}
	values, err := DecodeFloatTensor(td.Dtype, raw)
	if err != nil {
		return nil, fmt.Errorf("failed to decode tensor %s: %w", td.Name, err)
	}
	// Copy the shape so the transpose cannot mutate the caller's slice.
	shape := append([]int32(nil), td.Shape...)
	if needTranspose {
		values, shape = qwen35TransposeConv1D(values, shape)
	}
	if needShift {
		for i := range values {
			values[i] += 1.0
		}
	}
	dtype := td.Dtype
	if needCast {
		dtype = "BF16"
	}
	encoded, err := EncodeFloatTensor(dtype, values)
	if err != nil {
		return nil, fmt.Errorf("failed to encode tensor %s: %w", td.Name, err)
	}
	return safetensors.NewTensorDataFromBytes(td.Name, dtype, shape, encoded), nil
}
// transformTensor renames a tensor to its canonical name and, for fused
// expert weights, splits it into separate gate/up projections. Every
// resulting tensor is then run through rewriteTensorData.
func (t qwen35ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
	if td == nil {
		return nil, nil
	}
	name := t.canonicalTensorName(td.Name)
	base := strings.TrimSuffix(name, ".weight")

	// Phase 1: rename/split into intermediate tensors.
	var staged []*safetensors.TensorData
	switch {
	case strings.HasSuffix(base, ".mlp.experts.gate_up_proj"):
		// Fused gate+up expert weight: split along axis 1 into separate
		// gate_proj and up_proj tensors under switch_mlp.
		prefix := strings.TrimSuffix(base, ".mlp.experts.gate_up_proj")
		raw, err := io.ReadAll(td.Reader())
		if err != nil {
			return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
		}
		gateRaw, upRaw, splitShape, err := qwen35SplitAxis1Raw(raw, td.Dtype, td.Shape)
		if err != nil {
			return nil, fmt.Errorf("failed to split tensor %s: %w", td.Name, err)
		}
		staged = append(staged,
			safetensors.NewTensorDataFromBytes(prefix+".mlp.switch_mlp.gate_proj.weight", td.Dtype, splitShape, gateRaw),
			safetensors.NewTensorDataFromBytes(prefix+".mlp.switch_mlp.up_proj.weight", td.Dtype, splitShape, upRaw),
		)
	case strings.HasSuffix(base, ".mlp.experts.down_proj"):
		renamed := strings.TrimSuffix(base, ".mlp.experts.down_proj") + ".mlp.switch_mlp.down_proj.weight"
		staged = append(staged, td.WithName(renamed))
	default:
		staged = append(staged, td.WithName(name))
	}

	// Phase 2: element-level rewrites on every staged tensor.
	out := make([]*safetensors.TensorData, 0, len(staged))
	for _, s := range staged {
		rewritten, err := t.rewriteTensorData(s)
		if err != nil {
			return nil, err
		}
		out = append(out, rewritten)
	}
	return out, nil
}
// canonicalTensorName maps an incoming tensor name onto the canonical
// vision_tower.* / language_model.* namespaces used by the importer.
func (t qwen35ImportTransform) canonicalTensorName(name string) string {
	// Vision tensors: normalize onto the vision_tower.* prefix.
	if rest, ok := strings.CutPrefix(name, "model.visual."); ok {
		return "vision_tower." + rest
	}
	if strings.HasPrefix(name, "vision_tower.") {
		return name
	}

	// Language model tensors: normalize onto the language_model.model.* prefix,
	// but only when this transform is configured to rewrite them.
	if !t.rewriteLanguageModel {
		return name
	}
	if rest, ok := strings.CutPrefix(name, "model.language_model"); ok {
		return "language_model.model" + rest
	}
	if strings.HasPrefix(name, "language_model.") {
		return name
	}
	return "language_model." + name
}
// qwen35ShouldShiftNormKey reports whether key names a norm weight whose
// values must be shifted by +1 during import.
func qwen35ShouldShiftNormKey(key string) bool {
	normSuffixes := []string{
		".input_layernorm.weight",
		".post_attention_layernorm.weight",
		"model.norm.weight",
		".q_norm.weight",
		".k_norm.weight",
	}
	return slices.ContainsFunc(normSuffixes, func(suffix string) bool {
		return strings.HasSuffix(key, suffix)
	})
}
// qwen35NeedsCastToBF16 reports whether a tensor with the given name and
// dtype should be cast to BF16 during import. Float tensors (F16/F32/F64)
// are cast; A_log tensors are always kept in their original dtype.
func qwen35NeedsCastToBF16(name, dtype string) bool {
	if strings.HasSuffix(name, "A_log") {
		return false
	}
	dt := strings.ToUpper(dtype)
	return dt == "F16" || dt == "F32" || dt == "F64"
}
// qwen35TransposeConv1D swaps the last two axes of a 3D tensor
// (d0, d1, d2) -> (d0, d2, d1), returning the transposed values and the
// new shape. Tensors that are not 3D pass through unchanged.
func qwen35TransposeConv1D(values []float32, shape []int32) ([]float32, []int32) {
	if len(shape) != 3 {
		return values, shape
	}
	batch, rows, cols := int(shape[0]), int(shape[1]), int(shape[2])
	transposed := make([]float32, len(values))
	for b := 0; b < batch; b++ {
		// Offset of this batch slice; identical for input and output since
		// only the inner two axes are permuted.
		base := b * rows * cols
		for r := 0; r < rows; r++ {
			for c := 0; c < cols; c++ {
				transposed[base+c*rows+r] = values[base+r*cols+c]
			}
		}
	}
	return transposed, []int32{shape[0], shape[2], shape[1]}
}
// qwen35SplitAxis1Raw splits a raw 3D tensor of shape (d0, d1, d2) in half
// along axis 1, returning the first-half bytes (gate), second-half bytes (up),
// and the resulting shape (d0, d1/2, d2). It operates on raw bytes so no
// dtype decoding is required; d1 must be even.
func qwen35SplitAxis1Raw(raw []byte, dtype string, shape []int32) ([]byte, []byte, []int32, error) {
	if len(shape) != 3 {
		return nil, nil, nil, fmt.Errorf("expected 3D tensor, got shape %v", shape)
	}
	if shape[1]%2 != 0 {
		return nil, nil, nil, fmt.Errorf("axis 1 dim %d is not even", shape[1])
	}
	elemSize, err := DTypeSize(dtype)
	if err != nil {
		return nil, nil, nil, err
	}
	experts, rows, cols := int(shape[0]), int(shape[1]), int(shape[2])
	expertBytes := rows * cols * elemSize
	if len(raw) != experts*expertBytes {
		return nil, nil, nil, fmt.Errorf("raw byte length %d does not match shape %v and dtype %s", len(raw), shape, dtype)
	}
	halfRows := rows / 2
	halfBytes := halfRows * cols * elemSize
	gateRaw := make([]byte, 0, experts*halfBytes)
	upRaw := make([]byte, 0, experts*halfBytes)
	for e := 0; e < experts; e++ {
		// Each expert's block is laid out contiguously: first half of axis 1
		// belongs to gate, second half to up.
		start := e * expertBytes
		gateRaw = append(gateRaw, raw[start:start+halfBytes]...)
		upRaw = append(upRaw, raw[start+halfBytes:start+expertBytes]...)
	}
	return gateRaw, upRaw, []int32{shape[0], int32(halfRows), shape[2]}, nil
}

View File

@@ -36,6 +36,23 @@ type TensorData struct {
reader *io.SectionReader
}
// WithName returns a shallow copy of TensorData that carries a different
// logical tensor name while sharing the same underlying raw data reader.
func (td *TensorData) WithName(name string) *TensorData {
	if td == nil {
		return nil
	}
	// Copy the shape so the clone cannot alias the original's slice.
	shapeCopy := append([]int32(nil), td.Shape...)
	return &TensorData{
		Name:   name,
		Dtype:  td.Dtype,
		Shape:  shapeCopy,
		Size:   td.Size,
		reader: td.reader,
	}
}
// Reader returns an io.Reader for the tensor's raw bytes.
func (td *TensorData) Reader() io.Reader {
return td.reader
@@ -117,8 +134,15 @@ func ExtractRawFromSafetensors(r io.Reader) ([]byte, error) {
// BuildPackedSafetensorsReader builds a streaming io.Reader that packs tensors
// into a single blob without loading all data into memory.
// Each TensorData must have been obtained from GetTensor.
func BuildPackedSafetensorsReader(tensors []*TensorData) io.Reader {
	// Delegates to the metadata-aware variant; nil metadata means no
	// __metadata__ entry is written to the safetensors header.
	return BuildPackedSafetensorsReaderWithMetadata(tensors, nil)
}
// BuildPackedSafetensorsReaderWithMetadata builds a streaming io.Reader that
// outputs a valid safetensors file containing multiple tensors and optional
// metadata.
func BuildPackedSafetensorsReaderWithMetadata(tensors []*TensorData, metadata map[string]string) io.Reader {
// Build the header with sequential data offsets
header := make(map[string]tensorInfo, len(tensors))
header := make(map[string]any, len(tensors)+1)
var offset int
for _, td := range tensors {
header[td.Name] = tensorInfo{
@@ -128,6 +152,9 @@ func BuildPackedSafetensorsReader(tensors []*TensorData) io.Reader {
}
offset += int(td.Size)
}
if len(metadata) > 0 {
header["__metadata__"] = metadata
}
headerJSON, _ := json.Marshal(header)