s/tensor/array/g

x/mlxrunner/cache/cache.go (vendored)
@@ -7,8 +7,8 @@ import (
 )

 type Cache interface {
-    Update(keys, values *mlx.Tensor) (newKeys, newValues *mlx.Tensor)
-    State() (keys, values *mlx.Tensor)
+    Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
+    State() (keys, values *mlx.Array)
     Trim(int) int
     Clone() Cache
     Offset() int
@@ -16,16 +16,16 @@ type Cache interface {
 }

 type KVCache struct {
-    keys, values *mlx.Tensor
+    keys, values *mlx.Array
     offset int
     step   int
 }

 func NewKVCache() *KVCache {
-    return &KVCache{step: 256, keys: &mlx.Tensor{}, values: &mlx.Tensor{}}
+    return &KVCache{step: 256, keys: &mlx.Array{}, values: &mlx.Array{}}
 }

-func (c *KVCache) Update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
+func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
     B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

     prev := c.offset
@@ -56,7 +56,7 @@ func (c *KVCache) Update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
         c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
 }

-func (c *KVCache) State() (*mlx.Tensor, *mlx.Tensor) {
+func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
     if c.offset == c.keys.Dim(2) {
         return c.keys, c.values
     }
@@ -94,14 +94,14 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache {
     return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
 }

-func (c *RotatingKVCache) Update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
+func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
     if keys.Dim(2) > 1 {
         return c.concat(keys, values)
     }
     return c.update(keys, values)
 }

-func (c *RotatingKVCache) concat(keys, values *mlx.Tensor) (newK *mlx.Tensor, newV *mlx.Tensor) {
+func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
     slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
     if !c.keys.Valid() {
         c.keys, c.values = keys, values
@@ -127,7 +127,7 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Tensor) (newK *mlx.Tensor, newV *mlx.Tensor) {
     return c.keys, c.values
 }

-func (c *RotatingKVCache) update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
+func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
     slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
     B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

@@ -170,7 +170,7 @@ func (c *RotatingKVCache) update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
         c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
 }

-func (c *RotatingKVCache) State() (*mlx.Tensor, *mlx.Tensor) {
+func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
     if c.offset < c.keys.Dim(2) {
         return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
             c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
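
For orientation, here is a minimal sketch of how the renamed cache API reads after this change. The import paths are assumptions inferred from the x/mlxrunner layout, the shapes are illustrative, and *KVCache is assumed to satisfy the remaining Cache methods as this file suggests; only the signatures come from the hunks above.

    package main

    import (
        "fmt"

        "github.com/ollama/ollama/x/mlx"             // assumed import path
        "github.com/ollama/ollama/x/mlxrunner/cache" // assumed import path
    )

    func main() {
        var c cache.Cache = cache.NewKVCache()

        // Keys and values are [B, H, L, D] arrays, per the Dim(0..3) reads in Update.
        k := mlx.FromValues(make([]float32, 1*2*3*4), 1, 2, 3, 4)
        v := mlx.FromValues(make([]float32, 1*2*3*4), 1, 2, 3, 4)

        k, v = c.Update(k, v)   // accumulated keys/values so far
        allK, allV := c.State() // views trimmed to the current offset
        fmt.Println(allK.Dims(), allV.Dims(), c.Offset())
    }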
@@ -4,7 +4,7 @@ package mlx
 import "C"
 import "math"

-func GELUApprox(t *Tensor) *Tensor {
+func GELUApprox(t *Array) *Array {
     return t.Multiply(
         FromValue[float32](0.5),
     ).Multiply(
@@ -16,6 +16,6 @@ func GELUApprox(t *Tensor) *Tensor {
     ).AsType(t.DType())
 }

-func SILU(t *Tensor) *Tensor {
+func SILU(t *Array) *Array {
     return t.Multiply(t.Sigmoid()).AsType(t.DType())
 }
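
As a quick numeric check of the activations above (a sketch written as if inside package mlx): SILU(x) = x · sigmoid(x) follows directly from the body shown, and the math import plus the 0.5 factor suggest GELUApprox is the usual tanh approximation, though its middle is elided here.

    x := FromValues([]float32{-1, 0, 1}, 3)
    y := SILU(x) // x * sigmoid(x), cast back to x's dtype
    Eval(y)
    fmt.Println(y.Floats()) // approximately [-0.269, 0, 0.731]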
@@ -16,7 +16,7 @@ import (

 type tensorDesc struct {
     name    string
-    inputs  []*Tensor
+    inputs  []*Array
     numRefs int
 }

@@ -28,15 +28,15 @@ func (d tensorDesc) LogValue() slog.Value {
     )
 }

-type Tensor struct {
+type Array struct {
     ctx  C.mlx_array
     desc tensorDesc
 }

 // constructor utilities

-func New(name string, inputs ...*Tensor) *Tensor {
-    t := &Tensor{
+func New(name string, inputs ...*Array) *Array {
+    t := &Array{
         desc: tensorDesc{
             name:   name,
             inputs: inputs,
@@ -54,7 +54,7 @@ type scalarTypes interface {
     ~bool | ~int | ~float32 | ~float64 | ~complex64
 }

-func FromValue[T scalarTypes](t T) *Tensor {
+func FromValue[T scalarTypes](t T) *Array {
     tt := New("")
     switch v := any(t).(type) {
     case bool:
@@ -80,7 +80,7 @@ type arrayTypes interface {
     ~complex64
 }

-func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Tensor {
+func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
     if len(shape) == 0 {
         panic("shape must be provided for non-scalar tensors")
     }
@@ -130,13 +130,13 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Tensor {
     return tt
 }

-func (t *Tensor) Set(other *Tensor) {
+func (t *Array) Set(other *Array) {
     other.desc.numRefs++
-    t.desc.inputs = []*Tensor{other}
+    t.desc.inputs = []*Array{other}
     C.mlx_array_set(&t.ctx, other.ctx)
 }

-func (t *Tensor) Clone() *Tensor {
+func (t *Array) Clone() *Array {
     tt := New(t.desc.name, t.desc.inputs...)
     C.mlx_array_set(&tt.ctx, t.ctx)
     return tt
@@ -144,18 +144,18 @@ func (t *Tensor) Clone() *Tensor {

 // misc. utilities

-func (t *Tensor) Valid() bool {
+func (t *Array) Valid() bool {
     return t.ctx.ctx != nil
 }

-func (t *Tensor) String() string {
+func (t *Array) String() string {
     str := C.mlx_string_new()
     defer C.mlx_string_free(str)
     C.mlx_array_tostring(&str, t.ctx)
     return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
 }

-func (t *Tensor) LogValue() slog.Value {
+func (t *Array) LogValue() slog.Value {
     attrs := []slog.Attr{slog.Any("", t.desc)}
     if t.Valid() {
         attrs = append(attrs,
@@ -169,19 +169,19 @@ func (t *Tensor) LogValue() slog.Value {

 // shape utilities

-func (t Tensor) Size() int {
+func (t Array) Size() int {
     return int(C.mlx_array_size(t.ctx))
 }

-func (t Tensor) NumBytes() int {
+func (t Array) NumBytes() int {
     return int(C.mlx_array_nbytes(t.ctx))
 }

-func (t Tensor) NumDims() int {
+func (t Array) NumDims() int {
     return int(C.mlx_array_ndim(t.ctx))
 }

-func (t Tensor) Dims() []int {
+func (t Array) Dims() []int {
     dims := make([]int, t.NumDims())
     for i := range dims {
         dims[i] = t.Dim(i)
@@ -190,29 +190,29 @@ func (t Tensor) Dims() []int {
     return dims
 }

-func (t Tensor) Dim(dim int) int {
+func (t Array) Dim(dim int) int {
     return int(C.mlx_array_dim(t.ctx, C.int(dim)))
 }

-func (t Tensor) DType() DType {
+func (t Array) DType() DType {
     return DType(C.mlx_array_dtype(t.ctx))
 }

 // data utilities

-func (t Tensor) Int() int {
+func (t Array) Int() int {
     var item C.int64_t
     C.mlx_array_item_int64(&item, t.ctx)
     return int(item)
 }

-func (t Tensor) Float() float64 {
+func (t Array) Float() float64 {
     var item C.double
     C.mlx_array_item_float64(&item, t.ctx)
     return float64(item)
 }

-func (t Tensor) Ints() []int {
+func (t Array) Ints() []int {
     ints := make([]int, t.Size())
     for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
         ints[i] = int(f)
@@ -220,7 +220,7 @@ func (t Tensor) Ints() []int {
     return ints
 }

-func (t Tensor) Floats() []float32 {
+func (t Array) Floats() []float32 {
     floats := make([]float32, t.Size())
     for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
         floats[i] = float32(f)
@@ -228,7 +228,7 @@ func (t Tensor) Floats() []float32 {
     return floats
 }

-func Free(s ...*Tensor) (n int) {
+func Free(s ...*Array) (n int) {
     now := time.Now()
     defer func() {
         if n > 0 {
@@ -236,8 +236,8 @@ func Free(s ...*Tensor) (n int) {
         }
     }()

-    free := make([]*Tensor, 0, 8192)
-    fn := func(t *Tensor) {
+    free := make([]*Array, 0, 8192)
+    fn := func(t *Array) {
         if t.Valid() {
             free = append(free, t.desc.inputs...)
             t.desc.numRefs--
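
Taken together, the constructors and accessors support round-trips like the following sketch (written as if inside package mlx); note that Free walks desc.inputs, which is why New records every operand:

    s := FromValue[float32](2)                   // scalar
    m := FromValues([]float32{1, 2, 3, 4}, 2, 2) // 2x2
    out := m.Multiply(s)                         // broadcasts the scalar
    Eval(out)
    fmt.Println(out.Dims(), out.Floats()) // [2 2] [2 4 6 8]
    Free(out)                             // releases out and, transitively, its inputs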
@@ -7,7 +7,7 @@ import (
     "unsafe"
 )

-func ScaledDotProductAttention(query, key, value, mask *Tensor, scale float32) *Tensor {
+func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
     if mask == nil {
         mask = New("")
     }
@@ -24,21 +24,21 @@ func ScaledDotProductAttention(query, key, value, mask *Tensor, scale float32) *Tensor {
 }

 type LayerNorm struct {
-    Weight Tensor `weight:"weight"`
-    Bias   Tensor `weight:"bias"`
+    Weight Array `weight:"weight"`
+    Bias   Array `weight:"bias"`
 }

-func (r *LayerNorm) Forward(x *Tensor, eps float32) *Tensor {
+func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
     out := New("FAST_LAYERNORM", x)
     C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
     return out
 }

 type RMSNorm struct {
-    Weight Tensor `weight:"weight"`
+    Weight Array `weight:"weight"`
 }

-func (r RMSNorm) Forward(x *Tensor, eps float32) *Tensor {
+func (r RMSNorm) Forward(x *Array, eps float32) *Array {
     out := New("FAST_RMSNORM", x)
     C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
     return out
@@ -51,7 +51,7 @@ type RoPE struct {
     Scale float32
 }

-func (r RoPE) Forward(t *Tensor, offset int) *Tensor {
+func (r RoPE) Forward(t *Array, offset int) *Array {
     freqs := New("")
     out := New("FAST_ROPE", t, freqs)
     C.mlx_fast_rope(
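
A sketch of driving the fast-kernel wrappers above, with placeholder weights (in practice the loader populates Weight through the weight tags, so the direct assignment here is illustrative only):

    var norm RMSNorm
    norm.Weight = *FromValues([]float32{1, 1, 1, 1}, 4)
    x := FromValues([]float32{1, 2, 3, 4}, 1, 4)
    y := norm.Forward(x, 1e-6) // wraps mlx_fast_rms_norm
    Eval(y)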
@@ -13,8 +13,8 @@ import (
     "github.com/ollama/ollama/types/model"
 )

-func Load(path string) iter.Seq2[string, *Tensor] {
-    return func(yield func(string, *Tensor) bool) {
+func Load(path string) iter.Seq2[string, *Array] {
+    return func(yield func(string, *Array) bool) {
         string2array := C.mlx_map_string_to_array_new()
         defer C.mlx_map_string_to_array_free(string2array)

@@ -40,20 +40,20 @@ func Load(path string) iter.Seq2[string, *Tensor] {
             }

             name := C.GoString(key)
-            if !yield(name, &Tensor{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
+            if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
                 break
             }
         }
     }
 }

-func LoadAll(root *model.Root, pattern string, states map[string]*Tensor, afterLoadFuncs []func(*model.Root) error) error {
+func LoadAll(root *model.Root, pattern string, states map[string]*Array, afterLoadFuncs []func(*model.Root) error) error {
     matches, err := root.Glob(pattern)
     if err != nil {
         return err
     }

-    weights := make(map[string]*Tensor)
+    weights := make(map[string]*Array)
     for match := range matches {
         slog.Debug("Loading weights from", "file", match)
         maps.Copy(weights, maps.Collect(Load(root.JoinPath("blobs", match))))
@@ -80,7 +80,7 @@ func LoadAll(root *model.Root, pattern string, states map[string]*Tensor, afterLoadFuncs []func(*model.Root) error) error {
     return nil
 }

-func UnloadAll(states map[string]*Tensor) {
+func UnloadAll(states map[string]*Array) {
     weights := slices.Collect(maps.Values(states))
     for _, weight := range weights {
         weight.desc.numRefs = 0
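
Since Load returns an iter.Seq2 (Go 1.23+), callers can range over name/array pairs directly; a sketch with a hypothetical blob path:

    for name, w := range Load("/path/to/blob") {
        fmt.Println(name, w.Dims(), w.DType())
    }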
@@ -16,7 +16,7 @@ import (
     "unsafe"
 )

-func doEval(outputs []*Tensor, async bool) {
+func doEval(outputs []*Array, async bool) {
     vectorData := make([]C.mlx_array, 0, len(outputs))
     for _, output := range outputs {
         if output.Valid() {
@@ -34,10 +34,10 @@ func doEval(outputs []*Tensor, async bool) {
     }
 }

-func AsyncEval(outputs ...*Tensor) {
+func AsyncEval(outputs ...*Array) {
     doEval(outputs, true)
 }

-func Eval(outputs ...*Tensor) {
+func Eval(outputs ...*Array) {
     doEval(outputs, false)
 }
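
Graphs stay lazy until evaluated: AsyncEval schedules the computation without blocking, Eval forces it. A small sketch:

    a := FromValue[float32](3)
    b := a.Add(FromValue[float32](4))
    AsyncEval(b)           // kick off the work
    Eval(b)                // block until b is materialized
    fmt.Println(b.Float()) // 7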
@@ -1,12 +1,12 @@
 package mlx

 type Linear struct {
-    Weight Tensor `weight:"weight"`
-    Bias   Tensor `weight:"bias"`
+    Weight Array `weight:"weight"`
+    Bias   Array `weight:"bias"`
 }

 // Forward computes the linear transformation: x @ Weight.T + Bias
-func (m Linear) Forward(x *Tensor) *Tensor {
+func (m Linear) Forward(x *Array) *Array {
     w := m.Weight.Transpose(1, 0)
     if m.Bias.Valid() {
         return m.Bias.Addmm(x, w, 1.0, 1.0)
@@ -16,10 +16,10 @@ func (m Linear) Forward(x *Tensor) *Tensor {
 }

 type Embedding struct {
-    Weight Tensor `weight:"weight"`
+    Weight Array `weight:"weight"`
 }

-func (e *Embedding) Forward(indices *Tensor) *Tensor {
+func (e *Embedding) Forward(indices *Array) *Array {
     return e.Weight.TakeAxis(indices, 0)
 }
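
Weight is evidently stored [out, in] (hence the Transpose before the matmul), and a valid Bias folds the whole transformation into one Addmm. A sketch with an identity weight and the zero-value Bias, which fails Valid():

    var lin Linear
    lin.Weight = *FromValues([]float32{1, 0, 0, 1}, 2, 2)
    x := FromValues([]float32{5, 7}, 1, 2)
    y := lin.Forward(x) // x @ W.T with no bias term
    Eval(y)
    fmt.Println(y.Floats()) // [5 7]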
@@ -7,43 +7,43 @@ import (
     "unsafe"
 )

-func (t *Tensor) Abs() *Tensor {
+func (t *Array) Abs() *Array {
     out := New("ABS", t)
     C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Add(other *Tensor) *Tensor {
+func (t *Array) Add(other *Array) *Array {
     out := New("ADD", t, other)
     C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Addmm(a, b *Tensor, alpha, beta float32) *Tensor {
+func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
     out := New("ADDMM", t, a, b)
     C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Argmax(axis int, keepDims bool) *Tensor {
+func (t *Array) Argmax(axis int, keepDims bool) *Array {
     out := New("ARGMAX", t)
     C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) ArgpartitionAxis(kth int, axis int) *Tensor {
+func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
     out := New("ARGPARTITION", t)
     C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) AsType(dtype DType) *Tensor {
+func (t *Array) AsType(dtype DType) *Array {
     out := New("AS_TYPE", t)
     C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) AsStrided(shape []int, strides []int, offset int) *Tensor {
+func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
     cShape := make([]C.int, len(shape))
     for i, s := range shape {
         cShape[i] = C.int(s)
@@ -65,7 +65,7 @@ func (t *Tensor) AsStrided(shape []int, strides []int, offset int) *Tensor {
     return out
 }

-func (t *Tensor) Concatenate(axis int, others ...*Tensor) *Tensor {
+func (t *Array) Concatenate(axis int, others ...*Array) *Array {
     vectorData := make([]C.mlx_array, len(others)+1)
     vectorData[0] = t.ctx
     for i := range others {
@@ -80,49 +80,49 @@ func (t *Tensor) Concatenate(axis int, others ...*Tensor) *Tensor {
     return out
 }

-func (t *Tensor) ExpandDims(axis int) *Tensor {
+func (t *Array) ExpandDims(axis int) *Array {
     out := New("EXPAND_DIMS", t)
     C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Logsumexp(keepDims bool) *Tensor {
+func (t *Array) Logsumexp(keepDims bool) *Array {
     out := New("LOGSUMEXP", t)
     C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Matmul(other *Tensor) *Tensor {
+func (t *Array) Matmul(other *Array) *Array {
     out := New("MATMUL", t, other)
     C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Multiply(other *Tensor) *Tensor {
+func (t *Array) Multiply(other *Array) *Array {
     out := New("MULTIPLY", t, other)
     C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Negative() *Tensor {
+func (t *Array) Negative() *Array {
     out := New("NEGATIVE", t)
     C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Power(exponent *Tensor) *Tensor {
+func (t *Array) Power(exponent *Array) *Array {
     out := New("POWER", t, exponent)
     C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) PutAlongAxis(indices, values *Tensor, axis int) *Tensor {
+func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
     out := New("PUT_ALONG_AXIS", t, indices, values)
     C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Reshape(axes ...int) *Tensor {
+func (t *Array) Reshape(axes ...int) *Array {
     cAxes := make([]C.int, len(axes))
     for i := range axes {
         cAxes[i] = C.int(axes[i])
@@ -133,43 +133,43 @@ func (t *Tensor) Reshape(axes ...int) *Tensor {
     return out
 }

-func (t *Tensor) Sigmoid() *Tensor {
+func (t *Array) Sigmoid() *Array {
     out := New("SIGMOID", t)
     C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Sqrt() *Tensor {
+func (t *Array) Sqrt() *Array {
     out := New("SQRT", t)
     C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Squeeze(axis int) *Tensor {
+func (t *Array) Squeeze(axis int) *Array {
     out := New("SQUEEZE", t)
     C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Subtract(other *Tensor) *Tensor {
+func (t *Array) Subtract(other *Array) *Array {
     out := New("SUBTRACT", t, other)
     C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) TakeAxis(indices *Tensor, axis int) *Tensor {
+func (t *Array) TakeAxis(indices *Array, axis int) *Array {
     out := New("TAKE_AXIS", t, indices)
     C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Tanh() *Tensor {
+func (t *Array) Tanh() *Array {
     out := New("TANH", t)
     C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
     return out
 }

-func (t *Tensor) Transpose(axes ...int) *Tensor {
+func (t *Array) Transpose(axes ...int) *Array {
     cAxes := make([]C.int, len(axes))
     for i, axis := range axes {
         cAxes[i] = C.int(axis)
@@ -180,7 +180,7 @@ func (t *Tensor) Transpose(axes ...int) *Tensor {
     return out
 }

-func Zeros(dtype DType, shape ...int) *Tensor {
+func Zeros(dtype DType, shape ...int) *Array {
     cAxes := make([]C.int, len(shape))
     for i := range shape {
         cAxes[i] = C.int(shape[i])
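
Every op builds a new node through New(name, inputs...), so calls chain naturally and the whole expression remains releasable as a graph. Sketch:

    x := FromValues([]float32{4, 9, 16}, 3)
    y := x.Sqrt().Negative().Abs() // elementwise: [2 3 4]
    Eval(y)
    fmt.Println(y.Floats())
    Free(y) // frees y and, transitively, its recorded inputs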
@@ -3,7 +3,7 @@ package mlx
 // #include "generated.h"
 import "C"

-func (t *Tensor) Categorical(axis int) *Tensor {
+func (t *Array) Categorical(axis int) *Array {
     key := New("")
     out := New("", t, key)
     C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)

@@ -57,7 +57,7 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
     return args[0], args[1], args[2]
 }

-func (t *Tensor) Slice(slices ...slice) *Tensor {
+func (t *Array) Slice(slices ...slice) *Array {
     starts, stops, strides := makeSlices(t.Dims(), slices...)
     out := New("SLICE", t)
     C.mlx_slice(
@@ -70,7 +70,7 @@ func (t *Tensor) Slice(slices ...slice) *Tensor {
     return out
 }

-func (t *Tensor) SliceUpdate(other *Tensor, slices ...slice) *Tensor {
+func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
     starts, stops, strides := makeSlices(t.Dims(), slices...)
     out := New("SLICE_UPDATE", t, other)
     C.mlx_slice_update(
@@ -15,7 +15,7 @@ import (

 type Model interface {
     // Forward performs a forward pass through the model.
-    Forward(inputs *mlx.Tensor, cache []cache.Cache) *mlx.Tensor
+    Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array

     // NumLayers returns the number of layers in the model.
     // This is used to initialize caches.
@@ -25,11 +25,11 @@ type Model interface {

 type TextGeneration interface {
     Model
-    Unembed(*mlx.Tensor) *mlx.Tensor
+    Unembed(*mlx.Array) *mlx.Array
 }

-func Weights(m Model) (map[string]*mlx.Tensor, []func(*model.Root) error) {
-    mapping := make(map[string]*mlx.Tensor)
+func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) error) {
+    mapping := make(map[string]*mlx.Array)
     var afterLoadFuncs []func(*model.Root) error
     var fn func(v reflect.Value, tags []string)
     fn = func(v reflect.Value, tags []string) {
@@ -52,9 +52,9 @@ func Weights(m Model) (map[string]*mlx.Tensor, []func(*model.Root) error) {
             tags = append(tags, tag)
         }

-        if tt == reflect.TypeOf((*mlx.Tensor)(nil)).Elem() {
+        if tt == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
             name := strings.Join(tags, ".")
-            mapping[name] = vv.Addr().Interface().(*mlx.Tensor)
+            mapping[name] = vv.Addr().Interface().(*mlx.Array)
             continue
         }
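
The reflection in Weights descends struct fields, joining weight tags with dots until it reaches an mlx.Array, so a definition like this sketch (names illustrative) yields map entries "q_proj.weight" and "q_proj.bias", since mlx.Linear tags its own fields weight and bias:

    type ExampleAttention struct {
        QProj mlx.Linear `weight:"q_proj"`
    }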
@@ -30,11 +30,11 @@ func (m Model) Cache() []cache.Cache {
     return caches
 }

-func (m *Model) Forward(inputs *mlx.Tensor, cache []cache.Cache) *mlx.Tensor {
+func (m *Model) Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array {
     return m.Text.Forward(inputs, cache)
 }

-func (m *Model) Unembed(x *mlx.Tensor) *mlx.Tensor {
+func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
     return m.Text.EmbedTokens.AsLinear().Forward(x)
 }

@@ -29,7 +29,7 @@ type TextModel struct {
     Options TextOptions
 }

-func (m TextModel) Forward(inputs *mlx.Tensor, caches []cache.Cache) *mlx.Tensor {
+func (m TextModel) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
     B, L := inputs.Dim(0), inputs.Dim(1)
     hiddenStates := m.EmbedTokens.Forward(inputs)

@@ -53,7 +53,7 @@ type TextDecoderLayer struct {
     PostFFNorm RMSNorm `weight:"post_feedforward_layernorm"`
 }

-func (m TextDecoderLayer) Forward(hiddenStates *mlx.Tensor, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Tensor {
+func (m TextDecoderLayer) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
     residual := hiddenStates
     hiddenStates = m.InputNorm.Forward(hiddenStates, opts.RMSNormEps)
     hiddenStates = m.Attention.Forward(hiddenStates, cache, B, L, rope, opts)
@@ -77,7 +77,7 @@ type TextAttention struct {
     OProj mlx.Linear `weight:"o_proj"`
 }

-func (m TextAttention) Forward(hiddenStates *mlx.Tensor, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Tensor {
+func (m TextAttention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
     query := m.QProj.Forward(hiddenStates)
     key := m.KProj.Forward(hiddenStates)
     value := m.VProj.Forward(hiddenStates)
@@ -113,6 +113,6 @@ type TextMLP struct {
     DownProj mlx.Linear `weight:"down_proj"`
 }

-func (m TextMLP) Forward(h *mlx.Tensor, opts TextOptions) *mlx.Tensor {
+func (m TextMLP) Forward(h *mlx.Array, opts TextOptions) *mlx.Array {
     return m.DownProj.Forward(mlx.GELUApprox(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h)))
 }

@@ -36,7 +36,7 @@ func (m Model) NumLayers() int {
     return len(m.Layers)
 }

-func (m Model) Forward(inputs *mlx.Tensor, caches []cache.Cache) *mlx.Tensor {
+func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
     slog.Debug("Model.forward", "input shape", inputs.Dims(), "m.EmbedTokens", m.EmbedTokens.Weight.Dims())
     B, L := inputs.Dim(0), inputs.Dim(1)
     hiddenStates := m.EmbedTokens.Forward(inputs)
@@ -56,7 +56,7 @@ type Layer struct {
     MLP MLP `weight:"mlp"`
 }

-func (m Layer) Forward(hiddenStates *mlx.Tensor, c cache.Cache, B, L int, opts Options) *mlx.Tensor {
+func (m Layer) Forward(hiddenStates *mlx.Array, c cache.Cache, B, L int, opts Options) *mlx.Array {
     residual := hiddenStates
     hiddenStates = m.AttentionNorm.Forward(hiddenStates, opts.RMSNormEps)
     hiddenStates = m.Attention.Forward(hiddenStates, c, B, L, opts)
@@ -76,7 +76,7 @@ type Attention struct {
     OutputProj mlx.Linear `weight:"o_proj"`
 }

-func (m Attention) Forward(hiddenStates *mlx.Tensor, cache cache.Cache, B, L int, opts Options) *mlx.Tensor {
+func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
     query := m.QueryProj.Forward(hiddenStates)
     query = query.Reshape(B, L, opts.NumAttentionHeads, -1).Transpose(0, 2, 1, 3)

@@ -101,7 +101,7 @@ type MLP struct {
     Down mlx.Linear `weight:"down_proj"`
 }

-func (m MLP) Forward(h *mlx.Tensor) *mlx.Tensor {
+func (m MLP) Forward(h *mlx.Array) *mlx.Array {
     return m.Down.Forward(mlx.SILU(m.Gate.Forward(h)).Multiply(m.Up.Forward(h)))
 }
@@ -42,8 +42,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
         n := min(2<<10, total-processed-1)
         temp := model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
         defer mlx.Free(temp)
-        mlx.Eval(func() []*mlx.Tensor {
-            s := make([]*mlx.Tensor, 2*len(caches))
+        mlx.Eval(func() []*mlx.Array {
+            s := make([]*mlx.Array, 2*len(caches))
             for i, c := range caches {
                 s[2*i], s[2*i+1] = c.State()
             }
@@ -54,7 +54,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
         mlx.ClearCache()
     }

-    step := func(token *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
+    step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
         logits := model.Unembed(model.Forward(token.ExpandDims(0), caches))
         logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
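
Prefill above walks the prompt in chunks of at most 2<<10 (2048) tokens, evaluating only the cache states so intermediate activations can be freed between chunks; the -1 presumably reserves the final prompt token for the first decode step. A schematic of the loop framing (variable names from the hunk, the surrounding loop is assumed):

    for processed+1 < total {
        n := min(2<<10, total-processed-1)
        // tokens[processed : processed+n] goes through model.Forward, as above
        processed += n
    }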
@@ -7,7 +7,7 @@ import (
 )

 type Sampler interface {
-    Sample(*mlx.Tensor) *mlx.Tensor
+    Sample(*mlx.Array) *mlx.Array
 }

 func New(temp, top_p, min_p float32, top_k int) Sampler {
@@ -34,13 +34,13 @@ func New(temp, top_p, min_p float32, top_k int) Sampler {

 type greedy struct{}

-func (greedy) Sample(logits *mlx.Tensor) *mlx.Tensor {
+func (greedy) Sample(logits *mlx.Array) *mlx.Array {
     return logits.Argmax(-1, false)
 }

 type chain []Sampler

-func (c chain) Sample(logits *mlx.Tensor) *mlx.Tensor {
+func (c chain) Sample(logits *mlx.Array) *mlx.Array {
     for _, sampler := range c {
         logits = sampler.Sample(logits)
     }
@@ -49,27 +49,27 @@ func (c chain) Sample(logits *mlx.Tensor) *mlx.Tensor {

 type Temperature float32

-func (t Temperature) Sample(logits *mlx.Tensor) *mlx.Tensor {
+func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
     return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
 }

 type TopP float32

-func (p TopP) Sample(logprobs *mlx.Tensor) *mlx.Tensor {
+func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
     // TODO: implement
     return logprobs
 }

 type MinP float32

-func (p MinP) Sample(logprobs *mlx.Tensor) *mlx.Tensor {
+func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
     // TODO: implement
     return logprobs
 }

 type TopK int

-func (k TopK) Sample(logprobs *mlx.Tensor) *mlx.Tensor {
+func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
     mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
     return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
 }
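
A sketch of composing these samplers through New (the package name and the behavior of New's elided body are assumptions; presumably temp == 0 selects greedy, and otherwise the configured samplers run as a chain):

    s := sampler.New(0.7, 0.9, 0.05, 40) // temp, top_p, min_p, top_k
    logits := mlx.FromValues(make([]float32, 8), 1, 8)
    next := s.Sample(logits)
    mlx.Eval(next)
    fmt.Println(next.Int()) // sampled token id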