mlx: get parameters from modelfile during model creation (#14747)

This commit is contained in:
Patrick Devine
2026-03-09 15:33:24 -07:00
committed by GitHub
parent 6be2de8214
commit 3e06bde643
3 changed files with 180 additions and 29 deletions

View File

@@ -183,29 +183,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to parse Modelfile: %w", err) return fmt.Errorf("failed to parse Modelfile: %w", err)
} }
// Extract FROM path and configuration modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
var modelDir string if err != nil {
mfConfig := &xcreateclient.ModelfileConfig{} return err
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
} }
// Resolve relative paths based on Modelfile location // Resolve relative paths based on Modelfile location

View File

@@ -13,9 +13,12 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"slices"
"strings" "strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest" "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create" "github.com/ollama/ollama/x/create"
@@ -27,11 +30,79 @@ const MinOllamaVersion = "0.14.0"
// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
	Template   string         // TEMPLATE command
	System     string         // SYSTEM command
	License    string         // LICENSE command
	Parser     string         // PARSER command
	Renderer   string         // RENDERER command
	Parameters map[string]any // PARAMETER commands, normalized via api.FormatParams
}
// ignoredModelfileParameters lists PARAMETER names that are accepted in a
// Modelfile but silently dropped during model creation — presumably
// legacy llama.cpp-era options with no effect in this runtime (TODO
// confirm against the server's supported options).
var ignoredModelfileParameters = []string{
"penalize_newline",
"low_vram",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"mirostat",
"mirostat_tau",
"mirostat_eta",
}
// ConfigFromModelfile extracts the model directory and x/create-specific
// Modelfile configuration from a parsed Modelfile.
func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig, error) {
var modelDir string
mfConfig := &ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
case "adapter", "message", "requires":
continue
default:
if slices.Contains(ignoredModelfileParameters, cmd.Name) {
continue
}
ps, err := api.FormatParams(map[string][]string{cmd.Name: {cmd.Args}})
if err != nil {
return "", nil, err
}
if mfConfig.Parameters == nil {
mfConfig.Parameters = make(map[string]any)
}
for k, v := range ps {
if ks, ok := mfConfig.Parameters[k].([]string); ok {
mfConfig.Parameters[k] = append(ks, v.([]string)...)
} else if vs, ok := v.([]string); ok {
mfConfig.Parameters[k] = vs
} else {
mfConfig.Parameters[k] = v
}
}
}
}
if modelDir == "" {
modelDir = "."
}
return modelDir, mfConfig, nil
} }
// CreateOptions holds all options for model creation. // CreateOptions holds all options for model creation.
@@ -39,7 +110,7 @@ type CreateOptions struct {
ModelName string ModelName string
ModelDir string ModelDir string
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
} }
// CreateModel imports a model from a local directory. // CreateModel imports a model from a local directory.
@@ -351,6 +422,19 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
layers = append(layers, layer) layers = append(layers, layer)
} }
if len(mf.Parameters) > 0 {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(mf.Parameters); err != nil {
return nil, fmt.Errorf("failed to encode parameters: %w", err)
}
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, fmt.Errorf("failed to create params layer: %w", err)
}
layers = append(layers, layer)
}
return layers, nil return layers, nil
} }

View File

@@ -1,7 +1,13 @@
package client package client
import ( import (
"encoding/json"
"os"
"strings"
"testing" "testing"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/parser"
) )
func TestModelfileConfig(t *testing.T) { func TestModelfileConfig(t *testing.T) {
@@ -31,6 +37,40 @@ func TestModelfileConfig(t *testing.T) {
} }
} }
// TestConfigFromModelfile verifies that FROM, TEMPLATE, and PARAMETER
// commands are extracted correctly, including accumulation of repeated
// stop parameters in declaration order.
func TestConfigFromModelfile(t *testing.T) {
	modelfile, err := parser.ParseFile(strings.NewReader(`
FROM ./model
TEMPLATE {{ .Prompt }}
PARAMETER temperature 0.7
PARAMETER stop USER:
PARAMETER stop ASSISTANT:
`))
	if err != nil {
		t.Fatal(err)
	}

	modelDir, mfConfig, err := ConfigFromModelfile(modelfile)
	if err != nil {
		t.Fatal(err)
	}
	if modelDir != "./model" {
		t.Fatalf("modelDir = %q, want %q", modelDir, "./model")
	}
	if mfConfig.Template != "{{ .Prompt }}" {
		t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}")
	}
	if got := mfConfig.Parameters["temperature"]; got != float32(0.7) {
		t.Fatalf("temperature = %#v, want %v", got, float32(0.7))
	}
	// Check the stop values and their order, not just the count.
	stops, ok := mfConfig.Parameters["stop"].([]string)
	if !ok {
		t.Fatalf("stop = %#v, want []string", mfConfig.Parameters["stop"])
	}
	if len(stops) != 2 || stops[0] != "USER:" || stops[1] != "ASSISTANT:" {
		t.Fatalf("stop = %#v, want [USER: ASSISTANT:]", stops)
	}
}
func TestModelfileConfig_Empty(t *testing.T) { func TestModelfileConfig_Empty(t *testing.T) {
config := &ModelfileConfig{} config := &ModelfileConfig{}
@@ -120,6 +160,9 @@ func TestCreateOptions(t *testing.T) {
License: "MIT", License: "MIT",
Parser: "qwen3-thinking", Parser: "qwen3-thinking",
Renderer: "qwen3", Renderer: "qwen3",
Parameters: map[string]any{
"temperature": float32(0.7),
},
}, },
} }
@@ -144,6 +187,9 @@ func TestCreateOptions(t *testing.T) {
if opts.Modelfile.Renderer != "qwen3" { if opts.Modelfile.Renderer != "qwen3" {
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3") t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
} }
if opts.Modelfile.Parameters["temperature"] != float32(0.7) {
t.Errorf("Modelfile.Parameters[temperature] = %v, want %v", opts.Modelfile.Parameters["temperature"], float32(0.7))
}
} }
func TestResolveParserName(t *testing.T) { func TestResolveParserName(t *testing.T) {
@@ -252,3 +298,44 @@ func TestQuantizeSupported(t *testing.T) {
// We can't easily test both cases, so just verify it returns something // We can't easily test both cases, so just verify it returns something
_ = supported _ = supported
} }
// TestCreateModelfileLayersIncludesParameters verifies that a params
// layer is produced for Modelfile parameters and that its blob contains
// the JSON-encoded parameter map, including both scalar and list values.
func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	layers, err := createModelfileLayers(&ModelfileConfig{
		Parameters: map[string]any{
			"temperature": float32(0.7),
			"stop":        []string{"USER:", "ASSISTANT:"},
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(layers) != 1 {
		t.Fatalf("len(layers) = %d, want 1", len(layers))
	}
	if layers[0].MediaType != "application/vnd.ollama.image.params" {
		t.Fatalf("MediaType = %q, want %q", layers[0].MediaType, "application/vnd.ollama.image.params")
	}

	// Read the blob back and check the JSON round-trips both parameters.
	blobPath, err := manifest.BlobsPath(layers[0].Digest)
	if err != nil {
		t.Fatal(err)
	}
	data, err := os.ReadFile(blobPath)
	if err != nil {
		t.Fatal(err)
	}
	var got map[string]any
	if err := json.Unmarshal(data, &got); err != nil {
		t.Fatal(err)
	}
	// JSON numbers decode into map[string]any as float64.
	if got["temperature"] != float64(0.7) {
		t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
	}
	// The stop slice decodes as []any; verify contents and order too.
	stops, ok := got["stop"].([]any)
	if !ok || len(stops) != 2 || stops[0] != "USER:" || stops[1] != "ASSISTANT:" {
		t.Fatalf("stop = %#v, want [USER: ASSISTANT:]", got["stop"])
	}
}