Files
ollama/x/create/create_test.go
Patrick Devine de5cb7311f mlx: add mxfp4/mxfp8/nvfp4 importing (#15015)
This change allows importing bf16 and converting to mxfp4/mxfp8/nvfp4
and also importing fp8 and converting directly to mxfp8.
2026-03-24 13:45:44 -07:00

1665 lines
58 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package create
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/d4l3k/go-bfloat16"
st "github.com/ollama/ollama/x/imagegen/safetensors"
)
// TestIsTensorModelDir verifies that a directory is recognized as a diffusers
// tensor model exactly when it contains a model_index.json file.
func TestIsTensorModelDir(t *testing.T) {
	cases := []struct {
		name  string
		setup func(dir string) error
		want  bool
	}{
		{
			name: "valid diffusers model with model_index.json",
			setup: func(dir string) error {
				return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
			},
			want: true,
		},
		{
			name:  "empty directory",
			setup: func(dir string) error { return nil },
			want:  false,
		},
		{
			name: "directory with other files but no model_index.json",
			setup: func(dir string) error {
				return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
			},
			want: false,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			dir := t.TempDir()
			if err := tc.setup(dir); err != nil {
				t.Fatalf("setup failed: %v", err)
			}
			if got := IsTensorModelDir(dir); got != tc.want {
				t.Errorf("IsTensorModelDir() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestIsSafetensorsModelDir verifies safetensors-directory detection: a
// directory qualifies only when it holds both a config.json and at least one
// *.safetensors file.
func TestIsSafetensorsModelDir(t *testing.T) {
	cases := []struct {
		name  string
		setup func(dir string) error
		want  bool
	}{
		{
			name: "valid safetensors model with config.json and .safetensors file",
			setup: func(dir string) error {
				if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
					return err
				}
				return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
			},
			want: true,
		},
		{
			name: "config.json only, no safetensors files",
			setup: func(dir string) error {
				return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
			},
			want: false,
		},
		{
			name: "safetensors file only, no config.json",
			setup: func(dir string) error {
				return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
			},
			want: false,
		},
		{
			name:  "empty directory",
			setup: func(dir string) error { return nil },
			want:  false,
		},
		{
			name: "multiple safetensors files with config.json",
			setup: func(dir string) error {
				if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
					return err
				}
				if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
					return err
				}
				return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
			},
			want: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			dir := t.TempDir()
			if err := tc.setup(dir); err != nil {
				t.Fatalf("setup failed: %v", err)
			}
			if got := IsSafetensorsModelDir(dir); got != tc.want {
				t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestIsSafetensorsModelDir_NonexistentDir verifies that a path that does not
// exist is reported as not being a safetensors model directory.
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
	if got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist"); got {
		t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
	}
}
// createMinimalSafetensors writes a minimal valid safetensors file at path,
// containing a single 2x2 float32 tensor named "test_tensor" (all zeros).
//
// Layout: an 8-byte little-endian header length, the JSON header (padded with
// spaces to an 8-byte boundary), then the raw tensor bytes. Fails the test on
// any error.
func createMinimalSafetensors(t *testing.T, path string) {
	t.Helper()
	header := map[string]any{
		"test_tensor": map[string]any{
			"dtype":        "F32",
			"shape":        []int{2, 2},
			"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
		},
	}
	headerJSON, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("failed to marshal header: %v", err)
	}
	// Pad header to 8-byte alignment.
	padding := (8 - len(headerJSON)%8) % 8
	headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)

	// Assemble the whole blob in memory, then write it in one call.
	var buf bytes.Buffer
	// Header size (8 bytes, little endian).
	if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
		t.Fatalf("failed to write header size: %v", err)
	}
	buf.Write(headerJSON)
	// Tensor data: 16 bytes of zeros for the 4 float32 values.
	buf.Write(make([]byte, 16))
	if err := os.WriteFile(path, buf.Bytes(), 0o644); err != nil {
		t.Fatalf("failed to create file: %v", err)
	}
}
// createTestSafetensors writes a safetensors file at path containing the
// given tensors, packed via the safetensors helper package. Fails the test on
// any error.
func createTestSafetensors(t *testing.T, path string, tensors []*st.TensorData) {
	t.Helper()
	r := st.BuildPackedSafetensorsReader(tensors)
	data, err := io.ReadAll(r)
	if err != nil {
		t.Fatalf("failed to build packed safetensors: %v", err)
	}
	if err = os.WriteFile(path, data, 0o644); err != nil {
		t.Fatalf("failed to write safetensors: %v", err)
	}
}
// readSingleTensorHeader parses a safetensors blob and returns the dtype and
// shape of its sole tensor entry, ignoring any "__metadata__" entry. Fails
// the test if the header cannot be parsed or holds no tensor.
func readSingleTensorHeader(t *testing.T, data []byte) (string, []int32) {
	t.Helper()
	var headerSize uint64
	if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
		t.Fatalf("failed to read header size: %v", err)
	}
	var header map[string]struct {
		Dtype string  `json:"dtype"`
		Shape []int32 `json:"shape"`
	}
	if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}
	for name, info := range header {
		if name != "__metadata__" {
			return info.Dtype, info.Shape
		}
	}
	t.Fatal("no tensor entry found in header")
	return "", nil
}
// readSingleTensorRaw parses a safetensors blob and returns the raw data
// bytes of its sole tensor entry (the data_offsets range, relative to the end
// of the header), ignoring any "__metadata__" entry.
func readSingleTensorRaw(t *testing.T, data []byte) []byte {
	t.Helper()
	var headerSize uint64
	if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
		t.Fatalf("failed to read header size: %v", err)
	}
	var header map[string]struct {
		Dtype       string  `json:"dtype"`
		Shape       []int32 `json:"shape"`
		DataOffsets [2]int  `json:"data_offsets"`
	}
	if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}
	for name, info := range header {
		if name == "__metadata__" {
			continue
		}
		// Offsets are relative to the first byte after the header.
		base := 8 + int(headerSize)
		return data[base+info.DataOffsets[0] : base+info.DataOffsets[1]]
	}
	t.Fatal("no tensor entry found in header")
	return nil
}
// readSafetensorsHeaderNames parses a safetensors blob and returns the sorted
// tensor names in its header, excluding the "__metadata__" entry.
func readSafetensorsHeaderNames(t *testing.T, data []byte) []string {
	t.Helper()
	var headerSize uint64
	if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
		t.Fatalf("failed to read header size: %v", err)
	}
	var header map[string]json.RawMessage
	if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}
	names := make([]string, 0, len(header))
	for name := range header {
		if name != "__metadata__" {
			names = append(names, name)
		}
	}
	slices.Sort(names)
	return names
}
// TestCreateSafetensorsModel runs a full import of a minimal one-tensor model
// directory and checks that the manifest is written with the expected model
// name, config layer, tensor layer, and that progress updates are emitted.
func TestCreateSafetensorsModel(t *testing.T) {
	dir := t.TempDir()

	// Minimal model directory: a config.json plus one safetensors file.
	configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}
	createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

	// Capture everything the callbacks are handed.
	var (
		createdLayers       []LayerInfo
		manifestWritten     bool
		manifestModelName   string
		manifestConfigLayer LayerInfo
		manifestLayers      []LayerInfo
		statusMessages      []string
	)
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		data, err := io.ReadAll(r)
		if err != nil {
			return LayerInfo{}, err
		}
		layer := LayerInfo{
			Digest:    "sha256:test",
			Size:      int64(len(data)),
			MediaType: mediaType,
			Name:      name,
		}
		createdLayers = append(createdLayers, layer)
		return layer, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		data, err := io.ReadAll(r)
		if err != nil {
			return nil, err
		}
		layer := LayerInfo{
			Digest:    "sha256:tensor_" + name,
			Size:      int64(len(data)),
			MediaType: "application/vnd.ollama.image.tensor",
			Name:      name,
		}
		createdLayers = append(createdLayers, layer)
		return []LayerInfo{layer}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		manifestWritten = true
		manifestModelName = modelName
		manifestConfigLayer = config
		manifestLayers = layers
		return nil
	}
	progressFn := func(status string) {
		statusMessages = append(statusMessages, status)
	}

	if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}

	// The manifest must be written with the requested name and config layer.
	if !manifestWritten {
		t.Error("manifest was not written")
	}
	if manifestModelName != "test-model" {
		t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
	}
	if manifestConfigLayer.Name != "config.json" {
		t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
	}

	// The manifest layers must include both the tensor and the config file.
	hasTensor := slices.ContainsFunc(manifestLayers, func(l LayerInfo) bool { return l.Name == "test_tensor" })
	hasConfig := slices.ContainsFunc(manifestLayers, func(l LayerInfo) bool { return l.Name == "config.json" })
	if !hasTensor {
		t.Error("no tensor layer found in manifest")
	}
	if !hasConfig {
		t.Error("no config layer found in manifest")
	}
	if len(statusMessages) == 0 {
		t.Error("no status messages received")
	}
}
// TestCreateSafetensorsModel_NoConfigJson verifies that a directory holding a
// safetensors file but no config.json is rejected.
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
	dir := t.TempDir()
	createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

	// Minimal no-op callbacks; only the error result matters.
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r)
		return LayerInfo{Name: name}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		_, _ = io.ReadAll(r)
		return []LayerInfo{{Name: name}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

	if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err == nil {
		t.Error("expected error for missing config.json, got nil")
	}
}
// TestCreateSafetensorsModel_EmptyDir verifies that importing from a
// directory containing no model files fails.
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
	dir := t.TempDir()

	// Minimal no-op callbacks; only the error result matters.
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		return LayerInfo{}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		return []LayerInfo{{}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

	if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err == nil {
		t.Error("expected error for empty directory, got nil")
	}
}
// TestCreateSafetensorsModel_SkipsIndexJson verifies that the sharded-model
// index file (model.safetensors.index.json) is not imported as a config
// layer.
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
	dir := t.TempDir()
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}
	indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
	if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
		t.Fatalf("failed to write index.json: %v", err)
	}
	createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

	// Record every name handed to createLayer so we can assert the index file
	// was skipped.
	var configNames []string
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r)
		configNames = append(configNames, name)
		return LayerInfo{Name: name, Digest: "sha256:test"}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		_, _ = io.ReadAll(r)
		return []LayerInfo{{Name: name}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

	if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}
	if slices.Contains(configNames, "model.safetensors.index.json") {
		t.Error("model.safetensors.index.json should have been skipped")
	}
}
// TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets verifies that a
// source model whose config declares affine quantization has each quantized
// weight packed together with its scales/biases companions into one blob
// (with quant_type/group_size metadata), while plain tensors still go through
// createTensorLayer.
func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) {
	dir := t.TempDir()
	configJSON := `{
		"model_type": "test",
		"architectures": ["TestModel"],
		"quantization": {"group_size": 64, "bits": 4, "mode": "affine"}
	}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}
	createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
		st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
		st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
		st.NewTensorDataFromBytes("linear.biases", "BF16", []int32{4, 1}, make([]byte, 8)),
		st.NewTensorDataFromBytes("plain.weight", "F32", []int32{2, 2}, make([]byte, 16)),
	})

	var (
		packedHeader           map[string]json.RawMessage
		tensorLayerNames       []string
		createTensorLayerNames []string
	)
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		data, err := io.ReadAll(r)
		if err != nil {
			return LayerInfo{}, err
		}
		// Capture the packed safetensors header emitted for the quantized weight.
		if mediaType == "application/vnd.ollama.image.tensor" && name == "linear.weight" {
			var headerSize uint64
			if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
				return LayerInfo{}, err
			}
			if err := json.Unmarshal(data[8:8+headerSize], &packedHeader); err != nil {
				return LayerInfo{}, err
			}
		}
		tensorLayerNames = append(tensorLayerNames, name)
		return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType, Size: int64(len(data))}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		if _, err := io.ReadAll(r); err != nil {
			return nil, err
		}
		createTensorLayerNames = append(createTensorLayerNames, name)
		return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

	if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}

	// The packed blob must carry the weight plus both companions.
	if packedHeader == nil {
		t.Fatal("expected packed quantized header for linear.weight")
	}
	for _, key := range []string{"linear.weight", "linear.weight.scale", "linear.weight.bias"} {
		if _, ok := packedHeader[key]; !ok {
			t.Fatalf("packed header missing %s: %v", key, packedHeader)
		}
	}

	// Quantization parameters are recorded in the blob's __metadata__.
	var metadata map[string]string
	if metaRaw, ok := packedHeader["__metadata__"]; ok {
		if err := json.Unmarshal(metaRaw, &metadata); err != nil {
			t.Fatalf("failed to parse packed metadata: %v", err)
		}
	}
	if metadata["quant_type"] != "int4" {
		t.Fatalf("quant_type = %q, want %q", metadata["quant_type"], "int4")
	}
	if metadata["group_size"] != "64" {
		t.Fatalf("group_size = %q, want %q", metadata["group_size"], "64")
	}

	// The quantized weight and its companions must not also be imported as
	// standalone tensors.
	if slices.Contains(createTensorLayerNames, "linear.weight") {
		t.Fatalf("linear.weight unexpectedly handled by createTensorLayer: %v", createTensorLayerNames)
	}
	if slices.Contains(createTensorLayerNames, "linear.scales") || slices.Contains(createTensorLayerNames, "linear.biases") {
		t.Fatalf("quantized companions unexpectedly handled separately: %v", createTensorLayerNames)
	}
	if !slices.Contains(createTensorLayerNames, "plain.weight") {
		t.Fatalf("plain.weight missing from createTensorLayer calls: %v", createTensorLayerNames)
	}
	if slices.Contains(tensorLayerNames, "linear.scales") || slices.Contains(tensorLayerNames, "linear.biases") {
		t.Fatalf("quantized companions unexpectedly emitted as layers: %v", tensorLayerNames)
	}
}
// TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8 verifies that an HF
// fp8 block-quantized source is auto-converted to mxfp8: each fp8 weight
// carries its scale_inv companion inside one blob, while BF16 tensors pass
// through with no quantize hint and no companions.
func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
	dir := t.TempDir()
	configJSON := `{
		"model_type": "test",
		"architectures": ["TestModel"],
		"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
	}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}
	createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
		st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
		st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
		st.NewTensorDataFromBytes("dense.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
		st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
	})

	// Record, per tensor, the quantize hint and the names inside its blob.
	quantizeByName := make(map[string]string)
	headerNamesByName := make(map[string][]string)
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		if _, err := io.ReadAll(r); err != nil {
			return LayerInfo{}, err
		}
		return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		data, err := io.ReadAll(r)
		if err != nil {
			return nil, err
		}
		quantizeByName[name] = quantize
		headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
		return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

	if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}

	// Only the fp8 weight is converted; BF16 tensors keep an empty hint.
	if got := quantizeByName["linear.weight"]; got != "mxfp8" {
		t.Fatalf("linear.weight quantization = %q, want %q", got, "mxfp8")
	}
	if got := quantizeByName["norm.weight"]; got != "" {
		t.Fatalf("norm.weight quantization = %q, want empty", got)
	}
	if got := quantizeByName["dense.weight"]; got != "" {
		t.Fatalf("dense.weight quantization = %q, want empty", got)
	}
	if _, ok := quantizeByName["linear.weight_scale_inv"]; ok {
		t.Fatal("linear.weight_scale_inv should not be imported as a standalone tensor")
	}
	if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale_inv"}) {
		t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale_inv"})
	}
	if got := headerNamesByName["norm.weight"]; !slices.Equal(got, []string{"norm.weight"}) {
		t.Fatalf("norm.weight blob tensors = %v, want %v", got, []string{"norm.weight"})
	}
	if got := headerNamesByName["dense.weight"]; !slices.Equal(got, []string{"dense.weight"}) {
		t.Fatalf("dense.weight blob tensors = %v, want %v", got, []string{"dense.weight"})
	}
}
// TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources verifies
// that requesting --quantize on an already-quantized source (affine triplets
// or HF fp8) fails with a descriptive error instead of requantizing.
func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T) {
	cases := []struct {
		name       string
		configJSON string
		tensors    []*st.TensorData
		wantErr    string
	}{
		{
			name:       "prequantized affine",
			configJSON: `{"model_type": "test", "architectures": ["TestModel"]}`,
			tensors: []*st.TensorData{
				st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
				st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
			},
			wantErr: `cannot requantize already-quantized source model with --quantize "int4"`,
		},
		{
			name: "hf fp8 source",
			configJSON: `{
				"model_type": "test",
				"architectures": ["TestModel"],
				"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
			}`,
			tensors: []*st.TensorData{
				st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
				st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
			},
			wantErr: `cannot requantize already-quantized fp8 source model with --quantize "int4"`,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			dir := t.TempDir()
			if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tc.configJSON), 0o644); err != nil {
				t.Fatalf("failed to write config.json: %v", err)
			}
			createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), tc.tensors)

			// No-op callbacks; the import must fail before they matter.
			createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
				return LayerInfo{}, nil
			}
			createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
				return nil, nil
			}
			writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

			err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {})
			if err == nil {
				t.Fatal("expected error, got nil")
			}
			if !strings.Contains(err.Error(), tc.wantErr) {
				t.Fatalf("error = %q, want substring %q", err, tc.wantErr)
			}
		})
	}
}
func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create 2 experts so stacking produces a [2, 128, 128] tensor
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
})
var packedLayerNames []string
var packedLayerTensors [][]PackedTensorInput
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return nil, err
}
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
packedLayerNames = append(packedLayerNames, groupName)
packedLayerTensors = append(packedLayerTensors, tensors)
return LayerInfo{Name: groupName, Digest: "sha256:packed_" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if len(packedLayerNames) != 1 {
t.Fatalf("expected 1 packed layer, got %d: %v", len(packedLayerNames), packedLayerNames)
}
if packedLayerNames[0] != "language_model.model.layers.0.mlp.experts" {
t.Fatalf("unexpected packed layer name: %s", packedLayerNames[0])
}
// Verify all 6 expert tensors (2 experts × 3 proj types) were accumulated
tensors := packedLayerTensors[0]
if len(tensors) != 6 {
t.Fatalf("expected 6 tensors in packed group, got %d", len(tensors))
}
// All should be marked for mxfp8 quantization
for _, tensor := range tensors {
if tensor.Quantize != "mxfp8" {
t.Fatalf("expected mxfp8 quantize for %s, got %q", tensor.Name, tensor.Quantize)
}
}
}
// TestCreateSafetensorsModel_Qwen35Transforms exercises the Qwen3.5-MoE
// import rewrites end to end: tensor renames (model.language_model.* ->
// language_model.model.*, model.visual.* -> vision_tower.*), dropping of
// mtp.* tensors, the layernorm +1 shift and BF16 conversion, the conv1d axis
// swap, splitting the fused experts.gate_up_proj tensor into
// switch_mlp.{gate,up}_proj halves, and which tensors receive the requested
// int4 quantization.
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"text_config": {"dtype": "bfloat16"}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Fused expert tensor values: for each of the 2 experts, the first 64*64
// elements (gate half) are 1.0 and the next 64*64 (up half) are 2.0, so the
// split halves can be told apart in the assertions below.
gateUpValues := make([]float32, 2*128*64)
for expert := range 2 {
base := expert * 128 * 64
for i := range 64 * 64 {
gateUpValues[base+i] = 1
gateUpValues[base+64*64+i] = 2
}
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.input_layernorm.weight", "F32", []int32{64}, make([]byte, 64*4)),
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.A_log", "F32", []int32{32}, make([]byte, 32*4)),
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.conv1d.weight", "BF16", []int32{64, 1, 4}, make([]byte, 64*1*4*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert.down_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.visual.blocks.0.attn.proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("mtp.layers.0.foo.weight", "F32", []int32{64, 64}, make([]byte, 64*64*4)),
})
// tensorCall records what createTensorLayer received for one tensor: the
// quantize hint it was passed, plus dtype/shape/raw bytes parsed back out of
// the emitted blob.
type tensorCall struct {
dtype string
shape []int32
quantize string
raw []byte
}
calls := make(map[string]tensorCall)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
_, _ = io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
headerDType, headerShape := readSingleTensorHeader(t, data)
calls[name] = tensorCall{
dtype: headerDType,
shape: headerShape,
quantize: quantize,
raw: readSingleTensorRaw(t, data),
}
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
if err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// mtp.* tensors are dropped entirely from the import.
if _, ok := calls["mtp.layers.0.foo.weight"]; ok {
t.Fatal("mtp tensor should have been dropped")
}
// The layernorm weight is renamed, emitted as BF16, left unquantized, and
// shifted by +1 (the all-zero source becomes 1.0).
layerNorm := calls["language_model.model.layers.0.input_layernorm.weight"]
if layerNorm.dtype != "BF16" {
t.Fatalf("input_layernorm dtype = %q, want %q", layerNorm.dtype, "BF16")
}
if layerNorm.quantize != "" {
t.Fatalf("input_layernorm quantize = %q, want empty", layerNorm.quantize)
}
layerNormValues := bfloat16.DecodeFloat32(layerNorm.raw)
if len(layerNormValues) == 0 || layerNormValues[0] != 1.0 {
t.Fatalf("input_layernorm first value = %v, want 1.0 after +1 shift", layerNormValues[0])
}
// A_log keeps its F32 dtype.
alog := calls["language_model.model.layers.0.linear_attn.A_log"]
if alog.dtype != "F32" {
t.Fatalf("A_log dtype = %q, want %q", alog.dtype, "F32")
}
// The conv1d weight has its last two axes swapped: [64, 1, 4] -> [64, 4, 1].
conv := calls["language_model.model.layers.0.linear_attn.conv1d.weight"]
if !slices.Equal(conv.shape, []int32{64, 4, 1}) {
t.Fatalf("conv1d shape = %v, want %v", conv.shape, []int32{64, 4, 1})
}
// These weights receive the requested int4 quantization hint.
if got := calls["language_model.model.embed_tokens.weight"].quantize; got != "int4" {
t.Fatalf("embed_tokens quantize = %q, want %q", got, "int4")
}
if got := calls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "int4" {
t.Fatalf("mlp.gate quantize = %q, want %q", got, "int4")
}
if got := calls["language_model.model.layers.0.mlp.shared_expert.down_proj.weight"].quantize; got != "int4" {
t.Fatalf("down_proj quantize = %q, want %q", got, "int4")
}
// The fused expert tensors must be rewritten to switch_mlp.* names and not
// appear under their source names.
if _, ok := calls["language_model.model.layers.0.mlp.experts.gate_up_proj"]; ok {
t.Fatal("combined gate_up_proj tensor should have been rewritten")
}
if _, ok := calls["language_model.model.layers.0.mlp.experts.down_proj"]; ok {
t.Fatal("combined down_proj tensor should have been rewritten")
}
// gate_proj is the first half of the fused tensor (values 1.0)...
gateProj := calls["language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight"]
if !slices.Equal(gateProj.shape, []int32{2, 64, 64}) {
t.Fatalf("gate_proj shape = %v, want %v", gateProj.shape, []int32{2, 64, 64})
}
gateProjValues := bfloat16.DecodeFloat32(gateProj.raw)
if len(gateProjValues) == 0 || gateProjValues[0] != 1.0 {
t.Fatalf("gate_proj first value = %v, want 1.0", gateProjValues[0])
}
// ...and up_proj is the second half (values 2.0).
upProj := calls["language_model.model.layers.0.mlp.switch_mlp.up_proj.weight"]
if !slices.Equal(upProj.shape, []int32{2, 64, 64}) {
t.Fatalf("up_proj shape = %v, want %v", upProj.shape, []int32{2, 64, 64})
}
upProjValues := bfloat16.DecodeFloat32(upProj.raw)
if len(upProjValues) == 0 || upProjValues[0] != 2.0 {
t.Fatalf("up_proj first value = %v, want 2.0", upProjValues[0])
}
if got := calls["language_model.model.layers.0.mlp.switch_mlp.down_proj.weight"].quantize; got != "int4" {
t.Fatalf("switch_mlp down_proj quantize = %q, want %q", got, "int4")
}
// Vision tensors are renamed to vision_tower.* and stay BF16, unquantized.
vision := calls["vision_tower.blocks.0.attn.proj.weight"]
if vision.dtype != "BF16" {
t.Fatalf("vision weight dtype = %q, want %q", vision.dtype, "BF16")
}
if vision.quantize != "" {
t.Fatalf("vision weight quantize = %q, want empty", vision.quantize)
}
if _, ok := calls["language_model.model.visual.blocks.0.attn.proj.weight"]; ok {
t.Fatal("vision tensor should have been rewritten to vision_tower.*")
}
}
// TestCreateSafetensorsModel_Qwen35DirectNonAffineKeepsSensitiveWeightsBF16 runs
// the safetensors create path for the Qwen3_5 MoE architecture under each
// non-affine float quant type (nvfp4, mxfp8, mxfp4) and checks that:
//   - quantization-sensitive tensors (embed_tokens, lm_head, linear_attn
//     in_proj_a/b, the MoE routing gate, and shared_expert_gate) are emitted
//     with an empty quantize type,
//   - attention projection weights carry the requested quant type, and
//   - the combined expert tensors are repacked into a per-layer switch_mlp
//     group of three tensors, each tagged with the requested quant type.
func TestCreateSafetensorsModel_Qwen35DirectNonAffineKeepsSensitiveWeightsBF16(t *testing.T) {
	for _, quantize := range []string{"nvfp4", "mxfp8", "mxfp4"} {
		t.Run(quantize, func(t *testing.T) {
			dir := t.TempDir()
			// The architectures entry triggers the Qwen3_5-specific tensor
			// rewriting inside CreateSafetensorsModel.
			configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"text_config": {"dtype": "bfloat16"}
}`
			if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
				t.Fatalf("failed to write config.json: %v", err)
			}
			// Combined gate_up tensor: for each of the 2 experts, the first
			// 64x64 block holds gate values (1) and the second 64x64 block
			// holds up values (2).
			gateUpValues := make([]float32, 2*128*64)
			for expert := range 2 {
				base := expert * 128 * 64
				for i := range 64 * 64 {
					gateUpValues[base+i] = 1
					gateUpValues[base+64*64+i] = 2
				}
			}
			createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
				st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
				st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
				st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_a.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
				st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_b.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
				st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
				st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert_gate.weight", "BF16", []int32{1, 64}, make([]byte, 64*2)),
				st.NewTensorDataFromBytes("model.language_model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
				st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
				st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
			})
			// Record what reaches the per-tensor and packed-group callbacks so
			// the assertions below can inspect the chosen quantize types.
			type tensorCall struct {
				quantize string
			}
			type packedTensorCall struct {
				Name     string
				Quantize string
			}
			tensorCalls := make(map[string]tensorCall)
			packedCalls := make(map[string][]packedTensorCall)
			createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
				_, _ = io.ReadAll(r)
				return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
			}
			createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantizeType string) ([]LayerInfo, error) {
				_, _ = io.ReadAll(r)
				tensorCalls[name] = tensorCall{quantize: quantizeType}
				return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
			}
			createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
				group := make([]packedTensorCall, 0, len(tensors))
				for _, tensor := range tensors {
					group = append(group, packedTensorCall{
						Name:     tensor.Name,
						Quantize: tensor.Quantize,
					})
				}
				packedCalls[groupName] = group
				return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
			}
			writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
				return nil
			}
			if err := CreateSafetensorsModel("test-model", dir, quantize, createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
				t.Fatalf("CreateSafetensorsModel failed: %v", err)
			}
			// Sensitive weights must remain unquantized for every requested
			// quant type.
			for _, name := range []string{
				"language_model.model.embed_tokens.weight",
				"language_model.lm_head.weight",
				"language_model.model.layers.0.linear_attn.in_proj_a.weight",
				"language_model.model.layers.0.linear_attn.in_proj_b.weight",
				"language_model.model.layers.0.mlp.gate.weight",
				"language_model.model.layers.0.mlp.shared_expert_gate.weight",
			} {
				if got := tensorCalls[name].quantize; got != "" {
					t.Fatalf("%s quantize = %q, want empty", name, got)
				}
			}
			// Regular attention projections should be quantized as requested.
			if got := tensorCalls["language_model.model.layers.0.self_attn.q_proj.weight"].quantize; got != quantize {
				t.Fatalf("q_proj quantize = %q, want %q", got, quantize)
			}
			// The combined expert tensors should have been split into a packed
			// switch_mlp group of three projections, all quantized.
			group := packedCalls["language_model.model.layers.0.mlp.switch_mlp"]
			if len(group) != 3 {
				t.Fatalf("packed switch_mlp tensor count = %d, want 3", len(group))
			}
			for _, tensor := range group {
				if tensor.Quantize != quantize {
					t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, quantize)
				}
			}
		})
	}
}
// TestResolveManifestPath verifies that model references — bare names, names
// with tags, namespaced names, and fully qualified references — expand into a
// manifest path containing the expected registry, namespace, model, and tag
// components.
func TestResolveManifestPath(t *testing.T) {
	cases := []struct {
		name      string
		modelName string
		wantParts []string // Parts that should appear in the path
	}{
		{
			name:      "simple model name",
			modelName: "llama2",
			wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
		},
		{
			name:      "model name with tag",
			modelName: "llama2:7b",
			wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
		},
		{
			name:      "model name with namespace",
			modelName: "myuser/mymodel",
			wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
		},
		{
			name:      "model name with namespace and tag",
			modelName: "myuser/mymodel:v1",
			wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
		},
		{
			name:      "fully qualified model name",
			modelName: "registry.example.com/namespace/model:tag",
			wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			path := resolveManifestPath(tc.modelName)
			for _, part := range tc.wantParts {
				if strings.Contains(path, part) {
					continue
				}
				t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tc.modelName, path, part)
			}
		})
	}
}
// TestLayerInfo confirms each LayerInfo field round-trips the value it was
// initialized with.
func TestLayerInfo(t *testing.T) {
	const (
		wantDigest    = "sha256:abc123"
		wantMediaType = "application/vnd.ollama.image.tensor"
		wantName      = "model.weight"
	)
	l := LayerInfo{
		Digest:    wantDigest,
		Size:      1024,
		MediaType: wantMediaType,
		Name:      wantName,
	}
	if l.Digest != wantDigest {
		t.Errorf("Digest = %q, want %q", l.Digest, wantDigest)
	}
	if l.Size != 1024 {
		t.Errorf("Size = %d, want %d", l.Size, 1024)
	}
	if l.MediaType != wantMediaType {
		t.Errorf("MediaType = %q, want %q", l.MediaType, wantMediaType)
	}
	if l.Name != wantName {
		t.Errorf("Name = %q, want %q", l.Name, wantName)
	}
}
// TestModelConfig confirms ModelConfig retains its format string and its
// capability list after construction.
func TestModelConfig(t *testing.T) {
	cfg := ModelConfig{
		ModelFormat:  "safetensors",
		Capabilities: []string{"completion", "chat"},
	}
	if cfg.ModelFormat != "safetensors" {
		t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors")
	}
	if got := len(cfg.Capabilities); got != 2 {
		t.Errorf("Capabilities length = %d, want %d", got, 2)
	}
}
// TestManifest confirms a populated Manifest retains its schema version,
// config descriptor, and layer entries.
func TestManifest(t *testing.T) {
	m := Manifest{
		SchemaVersion: 2,
		MediaType:     "application/vnd.oci.image.manifest.v1+json",
		Config: ManifestLayer{
			MediaType: "application/vnd.docker.container.image.v1+json",
			Digest:    "sha256:config",
			Size:      100,
		},
		Layers: []ManifestLayer{{
			MediaType: "application/vnd.ollama.image.tensor",
			Digest:    "sha256:layer1",
			Size:      1000,
			Name:      "weight.bin",
		}},
	}
	if m.SchemaVersion != 2 {
		t.Errorf("SchemaVersion = %d, want %d", m.SchemaVersion, 2)
	}
	if m.Config.Digest != "sha256:config" {
		t.Errorf("Config.Digest = %q, want %q", m.Config.Digest, "sha256:config")
	}
	if len(m.Layers) != 1 {
		t.Errorf("Layers length = %d, want %d", len(m.Layers), 1)
	}
	if m.Layers[0].Name != "weight.bin" {
		t.Errorf("Layers[0].Name = %q, want %q", m.Layers[0].Name, "weight.bin")
	}
}
// TestShouldQuantize covers the name/component-based quantization gate: VAE
// tensors, embeddings, norms, and biases are excluded, while plain linear
// weights (including transformer/text_encoder components) are included.
func TestShouldQuantize(t *testing.T) {
	cases := []struct {
		name      string
		tensor    string
		component string
		want      bool
	}{
		// VAE component should never be quantized
		{"vae weight", "decoder.weight", "vae", false},
		{"vae bias", "decoder.bias", "vae", false},
		// Embeddings should not be quantized
		{"embedding weight", "embed_tokens.weight", "", false},
		{"embedding in name", "token_embedding.weight", "", false},
		// Norms should not be quantized
		{"layer norm", "layer_norm.weight", "", false},
		{"rms norm", "rms_norm.weight", "", false},
		{"ln prefix", "ln_1.weight", "", false},
		{"layernorm in name", "input_layernorm.weight", "", false},
		// Biases should not be quantized
		{"bias tensor", "attention.bias", "", false},
		{"proj bias", "o_proj.bias", "", false},
		// Linear weights should be quantized
		{"linear weight", "q_proj.weight", "", true},
		{"attention weight", "self_attn.weight", "", true},
		{"mlp weight", "mlp.gate_proj.weight", "", true},
		// Transformer component weights should be quantized
		{"transformer weight", "layers.0.weight", "transformer", true},
		{"text_encoder weight", "encoder.weight", "text_encoder", true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ShouldQuantize(tc.tensor, tc.component); got != tc.want {
				t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tc.tensor, tc.component, got, tc.want)
			}
		})
	}
}
// TestShouldQuantizeTensor covers the per-tensor quantization gate: rank and
// element-count limits, group-size divisibility per quant type, the stacked
// expert 3D exceptions, and name-based exclusions (embeddings, norms, biases).
func TestShouldQuantizeTensor(t *testing.T) {
	cases := []struct {
		name     string
		tensor   string
		shape    []int32
		quantize string
		want     bool
	}{
		// 2D tensors with sufficient size should be quantized
		{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
		{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
		{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
		{"large 2D weight mxfp4", "q_proj.weight", []int32{4096, 4096}, "mxfp4", true},
		// Small tensors should not be quantized (< 1024 elements)
		{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
		{"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
		// 1D tensors should not be quantized
		{"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
		// 3D+ tensors should not be quantized
		{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
		{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
		{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
		{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
		{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
		{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
		// Embeddings should not be quantized regardless of shape
		{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
		// Norms should not be quantized regardless of shape
		{"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
		// Biases should not be quantized
		{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
		// Group size divisibility tests
		// FP8/FP4/MXFP4 require divisible by 32
		{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
		{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
		{"not divisible by 32 mxfp4", "proj.weight", []int32{128, 48}, "mxfp4", false},
		{"divisible by 32 mxfp4", "proj.weight", []int32{128, 64}, "mxfp4", true},
		// NVFP4 requires divisible by 16
		{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
		{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ShouldQuantizeTensor(tc.tensor, tc.shape, tc.quantize); got != tc.want {
				t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tc.tensor, tc.shape, tc.quantize, got, tc.want)
			}
		})
	}
}
// TestExpertGroupPrefix verifies which tensor names are recognized as members
// of a packable expert group (experts, shared_experts, rewritten switch_mlp)
// and that everything else yields an empty prefix.
func TestExpertGroupPrefix(t *testing.T) {
	cases := []struct {
		name string
		want string
	}{
		// Expert tensors should return the group prefix
		{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
		{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
		{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
		// Expert tensors with language_model prefix should also match
		{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
		{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
		// Shared expert tensors should return their own group prefix
		{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
		{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
		// Rewritten Qwen switch_mlp tensors should also be packed per-layer.
		{"model.layers.1.mlp.switch_mlp.down_proj.weight", "model.layers.1.mlp.switch_mlp"},
		{"language_model.layers.2.mlp.switch_mlp.gate_proj.weight", "language_model.layers.2.mlp.switch_mlp"},
		{"language_model.model.layers.3.mlp.switch_mlp.up_proj.weight", "language_model.model.layers.3.mlp.switch_mlp"},
		{"model.language_model.layers.4.mlp.switch_mlp.gate_proj.weight", "model.language_model.layers.4.mlp.switch_mlp"},
		// Non-expert tensors should return empty string
		{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
		{"model.layers.1.mlp.gate.weight", ""},      // routing gate, not an expert
		{"model.embed_tokens.weight", ""},           // embedding
		{"model.layers.0.self_attn.q_proj.weight", ""}, // attention
		{"model.norm.weight", ""},                      // norm
		{"lm_head.weight", ""},                         // output head
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ExpertGroupPrefix(tc.name); got != tc.want {
				t.Errorf("ExpertGroupPrefix(%q) = %q, want %q", tc.name, got, tc.want)
			}
		})
	}
}
// TestGetTensorQuantization_StackedExpert3D verifies the quant type chosen for
// stacked-expert 3D tensors: with int4 requested, gate_up projections keep
// int4 while down projections come back as int8; the non-affine nvfp4/mxfp4
// types are returned unchanged for both projections.
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
	if got := GetTensorQuantization("model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int4"); got != "int4" {
		t.Fatalf("gate_up_proj quantization = %q, want %q", got, "int4")
	}
	if got := GetTensorQuantization("model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int4"); got != "int8" {
		t.Fatalf("down_proj quantization = %q, want %q", got, "int8")
	}
	if got := GetTensorQuantization("model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8"); got != "int8" {
		t.Fatalf("combined gate_up_proj quantization = %q, want %q", got, "int8")
	}
	if got := GetTensorQuantization("model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int4"); got != "int8" {
		t.Fatalf("combined down_proj quantization = %q, want %q", got, "int8")
	}
	if got := GetTensorQuantization("language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight", []int32{64, 11008, 4096}, "nvfp4"); got != "nvfp4" {
		t.Fatalf("nvfp4 gate_proj quantization = %q, want %q", got, "nvfp4")
	}
	if got := GetTensorQuantization("language_model.model.layers.0.mlp.switch_mlp.down_proj.weight", []int32{64, 4096, 11008}, "nvfp4"); got != "nvfp4" {
		t.Fatalf("nvfp4 down_proj quantization = %q, want %q", got, "nvfp4")
	}
	if got := GetTensorQuantization("language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight", []int32{64, 11008, 4096}, "mxfp4"); got != "mxfp4" {
		t.Fatalf("mxfp4 gate_proj quantization = %q, want %q", got, "mxfp4")
	}
	if got := GetTensorQuantization("language_model.model.layers.0.mlp.switch_mlp.down_proj.weight", []int32{64, 4096, 11008}, "mxfp4"); got != "mxfp4" {
		t.Fatalf("mxfp4 down_proj quantization = %q, want %q", got, "mxfp4")
	}
}
// TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts verifies that for
// the Qwen3_5 MoE architecture with nvfp4 quantization, the combined expert
// tensors (experts.gate_up_proj / experts.down_proj) are routed to the
// packed-layer callback as a single per-layer switch_mlp group containing the
// gate/up/down projections (all BF16 input, all tagged nvfp4), and are NOT
// additionally sent through the regular per-tensor callback; embeddings and
// the routing gate stay unquantized.
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
	dir := t.TempDir()
	// The architectures entry triggers the Qwen3_5-specific tensor rewriting
	// inside CreateSafetensorsModel.
	configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"text_config": {"dtype": "bfloat16"}
}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}
	// Combined gate_up tensor: for each of the 2 experts, the first 64x64
	// block holds gate values (1) and the second 64x64 block holds up
	// values (2).
	gateUpValues := make([]float32, 2*128*64)
	for expert := range 2 {
		base := expert * 128 * 64
		for i := range 64 * 64 {
			gateUpValues[base+i] = 1
			gateUpValues[base+64*64+i] = 2
		}
	}
	createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
		st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
	})
	// Record what reaches each callback so the assertions below can inspect
	// the names, dtypes, shapes, and quantize types.
	type tensorCall struct {
		quantize string
	}
	type packedTensorCall struct {
		Name     string
		Dtype    string
		Shape    []int32
		Quantize string
	}
	tensorCalls := make(map[string]tensorCall)
	packedCalls := make(map[string][]packedTensorCall)
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r)
		return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		_, _ = io.ReadAll(r)
		tensorCalls[name] = tensorCall{quantize: quantize}
		return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
	}
	createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
		group := make([]packedTensorCall, 0, len(tensors))
		for _, tensor := range tensors {
			group = append(group, packedTensorCall{
				Name:     tensor.Name,
				Dtype:    tensor.Dtype,
				Shape:    append([]int32(nil), tensor.Shape...), // defensive copy of the shape slice
				Quantize: tensor.Quantize,
			})
		}
		packedCalls[groupName] = group
		return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}
	// The split experts should arrive as one packed group of exactly the
	// three switch_mlp projections, each BF16 and tagged nvfp4.
	groupName := "language_model.model.layers.0.mlp.switch_mlp"
	group, ok := packedCalls[groupName]
	if !ok {
		t.Fatalf("missing packed group %q: %v", groupName, packedCalls)
	}
	if len(group) != 3 {
		t.Fatalf("packed group %q has %d tensors, want 3", groupName, len(group))
	}
	gotNames := make([]string, 0, len(group))
	for _, tensor := range group {
		gotNames = append(gotNames, tensor.Name)
		if tensor.Quantize != "nvfp4" {
			t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, "nvfp4")
		}
		if tensor.Dtype != "BF16" {
			t.Fatalf("packed tensor %q dtype = %q, want %q", tensor.Name, tensor.Dtype, "BF16")
		}
	}
	slices.Sort(gotNames)
	wantNames := []string{
		"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
		"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
		"language_model.model.layers.0.mlp.switch_mlp.up_proj.weight",
	}
	if !slices.Equal(gotNames, wantNames) {
		t.Fatalf("packed tensor names = %v, want %v", gotNames, wantNames)
	}
	// Packed tensors must not also flow through the per-tensor path.
	for _, name := range wantNames {
		if _, ok := tensorCalls[name]; ok {
			t.Fatalf("packed expert tensor %q unexpectedly handled by createTensorLayer", name)
		}
	}
	// Embeddings and the MoE routing gate stay unquantized.
	if got := tensorCalls["language_model.model.embed_tokens.weight"].quantize; got != "" {
		t.Fatalf("embed_tokens quantize = %q, want empty", got)
	}
	if got := tensorCalls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "" {
		t.Fatalf("mlp.gate quantize = %q, want empty", got)
	}
}
// TestCreateSafetensorsModel_WithQuantize runs the safetensors create path
// with an fp8 quantize request against a minimal single-file model and
// verifies the quantize type is threaded through to the tensor-layer callback.
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
	dir := t.TempDir()
	// Minimal config so the directory is accepted as a model.
	configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}
	// Create a minimal safetensors file.
	createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
	var quantizeRequested []string
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r) // drain; content is irrelevant to this test
		return LayerInfo{Name: name, Digest: "sha256:test"}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		_, _ = io.ReadAll(r)
		quantizeRequested = append(quantizeRequested, quantize)
		return []LayerInfo{{Name: name}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	progressFn := func(status string) {}
	// Run with quantize enabled.
	if err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}
	// Verify the callback saw at least one tensor. The recorded quantize type
	// may be empty since the small test tensor need not qualify for
	// quantization.
	if len(quantizeRequested) == 0 {
		t.Error("no tensors processed")
	}
}
// createMinimalImageGenModel builds the smallest diffusers-style layout the
// create path accepts: a model_index.json at the root plus a transformer
// subdirectory holding a safetensors file and its config.json.
func createMinimalImageGenModel(t *testing.T, dir string) {
	t.Helper()
	// Transformer directory must exist before anything is written into it.
	transformerDir := filepath.Join(dir, "transformer")
	if err := os.MkdirAll(transformerDir, 0o755); err != nil {
		t.Fatalf("failed to create transformer dir: %v", err)
	}
	if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`), 0o644); err != nil {
		t.Fatalf("failed to write model_index.json: %v", err)
	}
	createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
	if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(`{"hidden_size": 3072}`), 0o644); err != nil {
		t.Fatalf("failed to write transformer config: %v", err)
	}
}
// TestCreateImageGenModel exercises the happy path of creating an image
// generation model from a minimal diffusers-style directory: the manifest
// must be written under the requested model name and progress updates must
// be reported.
func TestCreateImageGenModel(t *testing.T) {
	dir := t.TempDir()
	createMinimalImageGenModel(t, dir)
	var manifestWritten bool
	var manifestModelName string
	var statusMessages []string
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r) // drain; content is irrelevant to this test
		return LayerInfo{Name: name, Digest: "sha256:test"}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		_, _ = io.ReadAll(r)
		return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		manifestWritten = true
		manifestModelName = modelName
		return nil
	}
	progressFn := func(status string) {
		statusMessages = append(statusMessages, status)
	}
	if err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
		t.Fatalf("CreateImageGenModel failed: %v", err)
	}
	if !manifestWritten {
		t.Error("manifest was not written")
	}
	if manifestModelName != "test-imagegen" {
		t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
	}
	if len(statusMessages) == 0 {
		t.Error("no status messages received")
	}
}
// TestCreateImageGenModel_NoModelIndex verifies that creating an image
// generation model fails when the directory lacks a root model_index.json,
// even if a transformer subdirectory with weights exists.
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
	dir := t.TempDir()
	// Create only the transformer, deliberately omitting model_index.json.
	transformerDir := filepath.Join(dir, "transformer")
	if err := os.MkdirAll(transformerDir, 0o755); err != nil {
		t.Fatalf("failed to create transformer dir: %v", err)
	}
	createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r) // drain; content is irrelevant to this test
		return LayerInfo{Name: name}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		_, _ = io.ReadAll(r)
		return []LayerInfo{{Name: name}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	progressFn := func(status string) {}
	err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
	if err == nil {
		t.Error("expected error for missing model_index.json, got nil")
	}
}
// TestCreateImageGenModel_WithQuantize runs the image generation create path
// with an int8 quantize request and verifies the quantize type is threaded
// through to the tensor-layer callback for every tensor.
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
	dir := t.TempDir()
	createMinimalImageGenModel(t, dir)
	var quantizeRequested []string
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r) // drain; content is irrelevant to this test
		return LayerInfo{Name: name, Digest: "sha256:test"}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		_, _ = io.ReadAll(r)
		quantizeRequested = append(quantizeRequested, quantize)
		return []LayerInfo{{Name: name}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	progressFn := func(status string) {}
	if err := CreateImageGenModel("test-imagegen", dir, "int8", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
		t.Fatalf("CreateImageGenModel failed: %v", err)
	}
	// The callback must have seen at least one tensor.
	if len(quantizeRequested) == 0 {
		t.Error("no tensors processed")
	}
}