NOTE: This post is going to be a compiler post, not a machine learning
tutorial, so please treat it as such. Maybe it will still help you understand
ML through a compilers lens.
I had a nice chat with my friend Chris
recently.
He walked me through the basics of machine learning while I was looking at
Andrej Karpathy’s
micrograd.
If you are unfamiliar, micrograd is a very small implementation of a
scalar-valued neural network (as opposed to vectors or matrices as the
computational unit) in pure Python, which uses no libraries.
Micrograd is a combination of a couple of different and complementary parts:
- a little graph-based expression builder and evaluator
- reverse-mode automatic differentiation on that same computation graph
- neural net building blocks for a multi-layer perceptron (MLP)
(If you don’t know what a MLP is, don’t worry too much. This post should give
you a bit of background, especially if you are already comfortable with Python.
You may want to go through and read and think about the micrograd source code
before coming back. Or not! Your call. Playing with it helped me a lot. Chris
suggested trying to make a network learn XOR.)
Together, these three major components let you write code that looks like this:
from micrograd.nn import MLP
model = MLP(2, [4, 1])
And summon a neural network from thin air.
The thing that got me the first time I read it was that I thought the building
blocks were the network. In this library, no. Using a building analogy, they
are more like blueprints or scaffolding. With each evaluation of the network,
the connective tissue (intermediate computation graph) is constructed anew. In
compiler terms, the building blocks are kind of like the front-end and the
expression graph is a sort of intermediate representation (IR).
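To make that distinction concrete, here is a minimal sketch (the input [1.0, -2.0] is made up, and the output's data depends on the random weight initialization):

>>> from micrograd.nn import MLP
>>> model = MLP(2, [4, 1])      # blueprint: 2 inputs, one hidden layer of 4, 1 output
>>> out = model([1.0, -2.0])    # evaluation: this call builds a fresh Value graph
>>> type(out)
<class 'micrograd.engine.Value'>
>>>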
You may be sitting there wondering why I am telling you this. I normally blog
about compilers. What’s this?
It’s because once I untangled and understood the three pieces in micrograd, I
realized:
- ML models are graphs
- forward and backward passes are graph traversals
- the graph structure does not change over time
- performance is important
Which means… it sounds like a great opportunity for a compiler! This is why
projects like PyTorch and TensorFlow have compilers
(TorchScript/TorchDynamo/AOT Autograd/PrimTorch/TorchInductor/Glow, XLA, etc).
Compiling your model speeds up both training and inference. So this post will
not contain anything novel—it’s hopefully a quick sketch of a small example
of what the Big Projects do.
We’re going to compile micrograd neural nets into C. In order, we will
- do a brief overview of neural networks
- look at how micrograd does forward and backward passes
- review the chain rule
- learn why micrograd is slow
- write a small compiler
- see micrograd go zoom
Let’s go!
How micrograd does neural networks
First, a bit about multi-layer perceptrons. MLPs are densely connected neural
networks where input flows in one direction through the network. As it exists
in the upstream repository, micrograd only supports MLPs.
In case visual learning is your thing, here is a small diagram:
[Diagram: a small multi-layer perceptron. I made this in Excalidraw. I love Excalidraw.]
In this image, circles represent data (input or intermediate computation
results) and arrows are weights and operations on the data. In this case, the
x, y, and z circles are input data. The arrows going right are multiplications
with weights. The meeting of the arrows represents an addition (forming a dot
product) followed by addition of the bias (kind of like another weight), all
fed into an activation function (in this case ReLU, for “rectified linear
unit”)1. The circles on the right are the results of the first layer.
Karpathy implements this pretty directly, with each neuron being an instance of
the Neuron class and having a __call__ method do the dot product. After each
dot product is an activation, in this case ReLU, which is equivalent to
max(x, 0). I think the 0 is an arbitrary threshold but I am not certain.
Below is the entire blueprint code for a multilayer perceptron in micrograd
(we’ll come back to the Value class later):
import random
from micrograd.engine import Value

class Module:

    def zero_grad(self):
        for p in self.parameters():
            p.grad = 0

    def parameters(self):
        return []

class Neuron(Module):

    def __init__(self, nin, nonlin=True):
        self.w = [Value(random.uniform(-1,1)) for _ in range(nin)]
        self.b = Value(0)
        self.nonlin = nonlin

    def __call__(self, x):
        act = sum((wi*xi for wi,xi in zip(self.w, x)), self.b)
        return act.relu() if self.nonlin else act

    def parameters(self):
        return self.w + [self.b]

    def __repr__(self):
        return f"{'ReLU' if self.nonlin else 'Linear'}Neuron({len(self.w)})"

class Layer(Module):

    def __init__(self, nin, nout, **kwargs):
        self.neurons = [Neuron(nin, **kwargs) for _ in range(nout)]

    def __call__(self, x):
        out = [n(x) for n in self.neurons]
        return out[0] if len(out) == 1 else out

    def parameters(self):
        return [p for n in self.neurons for p in n.parameters()]

    def __repr__(self):
        return f"Layer of [{', '.join(str(n) for n in self.neurons)}]"

class MLP(Module):

    def __init__(self, nin, nouts):
        sz = [nin] + nouts
        self.layers = [Layer(sz[i], sz[i+1], nonlin=i!=len(nouts)-1) for i in range(len(nouts))]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def parameters(self):
        return [p for layer in self.layers for p in layer.parameters()]

    def __repr__(self):
        return f"MLP of [{', '.join(str(layer) for layer in self.layers)}]"
You can ignore some of the clever coding in MLP.__init__. This ensures that
all of the layers match up end-to-end dimension-wise. It also ensures the last
layer is linear, meaning the neurons do not have an activation function
attached.
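For example, the model = MLP(2, [4, 1]) from the top of the post expands into a
hidden layer of four ReLU neurons feeding a single linear output neuron, which
you can see in its repr (roughly, given the __repr__ methods above):

>>> from micrograd.nn import MLP
>>> MLP(2, [4, 1])
MLP of [Layer of [ReLUNeuron(2), ReLUNeuron(2), ReLUNeuron(2), ReLUNeuron(2)], Layer of [LinearNeuron(4)]]
>>>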
But this neural network is not built just with floating point numbers. Instead
Karpathy uses this Value thing. What’s that about?
Intro to the expression builder
I said that one of micrograd’s three components is an expression graph builder.
Using the expression builder looks like a slightly more complicated way of
doing math in Python:
>>> from micrograd.engine import Value
>>> a = Value(2)
>>> b = Value(3)
>>> c = Value(4)
>>> d = (a + b) * c
>>> d
Value(data=20, grad=0)
>>>
The Value class even implements all the operator methods like __add__ to
make the process painless and look as much like normal Python math as possible.
But it’s a little different than normal math. It’s different first because it
has this grad field—which we’ll talk more about later—but also because as
it does the math it also builds up a graph (you can kind of think of it as an
abstract syntax tree, or AST).
It’s not visible in the normal string representation, though. Value instances
have a hidden field called _prev that stores the constituent parts that make
up an expression:
>>> d._prev
{Value(data=5, grad=0), Value(data=4, grad=0)}
>>>
They also have a hidden operator field:
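In micrograd this is the _op attribute, which stores the operator as a short
string. Continuing the session from above:

>>> d._op
'*'
>>>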
This means that we have two operands to the * node d: c (4) and a + b (5).
I said you could think about it like an AST but it’s not quite an AST because
it’s not a tree. It’s expected and normal to have more of a directed acyclic
graph (DAG)-like structure.
>>> w = Value(2)
>>> x = 1 + w
>>> y = 3 * w
>>> z = x + y
>>> z
Value(data=9, grad=0)
>>>
Here x and y both use w and then are both used by z, forming a diamond
pattern with two paths from z to w, making it a directed graph instead of a
tree. It is assumed that the graph won’t have cycles in it2.
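One quick way to see the sharing, continuing that session (this just pokes at
the _prev sets from earlier):

>>> z._prev == {x, y}
True
>>> x._prev & y._prev   # the shared node, w
{Value(data=2, grad=0)}
>>>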
So what does creating the graph look like in code? Well, the Value.__mul__
function, called on the left hand side of an x*y operation3, looks like this:
class Value:
    # ...
    def __mul__(self, other):
        # create a transient value if the right hand side is a constant int or
        # float, like v * 3
        other = other if isinstance(other, Value) else Value(other)
        # pass in new data, children, and operation
        out = Value(self.data * other.data, (self, other), '*')
        # ... we'll come back to this hidden part later ...
        return out
The children tuple (self, other) contains the pointers to the other nodes in
the graph.
But why do we have these expression graphs? Why not just use math? Who
cares about all the back pointers?
Let’s talk about grad(ient)
Training a neural network is a process of shaping your function (the neural
network) over time to output the results you want. Inside your function are a
bunch of coefficients (“weights”) which get iteratively adjusted during
training.
The standard training process involves your neural network structure and also
another function that tells you how far off your output is from some expected
value (a “loss function”). A simple example of a loss function is
loss(actual, expected) = (expected - actual)**2
(where ** is exponentiation in Python). If you use this particular function
across multiple inputs at a time, it’s called Mean Squared Error (MSE)4.
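Written with micrograd’s Value, that loss slots right into the expression
graph. A minimal sketch, assuming a made-up function name and sample numbers:

from micrograd.engine import Value

def squared_error(actual, expected):
    # actual is a Value produced by the network; expected can be a plain number.
    # Subtraction and integer powers are both supported by Value, so the result
    # is itself a Value with back-pointers into the graph.
    return (expected - actual) ** 2

loss = squared_error(Value(0.8), 1.0)
print(loss.data)  # roughly 0.04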
If you are trying to get some expected output, you want to minimize the value
of your loss function as much as possible. In order to minimize your loss, you
have to update the weights.
To figure out which weights to update and by how much, you need to know how
much each weight contributes to the final loss. Not every weight is equal; some
have significantly more impact than others.
The question “how much did this weight contribute to the loss this round” is
answered by the value of the grad (gradient) of that weight—the first
derivative—the slope at a point. For example, in y = mx + b, the equation
that describes a line, the derivative with respect to x is m, because the
value of x is scaled by m (and b is a constant).
To compute the grad, you need to traverse backwards from the loss5 to
do something called reverse mode automatic differentiation (reverse mode AD).
This sounds scary. Every article online about it has scary notation and
squiggly lines. But it’s pretty okay, actually, so keep on reading.
Fortunately for us, reverse mode AD, like evaluating an AST top to bottom, is
a graph traversal with some local state. If you can write a tree-walking
interpreter, you can do reverse mode automatic differentiation.
Reverse mode AD and backpropagation
Instead of building up a parallel graph of derivatives (a sort of “dual” to the
normal expression graph), reverse mode AD computes local derivatives at each
node in the grad (gradient) field. Then you can propagate these gradients
backward through the graph from the loss all the way to the
weights—backpropagation.
But how do you compose all those local derivatives? There’s no way it’s simple,
right? Taking derivatives of big math expressions is scary…
It turns out, calculus already has the answer in something called the chain
rule.
The chain rule
I am not going to pretend that I am a math person. Aside from what I re-learned
in the last couple of weeks, I only vaguely remember the chain rule from 10
years ago. Most of what I remember is my friend Julia figuring it out
instantaneously and wondering why I didn’t get it yet. That’s about it. So
please look elsewhere for details if this section doesn’t do it for you. I
won’t be offended.
A quick overview
The chain rule tells you how to compute derivatives of function composition.
Using the example from Wikipedia, if you have some function h(x) = f(g(x)),
then h'(x) = f'(g(x)) * g'(x) (where f', g', and h' are the derivatives of
f, g, and h, respectively). This rule is nice, because you don’t need to do
anything tricky when you start composing functions, as long as you understand
how to take the derivative of each of the component parts.
For example, if you have sin(x**2), you only need to know the derivative of
the component functions x**2 (it’s 2*x) and sin(x) (it’s cos(x)) to find out
the answer: cos(x**2) * 2*x.
To see the proof of this and also practice a bit, take a look at this short
slide deck (PDF) from Auburn University. Their course page table of contents
has more slide decks6.
Also make sure to check out the list of differentiation rules on Wikipedia.
It turns out that the chain rule comes in handy for taking derivatives of
potentially enormous expression graphs. Nobody needs to sit down and work out
how to take the derivative of your huge and no doubt overly complex function…
you just have your building blocks that you already understand, and they are
composed.
So let’s apply the chain rule to expression graphs.
Applying this to the graph
We’ll start with one Value node at a time. For a given node, we can do one
step of the chain rule (in pseudocode):
# pseudocode
def backward(node):
    for child in node._prev:
        child.grad += derivative_wrt_child(child) * node.grad
Where wrt means “with respect to”. It’s important that we take the derivative
of the parent node with respect to each of its children.
Instead of just setting child.grad, we are increasing it for two reasons:
- one child may be shared with other parents, in which case it affects both
  (the sketch after this list revisits the diamond graph from earlier)
- batching, but that’s not important right now
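The diamond graph from earlier is a small demonstration of the first reason:
w feeds into both x and y, so its grad accumulates one contribution per path.
A sketch of that, using Value.backward (covered below):

>>> from micrograd.engine import Value
>>> w = Value(2)
>>> x = 1 + w
>>> y = 3 * w
>>> z = x + y
>>> z.backward()
>>> w.grad   # 1 (through x) plus 3 (through y)
4
>>>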
To make this more concrete, let’s take a look at Karpathy’s implementation of
the derivative of *, for example. In math, if you have f(x,y) = x*y, then
f'(x, y) = 1*y (with respect to x) and f'(x, y) = x*1 (with respect to y). In
code, that looks like:
class Value:
    # ...
    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), '*')

        # The missing snippet from earlier!
        def _backward():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward

        return out
This means that for each of the children, we will use the other child’s data
and (because of the chain rule) multiply it by the parent expression’s grad.
That is, self’s grad (the left hand side) is adjusted using other’s data
(the right hand side) and vice versa. See what a nice translation of the math
that is? Get the derivative, apply the chain rule, add to the child’s grad.
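For comparison, micrograd’s __add__ follows the same pattern. Since
f(x, y) = x + y has a derivative of 1 with respect to both inputs, each child
just receives the parent’s grad unchanged; a sketch:

class Value:
    # ...
    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')

        def _backward():
            # d(x+y)/dx == d(x+y)/dy == 1, so pass the parent's grad straight through
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward

        return out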
Now we have a function to do one derivative step for one operation node, but we
need to do the whole graph.
But traversing a graph is not as simple as traversing a tree. You need to avoid
visiting a node more than once and also guarantee that you visit child nodes
before parent nodes (in forward mode) or parent nodes before child nodes (in
reverse mode). The tricky thing is that while we don’t visit a node more than
once, visiting updates the node’s children (not the node itself), and nodes may
share children, so children’s grads may be updated multiple times. This is
expected and normal!
For that reason, we have topological sort.
Topological sort and graph transformations
A topological sort on a graph is an order where children are always visited
before their parents. In general this only works if the graph does not have
cycles, but—thankfully—we already assume above that the graph does not have
cycles.
Here is a sample topological sort on the Value graph. It uses the nested
function build_topo for terseness, but that is not strictly necessary.
class Value:
    # ...
    def topo(self):
        # modified from Value.backward, which builds a topological sort
        # internally
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        return topo
To get a feel for how this works, we can do a topological sort of a very simple
expression graph, 1+2
.
>>> from micrograd.engine import Value
>>> x = Value(1)
>>> y = Value(2)
>>> z = x + y
>>> z.topo()
[Value(data=1, grad=0), Value(data=2, grad=0), Value(data=3, grad=0)]
>>>
The topological sort says that in order to calculate the value 3, we must
first calculate the values 1 and 2. It doesn’t matter in what order we do
1 and 2, but they both have to come before 3.
Now that we have a way to get a graph traversal order, we can start doing some
backpropagation.
Applying this to backpropagation
If we take what we know now about the chain rule and topological sort, we can
do backpropagation on the graph. Below is the code straight from micrograd. It
first builds a topological sort and then operates on it in reverse, applying
the chain rule to each Value one at a time.
class Value:
    # ...
    def backward(self):
        # topological order all of the children in the graph
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)

        # --- the new bit ---
        # go one variable at a time and apply the chain rule to get its gradient
        self.grad = 1
        for v in reversed(topo):
            v._backward()
The Value.backward function is normally called on the result Value of the
loss function.
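Going back to the very first expression, d = (a + b) * c, calling backward on
d fills in every grad; here d stands in for the loss (a sketch of the session):

>>> from micrograd.engine import Value
>>> a = Value(2)
>>> b = Value(3)
>>> c = Value(4)
>>> d = (a + b) * c
>>> d.backward()
>>> (a.grad, b.grad, c.grad)
(4, 4, 5)
>>>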
If you are wondering why we set self.grad to 1 here before doing
backpropagation, take a moment and wonder to yourself. Maybe it’s worth drawing
a picture!
Putting it all together
I am not going to get into the specifics, but here is a rough sketch of what a
very simplified training loop might look like for an MLP-based classifier for
the MNIST digit recognition problem. This code is not runnable as-is. It needs
the image loading support code and a loss function. The hyperparameters (batch
size, etc) are completely arbitrary and untuned.
The full training code and corresponding engine modifications to add
exp/log/Max are available in the GitHub repo.
import random
from micrograd.nn import MLP

# ...

NUM_DIGITS = 10
LEARNING_RATE = 0.1

# Each image is 28x28. Hidden layer of width 50. Output 10 digits.
model = MLP(28*28, [50, NUM_DIGITS])

# Pretend there is some kind of function that loads the labeled training images
# into memory.
db = list(images("train-images-idx3-ubyte", "train-labels-idx1-ubyte"))

num_epochs = 100
for epoch in range(num_epochs):
    for image in db:
        # zero grad
        for p in model.parameters():
            p.grad = 0.0
        # forward
        output = model(image.pixels)
        loss = compute_loss(output)
        # backward
        loss.backward()
        # update
        for p in model.parameters():
            p.data -= LEARNING_RATE * p.grad
In this snippet, constructing the MLP (model = MLP(...)) builds a bunch of
Neurons in Layers and initializes some weights as Values, but it does not
construct the graph yet. Only when it is called (as in model(image.pixels))
does it construct the graph and do all of the dot products.