mirror of
https://github.com/ollama/ollama.git
synced 2026-03-27 02:58:43 +07:00
quant
This commit is contained in:
@@ -294,9 +294,9 @@ func (r Root) Glob(pattern string) (iter.Seq[string], error) {
|
||||
}
|
||||
|
||||
return func(yield func(string) bool) {
|
||||
for name, blob := range r.blobs {
|
||||
for name := range r.blobs {
|
||||
if matched, _ := filepath.Match(pattern, name); matched {
|
||||
if !yield(blob.Filepath()) {
|
||||
if !yield(name) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -307,3 +307,10 @@ func (r Root) Glob(pattern string) (iter.Seq[string], error) {
|
||||
func (r Root) JoinPath(parts ...string) string {
|
||||
return filepath.Join(append([]string{r.root.Name()}, parts...)...)
|
||||
}
|
||||
|
||||
func (r Root) Real(name string) string {
|
||||
if b, ok := r.blobs[name]; ok {
|
||||
return b.Filepath()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -4,6 +4,10 @@ package mlx
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
@@ -47,24 +51,78 @@ func Load(path string) iter.Seq2[string, *Array] {
|
||||
}
|
||||
}
|
||||
|
||||
func LoadAll(root *model.Root, pattern string, states map[string]*Array, afterLoadFuncs []func(*model.Root) ([]*Array, error)) error {
|
||||
matches, err := root.Glob(pattern)
|
||||
func Parse(root *model.Root, path string) (map[string]Quantization, error) {
|
||||
f, err := root.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var n uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bts := make([]byte, n)
|
||||
if _, err := io.ReadFull(f, bts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var m struct {
|
||||
Metadata struct {
|
||||
Quantization map[string]Quantization `json:"quantization"`
|
||||
} `json:"__metadata__"`
|
||||
}
|
||||
if err := json.Unmarshal(bts, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.Metadata.Quantization, nil
|
||||
}
|
||||
|
||||
func LoadWeights(root *model.Root, match string, states map[string]*Array) error {
|
||||
slog.Debug("Loading weights from", "file", match)
|
||||
for name, weight := range Load(root.JoinPath("blobs", root.Real(match))) {
|
||||
if state, ok := states[name]; ok {
|
||||
*state = *weight
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoadQuantizations(root *model.Root, match string, quantizations map[string]*Quantization) error {
|
||||
slog.Debug("Loading quantizations from", "file", match)
|
||||
metadata, err := Parse(root, match)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
weights := make(map[string]*Array)
|
||||
for match := range matches {
|
||||
slog.Debug("Loading weights from", "file", match)
|
||||
maps.Copy(weights, maps.Collect(Load(root.JoinPath("blobs", match))))
|
||||
for name := range metadata {
|
||||
if q, ok := quantizations[name+".weight"]; ok {
|
||||
q.GroupSize = metadata[name].GroupSize
|
||||
q.Bits = metadata[name].Bits
|
||||
q.Mode = metadata[name].Mode
|
||||
}
|
||||
}
|
||||
|
||||
var numBytes int
|
||||
for name, weight := range states {
|
||||
if _, ok := weights[name]; ok {
|
||||
slog.Debug("Loading weight", "name", name, "weight", weight)
|
||||
*weight = *weights[name]
|
||||
numBytes += weight.NumBytes()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterLoadFunc is a callback that receives the model root and returns a
// slice of arrays or an error. LoadAll takes a slice of these callbacks.
type AfterLoadFunc func(*model.Root) ([]*Array, error)
|
||||
|
||||
func LoadAll(root *model.Root, states map[string]*Array, quantizations map[string]*Quantization, afterLoadFuncs []AfterLoadFunc) error {
|
||||
matches, err := root.Glob("model*.safetensors")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for match := range matches {
|
||||
if err := errors.Join(
|
||||
LoadWeights(root, match, states),
|
||||
LoadQuantizations(root, match, quantizations),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +151,7 @@ func LoadAll(root *model.Root, pattern string, states map[string]*Array, afterLo
|
||||
|
||||
Eval(slices.Collect(maps.Values(states))...)
|
||||
ClearCache()
|
||||
slog.Info("Loaded weights", "count", len(states), "num_bytes", PrettyBytes(numBytes), "memory", Memory{})
|
||||
slog.Info("Loaded weights", "count", len(states), "memory", Memory{})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,40 @@
|
||||
package mlx
|
||||
|
||||
import "cmp"
|
||||
|
||||
// Quantization groups the quantization data attached to a weight: the Scales
// and Biases arrays (weight tags "scales" and "biases") and the GroupSize,
// Bits and Mode settings (JSON keys "group_size", "bits" and "mode").
// Linear and Embedding embed it and take their quantized code path when
// Scales is valid; an empty Mode is treated as "affine" by those callers.
type Quantization struct {
|
||||
Scales Array `weight:"scales"`
|
||||
Biases Array `weight:"biases"`
|
||||
GroupSize int `json:"group_size"`
|
||||
Bits int `json:"bits"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
// Linear is a linear layer made of a Weight array, a Bias array that is only
// applied when it is valid, and embedded Quantization data that switches the
// layer to quantized matmul when its Scales array is valid.
type Linear struct {
|
||||
Weight Array `weight:"weight"`
|
||||
Bias Array `weight:"bias"`
|
||||
|
||||
Quantization
|
||||
}
|
||||
|
||||
// Forward computes the linear transformation: x @ Weight.T + Bias
|
||||
func (m Linear) Forward(x *Array) *Array {
|
||||
if m.Scales.Valid() {
|
||||
x = x.QuantizedMatmul(
|
||||
&m.Weight,
|
||||
&m.Scales,
|
||||
&m.Biases,
|
||||
true,
|
||||
m.GroupSize,
|
||||
m.Bits,
|
||||
cmp.Or(m.Mode, "affine"),
|
||||
)
|
||||
if m.Bias.Valid() {
|
||||
x = m.Bias.Add(x)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
w := m.Weight.Transpose(1, 0)
|
||||
if m.Bias.Valid() {
|
||||
return m.Bias.Addmm(x, w, 1.0, 1.0)
|
||||
@@ -16,21 +44,59 @@ func (m Linear) Forward(x *Array) *Array {
|
||||
}
|
||||
|
||||
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
|
||||
w := m.Weight.Transpose(0, 2, 1)
|
||||
// TODO: bias
|
||||
return x.GatherMM(w, lhs, rhs, sorted)
|
||||
if m.Scales.Valid() {
|
||||
x = x.GatherQMM(
|
||||
&m.Weight,
|
||||
&m.Scales,
|
||||
&m.Biases,
|
||||
lhs,
|
||||
rhs,
|
||||
sorted,
|
||||
m.GroupSize,
|
||||
m.Bits,
|
||||
cmp.Or(m.Mode, "affine"),
|
||||
sorted,
|
||||
)
|
||||
if m.Bias.Valid() {
|
||||
x = m.Bias.Add(x)
|
||||
}
|
||||
return x
|
||||
} else {
|
||||
w := m.Weight.Transpose(0, 2, 1)
|
||||
x = x.GatherMM(w, lhs, rhs, sorted)
|
||||
}
|
||||
|
||||
if m.Bias.Valid() {
|
||||
x = m.Bias.Add(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Embedding holds a Weight array from which Forward selects entries along
// axis 0, plus embedded Quantization data; when its Scales array is valid the
// selected entries are dequantized before being returned.
type Embedding struct {
|
||||
Weight Array `weight:"weight"`
|
||||
|
||||
Quantization
|
||||
}
|
||||
|
||||
func (e *Embedding) Forward(indices *Array) *Array {
|
||||
if e.Scales.Valid() {
|
||||
w := e.Weight.TakeAxis(indices, 0)
|
||||
return w.Dequantize(
|
||||
e.Scales.TakeAxis(indices, 0),
|
||||
e.Biases.TakeAxis(indices, 0),
|
||||
e.GroupSize,
|
||||
e.Bits,
|
||||
cmp.Or(e.Mode, "affine"),
|
||||
)
|
||||
}
|
||||
|
||||
return e.Weight.TakeAxis(indices, 0)
|
||||
}
|
||||
|
||||
func (e *Embedding) AsLinear() Linear {
|
||||
return Linear{
|
||||
Weight: e.Weight,
|
||||
Weight: e.Weight,
|
||||
Quantization: e.Quantization,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +91,33 @@ func (t *Array) Divide(other *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Dequantize(scales, biases *Array, groupSize, bits int, mode string) *Array {
|
||||
out := New("DEQUANTIZE", t, scales, biases)
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
C.mlx_dequantize(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
scales.ctx,
|
||||
biases.ctx,
|
||||
C.mlx_optional_int{
|
||||
value: C.int(groupSize),
|
||||
has_value: C.bool(groupSize > 0),
|
||||
},
|
||||
C.mlx_optional_int{
|
||||
value: C.int(bits),
|
||||
has_value: C.bool(bits > 0),
|
||||
},
|
||||
cMode,
|
||||
C.mlx_optional_dtype{
|
||||
has_value: false,
|
||||
},
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ExpandDims(axis int) *Array {
|
||||
out := New("EXPAND_DIMS", t)
|
||||
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
@@ -121,6 +148,40 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) GatherQMM(weight, scales, biases, lhs, rhs *Array, transpose bool, groupSize, bits int, mode string, sorted bool) *Array {
|
||||
if lhs == nil {
|
||||
lhs = New("")
|
||||
}
|
||||
if rhs == nil {
|
||||
rhs = New("")
|
||||
}
|
||||
out := New("GATHER_QMM", t, weight, scales, biases, lhs, rhs)
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
C.mlx_gather_qmm(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
weight.ctx,
|
||||
scales.ctx,
|
||||
biases.ctx,
|
||||
lhs.ctx,
|
||||
rhs.ctx,
|
||||
C.bool(transpose),
|
||||
C.mlx_optional_int{
|
||||
value: C.int(groupSize),
|
||||
has_value: C.bool(groupSize > 0),
|
||||
},
|
||||
C.mlx_optional_int{
|
||||
value: C.int(bits),
|
||||
has_value: C.bool(bits > 0),
|
||||
},
|
||||
cMode,
|
||||
C.bool(sorted),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Logsumexp(keepDims bool) *Array {
|
||||
out := New("LOGSUMEXP", t)
|
||||
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
|
||||
@@ -157,6 +218,32 @@ func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) QuantizedMatmul(weight, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
|
||||
out := New("QUANTIZED_MATMUL", t, weight, scales, biases)
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
C.mlx_quantized_matmul(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
weight.ctx,
|
||||
scales.ctx,
|
||||
biases.ctx,
|
||||
C.bool(transpose),
|
||||
C.mlx_optional_int{
|
||||
value: C.int(groupSize),
|
||||
has_value: C.bool(groupSize > 0),
|
||||
},
|
||||
C.mlx_optional_int{
|
||||
value: C.int(bits),
|
||||
has_value: C.bool(bits > 0),
|
||||
},
|
||||
cMode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Reshape(axes ...int) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i := range axes {
|
||||
|
||||
@@ -28,19 +28,28 @@ type TextGeneration interface {
|
||||
Unembed(*mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) ([]*mlx.Array, error)) {
|
||||
mapping := make(map[string]*mlx.Array)
|
||||
var afterLoadFuncs []func(*model.Root) ([]*mlx.Array, error)
|
||||
func Walk(m Model) (map[string]*mlx.Array, map[string]*mlx.Quantization, []mlx.AfterLoadFunc) {
|
||||
weights := make(map[string]*mlx.Array)
|
||||
quantizations := make(map[string]*mlx.Quantization)
|
||||
var afterLoadFuncs []mlx.AfterLoadFunc
|
||||
var fn func(v reflect.Value, tags []string)
|
||||
fn = func(v reflect.Value, tags []string) {
|
||||
t := v.Type()
|
||||
|
||||
if method := v.Addr().MethodByName("AfterLoad"); method.IsValid() {
|
||||
var afterLoadFunc func(*model.Root) ([]*mlx.Array, error)
|
||||
var afterLoadFunc mlx.AfterLoadFunc
|
||||
reflect.ValueOf(&afterLoadFunc).Elem().Set(method)
|
||||
afterLoadFuncs = append(afterLoadFuncs, afterLoadFunc)
|
||||
}
|
||||
|
||||
if t == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
|
||||
name := strings.Join(tags, ".")
|
||||
weights[name] = v.Addr().Interface().(*mlx.Array)
|
||||
return
|
||||
} else if t == reflect.TypeOf((*mlx.Quantization)(nil)).Elem() {
|
||||
quantizations[strings.Join(tags, ".")] = v.Addr().Interface().(*mlx.Quantization)
|
||||
}
|
||||
|
||||
for _, field := range reflect.VisibleFields(t) {
|
||||
if field.IsExported() {
|
||||
tt, vv := field.Type, v.FieldByIndex(field.Index)
|
||||
@@ -52,12 +61,6 @@ func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) ([]*mlx.Array,
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
|
||||
if tt == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
|
||||
name := strings.Join(tags, ".")
|
||||
mapping[name] = vv.Addr().Interface().(*mlx.Array)
|
||||
continue
|
||||
}
|
||||
|
||||
switch tt.Kind() {
|
||||
case reflect.Interface:
|
||||
vv = vv.Elem()
|
||||
@@ -76,7 +79,7 @@ func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) ([]*mlx.Array,
|
||||
}
|
||||
}
|
||||
fn(reflect.ValueOf(m).Elem(), []string{})
|
||||
return mapping, afterLoadFuncs
|
||||
return weights, quantizations, afterLoadFuncs
|
||||
}
|
||||
|
||||
var m = make(map[string]func(*model.Root) (Model, error))
|
||||
|
||||
@@ -79,8 +79,8 @@ func (r *Runner) Load(name model.Name) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
weights, afterLoadFuncs := base.Weights(r.Model)
|
||||
return mlx.LoadAll(root, "model*.safetensors", weights, afterLoadFuncs)
|
||||
weights, quantizations, afterLoadFuncs := base.Walk(r.Model)
|
||||
return mlx.LoadAll(root, weights, quantizations, afterLoadFuncs)
|
||||
}
|
||||
|
||||
func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
|
||||
Reference in New Issue
Block a user