This commit is contained in:
Michael Yang
2026-02-06 09:50:05 -08:00
parent 20299cb1da
commit bd5d3b0ebd
6 changed files with 253 additions and 32 deletions

View File

@@ -294,9 +294,9 @@ func (r Root) Glob(pattern string) (iter.Seq[string], error) {
}
return func(yield func(string) bool) {
for name, blob := range r.blobs {
for name := range r.blobs {
if matched, _ := filepath.Match(pattern, name); matched {
if !yield(blob.Filepath()) {
if !yield(name) {
return
}
}
@@ -307,3 +307,10 @@ func (r Root) Glob(pattern string) (iter.Seq[string], error) {
// JoinPath joins the name of the underlying root (r.root.Name()) with parts,
// in order, using filepath.Join and returns the resulting path.
func (r Root) JoinPath(parts ...string) string {
	elems := make([]string, 0, len(parts)+1)
	elems = append(elems, r.root.Name())
	elems = append(elems, parts...)
	return filepath.Join(elems...)
}
// Real looks name up in r.blobs and returns the matching blob's Filepath().
// It returns the empty string when name has no entry.
func (r Root) Real(name string) string {
	b, found := r.blobs[name]
	if !found {
		return ""
	}
	return b.Filepath()
}

View File

@@ -4,6 +4,10 @@ package mlx
import "C"
import (
"encoding/binary"
"encoding/json"
"errors"
"io"
"iter"
"log/slog"
"maps"
@@ -47,24 +51,78 @@ func Load(path string) iter.Seq2[string, *Array] {
}
}
func LoadAll(root *model.Root, pattern string, states map[string]*Array, afterLoadFuncs []func(*model.Root) ([]*Array, error)) error {
matches, err := root.Glob(pattern)
// Parse opens path through root and decodes the quantization settings stored
// in the file's JSON header under "__metadata__" -> "quantization".
//
// The file is read as a little-endian uint64 header length followed by that
// many bytes of JSON. The returned map is nil when the header carries no
// quantization entry. Any open, read, or JSON decoding error is returned
// unchanged; a declared header length above maxHeaderSize is rejected before
// any buffer is allocated.
func Parse(root *model.Root, path string) (map[string]Quantization, error) {
	// The length prefix comes straight from the file, so it must be bounded
	// before it is used as an allocation size: a corrupt or hostile value
	// would otherwise trigger a huge allocation or a makeslice panic.
	// NOTE(review): 100 MB is chosen to mirror the header cap used by the
	// reference safetensors implementation — confirm it suits this project.
	const maxHeaderSize = 100_000_000

	f, err := root.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var n uint64
	if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
		return nil, err
	}
	if n > maxHeaderSize {
		return nil, errors.New("safetensors header length exceeds limit")
	}

	bts := make([]byte, n)
	if _, err := io.ReadFull(f, bts); err != nil {
		return nil, err
	}

	var m struct {
		Metadata struct {
			Quantization map[string]Quantization `json:"quantization"`
		} `json:"__metadata__"`
	}
	if err := json.Unmarshal(bts, &m); err != nil {
		return nil, err
	}

	return m.Metadata.Quantization, nil
}
// LoadWeights iterates the arrays yielded by Load for the blob that match
// resolves to (root.Real(match) under the root's "blobs" directory) and, for
// every yielded name that is a key of states, overwrites the Array that
// states points at with the loaded one. Names absent from states are ignored.
// It always returns nil.
func LoadWeights(root *model.Root, match string, states map[string]*Array) error {
	slog.Debug("Loading weights from", "file", match)

	blobPath := root.JoinPath("blobs", root.Real(match))
	for name, loaded := range Load(blobPath) {
		dst, wanted := states[name]
		if !wanted {
			continue
		}
		// Copy the value so existing pointers into the model keep working.
		*dst = *loaded
	}

	return nil
}
func LoadQuantizations(root *model.Root, match string, quantizations map[string]*Quantization) error {
slog.Debug("Loading quantizations from", "file", match)
metadata, err := Parse(root, match)
if err != nil {
return err
}
weights := make(map[string]*Array)
for match := range matches {
slog.Debug("Loading weights from", "file", match)
maps.Copy(weights, maps.Collect(Load(root.JoinPath("blobs", match))))
for name := range metadata {
if q, ok := quantizations[name+".weight"]; ok {
q.GroupSize = metadata[name].GroupSize
q.Bits = metadata[name].Bits
q.Mode = metadata[name].Mode
}
}
var numBytes int
for name, weight := range states {
if _, ok := weights[name]; ok {
slog.Debug("Loading weight", "name", name, "weight", weight)
*weight = *weights[name]
numBytes += weight.NumBytes()
return nil
}
// AfterLoadFunc is the signature of a hook that receives the model root and
// returns a slice of arrays or an error. LoadAll accepts a slice of these.
// NOTE(review): presumably invoked once weights have been loaded — the call
// site is not visible in this view; confirm.
type AfterLoadFunc func(*model.Root) ([]*Array, error)
func LoadAll(root *model.Root, states map[string]*Array, quantizations map[string]*Quantization, afterLoadFuncs []AfterLoadFunc) error {
matches, err := root.Glob("model*.safetensors")
if err != nil {
return err
}
for match := range matches {
if err := errors.Join(
LoadWeights(root, match, states),
LoadQuantizations(root, match, quantizations),
); err != nil {
return err
}
}
@@ -93,7 +151,7 @@ func LoadAll(root *model.Root, pattern string, states map[string]*Array, afterLo
Eval(slices.Collect(maps.Values(states))...)
ClearCache()
slog.Info("Loaded weights", "count", len(states), "num_bytes", PrettyBytes(numBytes), "memory", Memory{})
slog.Info("Loaded weights", "count", len(states), "memory", Memory{})
return nil
}

View File

@@ -1,12 +1,40 @@
package mlx
import "cmp"
// Quantization groups the extra arrays and parameters that accompany a
// quantized weight. It is embedded in Linear and Embedding, whose Forward
// methods take the quantized path when Scales is valid.
type Quantization struct {
// Scales and Biases are tagged as weights named "scales" and "biases".
Scales Array `weight:"scales"`
Biases Array `weight:"biases"`
// GroupSize and Bits are decoded from JSON; the Array wrappers pass values
// <= 0 to MLX as unset optionals.
GroupSize int `json:"group_size"`
Bits int `json:"bits"`
// Mode is decoded from JSON; callers in this file substitute "affine" when
// it is empty.
Mode string `json:"mode"`
}
// Linear is a linear layer made of a weight, a bias that is only applied when
// valid, and embedded quantization parameters.
type Linear struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
// Embedded so Scales, Biases, GroupSize, Bits and Mode are promoted onto
// Linear.
Quantization
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array {
if m.Scales.Valid() {
x = x.QuantizedMatmul(
&m.Weight,
&m.Scales,
&m.Biases,
true,
m.GroupSize,
m.Bits,
cmp.Or(m.Mode, "affine"),
)
if m.Bias.Valid() {
x = m.Bias.Add(x)
}
return x
}
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
@@ -16,21 +44,59 @@ func (m Linear) Forward(x *Array) *Array {
}
// Gather applies the layer's weight to x through the gather matmul operations
// on Array, forwarding lhs, rhs and sorted to them.
//
// NOTE(review): as written, the first three statements return the GatherMM
// result unconditionally, so everything after that return is unreachable.
// This span appears to interleave a removed body (the leading statements)
// with its replacement (the quantized/unquantized branches plus bias) —
// confirm against the real file before relying on either path.
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
w := m.Weight.Transpose(0, 2, 1)
// TODO: bias
return x.GatherMM(w, lhs, rhs, sorted)
// Quantized path: taken when the layer has valid scales.
if m.Scales.Valid() {
x = x.GatherQMM(
&m.Weight,
&m.Scales,
&m.Biases,
lhs,
rhs,
// NOTE(review): this argument fills GatherQMM's transpose parameter, yet
// it is given sorted; Linear.Forward passes true in the equivalent
// position of QuantizedMatmul — confirm this is intended.
sorted,
m.GroupSize,
m.Bits,
// An empty Mode falls back to "affine".
cmp.Or(m.Mode, "affine"),
sorted,
)
if m.Bias.Valid() {
x = m.Bias.Add(x)
}
return x
} else {
// Unquantized path: swap the last two axes of the weight before GatherMM.
w := m.Weight.Transpose(0, 2, 1)
x = x.GatherMM(w, lhs, rhs, sorted)
}
// The bias is added only when one is present.
if m.Bias.Valid() {
x = m.Bias.Add(x)
}
return x
}
// Embedding is a lookup table backed by a single weight array, with embedded
// quantization parameters used when Scales is valid.
type Embedding struct {
Weight Array `weight:"weight"`
// Embedded so the quantization fields are promoted onto Embedding.
Quantization
}
// Forward selects the entries of the embedding weight addressed by indices
// along axis 0. When the embedding has valid scales, the selected weight
// entries are dequantized using the scales and biases taken at the same
// indices; an empty Mode falls back to "affine".
func (e *Embedding) Forward(indices *Array) *Array {
	if !e.Scales.Valid() {
		return e.Weight.TakeAxis(indices, 0)
	}

	rows := e.Weight.TakeAxis(indices, 0)
	scales := e.Scales.TakeAxis(indices, 0)
	biases := e.Biases.TakeAxis(indices, 0)
	mode := cmp.Or(e.Mode, "affine")
	return rows.Dequantize(scales, biases, e.GroupSize, e.Bits, mode)
}
// AsLinear returns a Linear that carries the embedding's weight and its
// quantization parameters. The Linear's Bias is left at its zero value.
func (e *Embedding) AsLinear() Linear {
	// Each field may appear only once in a composite literal; the weight was
	// previously listed twice, which does not compile.
	return Linear{
		Weight:       e.Weight,
		Quantization: e.Quantization,
	}
}

View File

@@ -91,6 +91,33 @@ func (t *Array) Divide(other *Array) *Array {
return out
}
// Dequantize creates a "DEQUANTIZE" array from t, scales and biases and fills
// it by calling mlx_dequantize on the default stream.
//
// groupSize and bits are forwarded as optional ints that count as set only
// when greater than zero. mode is passed as a C string that is freed on
// return. The output dtype optional is always passed unset.
func (t *Array) Dequantize(scales, biases *Array, groupSize, bits int, mode string) *Array {
	out := New("DEQUANTIZE", t, scales, biases)

	modeStr := C.CString(mode)
	defer C.free(unsafe.Pointer(modeStr))

	optGroupSize := C.mlx_optional_int{
		has_value: C.bool(groupSize > 0),
		value:     C.int(groupSize),
	}
	optBits := C.mlx_optional_int{
		has_value: C.bool(bits > 0),
		value:     C.int(bits),
	}
	noDType := C.mlx_optional_dtype{has_value: false}

	C.mlx_dequantize(
		&out.ctx,
		t.ctx,
		scales.ctx,
		biases.ctx,
		optGroupSize,
		optBits,
		modeStr,
		noDType,
		DefaultStream().ctx,
	)
	return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS", t)
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
@@ -121,6 +148,40 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
return out
}
// GatherQMM creates a "GATHER_QMM" array and fills it by calling
// mlx_gather_qmm on the default stream with t, the quantized weight, its
// scales and biases, and the lhs/rhs index arrays.
//
// A nil lhs or rhs is replaced by New(""). groupSize and bits are forwarded
// as optional ints that count as set only when greater than zero. mode is
// passed as a C string that is freed on return. transpose and sorted are
// forwarded unchanged.
func (t *Array) GatherQMM(weight, scales, biases, lhs, rhs *Array, transpose bool, groupSize, bits int, mode string, sorted bool) *Array {
	if lhs == nil {
		lhs = New("")
	}
	if rhs == nil {
		rhs = New("")
	}

	out := New("GATHER_QMM", t, weight, scales, biases, lhs, rhs)

	modeStr := C.CString(mode)
	defer C.free(unsafe.Pointer(modeStr))

	optGroupSize := C.mlx_optional_int{
		has_value: C.bool(groupSize > 0),
		value:     C.int(groupSize),
	}
	optBits := C.mlx_optional_int{
		has_value: C.bool(bits > 0),
		value:     C.int(bits),
	}

	C.mlx_gather_qmm(
		&out.ctx,
		t.ctx,
		weight.ctx,
		scales.ctx,
		biases.ctx,
		lhs.ctx,
		rhs.ctx,
		C.bool(transpose),
		optGroupSize,
		optBits,
		modeStr,
		C.bool(sorted),
		DefaultStream().ctx,
	)
	return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP", t)
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
@@ -157,6 +218,32 @@ func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
return out
}
// QuantizedMatmul creates a "QUANTIZED_MATMUL" array and fills it by calling
// mlx_quantized_matmul on the default stream with t, the quantized weight and
// its scales and biases.
//
// groupSize and bits are forwarded as optional ints that count as set only
// when greater than zero. mode is passed as a C string that is freed on
// return. transpose is forwarded unchanged.
func (t *Array) QuantizedMatmul(weight, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
	out := New("QUANTIZED_MATMUL", t, weight, scales, biases)

	modeStr := C.CString(mode)
	defer C.free(unsafe.Pointer(modeStr))

	optGroupSize := C.mlx_optional_int{
		has_value: C.bool(groupSize > 0),
		value:     C.int(groupSize),
	}
	optBits := C.mlx_optional_int{
		has_value: C.bool(bits > 0),
		value:     C.int(bits),
	}

	C.mlx_quantized_matmul(
		&out.ctx,
		t.ctx,
		weight.ctx,
		scales.ctx,
		biases.ctx,
		C.bool(transpose),
		optGroupSize,
		optBits,
		modeStr,
		DefaultStream().ctx,
	)
	return out
}
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {

View File

@@ -28,19 +28,28 @@ type TextGeneration interface {
Unembed(*mlx.Array) *mlx.Array
}
func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) ([]*mlx.Array, error)) {
mapping := make(map[string]*mlx.Array)
var afterLoadFuncs []func(*model.Root) ([]*mlx.Array, error)
func Walk(m Model) (map[string]*mlx.Array, map[string]*mlx.Quantization, []mlx.AfterLoadFunc) {
weights := make(map[string]*mlx.Array)
quantizations := make(map[string]*mlx.Quantization)
var afterLoadFuncs []mlx.AfterLoadFunc
var fn func(v reflect.Value, tags []string)
fn = func(v reflect.Value, tags []string) {
t := v.Type()
if method := v.Addr().MethodByName("AfterLoad"); method.IsValid() {
var afterLoadFunc func(*model.Root) ([]*mlx.Array, error)
var afterLoadFunc mlx.AfterLoadFunc
reflect.ValueOf(&afterLoadFunc).Elem().Set(method)
afterLoadFuncs = append(afterLoadFuncs, afterLoadFunc)
}
if t == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
name := strings.Join(tags, ".")
weights[name] = v.Addr().Interface().(*mlx.Array)
return
} else if t == reflect.TypeOf((*mlx.Quantization)(nil)).Elem() {
quantizations[strings.Join(tags, ".")] = v.Addr().Interface().(*mlx.Quantization)
}
for _, field := range reflect.VisibleFields(t) {
if field.IsExported() {
tt, vv := field.Type, v.FieldByIndex(field.Index)
@@ -52,12 +61,6 @@ func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) ([]*mlx.Array,
tags = append(tags, tag)
}
if tt == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
name := strings.Join(tags, ".")
mapping[name] = vv.Addr().Interface().(*mlx.Array)
continue
}
switch tt.Kind() {
case reflect.Interface:
vv = vv.Elem()
@@ -76,7 +79,7 @@ func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) ([]*mlx.Array,
}
}
fn(reflect.ValueOf(m).Elem(), []string{})
return mapping, afterLoadFuncs
return weights, quantizations, afterLoadFuncs
}
var m = make(map[string]func(*model.Root) (Model, error))

View File

@@ -79,8 +79,8 @@ func (r *Runner) Load(name model.Name) (err error) {
return err
}
weights, afterLoadFuncs := base.Weights(r.Model)
return mlx.LoadAll(root, "model*.safetensors", weights, afterLoadFuncs)
weights, quantizations, afterLoadFuncs := base.Walk(r.Model)
return mlx.LoadAll(root, weights, quantizations, afterLoadFuncs)
}
func (r *Runner) Run(host, port string, mux http.Handler) error {