s/tensor/array/g

This commit is contained in:
Michael Yang
2026-02-05 20:47:04 -08:00
parent 5287acdb21
commit e19fbe7369
17 changed files with 111 additions and 111 deletions

View File

@@ -7,8 +7,8 @@ import (
)
type Cache interface {
Update(keys, values *mlx.Tensor) (newKeys, newValues *mlx.Tensor)
State() (keys, values *mlx.Tensor)
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
State() (keys, values *mlx.Array)
Trim(int) int
Clone() Cache
Offset() int
@@ -16,16 +16,16 @@ type Cache interface {
}
type KVCache struct {
keys, values *mlx.Tensor
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256, keys: &mlx.Tensor{}, values: &mlx.Tensor{}}
return &KVCache{step: 256, keys: &mlx.Array{}, values: &mlx.Array{}}
}
func (c *KVCache) Update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
@@ -56,7 +56,7 @@ func (c *KVCache) Update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) State() (*mlx.Tensor, *mlx.Tensor) {
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
if c.offset == c.keys.Dim(2) {
return c.keys, c.values
}
@@ -94,14 +94,14 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}
func (c *RotatingKVCache) Update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
if keys.Dim(2) > 1 {
return c.concat(keys, values)
}
return c.update(keys, values)
}
func (c *RotatingKVCache) concat(keys, values *mlx.Tensor) (newK *mlx.Tensor, newV *mlx.Tensor) {
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if !c.keys.Valid() {
c.keys, c.values = keys, values
@@ -127,7 +127,7 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Tensor) (newK *mlx.Tensor, ne
return c.keys, c.values
}
func (c *RotatingKVCache) update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
@@ -170,7 +170,7 @@ func (c *RotatingKVCache) update(keys, values *mlx.Tensor) (*mlx.Tensor, *mlx.Te
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
func (c *RotatingKVCache) State() (*mlx.Tensor, *mlx.Tensor) {
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
if c.offset < c.keys.Dim(2) {
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())

View File

@@ -4,7 +4,7 @@ package mlx
import "C"
import "math"
func GELUApprox(t *Tensor) *Tensor {
func GELUApprox(t *Array) *Array {
return t.Multiply(
FromValue[float32](0.5),
).Multiply(
@@ -16,6 +16,6 @@ func GELUApprox(t *Tensor) *Tensor {
).AsType(t.DType())
}
func SILU(t *Tensor) *Tensor {
func SILU(t *Array) *Array {
return t.Multiply(t.Sigmoid()).AsType(t.DType())
}

View File

@@ -16,7 +16,7 @@ import (
type tensorDesc struct {
name string
inputs []*Tensor
inputs []*Array
numRefs int
}
@@ -28,15 +28,15 @@ func (d tensorDesc) LogValue() slog.Value {
)
}
type Tensor struct {
type Array struct {
ctx C.mlx_array
desc tensorDesc
}
// constructor utilities
func New(name string, inputs ...*Tensor) *Tensor {
t := &Tensor{
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
@@ -54,7 +54,7 @@ type scalarTypes interface {
~bool | ~int | ~float32 | ~float64 | ~complex64
}
func FromValue[T scalarTypes](t T) *Tensor {
func FromValue[T scalarTypes](t T) *Array {
tt := New("")
switch v := any(t).(type) {
case bool:
@@ -80,7 +80,7 @@ type arrayTypes interface {
~complex64
}
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Tensor {
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
if len(shape) == 0 {
panic("shape must be provided for non-scalar tensors")
}
@@ -130,13 +130,13 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Tensor {
return tt
}
func (t *Tensor) Set(other *Tensor) {
func (t *Array) Set(other *Array) {
other.desc.numRefs++
t.desc.inputs = []*Tensor{other}
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Tensor) Clone() *Tensor {
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
@@ -144,18 +144,18 @@ func (t *Tensor) Clone() *Tensor {
// misc. utilities
func (t *Tensor) Valid() bool {
func (t *Array) Valid() bool {
return t.ctx.ctx != nil
}
func (t *Tensor) String() string {
func (t *Array) String() string {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_array_tostring(&str, t.ctx)
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
func (t *Tensor) LogValue() slog.Value {
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{slog.Any("", t.desc)}
if t.Valid() {
attrs = append(attrs,
@@ -169,19 +169,19 @@ func (t *Tensor) LogValue() slog.Value {
// shape utilities
func (t Tensor) Size() int {
func (t Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t Tensor) NumBytes() int {
func (t Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t Tensor) NumDims() int {
func (t Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t Tensor) Dims() []int {
func (t Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
@@ -190,29 +190,29 @@ func (t Tensor) Dims() []int {
return dims
}
func (t Tensor) Dim(dim int) int {
func (t Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t Tensor) DType() DType {
func (t Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t Tensor) Int() int {
func (t Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t Tensor) Float() float64 {
func (t Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t Tensor) Ints() []int {
func (t Array) Ints() []int {
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
@@ -220,7 +220,7 @@ func (t Tensor) Ints() []int {
return ints
}
func (t Tensor) Floats() []float32 {
func (t Array) Floats() []float32 {
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)
@@ -228,7 +228,7 @@ func (t Tensor) Floats() []float32 {
return floats
}
func Free(s ...*Tensor) (n int) {
func Free(s ...*Array) (n int) {
now := time.Now()
defer func() {
if n > 0 {
@@ -236,8 +236,8 @@ func Free(s ...*Tensor) (n int) {
}
}()
free := make([]*Tensor, 0, 8192)
fn := func(t *Tensor) {
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
free = append(free, t.desc.inputs...)
t.desc.numRefs--

View File

@@ -7,7 +7,7 @@ import (
"unsafe"
)
func ScaledDotProductAttention(query, key, value, mask *Tensor, scale float32) *Tensor {
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
if mask == nil {
mask = New("")
}
@@ -24,21 +24,21 @@ func ScaledDotProductAttention(query, key, value, mask *Tensor, scale float32) *
}
type LayerNorm struct {
Weight Tensor `weight:"weight"`
Bias Tensor `weight:"bias"`
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Tensor, eps float32) *Tensor {
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM", x)
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RMSNorm struct {
Weight Tensor `weight:"weight"`
Weight Array `weight:"weight"`
}
func (r RMSNorm) Forward(x *Tensor, eps float32) *Tensor {
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
@@ -51,7 +51,7 @@ type RoPE struct {
Scale float32
}
func (r RoPE) Forward(t *Tensor, offset int) *Tensor {
func (r RoPE) Forward(t *Array, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", t, freqs)
C.mlx_fast_rope(

View File

@@ -13,8 +13,8 @@ import (
"github.com/ollama/ollama/types/model"
)
func Load(path string) iter.Seq2[string, *Tensor] {
return func(yield func(string, *Tensor) bool) {
func Load(path string) iter.Seq2[string, *Array] {
return func(yield func(string, *Array) bool) {
string2array := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(string2array)
@@ -40,20 +40,20 @@ func Load(path string) iter.Seq2[string, *Tensor] {
}
name := C.GoString(key)
if !yield(name, &Tensor{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
break
}
}
}
}
func LoadAll(root *model.Root, pattern string, states map[string]*Tensor, afterLoadFuncs []func(*model.Root) error) error {
func LoadAll(root *model.Root, pattern string, states map[string]*Array, afterLoadFuncs []func(*model.Root) error) error {
matches, err := root.Glob(pattern)
if err != nil {
return err
}
weights := make(map[string]*Tensor)
weights := make(map[string]*Array)
for match := range matches {
slog.Debug("Loading weights from", "file", match)
maps.Copy(weights, maps.Collect(Load(root.JoinPath("blobs", match))))
@@ -80,7 +80,7 @@ func LoadAll(root *model.Root, pattern string, states map[string]*Tensor, afterL
return nil
}
func UnloadAll(states map[string]*Tensor) {
func UnloadAll(states map[string]*Array) {
weights := slices.Collect(maps.Values(states))
for _, weight := range weights {
weight.desc.numRefs = 0

View File

@@ -16,7 +16,7 @@ import (
"unsafe"
)
func doEval(outputs []*Tensor, async bool) {
func doEval(outputs []*Array, async bool) {
vectorData := make([]C.mlx_array, 0, len(outputs))
for _, output := range outputs {
if output.Valid() {
@@ -34,10 +34,10 @@ func doEval(outputs []*Tensor, async bool) {
}
}
func AsyncEval(outputs ...*Tensor) {
func AsyncEval(outputs ...*Array) {
doEval(outputs, true)
}
func Eval(outputs ...*Tensor) {
func Eval(outputs ...*Array) {
doEval(outputs, false)
}

View File

@@ -1,12 +1,12 @@
package mlx
type Linear struct {
Weight Tensor `weight:"weight"`
Bias Tensor `weight:"bias"`
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Tensor) *Tensor {
func (m Linear) Forward(x *Array) *Array {
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
@@ -16,10 +16,10 @@ func (m Linear) Forward(x *Tensor) *Tensor {
}
type Embedding struct {
Weight Tensor `weight:"weight"`
Weight Array `weight:"weight"`
}
func (e *Embedding) Forward(indices *Tensor) *Tensor {
func (e *Embedding) Forward(indices *Array) *Array {
return e.Weight.TakeAxis(indices, 0)
}

View File

@@ -7,43 +7,43 @@ import (
"unsafe"
)
func (t *Tensor) Abs() *Tensor {
func (t *Array) Abs() *Array {
out := New("ABS", t)
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) Add(other *Tensor) *Tensor {
func (t *Array) Add(other *Array) *Array {
out := New("ADD", t, other)
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) Addmm(a, b *Tensor, alpha, beta float32) *Tensor {
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM", t, a, b)
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Tensor) Argmax(axis int, keepDims bool) *Tensor {
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX", t)
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Tensor) ArgpartitionAxis(kth int, axis int) *Tensor {
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION", t)
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Tensor) AsType(dtype DType) *Tensor {
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE", t)
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func (t *Tensor) AsStrided(shape []int, strides []int, offset int) *Tensor {
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
@@ -65,7 +65,7 @@ func (t *Tensor) AsStrided(shape []int, strides []int, offset int) *Tensor {
return out
}
func (t *Tensor) Concatenate(axis int, others ...*Tensor) *Tensor {
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
vectorData := make([]C.mlx_array, len(others)+1)
vectorData[0] = t.ctx
for i := range others {
@@ -80,49 +80,49 @@ func (t *Tensor) Concatenate(axis int, others ...*Tensor) *Tensor {
return out
}
func (t *Tensor) ExpandDims(axis int) *Tensor {
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS", t)
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Tensor) Logsumexp(keepDims bool) *Tensor {
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP", t)
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Tensor) Matmul(other *Tensor) *Tensor {
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL", t, other)
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) Multiply(other *Tensor) *Tensor {
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY", t, other)
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) Negative() *Tensor {
func (t *Array) Negative() *Array {
out := New("NEGATIVE", t)
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) Power(exponent *Tensor) *Tensor {
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER", t, exponent)
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) PutAlongAxis(indices, values *Tensor, axis int) *Tensor {
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", t, indices, values)
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Tensor) Reshape(axes ...int) *Tensor {
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
@@ -133,43 +133,43 @@ func (t *Tensor) Reshape(axes ...int) *Tensor {
return out
}
func (t *Tensor) Sigmoid() *Tensor {
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID", t)
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) Sqrt() *Tensor {
func (t *Array) Sqrt() *Array {
out := New("SQRT", t)
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) Squeeze(axis int) *Tensor {
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE", t)
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Tensor) Subtract(other *Tensor) *Tensor {
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT", t, other)
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) TakeAxis(indices *Tensor, axis int) *Tensor {
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS", t, indices)
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Tensor) Tanh() *Tensor {
func (t *Array) Tanh() *Array {
out := New("TANH", t)
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Tensor) Transpose(axes ...int) *Tensor {
func (t *Array) Transpose(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i, axis := range axes {
cAxes[i] = C.int(axis)
@@ -180,7 +180,7 @@ func (t *Tensor) Transpose(axes ...int) *Tensor {
return out
}
func Zeros(dtype DType, shape ...int) *Tensor {
func Zeros(dtype DType, shape ...int) *Array {
cAxes := make([]C.int, len(shape))
for i := range shape {
cAxes[i] = C.int(shape[i])

View File

@@ -3,7 +3,7 @@ package mlx
// #include "generated.h"
import "C"
func (t *Tensor) Categorical(axis int) *Tensor {
func (t *Array) Categorical(axis int) *Array {
key := New("")
out := New("", t, key)
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)

View File

@@ -57,7 +57,7 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
return args[0], args[1], args[2]
}
func (t *Tensor) Slice(slices ...slice) *Tensor {
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE", t)
C.mlx_slice(
@@ -70,7 +70,7 @@ func (t *Tensor) Slice(slices ...slice) *Tensor {
return out
}
func (t *Tensor) SliceUpdate(other *Tensor, slices ...slice) *Tensor {
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE", t, other)
C.mlx_slice_update(

View File

@@ -15,7 +15,7 @@ import (
type Model interface {
// Forward performs a forward pass through the model.
Forward(inputs *mlx.Tensor, cache []cache.Cache) *mlx.Tensor
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
// NumLayers returns the number of layers in the model.
// This is used to initialize caches.
@@ -25,11 +25,11 @@ type Model interface {
type TextGeneration interface {
Model
Unembed(*mlx.Tensor) *mlx.Tensor
Unembed(*mlx.Array) *mlx.Array
}
func Weights(m Model) (map[string]*mlx.Tensor, []func(*model.Root) error) {
mapping := make(map[string]*mlx.Tensor)
func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) error) {
mapping := make(map[string]*mlx.Array)
var afterLoadFuncs []func(*model.Root) error
var fn func(v reflect.Value, tags []string)
fn = func(v reflect.Value, tags []string) {
@@ -52,9 +52,9 @@ func Weights(m Model) (map[string]*mlx.Tensor, []func(*model.Root) error) {
tags = append(tags, tag)
}
if tt == reflect.TypeOf((*mlx.Tensor)(nil)).Elem() {
if tt == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
name := strings.Join(tags, ".")
mapping[name] = vv.Addr().Interface().(*mlx.Tensor)
mapping[name] = vv.Addr().Interface().(*mlx.Array)
continue
}

View File

@@ -30,11 +30,11 @@ func (m Model) Cache() []cache.Cache {
return caches
}
func (m *Model) Forward(inputs *mlx.Tensor, cache []cache.Cache) *mlx.Tensor {
func (m *Model) Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array {
return m.Text.Forward(inputs, cache)
}
func (m *Model) Unembed(x *mlx.Tensor) *mlx.Tensor {
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.Text.EmbedTokens.AsLinear().Forward(x)
}

View File

@@ -29,7 +29,7 @@ type TextModel struct {
Options TextOptions
}
func (m TextModel) Forward(inputs *mlx.Tensor, caches []cache.Cache) *mlx.Tensor {
func (m TextModel) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := inputs.Dim(0), inputs.Dim(1)
hiddenStates := m.EmbedTokens.Forward(inputs)
@@ -53,7 +53,7 @@ type TextDecoderLayer struct {
PostFFNorm RMSNorm `weight:"post_feedforward_layernorm"`
}
func (m TextDecoderLayer) Forward(hiddenStates *mlx.Tensor, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Tensor {
func (m TextDecoderLayer) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
residual := hiddenStates
hiddenStates = m.InputNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.Attention.Forward(hiddenStates, cache, B, L, rope, opts)
@@ -77,7 +77,7 @@ type TextAttention struct {
OProj mlx.Linear `weight:"o_proj"`
}
func (m TextAttention) Forward(hiddenStates *mlx.Tensor, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Tensor {
func (m TextAttention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
query := m.QProj.Forward(hiddenStates)
key := m.KProj.Forward(hiddenStates)
value := m.VProj.Forward(hiddenStates)
@@ -113,6 +113,6 @@ type TextMLP struct {
DownProj mlx.Linear `weight:"down_proj"`
}
func (m TextMLP) Forward(h *mlx.Tensor, opts TextOptions) *mlx.Tensor {
func (m TextMLP) Forward(h *mlx.Array, opts TextOptions) *mlx.Array {
return m.DownProj.Forward(mlx.GELUApprox(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h)))
}

View File

@@ -36,7 +36,7 @@ func (m Model) NumLayers() int {
return len(m.Layers)
}
func (m Model) Forward(inputs *mlx.Tensor, caches []cache.Cache) *mlx.Tensor {
func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
slog.Debug("Model.forward", "input shape", inputs.Dims(), "m.EmbedTokens", m.EmbedTokens.Weight.Dims())
B, L := inputs.Dim(0), inputs.Dim(1)
hiddenStates := m.EmbedTokens.Forward(inputs)
@@ -56,7 +56,7 @@ type Layer struct {
MLP MLP `weight:"mlp"`
}
func (m Layer) Forward(hiddenStates *mlx.Tensor, c cache.Cache, B, L int, opts Options) *mlx.Tensor {
func (m Layer) Forward(hiddenStates *mlx.Array, c cache.Cache, B, L int, opts Options) *mlx.Array {
residual := hiddenStates
hiddenStates = m.AttentionNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.Attention.Forward(hiddenStates, c, B, L, opts)
@@ -76,7 +76,7 @@ type Attention struct {
OutputProj mlx.Linear `weight:"o_proj"`
}
func (m Attention) Forward(hiddenStates *mlx.Tensor, cache cache.Cache, B, L int, opts Options) *mlx.Tensor {
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
query := m.QueryProj.Forward(hiddenStates)
query = query.Reshape(B, L, opts.NumAttentionHeads, -1).Transpose(0, 2, 1, 3)
@@ -101,7 +101,7 @@ type MLP struct {
Down mlx.Linear `weight:"down_proj"`
}
func (m MLP) Forward(h *mlx.Tensor) *mlx.Tensor {
func (m MLP) Forward(h *mlx.Array) *mlx.Array {
return m.Down.Forward(mlx.SILU(m.Gate.Forward(h)).Multiply(m.Up.Forward(h)))
}

View File

@@ -42,8 +42,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
n := min(2<<10, total-processed-1)
temp := model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
defer mlx.Free(temp)
mlx.Eval(func() []*mlx.Tensor {
s := make([]*mlx.Tensor, 2*len(caches))
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
@@ -54,7 +54,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
mlx.ClearCache()
}
step := func(token *mlx.Tensor) (*mlx.Tensor, *mlx.Tensor) {
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
logits := model.Unembed(model.Forward(token.ExpandDims(0), caches))
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

View File

@@ -7,7 +7,7 @@ import (
)
type Sampler interface {
Sample(*mlx.Tensor) *mlx.Tensor
Sample(*mlx.Array) *mlx.Array
}
func New(temp, top_p, min_p float32, top_k int) Sampler {
@@ -34,13 +34,13 @@ func New(temp, top_p, min_p float32, top_k int) Sampler {
type greedy struct{}
func (greedy) Sample(logits *mlx.Tensor) *mlx.Tensor {
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false)
}
type chain []Sampler
func (c chain) Sample(logits *mlx.Tensor) *mlx.Tensor {
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
for _, sampler := range c {
logits = sampler.Sample(logits)
}
@@ -49,27 +49,27 @@ func (c chain) Sample(logits *mlx.Tensor) *mlx.Tensor {
type Temperature float32
func (t Temperature) Sample(logits *mlx.Tensor) *mlx.Tensor {
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
}
type TopP float32
func (p TopP) Sample(logprobs *mlx.Tensor) *mlx.Tensor {
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
type MinP float32
func (p MinP) Sample(logprobs *mlx.Tensor) *mlx.Tensor {
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
type TopK int
func (k TopK) Sample(logprobs *mlx.Tensor) *mlx.Tensor {
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}