From fa69b833cd1323b2d96b80da9e38cadc7e8fe97a Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Tue, 17 Mar 2026 11:21:18 -0700 Subject: [PATCH] mlx: add prequantized tensor packing + changes for qwen35 (#14878) This change adds a tensorImportTransform interface for model-specific tensor transformations during safetensors import. This allows importing and modifying the standard HF based weights as well as the mlx-community derived pre-quantized safetensors repos to be directly imported into `ollama create`. Right now this only works with Qwen3.5 importing which does tensor renaming, norm weight shifting (it adds +1 to each value of the norm vectors), conv1d transposition, and casts to BF16s for F32 based vectors. --- x/create/create.go | 295 ++++++++++++++++++++++--- x/create/create_test.go | 319 +++++++++++++++++++++++++++ x/create/dtype.go | 109 ++++++++++ x/create/qwen35.go | 323 ++++++++++++++++++++++++++++ x/imagegen/safetensors/extractor.go | 29 ++- 5 files changed, 1039 insertions(+), 36 deletions(-) create mode 100644 x/create/dtype.go create mode 100644 x/create/qwen35.go diff --git a/x/create/create.go b/x/create/create.go index 9fb9b1e64..efd4065c2 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -9,6 +9,7 @@ import ( "regexp" "slices" "sort" + "strconv" "strings" "github.com/ollama/ollama/envconfig" @@ -422,6 +423,117 @@ type PackedTensorInput struct { // groupName is the group prefix (e.g., "model.layers.1.mlp.experts"). 
type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) +type sourceQuantization struct { + Bits int `json:"bits"` + GroupSize int `json:"group_size"` + Mode string `json:"mode"` +} + +type sourceModelConfig struct { + ModelType string `json:"model_type"` + Architectures []string `json:"architectures"` + Quantization sourceQuantization `json:"quantization"` + QuantizationConfig sourceQuantization `json:"quantization_config"` + TextConfig struct { + ModelType string `json:"model_type"` + Quantization sourceQuantization `json:"quantization"` + QuantizationConfig sourceQuantization `json:"quantization_config"` + } `json:"text_config"` +} + +func readSourceModelConfig(modelDir string) (sourceModelConfig, error) { + configPath := filepath.Join(modelDir, "config.json") + data, err := os.ReadFile(configPath) + if err != nil { + return sourceModelConfig{}, err + } + + var cfg sourceModelConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return sourceModelConfig{}, err + } + + return cfg, nil +} + +func (cfg sourceModelConfig) Architecture() string { + if len(cfg.Architectures) > 0 && cfg.Architectures[0] != "" { + return cfg.Architectures[0] + } + if cfg.ModelType != "" { + return cfg.ModelType + } + return cfg.TextConfig.ModelType +} + +func (cfg sourceModelConfig) QuantMetadata() map[string]string { + // Use the first non-empty quantization config found + var q sourceQuantization + for _, candidate := range []sourceQuantization{ + cfg.Quantization, + cfg.QuantizationConfig, + cfg.TextConfig.Quantization, + cfg.TextConfig.QuantizationConfig, + } { + if candidate.Bits != 0 { + q = candidate + break + } + } + + quantType := sourceQuantType(q.Mode, q.Bits) + if quantType == "" { + return nil + } + + metadata := map[string]string{"quant_type": quantType} + if q.GroupSize > 0 { + metadata["group_size"] = strconv.Itoa(q.GroupSize) + } + return metadata +} + +type tensorImportTransform interface { + skipTensor(name string) 
bool + transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) + quantizationType(name string, shape []int32, quantize string) string +} + +type noopImportTransform struct{} + +func (noopImportTransform) skipTensor(string) bool { return false } + +func (noopImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) { + if td == nil { + return nil, nil + } + return []*safetensors.TensorData{td}, nil +} + +func (noopImportTransform) quantizationType(name string, shape []int32, quantize string) string { + return GetTensorQuantization(name, shape, quantize) +} + +type tensorImportTransformFactory func(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) + +var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{ + "Qwen3_5ForCausalLM": newQwen35ImportTransform, + "Qwen3_5ForConditionalGeneration": newQwen35ImportTransform, + "Qwen3NextForCausalLM": newQwen35ImportTransform, + "Qwen3NextForConditionalGeneration": newQwen35ImportTransform, + "Qwen3_5MoeForCausalLM": newQwen35ImportTransform, + "Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform, + "Qwen3NextMoeForCausalLM": newQwen35ImportTransform, + "Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform, +} + +func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) { + if factory, ok := tensorImportTransformRegistry[cfg.Architecture()]; ok { + return factory(modelDir, cfg) + } + return noopImportTransform{}, nil +} + // CreateSafetensorsModel imports a standard safetensors model from a directory. // This handles Hugging Face style models with config.json and *.safetensors files. // Stores each tensor as a separate blob for fine-grained deduplication. 
@@ -430,6 +542,15 @@ type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string), createPackedLayer ...PackedTensorLayerCreator) error { var layers []LayerInfo var configLayer LayerInfo + sourceConfig, err := readSourceModelConfig(modelDir) + if err != nil { + return fmt.Errorf("failed to read source config.json: %w", err) + } + sourceQuantMetadata := sourceConfig.QuantMetadata() + importTransform, err := newTensorImportTransform(modelDir, sourceConfig) + if err != nil { + return fmt.Errorf("failed to construct import transform for architecture %q: %w", sourceConfig.Architecture(), err) + } // Resolve the optional packed layer creator var packedCreator PackedTensorLayerCreator @@ -474,6 +595,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La } tensorNames := extractor.ListTensors() + tensorSet := make(map[string]struct{}, len(tensorNames)) + for _, name := range tensorNames { + tensorSet[name] = struct{}{} + } quantizeMsg := "" if quantize != "" { quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize) @@ -484,6 +609,13 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La hasExpertTensors := false for _, tensorName := range tensorNames { + if importTransform.skipTensor(tensorName) { + continue + } + if shouldSkipPrequantizedCompanion(tensorName, tensorSet) { + continue + } + td, err := extractor.GetTensor(tensorName) if err != nil { extractor.Close() @@ -491,45 +623,67 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La return fmt.Errorf("failed to get tensor %s: %w", tensorName, err) } - // Determine quantization type for this tensor (empty string if not quantizing) - // GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN) - 
quantizeType := "" - if quantize != "" { - quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize) - } - - // Check if this tensor belongs to an expert group for packing - groupPrefix := "" - if packedCreator != nil { - groupPrefix = ExpertGroupPrefix(tensorName) - } - - if groupPrefix != "" { - // Accumulate expert tensor for packed blob. - // The Reader uses a file-backed SectionReader, so we must - // keep the extractor open until this group is flushed. - hasExpertTensors = true - if _, exists := expertGroups[groupPrefix]; !exists { - expertGroupOrder = append(expertGroupOrder, groupPrefix) - } - expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{ - Name: tensorName, - Dtype: td.Dtype, - Shape: td.Shape, - Quantize: quantizeType, - Reader: td.SafetensorsReader(), - }) - } else { - // Store as minimal safetensors format (88 bytes header overhead) - // This enables native mmap loading via mlx_load_safetensors - // createTensorLayer returns multiple layers if quantizing (weight + scales) - newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType) + if quantize == "" { + layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer) if err != nil { extractor.Close() closeExtractors() - return fmt.Errorf("failed to create layer for %s: %w", tensorName, err) + return err + } + if ok { + layers = append(layers, layer) + continue + } + } + + outputTensors, err := importTransform.transformTensor(td) + if err != nil { + extractor.Close() + closeExtractors() + return fmt.Errorf("failed to transform tensor %s: %w", tensorName, err) + } + + for _, outTD := range outputTensors { + // Determine quantization type for this tensor (empty string if not quantizing) + // GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN) + quantizeType := "" + if quantize != "" { + quantizeType = 
importTransform.quantizationType(outTD.Name, outTD.Shape, quantize) + } + + // Check if this tensor belongs to an expert group for packing + groupPrefix := "" + if packedCreator != nil { + groupPrefix = ExpertGroupPrefix(outTD.Name) + } + + if groupPrefix != "" { + // Accumulate expert tensor for packed blob. + // The Reader uses a file-backed SectionReader, so we must + // keep the extractor open until this group is flushed. + hasExpertTensors = true + if _, exists := expertGroups[groupPrefix]; !exists { + expertGroupOrder = append(expertGroupOrder, groupPrefix) + } + expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{ + Name: outTD.Name, + Dtype: outTD.Dtype, + Shape: outTD.Shape, + Quantize: quantizeType, + Reader: outTD.SafetensorsReader(), + }) + } else { + // Store as minimal safetensors format (88 bytes header overhead) + // This enables native mmap loading via mlx_load_safetensors + // createTensorLayer returns multiple layers if quantizing (weight + scales) + newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType) + if err != nil { + extractor.Close() + closeExtractors() + return fmt.Errorf("failed to create layer for %s: %w", outTD.Name, err) + } + layers = append(layers, newLayers...) } - layers = append(layers, newLayers...) 
} } @@ -605,3 +759,74 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers))) return nil } + +func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{}) bool { + switch { + case strings.HasSuffix(name, ".scales"): + _, ok := tensorSet[strings.TrimSuffix(name, ".scales")+".weight"] + return ok + case strings.HasSuffix(name, ".biases"): + _, ok := tensorSet[strings.TrimSuffix(name, ".biases")+".weight"] + return ok + default: + return false + } +} + +func createPrequantizedLayer( + extractor *safetensors.TensorExtractor, + td *safetensors.TensorData, + tensorName string, + tensorSet map[string]struct{}, + metadata map[string]string, + createLayer LayerCreator, +) (LayerInfo, bool, error) { + scaleName, biasName, ok := prequantizedCompanions(tensorName, tensorSet) + if !ok { + return LayerInfo{}, false, nil + } + + tensors := []*safetensors.TensorData{td.WithName(tensorName)} + + scaleTD, err := extractor.GetTensor(scaleName) + if err != nil { + return LayerInfo{}, false, fmt.Errorf("failed to get tensor %s: %w", scaleName, err) + } + tensors = append(tensors, scaleTD.WithName(tensorName+".scale")) + + if biasName != "" { + biasTD, err := extractor.GetTensor(biasName) + if err != nil { + return LayerInfo{}, false, fmt.Errorf("failed to get tensor %s: %w", biasName, err) + } + tensors = append(tensors, biasTD.WithName(tensorName+".bias")) + } + + layer, err := createLayer( + safetensors.BuildPackedSafetensorsReaderWithMetadata(tensors, metadata), + "application/vnd.ollama.image.tensor", + tensorName, + ) + if err != nil { + return LayerInfo{}, false, fmt.Errorf("failed to create prequantized layer for %s: %w", tensorName, err) + } + return layer, true, nil +} + +func prequantizedCompanions(weightName string, tensorSet map[string]struct{}) (scaleName, biasName string, ok bool) { + if !strings.HasSuffix(weightName, ".weight") { + return "", 
"", false + } + + base := strings.TrimSuffix(weightName, ".weight") + scaleName = base + ".scales" + if _, ok := tensorSet[scaleName]; !ok { + return "", "", false + } + + biasName = base + ".biases" + if _, ok := tensorSet[biasName]; !ok { + biasName = "" + } + return scaleName, biasName, true +} diff --git a/x/create/create_test.go b/x/create/create_test.go index 7d9e68956..9eac89614 100644 --- a/x/create/create_test.go +++ b/x/create/create_test.go @@ -7,8 +7,12 @@ import ( "io" "os" "path/filepath" + "slices" "strings" "testing" + + "github.com/d4l3k/go-bfloat16" + st "github.com/ollama/ollama/x/imagegen/safetensors" ) func TestIsTensorModelDir(t *testing.T) { @@ -173,6 +177,75 @@ func createMinimalSafetensors(t *testing.T, path string) { } } +func createTestSafetensors(t *testing.T, path string, tensors []*st.TensorData) { + t.Helper() + + data, err := io.ReadAll(st.BuildPackedSafetensorsReader(tensors)) + if err != nil { + t.Fatalf("failed to build packed safetensors: %v", err) + } + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatalf("failed to write safetensors: %v", err) + } +} + +func readSingleTensorHeader(t *testing.T, data []byte) (string, []int32) { + t.Helper() + + var headerSize uint64 + if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil { + t.Fatalf("failed to read header size: %v", err) + } + + var header map[string]struct { + Dtype string `json:"dtype"` + Shape []int32 `json:"shape"` + } + if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil { + t.Fatalf("failed to parse header: %v", err) + } + + for name, info := range header { + if name == "__metadata__" { + continue + } + return info.Dtype, info.Shape + } + + t.Fatal("no tensor entry found in header") + return "", nil +} + +func readSingleTensorRaw(t *testing.T, data []byte) []byte { + t.Helper() + + var headerSize uint64 + if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil { + 
t.Fatalf("failed to read header size: %v", err) + } + + var header map[string]struct { + Dtype string `json:"dtype"` + Shape []int32 `json:"shape"` + DataOffsets [2]int `json:"data_offsets"` + } + if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil { + t.Fatalf("failed to parse header: %v", err) + } + + for name, info := range header { + if name == "__metadata__" { + continue + } + start := 8 + int(headerSize) + info.DataOffsets[0] + end := 8 + int(headerSize) + info.DataOffsets[1] + return data[start:end] + } + + t.Fatal("no tensor entry found in header") + return nil +} + func TestCreateSafetensorsModel(t *testing.T) { dir := t.TempDir() @@ -374,6 +447,252 @@ func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) { } } +func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) { + dir := t.TempDir() + + configJSON := `{ + "model_type": "test", + "architectures": ["TestModel"], + "quantization": {"group_size": 64, "bits": 4, "mode": "affine"} + }` + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil { + t.Fatalf("failed to write config.json: %v", err) + } + + createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{ + st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)), + st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)), + st.NewTensorDataFromBytes("linear.biases", "BF16", []int32{4, 1}, make([]byte, 8)), + st.NewTensorDataFromBytes("plain.weight", "F32", []int32{2, 2}, make([]byte, 16)), + }) + + var packedHeader map[string]json.RawMessage + var tensorLayerNames []string + var createTensorLayerNames []string + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + data, err := io.ReadAll(r) + if err != nil { + return LayerInfo{}, err + } + if mediaType == "application/vnd.ollama.image.tensor" && name == "linear.weight" { + var headerSize uint64 + if err := 
binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil { + return LayerInfo{}, err + } + if err := json.Unmarshal(data[8:8+headerSize], &packedHeader); err != nil { + return LayerInfo{}, err + } + } + tensorLayerNames = append(tensorLayerNames, name) + return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType, Size: int64(len(data))}, nil + } + + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + if _, err := io.ReadAll(r); err != nil { + return nil, err + } + createTensorLayerNames = append(createTensorLayerNames, name) + return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil + } + + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + return nil + } + + progressFn := func(status string) {} + + if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil { + t.Fatalf("CreateSafetensorsModel failed: %v", err) + } + + if packedHeader == nil { + t.Fatal("expected packed quantized header for linear.weight") + } + if _, ok := packedHeader["linear.weight"]; !ok { + t.Fatalf("packed header missing linear.weight: %v", packedHeader) + } + if _, ok := packedHeader["linear.weight.scale"]; !ok { + t.Fatalf("packed header missing linear.weight.scale: %v", packedHeader) + } + if _, ok := packedHeader["linear.weight.bias"]; !ok { + t.Fatalf("packed header missing linear.weight.bias: %v", packedHeader) + } + + var metadata map[string]string + if metaRaw, ok := packedHeader["__metadata__"]; ok { + if err := json.Unmarshal(metaRaw, &metadata); err != nil { + t.Fatalf("failed to parse packed metadata: %v", err) + } + } + if metadata["quant_type"] != "int4" { + t.Fatalf("quant_type = %q, want %q", metadata["quant_type"], "int4") + } + if metadata["group_size"] != "64" { + t.Fatalf("group_size = %q, want %q", 
metadata["group_size"], "64") + } + + if slices.Contains(createTensorLayerNames, "linear.weight") { + t.Fatalf("linear.weight unexpectedly handled by createTensorLayer: %v", createTensorLayerNames) + } + if slices.Contains(createTensorLayerNames, "linear.scales") || slices.Contains(createTensorLayerNames, "linear.biases") { + t.Fatalf("quantized companions unexpectedly handled separately: %v", createTensorLayerNames) + } + if !slices.Contains(createTensorLayerNames, "plain.weight") { + t.Fatalf("plain.weight missing from createTensorLayer calls: %v", createTensorLayerNames) + } + if slices.Contains(tensorLayerNames, "linear.scales") || slices.Contains(tensorLayerNames, "linear.biases") { + t.Fatalf("quantized companions unexpectedly emitted as layers: %v", tensorLayerNames) + } +} + +func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) { + dir := t.TempDir() + + configJSON := `{ + "model_type": "test", + "architectures": ["Qwen3_5MoeForConditionalGeneration"], + "text_config": {"dtype": "bfloat16"} + }` + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil { + t.Fatalf("failed to write config.json: %v", err) + } + + gateUpValues := make([]float32, 2*128*64) + for expert := range 2 { + base := expert * 128 * 64 + for i := range 64 * 64 { + gateUpValues[base+i] = 1 + gateUpValues[base+64*64+i] = 2 + } + } + + createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{ + st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)), + st.NewTensorDataFromBytes("model.language_model.layers.0.input_layernorm.weight", "F32", []int32{64}, make([]byte, 64*4)), + st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.A_log", "F32", []int32{32}, make([]byte, 32*4)), + st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.conv1d.weight", "BF16", []int32{64, 1, 4}, make([]byte, 64*1*4*2)), + 
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)), + st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)), + st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))), + st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert.down_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)), + st.NewTensorDataFromBytes("model.visual.blocks.0.attn.proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)), + st.NewTensorDataFromBytes("mtp.layers.0.foo.weight", "F32", []int32{64, 64}, make([]byte, 64*64*4)), + }) + + type tensorCall struct { + dtype string + shape []int32 + quantize string + raw []byte + } + calls := make(map[string]tensorCall) + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + _, _ = io.ReadAll(r) + return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil + } + + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + headerDType, headerShape := readSingleTensorHeader(t, data) + calls[name] = tensorCall{ + dtype: headerDType, + shape: headerShape, + quantize: quantize, + raw: readSingleTensorRaw(t, data), + } + return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil + } + + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + return nil + } + + if err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil { + t.Fatalf("CreateSafetensorsModel failed: %v", err) + } + + if _, ok := calls["mtp.layers.0.foo.weight"]; 
ok { + t.Fatal("mtp tensor should have been dropped") + } + + layerNorm := calls["language_model.model.layers.0.input_layernorm.weight"] + if layerNorm.dtype != "BF16" { + t.Fatalf("input_layernorm dtype = %q, want %q", layerNorm.dtype, "BF16") + } + if layerNorm.quantize != "" { + t.Fatalf("input_layernorm quantize = %q, want empty", layerNorm.quantize) + } + layerNormValues := bfloat16.DecodeFloat32(layerNorm.raw) + if len(layerNormValues) == 0 || layerNormValues[0] != 1.0 { + t.Fatalf("input_layernorm first value = %v, want 1.0 after +1 shift", layerNormValues[0]) + } + + alog := calls["language_model.model.layers.0.linear_attn.A_log"] + if alog.dtype != "F32" { + t.Fatalf("A_log dtype = %q, want %q", alog.dtype, "F32") + } + + conv := calls["language_model.model.layers.0.linear_attn.conv1d.weight"] + if !slices.Equal(conv.shape, []int32{64, 4, 1}) { + t.Fatalf("conv1d shape = %v, want %v", conv.shape, []int32{64, 4, 1}) + } + + if got := calls["language_model.model.embed_tokens.weight"].quantize; got != "int4" { + t.Fatalf("embed_tokens quantize = %q, want %q", got, "int4") + } + if got := calls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "int4" { + t.Fatalf("mlp.gate quantize = %q, want %q", got, "int4") + } + if got := calls["language_model.model.layers.0.mlp.shared_expert.down_proj.weight"].quantize; got != "int4" { + t.Fatalf("down_proj quantize = %q, want %q", got, "int4") + } + + if _, ok := calls["language_model.model.layers.0.mlp.experts.gate_up_proj"]; ok { + t.Fatal("combined gate_up_proj tensor should have been rewritten") + } + if _, ok := calls["language_model.model.layers.0.mlp.experts.down_proj"]; ok { + t.Fatal("combined down_proj tensor should have been rewritten") + } + + gateProj := calls["language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight"] + if !slices.Equal(gateProj.shape, []int32{2, 64, 64}) { + t.Fatalf("gate_proj shape = %v, want %v", gateProj.shape, []int32{2, 64, 64}) + } + gateProjValues := 
bfloat16.DecodeFloat32(gateProj.raw) + if len(gateProjValues) == 0 || gateProjValues[0] != 1.0 { + t.Fatalf("gate_proj first value = %v, want 1.0", gateProjValues[0]) + } + + upProj := calls["language_model.model.layers.0.mlp.switch_mlp.up_proj.weight"] + if !slices.Equal(upProj.shape, []int32{2, 64, 64}) { + t.Fatalf("up_proj shape = %v, want %v", upProj.shape, []int32{2, 64, 64}) + } + upProjValues := bfloat16.DecodeFloat32(upProj.raw) + if len(upProjValues) == 0 || upProjValues[0] != 2.0 { + t.Fatalf("up_proj first value = %v, want 2.0", upProjValues[0]) + } + + if got := calls["language_model.model.layers.0.mlp.switch_mlp.down_proj.weight"].quantize; got != "int4" { + t.Fatalf("switch_mlp down_proj quantize = %q, want %q", got, "int4") + } + + vision := calls["vision_tower.blocks.0.attn.proj.weight"] + if vision.dtype != "BF16" { + t.Fatalf("vision weight dtype = %q, want %q", vision.dtype, "BF16") + } + if vision.quantize != "" { + t.Fatalf("vision weight quantize = %q, want empty", vision.quantize) + } + if _, ok := calls["language_model.model.visual.blocks.0.attn.proj.weight"]; ok { + t.Fatal("vision tensor should have been rewritten to vision_tower.*") + } +} + func TestResolveManifestPath(t *testing.T) { tests := []struct { name string diff --git a/x/create/dtype.go b/x/create/dtype.go new file mode 100644 index 000000000..a7181df79 --- /dev/null +++ b/x/create/dtype.go @@ -0,0 +1,109 @@ +package create + +import ( + "encoding/binary" + "fmt" + "math" + "strings" + + "github.com/d4l3k/go-bfloat16" + "github.com/x448/float16" +) + +// DTypeSize returns the byte size of a single element for the given dtype string. +func DTypeSize(dtype string) (int, error) { + switch strings.ToUpper(dtype) { + case "BF16", "F16": + return 2, nil + case "F32", "U32", "I32": + return 4, nil + case "F64": + return 8, nil + default: + return 0, fmt.Errorf("unsupported dtype %q", dtype) + } +} + +// DecodeFloatTensor decodes raw bytes into []float32 according to the given dtype. 
+func DecodeFloatTensor(dtype string, raw []byte) ([]float32, error) { + switch strings.ToUpper(dtype) { + case "BF16": + return bfloat16.DecodeFloat32(raw), nil + case "F16": + if len(raw)%2 != 0 { + return nil, fmt.Errorf("invalid f16 byte length %d", len(raw)) + } + values := make([]float32, len(raw)/2) + for i := range values { + values[i] = float16.Frombits(binary.LittleEndian.Uint16(raw[i*2:])).Float32() + } + return values, nil + case "F32": + if len(raw)%4 != 0 { + return nil, fmt.Errorf("invalid f32 byte length %d", len(raw)) + } + values := make([]float32, len(raw)/4) + for i := range values { + values[i] = math.Float32frombits(binary.LittleEndian.Uint32(raw[i*4:])) + } + return values, nil + case "F64": + if len(raw)%8 != 0 { + return nil, fmt.Errorf("invalid f64 byte length %d", len(raw)) + } + values := make([]float32, len(raw)/8) + for i := range values { + values[i] = float32(math.Float64frombits(binary.LittleEndian.Uint64(raw[i*8:]))) + } + return values, nil + default: + return nil, fmt.Errorf("unsupported dtype %q", dtype) + } +} + +// EncodeFloatTensor encodes []float32 into raw bytes according to the given dtype. 
+func EncodeFloatTensor(dtype string, values []float32) ([]byte, error) { + switch strings.ToUpper(dtype) { + case "BF16": + return bfloat16.EncodeFloat32(values), nil + case "F16": + out := make([]byte, len(values)*2) + for i, v := range values { + binary.LittleEndian.PutUint16(out[i*2:], float16.Fromfloat32(v).Bits()) + } + return out, nil + case "F32": + out := make([]byte, len(values)*4) + for i, v := range values { + binary.LittleEndian.PutUint32(out[i*4:], math.Float32bits(v)) + } + return out, nil + case "F64": + out := make([]byte, len(values)*8) + for i, v := range values { + binary.LittleEndian.PutUint64(out[i*8:], math.Float64bits(float64(v))) + } + return out, nil + default: + return nil, fmt.Errorf("unsupported dtype %q", dtype) + } +} + +func sourceQuantType(mode string, bits int) string { + switch strings.ToLower(mode) { + case "affine": + switch bits { + case 4: + return "int4" + case 8: + return "int8" + } + case "nvfp4": + return "nvfp4" + case "mxfp8": + return "mxfp8" + case "mxfp4": + return "mxfp4" + } + return "" +} diff --git a/x/create/qwen35.go b/x/create/qwen35.go new file mode 100644 index 000000000..77bec9ee7 --- /dev/null +++ b/x/create/qwen35.go @@ -0,0 +1,323 @@ +package create + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +type qwen35ImportTransform struct { + shouldShiftNormWeights bool + rewriteLanguageModel bool +} + +type qwen35SourceInfo struct { + hasPrequantizedWeights bool + shouldShiftNormWeights bool +} + +func newQwen35ImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) { + sourceInfo, err := qwen35InspectSource(modelDir) + if err != nil { + return qwen35ImportTransform{}, err + } + if sourceInfo.hasPrequantizedWeights { + return noopImportTransform{}, nil + } + + return qwen35ImportTransform{ + shouldShiftNormWeights: sourceInfo.shouldShiftNormWeights, + rewriteLanguageModel: 
strings.Contains(cfg.Architecture(), "ConditionalGeneration"), + }, nil +} + +func qwen35InspectSource(modelDir string) (qwen35SourceInfo, error) { + entries, err := os.ReadDir(modelDir) + if err != nil { + return qwen35SourceInfo{}, err + } + + var info qwen35SourceInfo + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") { + continue + } + + extractor, err := safetensors.OpenForExtraction(filepath.Join(modelDir, entry.Name())) + if err != nil { + return qwen35SourceInfo{}, err + } + + for _, name := range extractor.ListTensors() { + if strings.HasSuffix(name, ".scales") { + extractor.Close() + info.hasPrequantizedWeights = true + return info, nil + } + // This should change when MTP is supported + if strings.Contains(name, "mtp.") { + info.shouldShiftNormWeights = true + continue + } + if info.shouldShiftNormWeights || !strings.Contains(name, "conv1d.weight") { + continue + } + + td, err := extractor.GetTensor(name) + if err != nil { + extractor.Close() + return qwen35SourceInfo{}, err + } + if len(td.Shape) == 3 && td.Shape[2] != 1 { + info.shouldShiftNormWeights = true + } + } + + extractor.Close() + } + + return info, nil +} + +func (t qwen35ImportTransform) skipTensor(name string) bool { + return strings.Contains(name, "mtp.") +} + +func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string { + if strings.HasPrefix(name, "vision_tower.") { + return "" + } + + stackedExpert := isStackedExpertWeight(name) + if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") || + strings.HasSuffix(name, ".biases") || strings.HasSuffix(name, ".scales") { + return "" + } + if !stackedExpert && !strings.HasSuffix(name, ".weight") { + return "" + } + if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") { + return "" + } + if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) { 
+ return "" + } + + var elems int64 = 1 + for _, d := range shape { + elems *= int64(d) + } + if elems < 1024 { + return "" + } + + quantNorm := normalizeQuantType(quantize) + groupSize := int32(32) + switch quantNorm { + case "nvfp4": + groupSize = 16 + case "int4", "int8": + groupSize = 64 + } + if shape[len(shape)-1]%groupSize != 0 { + return "" + } + + return quantNorm +} + +func (t qwen35ImportTransform) rewriteTensorData(td *safetensors.TensorData) (*safetensors.TensorData, error) { + if td == nil { + return td, nil + } + + shiftNorm := t.shouldShiftNormWeights && qwen35ShouldShiftNormKey(td.Name) + transposeConv := strings.Contains(td.Name, "conv1d.weight") && len(td.Shape) == 3 && td.Shape[2] != 1 + castToBF16 := qwen35NeedsCastToBF16(td.Name, td.Dtype) + if !shiftNorm && !transposeConv && !castToBF16 { + return td, nil + } + + raw, err := io.ReadAll(td.Reader()) + if err != nil { + return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err) + } + + values, err := DecodeFloatTensor(td.Dtype, raw) + if err != nil { + return nil, fmt.Errorf("failed to decode tensor %s: %w", td.Name, err) + } + + shape := append([]int32(nil), td.Shape...) 
+ if transposeConv {
+ values, shape = qwen35TransposeConv1D(values, shape)
+ }
+ if shiftNorm { // add +1 to every element of the norm vector
+ for i := range values {
+ values[i] += 1.0
+ }
+ }
+
+ targetDtype := td.Dtype
+ if castToBF16 {
+ targetDtype = "BF16"
+ }
+
+ out, err := EncodeFloatTensor(targetDtype, values)
+ if err != nil {
+ return nil, fmt.Errorf("failed to encode tensor %s: %w", td.Name, err)
+ }
+
+ return safetensors.NewTensorDataFromBytes(td.Name, targetDtype, shape, out), nil
+}
+
+func (t qwen35ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) { // renames/splits one source tensor into output tensors, then rewrites each
+ if td == nil {
+ return nil, nil
+ }
+
+ name := t.canonicalTensorName(td.Name)
+
+ // Phase 1: rename/split into intermediate tensors
+ var intermediates []*safetensors.TensorData
+ stripped := strings.TrimSuffix(name, ".weight")
+ switch {
+ case strings.HasSuffix(stripped, ".mlp.experts.gate_up_proj"): // fused gate+up expert weight: split along axis 1 into separate projections
+ prefix := strings.TrimSuffix(stripped, ".mlp.experts.gate_up_proj")
+ raw, err := io.ReadAll(td.Reader())
+ if err != nil {
+ return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
+ }
+ gateRaw, upRaw, splitShape, err := qwen35SplitAxis1Raw(raw, td.Dtype, td.Shape)
+ if err != nil {
+ return nil, fmt.Errorf("failed to split tensor %s: %w", td.Name, err)
+ }
+ intermediates = []*safetensors.TensorData{
+ safetensors.NewTensorDataFromBytes(prefix+".mlp.switch_mlp.gate_proj.weight", td.Dtype, splitShape, gateRaw),
+ safetensors.NewTensorDataFromBytes(prefix+".mlp.switch_mlp.up_proj.weight", td.Dtype, splitShape, upRaw),
+ }
+ case strings.HasSuffix(stripped, ".mlp.experts.down_proj"): // rename only; data passes through unchanged
+ newName := strings.TrimSuffix(stripped, ".mlp.experts.down_proj") + ".mlp.switch_mlp.down_proj.weight"
+ intermediates = []*safetensors.TensorData{td.WithName(newName)}
+ default:
+ intermediates = []*safetensors.TensorData{td.WithName(name)}
+ }
+
+ // Phase 2: rewrite all intermediates
+ results := make([]*safetensors.TensorData, 0, len(intermediates))
+ for _, inter := range intermediates {
+ 
rewritten, err := t.rewriteTensorData(inter)
+ if err != nil {
+ return nil, err
+ }
+ results = append(results, rewritten)
+ }
+ return results, nil
+}
+
+func (t qwen35ImportTransform) canonicalTensorName(name string) string { // maps source tensor names onto the canonical prefixes used downstream
+ // Vision tensors: normalize to vision_tower.* prefix
+ switch {
+ case strings.HasPrefix(name, "model.visual."):
+ return "vision_tower." + strings.TrimPrefix(name, "model.visual.")
+ case strings.HasPrefix(name, "vision_tower."):
+ return name
+ }
+
+ // Language model tensors: normalize to language_model.model.* prefix
+ if !t.rewriteLanguageModel {
+ return name
+ }
+ switch {
+ case strings.HasPrefix(name, "model.language_model"):
+ return "language_model.model" + strings.TrimPrefix(name, "model.language_model")
+ case strings.HasPrefix(name, "language_model."):
+ return name
+ default:
+ return "language_model." + name
+ }
+}
+
+func qwen35ShouldShiftNormKey(key string) bool { // reports whether this norm weight receives the +1 shift during import
+ for _, suffix := range []string{
+ ".input_layernorm.weight",
+ ".post_attention_layernorm.weight",
+ "model.norm.weight",
+ ".q_norm.weight",
+ ".k_norm.weight",
+ } {
+ if strings.HasSuffix(key, suffix) {
+ return true
+ }
+ }
+ return false
+}
+
+func qwen35NeedsCastToBF16(name, dtype string) bool { // float tensors are cast to BF16, except A_log which keeps its source dtype
+ if strings.HasSuffix(name, "A_log") {
+ return false
+ }
+ switch strings.ToUpper(dtype) {
+ case "F16", "F32", "F64":
+ return true
+ default:
+ return false
+ }
+}
+
+func qwen35TransposeConv1D(values []float32, shape []int32) ([]float32, []int32) { // swaps axes 1 and 2: (d0,d1,d2) -> (d0,d2,d1)
+ if len(shape) != 3 {
+ return values, shape
+ }
+
+ d0, d1, d2 := int(shape[0]), int(shape[1]), int(shape[2])
+ out := make([]float32, len(values))
+ for i := range d0 {
+ for j := range d1 {
+ for k := range d2 {
+ inIdx := (i*d1+j)*d2 + k
+ outIdx := (i*d2+k)*d1 + j
+ out[outIdx] = values[inIdx]
+ }
+ }
+ }
+
+ return out, []int32{shape[0], shape[2], shape[1]}
+}
+
+func qwen35SplitAxis1Raw(raw []byte, dtype string, shape []int32) ([]byte, []byte, []int32, error) { // splits a 3D tensor's raw bytes in half along axis 1 without decoding values
+ if len(shape) != 3 {
+ return nil, nil, 
nil, fmt.Errorf("expected 3D tensor, got shape %v", shape) + } + if shape[1]%2 != 0 { + return nil, nil, nil, fmt.Errorf("axis 1 dim %d is not even", shape[1]) + } + + elemSize, err := DTypeSize(dtype) + if err != nil { + return nil, nil, nil, err + } + + d0, d1, d2 := int(shape[0]), int(shape[1]), int(shape[2]) + perExpertBytes := d1 * d2 * elemSize + if len(raw) != d0*perExpertBytes { + return nil, nil, nil, fmt.Errorf("raw byte length %d does not match shape %v and dtype %s", len(raw), shape, dtype) + } + + halfD1 := d1 / 2 + halfExpertBytes := halfD1 * d2 * elemSize + gateRaw := make([]byte, d0*halfExpertBytes) + upRaw := make([]byte, d0*halfExpertBytes) + for e := range d0 { + src := e * perExpertBytes + dst := e * halfExpertBytes + copy(gateRaw[dst:dst+halfExpertBytes], raw[src:src+halfExpertBytes]) + copy(upRaw[dst:dst+halfExpertBytes], raw[src+halfExpertBytes:src+perExpertBytes]) + } + + return gateRaw, upRaw, []int32{shape[0], int32(halfD1), shape[2]}, nil +} diff --git a/x/imagegen/safetensors/extractor.go b/x/imagegen/safetensors/extractor.go index 65a4f6da0..549222eab 100644 --- a/x/imagegen/safetensors/extractor.go +++ b/x/imagegen/safetensors/extractor.go @@ -36,6 +36,23 @@ type TensorData struct { reader *io.SectionReader } +// WithName returns a shallow copy of TensorData with a different logical tensor +// name but the same underlying raw data reader. +func (td *TensorData) WithName(name string) *TensorData { + if td == nil { + return nil + } + shape := make([]int32, len(td.Shape)) + copy(shape, td.Shape) + return &TensorData{ + Name: name, + Dtype: td.Dtype, + Shape: shape, + Size: td.Size, + reader: td.reader, + } +} + // Reader returns an io.Reader for the tensor's raw bytes. func (td *TensorData) Reader() io.Reader { return td.reader @@ -117,8 +134,15 @@ func ExtractRawFromSafetensors(r io.Reader) ([]byte, error) { // into a single blob without loading all data into memory. // Each TensorData must have been obtained from GetTensor. 
func BuildPackedSafetensorsReader(tensors []*TensorData) io.Reader {
+ return BuildPackedSafetensorsReaderWithMetadata(tensors, nil)
+}
+
+// BuildPackedSafetensorsReaderWithMetadata builds a streaming io.Reader that
+// outputs a valid safetensors file containing multiple tensors and optional
+// metadata.
+func BuildPackedSafetensorsReaderWithMetadata(tensors []*TensorData, metadata map[string]string) io.Reader {
 // Build the header with sequential data offsets
- header := make(map[string]tensorInfo, len(tensors))
+ header := make(map[string]any, len(tensors)+1) // +1 leaves room for the optional __metadata__ entry
 var offset int
 for _, td := range tensors {
 header[td.Name] = tensorInfo{
@@ -128,6 +152,9 @@ func BuildPackedSafetensorsReader(tensors []*TensorData) io.Reader {
 offset += int(td.Size)
 }
+ if len(metadata) > 0 { // stored under the "__metadata__" header key
+ header["__metadata__"] = metadata
+ }
 headerJSON, _ := json.Marshal(header)