The repository: https://github.com/siliconlad/ferrix
Earlier this year, I set myself the goal of playing around with Rust, a language known for its memory safety and performance. In this post, I’ll walk through how I built a basic static matrix library.
As a fun side quest, I made donuts rotate with Ferrix (based on the blog post by Andy Sloane). The code for this is here.
Because I was using Rust, which is a compiled language, I wanted to type my library so that almost all operations could be checked at compile time. That way, if it compiles, then it will probably work!
To do this, I have to store the shape of a matrix as part of the type. This lets us ensure operations (e.g. matrix multiplication) are only allowed between certain matrices.
Here’s a snippet of code to illustrate my point. We define a mul operation between two matrices of shape (N, M) and (M, P) to produce a new matrix of shape (N, P). The compiler will ensure that any two matrices provided to mul adhere to the necessary dimensional constraints (i.e. that the number of columns in the first matrix matches the number of rows in the second).
use std::ops::Mul;
// T is the type of the elements in the matrix
// R and C are the number of rows and columns of the matrix
pub struct Matrix<T, const R: usize, const C: usize> {
    data: [[T; C]; R],
}
impl<T, const N: usize, const M: usize, const P: usize> Mul<Matrix<T, M, P>>
    for Matrix<T, N, M>
{
    // (N, M) x (M, P) -> (N, P)
    type Output = Matrix<T, N, P>;
    fn mul(self, rhs: Matrix<T, M, P>) -> Matrix<T, N, P> {
        // Implementation details
    }
}
// This works because (2, 3) x (3, 1) -> (2, 1)
let m1 = Matrix::<f32, 2, 3>::new();
let m2 = Matrix::<f32, 3, 1>::new();
let m3 = m1 * m2; // This works!
// This doesn't compile, because (2, 4) x (3, 1) -> ??
let m1 = Matrix::<f32, 2, 4>::new();
let m2 = Matrix::<f32, 3, 1>::new();
let m3 = m1 * m2; // No luck :(
This is the basis for the entire library. Once you have this, it’s pretty much just a matter of implementing the necessary operations and methods that you need for a matrix.
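To make the idea concrete, here is a minimal sketch of what the elided mul body could look like. It is not necessarily how Ferrix implements it: the trait bounds on T, the new constructor taking nested arrays, and the use of T::default() as the additive zero (true for Rust's primitive numeric types) are all assumptions made for illustration.

```rust
use std::ops::{Add, Mul};

/// A statically sized matrix with R rows and C columns.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<T, const R: usize, const C: usize> {
    data: [[T; C]; R],
}

impl<T, const R: usize, const C: usize> Matrix<T, R, C> {
    /// Hypothetical constructor taking the rows as nested arrays.
    pub fn new(data: [[T; C]; R]) -> Self {
        Self { data }
    }
}

// (N, M) x (M, P) -> (N, P). The shared dimension M appears in both
// operand types, so a mismatch is rejected by the type checker.
impl<T, const N: usize, const M: usize, const P: usize> Mul<Matrix<T, M, P>>
    for Matrix<T, N, M>
where
    // Assumption: T::default() is the additive identity for T.
    T: Copy + Default + Add<Output = T> + Mul<Output = T>,
{
    type Output = Matrix<T, N, P>;

    fn mul(self, rhs: Matrix<T, M, P>) -> Self::Output {
        let mut out = [[T::default(); P]; N];
        for i in 0..N {
            for j in 0..P {
                for k in 0..M {
                    // Accumulate the dot product of row i and column j.
                    out[i][j] = out[i][j] + self.data[i][k] * rhs.data[k][j];
                }
            }
        }
        Matrix::new(out)
    }
}
```

With these definitions, multiplying a Matrix::<i32, 2, 3> by a Matrix::<i32, 3, 1> type-checks and yields a Matrix<i32, 2, 1>, while swapping in a (2, 4) left operand fails to compile because no Mul impl matches.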
The first challenge was figuring out how to represent row vectors and column vectors. The natural option is to define them as special cases of Matrix.
// Type aliases for special cases of a matrix
type Vector<T, const N: usize> = Matrix<T, N, 1>;
type RowVector<T, const N: usize> = Matrix<T, 1, N>;
The biggest advantage of this method is also its biggest downside. Every method I define for the Matrix type also works for Vector and RowVector. This reduces code duplication. However, it introduces a challenge.
Stable Rust does not have what’s known as specialization: the ability to provide a more specific implementation that overrides a general one for particular types.
Here’s an example. We define two new methods: one for the general Matrix case, and the other for the Vector specific case. Without specialization, Rust doesn’t allow me to define both methods simultaneously, since they would conflict.
impl<T, const N: usize> Matrix<T, N, 1> {
    fn new(data: [T; N]) -> Matrix<T, N, 1> {
        // Implementation details
    }
}
// Not possible because new is already defined above
// (the two impls overlap when M = 1).
impl<T, const N: usize, const M: usize> Matrix<T, N, M> {
    fn new(data: [[T; M]; N]) -> Matrix<T, N, M> {
        // Implementation details
    }
}
This creates a problem because I would prefer to define my Vector types like this:
Vector::new([1.0, 2.0])
But because there is no specialization, I have to define them like this:
Vector::new([[1.0], [2.0]])
This is a minor instance of a general theme of not being able to tailor Matrix methods to fit the special case of Vector and RowVector (this affects methods for creating views too). You probably could get it to work, but I decided to have separate Vector and RowVector types.
I wanted my cleaner API!
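As a rough illustration of that trade-off, here is a minimal sketch of a dedicated Vector type with the flat-array constructor. The field layout, the simplified Matrix, and the From conversion are assumptions made for this example, not Ferrix's actual definitions.

```rust
/// Hypothetical dedicated column-vector type, stored as a flat array.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Vector<T, const N: usize> {
    data: [T; N],
}

impl<T, const N: usize> Vector<T, N> {
    /// No conflict here: Vector is its own type, so it gets its own `new`.
    pub fn new(data: [T; N]) -> Self {
        Self { data }
    }
}

/// Simplified matrix type, included only to keep the sketch self-contained.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<T, const R: usize, const C: usize> {
    data: [[T; C]; R],
}

impl<T, const R: usize, const C: usize> Matrix<T, R, C> {
    pub fn new(data: [[T; C]; R]) -> Self {
        Self { data }
    }
}

// Because the types are now distinct, moving between them needs an
// explicit conversion. Each element becomes a one-element row.
impl<T, const N: usize> From<Vector<T, N>> for Matrix<T, N, 1> {
    fn from(v: Vector<T, N>) -> Self {
        Matrix::new(v.data.map(|x| [x]))
    }
}
```

The price of the cleaner constructor is visible in the last impl: anything that was shared for free through a type alias now has to be written (or generated) for each type.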
Another decision I made was to support views. This is a common concept in other libraries dealing with Vectors and Matrices (e.g. Eigen in C++). A view provides a way to reference existing data in a matrix without making a full copy, improving performance for large matrices.
Take the following example where * is matrix multiplication.
let v1 = Vector::<f32, 3>::new();
let m1 = Matrix::<f32, 3, 2>::new();
let result = m1.t() * v1; // (2, 3) x (3, 1) -> (2, 1)
In a naive implementation, t() would return a new Matrix of shape (2, 3) before the matrix multiplication. However, for large matrices, this is expensive due to the cost of copying.
Instead, we can create views like MatrixTransposeView, which stores a reference to the original matrix but presents the data as if it were transposed, thus avoiding the overhead of copying.
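A transpose view can be sketched in a few lines. The version below is an assumption-laden illustration rather than Ferrix's real MatrixTransposeView: the inner field, the shape method, and indexing by a (row, col) tuple are all hypothetical choices.

```rust
use std::ops::Index;

pub struct Matrix<T, const R: usize, const C: usize> {
    data: [[T; C]; R],
}

impl<T, const R: usize, const C: usize> Matrix<T, R, C> {
    pub fn new(data: [[T; C]; R]) -> Self {
        Self { data }
    }

    /// Borrow the matrix as its transpose. No elements are copied.
    pub fn t(&self) -> MatrixTransposeView<'_, T, R, C> {
        MatrixTransposeView { inner: self }
    }
}

/// A (C, R) view over an (R, C) matrix; it only holds a reference.
pub struct MatrixTransposeView<'a, T, const R: usize, const C: usize> {
    inner: &'a Matrix<T, R, C>,
}

impl<'a, T, const R: usize, const C: usize> MatrixTransposeView<'a, T, R, C> {
    /// Shape as seen through the view: rows and columns are swapped.
    pub fn shape(&self) -> (usize, usize) {
        (C, R)
    }
}

impl<'a, T, const R: usize, const C: usize> Index<(usize, usize)>
    for MatrixTransposeView<'a, T, R, C>
{
    type Output = T;

    fn index(&self, (row, col): (usize, usize)) -> &T {
        // Element (row, col) of the transpose is element (col, row)
        // of the original matrix.
        &self.inner.data[col][row]
    }
}
```

Creating the view costs one pointer regardless of the matrix size; the index swap happens lazily on each element access.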
I have these views (v0.1.0):
VectorView & VectorViewMut
RowVectorView & RowVectorViewMut
MatrixView & MatrixViewMut
MatrixTransposeView & MatrixTransposeViewMut
Because I chose to have three main types (Vector, RowVector, and Matrix) as well as views to limit the number of copies, specifying all the operators became a pain.
Take the + (Add) operation where Vector is on the left-hand side of the + operator. We would want the following combinations to be defined:
Vector + Vector
Vector + VectorView
Vector + VectorViewMut
Vector + Matrix
Vector + MatrixView
Vector + MatrixViewMut
Vector + MatrixTransposeView
Vector + MatrixTransposeViewMut
We also would want to define the operations for references for each of the bullets above!
Vector + Vector
Vector + &Vector
&Vector + Vector
&Vector + &Vector
Rust macros helped a little to cover every case (see src/ops if you’re curious about how I did it). There must be a better way, but it seems to work (if you know a better way, please let me know!).
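For readers who have not used macros this way, here is a small sketch of the general approach for the four owned/reference combinations of Vector + Vector. The macro name impl_vector_add and the simplified Vector type are hypothetical; the actual macros in src/ops may be structured quite differently.

```rust
use std::ops::Add;

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Vector<T, const N: usize> {
    data: [T; N],
}

impl<T, const N: usize> Vector<T, N> {
    pub fn new(data: [T; N]) -> Self {
        Self { data }
    }
}

// Hypothetical helper: emits one element-wise Add impl for a given
// (left-hand type, right-hand type) pair. The names T and N used in the
// invocations refer to the generics declared inside the macro body.
macro_rules! impl_vector_add {
    ($lhs:ty, $rhs:ty) => {
        impl<T, const N: usize> Add<$rhs> for $lhs
        where
            T: Copy + Add<Output = T>,
        {
            type Output = Vector<T, N>;

            fn add(self, rhs: $rhs) -> Self::Output {
                // Copy the left operand's storage, then add in place.
                let mut data = self.data;
                for i in 0..N {
                    data[i] = data[i] + rhs.data[i];
                }
                Vector::new(data)
            }
        }
    };
}

// One line per combination instead of one hand-written impl each.
impl_vector_add!(Vector<T, N>, Vector<T, N>);
impl_vector_add!(Vector<T, N>, &Vector<T, N>);
impl_vector_add!(&Vector<T, N>, Vector<T, N>);
impl_vector_add!(&Vector<T, N>, &Vector<T, N>);
```

This keeps the body in one place, but the number of invocations still grows with every new type and view, which is the scaling problem described above.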
This was definitely a huge downside to having so many different types. It will only get worse with the more types I add (e.g. 3rd and 4th order Tensors) unless I can figure out a better system.
An interesting detail from the earlier Matrix definition is that the underlying storage must be an array of arrays. This is because stable Rust does not yet support const expressions that depend on generic parameters (the unstable generic_const_exprs feature).
To store the matrix data in a single array, I would need to multiply the number of rows R by the number of columns C at compile time. Unfortunately, stable Rust does not allow this.
// Rejected by stable Rust: C * R is a generic const expression
pub struct Matrix<T, const R: usize, const C: usize> {
    data: [T; C * R],
}
This actually prevents me from implementing some other features:
Vector<T, N> * Vector<T, M> = Vector<T, N + M>
Matrix<T, N, M> → Vector<T, N * M>
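One possible workaround on stable Rust, sketched here under assumptions and not taken from Ferrix, is to make the output length its own const parameter and check it against R * C with an associated constant. The names AssertLen and flatten are hypothetical, and it returns a plain array instead of a Vector to stay short. The check fires when the code is actually built for a concrete (R, C, L), so a wrong length is still a compile-time error, just a later and less readable one.

```rust
pub struct Matrix<T, const R: usize, const C: usize> {
    data: [[T; C]; R],
}

/// Hypothetical helper: evaluating OK fails the build unless L == R * C.
struct AssertLen<const R: usize, const C: usize, const L: usize>;

impl<const R: usize, const C: usize, const L: usize> AssertLen<R, C, L> {
    const OK: () = assert!(L == R * C, "L must equal R * C");
}

impl<T: Copy + Default, const R: usize, const C: usize> Matrix<T, R, C> {
    pub fn new(data: [[T; C]; R]) -> Self {
        Self { data }
    }

    /// Flatten in row-major order. The caller supplies L, usually through
    /// type inference on the result.
    pub fn flatten<const L: usize>(&self) -> [T; L] {
        // Referencing the constant forces the length check for this
        // particular (R, C, L) instantiation.
        let () = AssertLen::<R, C, L>::OK;
        let mut out = [T::default(); L];
        for r in 0..R {
            for c in 0..C {
                out[r * C + c] = self.data[r][c];
            }
        }
        out
    }
}
```

The downside is that the caller has to state the length (for example `let flat: [i32; 6] = m.flatten();`), whereas a true N * M in the return type would let the compiler compute it.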
Implementing a static matrix in Rust was much more involved than I expected. Being new to Rust probably also meant I made many suboptimal design choices. While Rust lacks some features that could have made my life easier, it was a valuable learning experience.
I learned a lot about how to take a Rust project from zero to one. I learned how to use Rust macros (a little bit), and had a lot of fun trying to figure out the best way to structure my code.
If you’re curious about the donut, check out the examples folder :)
There’s a lot more to be done to make this library usable, especially around performance. Perhaps in a later post we will explore implementing SIMD and/or CUDA to try and speed things up.
Until next time!