STML - tiny tensor library to run pytorch models

I wanted to use SAM2 for a project, but I couldn't use my new favourite language, Odin, because SAM is written in Python and uses pytorch and god knows what other libraries. I can't make it run in the browser, it takes forever to install, and I need CUDA and weird drivers, ugh. Imagine if car manufacturers shipped a whole factory alongside the car when you buy one. It's wasteful and confusing for the user.

I'm not the first person with this idea. 4-letter efforts like onnx or ggml already do that, but I wanted to do something like that myself. After all, when you look deep down into the couple-thousand-line source code, beneath all the image loading and device shuffling, it's just a bunch of tensor operations, right? As long as I can multiply and add numbers, I should be able to port any model to run anywhere.

(Bad) basic tensor operations

In pytorch, numpy and everything else I've seen, a tensor is an array of numbers plus a shape telling you how to use them. For example, an image might be a tensor with shape (3, 1024, 1024), which can be understood as 3 channels (RGB), 1024 rows per channel and 1024 columns in each row. The element at (channel, row, col) then lives at data[channel*1024*1024 + row*1024 + col]. This can be implemented in Odin the following way:

Tensor :: struct {
    shape: []int,
    data:  []f32,
}
Here data is a slice, which is nice syntax for a structure holding a pointer to some data plus a length. Operating on tensors looks like this:
add :: proc(a, b: Tensor) -> Tensor {
    out := Tensor {
        shape = slice.clone(a.shape),
        data = make([]f32, len(a.data)),
    }
    for i in 0..<len(a.data) {
        out.data[i] = a.data[i] + b.data[i]
    }
    return out
}
Notice that the code above will break if len(a.data) != len(b.data). More complex operations like matrix multiplication actually need to verify the shapes themselves. We can assert(), risk out-of-bounds accesses, or...
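For reference, here's roughly what the assert() route would look like (my sketch, not stml's actual code). The shapes get compared on every call, so a mismatch only shows up once the program is already running:
import "core:slice"

// Runtime-checked add (hypothetical): compare the shapes on every call
// and fail at runtime if they don't match.
add_checked :: proc(a, b: Tensor) -> Tensor {
    assert(slice.equal(a.shape, b.shape), "shape mismatch")
    out := Tensor {
        shape = slice.clone(a.shape),
        data  = make([]f32, len(a.data)),
    }
    for i in 0..<len(a.data) {
        out.data[i] = a.data[i] + b.data[i]
    }
    return out
}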

Compile-time tensor shape matching

You can add two images together as long as their shapes match. At runtime, pytorch might throw an exception when it encounters incompatible tensors. I hate that. I wish there was some way to catch this while I'm writing the code, not while running it. Let's use Odin's parametric structures to achieve that.

V :: struct($a: int)             { data: [a      ]f32 } // 1D - Vector
M :: struct($a, $b: int)         { data: [a*b    ]f32 } // 2D - Matrix
T :: struct($a, $b, $c: int)     { data: [a*b*c  ]f32 } // 3D - Tri
Q :: struct($a, $b, $c, $d: int) { data: [a*b*c*d]f32 } // 4D - Quad
The downside is that instead of a single "tensor" struct, we need a struct for each dimensionality a tensor can have: a struct for vectors, a struct for matrices, and so on. This looks ugly, but it turns out to be really useful later on. Here's what add looks like:
add_slice :: proc(a, b, o: []f32) { for i in 0..<len(a) { o[i] = a[i] + b[i] } }
add_v :: proc(a, b, o: ^V($x))    { add_slice(a.data[:], b.data[:], o.data[:]) }
add_m :: proc(a, b, o: ^M($x,$y)) { add_slice(a.data[:], b.data[:], o.data[:]) }
add :: proc{add_slice, add_v, add_m}
The last line lets us call add(something, else) and the compiler will figure out which of the three procedures is appropriate. The compiler enforces that the tensors we add are either two matrices with the same shape or two vectors with the same length. Here's matrix multiplication:
mul :: proc(a: ^M($ar, $ac), b: ^M(ac, $bc), o: ^M(ar, bc)) {
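    // Accumulates into o; relies on o starting out zeroed, which Odin does by default.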
    for r in 0..<ar {
        for c in 0..<bc {
            for k in 0..<ac {
                o.data[r*bc+c] += a.data[r*ac+k] * b.data[k*bc+c]
            }
        }
    }
}
This way if you try to do something stupid like this:
a: M(10, 5)
b: M(6, 3)
o: M(10, 3)
mul(&a, &b, &o)
The compiler can tell you that b is the wrong size. If you change M(6,3) to M(5,3), Odin is happy. Here's what the conv2d operation's checks look like:
conv2d :: proc(img: ^T($I,$R,$C), w: ^Q($O,I,$S,S), b: ^V(O), out: ^T(O, $OR, $OC), $stride: int) {
    #assert(OR == (R-(S-1)-1)/stride+1, "Image, convolution and output height do not match")
    #assert(OC == (C-(S-1)-1)/stride+1, "Image, convolution and output width do not match")
    ...
}
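The elided body is just the usual pile of nested loops. Here's a minimal sketch of what it could look like, assuming no padding, a dilation of 1 and the same channel-major layout pytorch uses (stml's actual implementation may differ):
// Body of conv2d: for every output channel and output pixel, slide the
// SxS window over all input channels and accumulate, starting from the bias.
for o in 0..<O {
    for orow in 0..<OR {
        for ocol in 0..<OC {
            acc := b.data[o]
            for i in 0..<I {
                for kr in 0..<S {
                    for kc in 0..<S {
                        r := orow*stride + kr
                        c := ocol*stride + kc
                        acc += img.data[i*R*C + r*C + c] *
                               w.data[o*I*S*S + i*S*S + kr*S + kc]
                    }
                }
            }
            out.data[o*OR*OC + orow*OC + ocol] = acc
        }
    }
}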

Memory and reshaping

Since M(4,4) is practically an alias for 16 floats, we can trivially cast it to some other shape:

a: M(4,4)
b := cast(^V(16))&a
Now b is a pointer to V(16), meaning that if you change b.data you also change a.data. This allows us to do other fun stuff, like taking slices out of a tensor:
c := cast(^M(2,4))raw_data(a.data[8:])
I've called those functions as and get in stml's source code.
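For illustration, helpers along those lines could look roughly like this (a sketch; stml's actual as and get may have different names and signatures):
// Reinterpret an (a x b) matrix as a flat vector of a*b floats.
// No data is copied; the result points at the same memory.
as_vector :: proc(m: ^M($a, $b)) -> ^V(a*b) {
    return cast(^V(a*b))m
}

// View row r of an (a x b) matrix as a vector of b floats.
row :: proc(m: ^M($a, $b), r: int) -> ^V(b) {
    return cast(^V(b))raw_data(m.data[r*b:])
}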

Training a model

Here's an MLP in pytorch that trains in a couple of seconds. The core model is just 2 fully connected layers:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 10)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x
Saving the weights so we can #load them from Odin:
net.fc1.weight.detach().numpy().tofile('fc_weights/fc1.weight.bin')
net.fc1.bias  .detach().numpy().tofile('fc_weights/fc1.bias.bin')
net.fc2.weight.detach().numpy().tofile('fc_weights/fc2.weight.bin')
net.fc2.bias  .detach().numpy().tofile('fc_weights/fc2.bias.bin')

Running the model with stml

Odin has this nifty keyword #load that allows us to take any file and embed it into the program at compile time. This means I can have a single binary file that I give to people, like llamafile. Notice that pytorch's linear layers store their weights transposed, so the shape is (20, 784) and not (784, 20).
fc1_weight := load(20, 784, "./assets/fc_weights/fc1.weight.bin")
fc1_bias   := load( 1,  20, "./assets/fc_weights/fc1.bias.bin")
fc2_weight := load(10,  20, "./assets/fc_weights/fc2.weight.bin")
fc2_bias   := load( 1,  10, "./assets/fc_weights/fc2.bias.bin")
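Presumably load just wraps #load and one of the casts from earlier. A hypothetical expansion for a single weight file (stml's real helper may differ):
// What load(20, 784, path) might boil down to: embed the raw bytes at
// compile time and reinterpret them as a 20x784 matrix of f32s. Assumes
// the file is plain little-endian float32s, as written by numpy's tofile.
bytes  := #load("./assets/fc_weights/fc1.weight.bin")
weight := cast(^stml.M(20, 784))raw_data(bytes)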
Executing:
hid1 := new(stml.M(1, 20))
hid2 := new(stml.M(1, 10))

stml.mulT(img,  fc1_weight, hid1) // img=^T(1,28,28)
stml.add(hid1,  fc1_bias,   hid1)
stml.relu_(hid1)
stml.mulT(hid1, fc2_weight, hid2)
stml.add(hid2,  fc2_bias,   hid2)
stml.sigmoid_(hid2)
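To turn hid2 into an actual prediction, all that's left is picking the index of the largest of the 10 outputs. Continuing the snippet above (my addition, not part of the original listing):
// The predicted digit is the index of the largest output.
// (Printing needs import "core:fmt".)
best := 0
for i in 1..<10 {
    if hid2.data[i] > hid2.data[best] {
        best = i
    }
}
fmt.println("predicted digit:", best)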