This is the eighth post in my trek through Sebastian Raschka’s book
“Build a Large Language Model (from Scratch)”.
I’m blogging about bits that grab my interest, and things I had to rack my
brains over, as a way
to get things straight in my own head — and perhaps to help anyone else that
is working through it too. It’s been almost a month since my
last update — and
if you were suspecting that I was
blogging about blogging and spending time
getting LaTeX working on this site as
procrastination because this next section was always going to be a hard one, then you
were 100% right! The good news is that — as so often happens with these things —
it turned out to not be all that tough when I really got down to it. Momentum
regained.
If you found this blog through the blogging-about-blogging, welcome! Those
posts were not all that typical, though, and I hope
you’ll enjoy this return to my normal form.
This time I’m covering section 3.4, “Implementing self-attention
with trainable weights”. How do we create a system that can learn how much attention to pay
to each word in a sentence when it is looking at another word? For
example, how do we build one that learns that in “the fat cat sat on the mat”, when you’re looking at “cat”,
the word “fat” is important, but when you’re looking at “mat”, “fat” doesn’t matter
as much?
Before diving into that, especially given the amount of time since the last post,
let’s start with the 1,000-foot view of how the GPT-type
decoder-only transformer-based LLMs (hereafter “LLMs” to save me from RSI) work.
For each step I’ve linked to the posts where I went through the details.
- You start off with a string, presumably of words. (Part 2)
- You split it up into tokens (words like “the”, or chunks like “semi”). (Part 2)
- The job of the LLM is to predict the next token, given all of the tokens in the
string so far. (Part 1)
- Step 1: map the tokens to a sequence of
vectors called token embeddings. A particular token,
say, “the”, will have a specific embedding — these start out random but the LLM
works out useful embeddings as it’s trained. (Part 3)
- Step 2: generate another sequence of position embeddings: vectors of the
same size as the token embeddings, also starting random but trained, that represent
“this is the first token”, “this is
the second token”, and so on. (Part 3)
- Step 3: add the two sequences to generate a new sequence of input embeddings.
The first input embedding is the first token embedding plus the first position
embedding (added element-wise), the second is the second token embedding plus the second
position embedding, and so on. (Part 3)
- Step 4: self-attention. Take the input embeddings
and for each one, generate a list of attention scores. These
are numbers that represent how much attention to pay to each token in the sequence (including itself) when considering the token
in question. So (assuming one token per word) in “the fat cat sat on the mat”,
the token “cat” would need a list of 7 attention scores — how much attention to
pay to the first “the”, how much to pay to “fat”, how much to pay to itself, “cat”,
how much to pay to “sat”, and so on. Exactly how it does that is what this section
of the book covers — up until now we’ve been using a “toy” example calculation.
(Part 4,
Part 5, Part 6,
Part 7).
- Step 5: normalise the attention scores to attention weights. We
want each token’s list of attention weights to add up to one — we do this by running each list through
the softmax function.
(Part 4,
Part 5, Part 6,
Part 7).
- Step 6: generate a new sequence of context vectors.
In the system that we’ve built so far, each token’s context vector is made by multiplying all of the input embeddings
by their respective attention weights and adding the results together.
So in the example above, the context vector for “cat”
would be the input embedding for the first “the” times “cat”’s attention weight for
that “the”, plus the input embedding for “fat” times “cat”’s attention weight for
“fat”, and so on for the rest of the tokens in the sequence, including “cat” itself.
(Part 4,
Part 5, Part 6,
Part 7).
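To make the six steps concrete, here is a minimal sketch in plain Python (the book itself uses PyTorch; lists are used here so that every operation is visible). It runs the pipeline as it stands so far, with the “toy” dot-product attention from the earlier posts. The tiny 3-dimensional embeddings and their random values are made-up assumptions, standing in for tables that a real model would learn.

```python
import math
import random

random.seed(0)

# Made-up setup: one token per word, tiny 3-dimensional embeddings.
tokens = "the fat cat sat on the mat".split()
vocab = sorted(set(tokens))
d = 3

# Steps 1 and 2: token and position embedding tables, randomly initialised
# (a real LLM would learn these values during training).
token_embeddings = {tok: [random.uniform(-1, 1) for _ in range(d)] for tok in vocab}
position_embeddings = [[random.uniform(-1, 1) for _ in range(d)] for _ in tokens]

# Step 3: input embedding = token embedding + position embedding, element-wise.
inputs = [
    [t + p for t, p in zip(token_embeddings[tok], position_embeddings[pos])]
    for pos, tok in enumerate(tokens)
]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    # Subtracting the max avoids overflow and doesn't change the result.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Step 4: "toy" attention scores, just dot products between input embeddings.
scores = [[dot(x_i, x_j) for x_j in inputs] for x_i in inputs]

# Step 5: normalise each token's list of scores into attention weights.
weights = [softmax(row) for row in scores]

# Step 6: each context vector is the attention-weighted sum of all the inputs.
contexts = [
    [sum(w * x[k] for w, x in zip(row, inputs)) for k in range(d)]
    for row in weights
]

cat = tokens.index("cat")
print([round(w, 3) for w in weights[cat]])  # "cat"'s seven attention weights
```

Note that the two “the” tokens share a token embedding but end up with different input embeddings, because their position embeddings differ.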
After all of this is done, we have a sequence of context vectors,
each of which should in some way represent the meaning of its respective token in
the input, including those bits of meaning it gets from all of the other tokens.
So the context vector for “cat” will include some hint of its fatness, for example.
What happens with those context vectors that allows the LLM to use them to predict
what the next token might be? That bit is still to be explained, so
we’ll have to wait and see. But the first thing to learn is how we create a trainable
attention mechanism that can take the input vectors and generate the attention
scores so that we can work out those context vectors in the first place.
The answer Raschka gives in this section is called scaled dot product attention.
He gives a crystal-clear runthrough of the code to do it, but I had to bang my head
against it for a weekend to get to a solid mental model.
So, instead of going through the
section bit-by-bit, I’ll present my own explanation of how it works — to save me
from future head-banging when trying to remember it, and perhaps to save other people’s
foreheads from the same fate.
The summary, ahead of time
I’m a long-time fan of the Pimsleur
style of language course, where they start each tutorial with a minute or so of conversation
in the language you’re trying to learn, then say “in 30 minutes, you’ll hear that again
and you’ll understand it”. You go through the lesson, they play the conversation again, and you
do indeed understand it.
So here is a compressed summary of how self-attention works,
in my own words, based on Raschka’s explanation. It might look like a wall of jargon now, but
(hopefully) by the time
you’ve finished reading this blog post, you’ll be able to re-read it and it will all make sense.
We have an input sequence of $n$ tokens. We have converted it to a
sequence of input embeddings,
each of which is a vector of length $d_\text{in}$; each of these can be treated as a
point in $d_\text{in}$-dimensional
space. Let’s represent that sequence of embeddings with values like this: $x_1, x_2, \ldots, x_n$. Our goal is to produce a
sequence of length $n$ made up of context vectors, each of which represents
the meaning of the respective input token in the context of the input as a whole. These
context vectors will each be of length $d_\text{out}$ (which in practice is often equal to $d_\text{in}$,
but could in theory be of any length).
We define three matrices: the query weights matrix $W_q$, the key weights matrix $W_k$,
and the value weights matrix $W_v$. These are made up of trainable
weights; each one of them is sized $d_\text{in} \times d_\text{out}$. Because of those dimensions, we
can treat them as operations that project a vector of length $d_\text{in}$ (a point in $d_\text{in}$-dimensional
space) to a vector of length $d_\text{out}$ (a point in
a $d_\text{out}$-dimensional space). We will call these projected spaces query space,
key space and
value space. To convert an input vector $x_i$ into query space, for example, we just
multiply it by $W_q$, like this: $q_i = x_i W_q$.
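As a small illustration of the projection step, here is a sketch with made-up sizes ($d_\text{in} = 3$, $d_\text{out} = 2$) and random numbers standing in for the trained weights. The `project` and `random_matrix` helpers and the sample input are hypothetical, not code from the book.

```python
import random

random.seed(42)

d_in, d_out = 3, 2  # made-up sizes; in practice d_out is often equal to d_in


def random_matrix(rows, cols):
    # Stand-in for trainable weights: random starting values that training would adjust.
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def project(x, W):
    # Row vector of length d_in times a d_in x d_out matrix gives a vector of length d_out.
    if len(x) != len(W):
        raise ValueError("vector length must equal the matrix's number of rows")
    return [sum(x[r] * W[r][c] for r in range(len(W))) for c in range(len(W[0]))]


W_q = random_matrix(d_in, d_out)  # query weights
W_k = random_matrix(d_in, d_out)  # key weights
W_v = random_matrix(d_in, d_out)  # value weights

x_i = [0.5, 0.9, 0.7]    # a made-up input embedding
q_i = project(x_i, W_q)  # x_i in query space
k_i = project(x_i, W_k)  # x_i in key space
v_i = project(x_i, W_v)  # x_i in value space
print(len(x_i), "->", len(q_i), len(k_i), len(v_i))
```

The same input lands at three different points, one in each of the projected spaces.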
When we are considering input $x_i$, we want to work out its attention weights for
every input in the sequence (including itself). The first step is to work out the attention score,
which, when considering another input $x_j$, is calculated by taking the dot
product of the projection of $x_i$ into query space and the projection of $x_j$ into
key space: $q_i \cdot k_j$, where $k_j = x_j W_k$. Doing this across all inputs provides us with an attention score
for every token in the sequence, from the point of view of $x_i$. We then divide these by $\sqrt{d_\text{out}}$, the square root of the
dimensionality of the spaces we are projecting into, and run the resulting
list through the softmax function to make them all add up to one. This list is the
attention weights for $x_i$. This process is called scaled dot product attention.
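Here is a sketch of the attention-weight calculation for a single input $x_i$, again in plain Python. The inputs and the two weight matrices are made-up values chosen so that the arithmetic is easy to follow, and `attention_weights` is a hypothetical helper name.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project(x, W):
    # Row vector times matrix: maps a d_in-vector into a d_out-dimensional space.
    return [sum(x[r] * W[r][c] for r in range(len(W))) for c in range(len(W[0]))]


def softmax(scores):
    m = max(scores)  # subtracting the max avoids overflow without changing the result
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(i, inputs, W_q, W_k):
    """Scaled dot product attention weights for inputs[i] against every input."""
    q_i = project(inputs[i], W_q)                # x_i projected into query space
    keys = [project(x_j, W_k) for x_j in inputs]  # every x_j projected into key space
    d_out = len(q_i)
    scores = [dot(q_i, k_j) for k_j in keys]      # one attention score per input
    scaled = [s / math.sqrt(d_out) for s in scores]
    return softmax(scaled)                        # weights that add up to one


# Made-up example: three 3-dimensional inputs and 3x2 weight matrices.
inputs = [[1.0, 0.0, 1.0], [0.0, 2.0, 0.0], [1.0, 1.0, 1.0]]
W_q = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
W_k = [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]

weights = attention_weights(0, inputs, W_q, W_k)
print([round(w, 3) for w in weights])
```

With these numbers, $x_0$’s query lines up best with $x_1$’s key, so $x_1$ gets by far the largest weight, even though the plain dot product $x_0 \cdot x_1$ is zero.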
The next step is to generate a context vector for $x_i$. This is simply the
sum of the projections of all of the inputs into value space ($v_j = x_j W_v$), each one multiplied
by its associated attention weight.
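Using the same made-up inputs, $W_q$ and $W_k$ as in the previous sketch, plus an equally made-up $W_v$, the context vector for $x_0$ can be sketched like this. `context_vector` is a hypothetical helper, and it recomputes everything for one input at a time.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project(x, W):
    # Row vector times matrix: maps a d_in-vector into a d_out-dimensional space.
    return [sum(x[r] * W[r][c] for r in range(len(W))) for c in range(len(W[0]))]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def context_vector(i, inputs, W_q, W_k, W_v):
    """Context vector for inputs[i] using scaled dot product attention."""
    q_i = project(inputs[i], W_q)
    keys = [project(x_j, W_k) for x_j in inputs]
    values = [project(x_j, W_v) for x_j in inputs]  # every input in value space
    d_out = len(q_i)
    weights = softmax([dot(q_i, k_j) / math.sqrt(d_out) for k_j in keys])
    # Weighted sum of the value-space vectors, one component at a time.
    return [sum(w * v_j[c] for w, v_j in zip(weights, values)) for c in range(d_out)]


inputs = [[1.0, 0.0, 1.0], [0.0, 2.0, 0.0], [1.0, 1.0, 1.0]]  # made-up inputs
W_q = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]  # made-up 3x2 weight matrices
W_k = [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
W_v = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]

z_0 = context_vector(0, inputs, W_q, W_k, W_v)
print([round(c, 3) for c in z_0])
```

Because the weights add up to one, the context vector is a weighted average of the value vectors; if every input were identical, it would simply equal that input’s value-space projection.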
By performing these operations for each of the input vectors, we can generate a list
of length $n$ made up of context vectors of length $d_\text{out}$, each of which represents the meaning of an input token in the context of
the input as a whole.
Importantly, with clever use of matrix multiplication, all of this can be done for
all inputs in the sequence, producing a context vector for every one of them,
with just five matrix multiplications and a transpose.
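In matrix form, with the inputs stacked as the rows of an $n \times d_\text{in}$ matrix $X$, this is usually written as

$$Q = X W_q,\quad K = X W_k,\quad V = X W_v,\qquad \text{context} = \operatorname{softmax}\!\left(\frac{Q K^T}{\sqrt{d_\text{out}}}\right) V$$

with the softmax applied to each row separately. The plain-Python sketch below counts off the five multiplications and the transpose; the sizes, the random weights and the helper names are assumptions for illustration.

```python
import math
import random

random.seed(1)


def random_matrix(rows, cols):
    # Stand-in for trainable weights (random starting values).
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def matmul(A, B):
    # (a x b) times (b x c) gives (a x c).
    if len(A[0]) != len(B):
        raise ValueError("inner dimensions must match")
    return [[sum(A[r][k] * B[k][c] for k in range(len(B))) for c in range(len(B[0]))]
            for r in range(len(A))]


def transpose(M):
    return [list(col) for col in zip(*M)]


def softmax_rows(M):
    # Softmax applied to each row independently, so every row sums to one.
    result = []
    for row in M:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


def self_attention(X, W_q, W_k, W_v):
    Q = matmul(X, W_q)                # 1: queries, n x d_out
    K = matmul(X, W_k)                # 2: keys, n x d_out
    V = matmul(X, W_v)                # 3: values, n x d_out
    scores = matmul(Q, transpose(K))  # 4 (plus the transpose): n x n scores
    d_out = len(W_q[0])
    weights = softmax_rows([[s / math.sqrt(d_out) for s in row] for row in scores])
    return matmul(weights, V)         # 5: context vectors, n x d_out


n, d_in, d_out = 6, 3, 2    # made-up sizes
X = random_matrix(n, d_in)  # one made-up input embedding per row
W_q, W_k, W_v = random_matrix(d_in, d_out), random_matrix(d_in, d_out), random_matrix(d_in, d_out)
context = self_attention(X, W_q, W_k, W_v)
print(len(context), "x", len(context[0]))
```

Each row of `context` matches what the one-input-at-a-time calculation gives for that token; the matrix version just does all $n$ of them at once.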
Now let’s explain it
First things first, if there’s anyone there that understood all of that without
already knowing how attention mechanisms work, then I salute you! It was pretty
dense, and I hope it didn’t read like my friend Jonathan’s
parody of incomprehensible guides to using git.
For me, it took eight re-reads of Raschka’s (eminently clear and readable)
explanation to get to a level where I felt I understood it. I think it’s also worth noting
that it’s very much a “mechanistic” explanation — it says how we do these calculations
without saying why. I think that the “why” is actually out of scope for this book,
but it’s something that fascinates me, and I’ll blog about it soon. But,
in order to understand the “why”, I think we need to have a solid grounding in the
“how”, so let’s dig into that for this post.
Up until this section of the book, we have been working out the attention scores by taking the dot product
of the input embeddings against each other; that is, when you’re looking
at $x_i$, the attention score for $x_j$ is just $x_i \cdot x_j$. I suspected
earlier that the reason that Raschka was using that specific operation for his
“toy” self-attention was that the real implementation is similar, and that has turned
out to be right, as we’re doing scaled dot products here. But now we adjust the vectors first: $x_i$, the one that we’re considering,
is multiplied by the query weights matrix $W_q$, and the other one, $x_j$, is
multiplied by the key weights matrix $W_k$. Raschka refers to this as a projection,
which for me is a really nice way to look at it. But his reference is just in passing,
and for me it needed a bit more digging in.
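To see the two scoring schemes side by side, here is a small sketch. With identity matrices as $W_q$ and $W_k$ (which only works when $d_\text{out} = d_\text{in}$), the projected score collapses back to the plain dot product; with other, made-up “trained” weights, the same pair of embeddings gets a different score. All the numbers and helper names are invented for illustration.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project(x, W):
    # Row vector times matrix.
    return [sum(x[r] * W[r][c] for r in range(len(W))) for c in range(len(W[0]))]


def toy_score(x_i, x_j):
    # Earlier posts: the attention score is just the dot product of the embeddings.
    return dot(x_i, x_j)


def projected_score(x_i, x_j, W_q, W_k):
    # This section: project x_i into query space and x_j into key space first.
    return dot(project(x_i, W_q), project(x_j, W_k))


x_cat = [0.2, 0.9, 0.4]  # made-up embeddings
x_fat = [0.8, 0.1, 0.5]

identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(toy_score(x_cat, x_fat), projected_score(x_cat, x_fat, identity, identity))

W_q = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]  # made-up "trained" 3x2 weights
W_k = [[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]]
print(projected_score(x_cat, x_fat, W_q, W_k))
```

Training can only change the scores by changing $W_q$ and $W_k$, which is the extra indirection the projections add.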
Matrices as projections between spaces
If your matrix maths is a bit rusty — like mine was — and you haven’t read the
primer I posted the other week, then
you might want to check it out now.
From your schooldays, you might remember that matrices can be used to apply geometric
transformations. For example, if you take a vector representing a point, you can multiply
it by a matrix to rotate that point about the origin.
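You can use a matrix like this to rotate things anti-clockwise by $\theta$ degrees:

$$
\begin{bmatrix} x & y \end{bmatrix}
\begin{bmatrix} \cos\theta & \sin\theta \\ -\sin\theta & \cos\theta \end{bmatrix}
=
\begin{bmatrix} x\cos\theta - y\sin\theta & x\sin\theta + y\cos\theta \end{bmatrix}
$$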
This being matrix multiplication, you could add on more points — that is, if the
first matrix had more rows, each of which was a point you wanted to rotate, the same
multiplication would rotate them all by $\theta$. So you can see that matrix as
being a function that maps sets of points to their rotated equivalents. This
works in higher dimensions, too: a $2 \times 2$ matrix like this can represent
transformations in 2 dimensions, but, for example, in 3d graphics, people
use matrices to do similar transformations to the points that make up
3d objects.
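As a quick check on the row-vector convention used above, here is a sketch that rotates three points by 90 degrees with a single multiplication. The `matmul` and `rotation_matrix` helpers are hypothetical, written in plain Python.

```python
import math


def matmul(A, B):
    # (a x b) times (b x c) gives (a x c).
    return [[sum(A[r][k] * B[k][c] for k in range(len(B))) for c in range(len(B[0]))]
            for r in range(len(A))]


def rotation_matrix(theta_degrees):
    # For row-vector points [x, y], multiplying on the right by this matrix
    # rotates them anti-clockwise about the origin.
    t = math.radians(theta_degrees)
    return [[math.cos(t), math.sin(t)],
            [-math.sin(t), math.cos(t)]]


# Each row is one point; one multiplication rotates them all.
points = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
rotated = matmul(points, rotation_matrix(90))
print([[round(v, 6) for v in p] for p in rotated])
```

The point $[1, 0]$ ends up at $[0, 1]$ and $[0, 1]$ at $[-1, 0]$, which is what an anti-clockwise quarter-turn should do.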
An alternative way of looking at this matrix is that it’s a function that
projects points from
one 2-dimensional space to another, the target space being the first space rotated
by $\theta$ degrees anti-clockwise. For a simple 2d example like this, or even the
3d ones, that’s not
necessarily a better way of seeing it. It’s a philosophical difference rather
than a practical one.
But imagine if the matrix wasn’t square —
that is, it had a different number of rows to the number of columns.
If you had a $3 \times 2$ matrix, it could be used to multiply an $m \times 3$ matrix of vectors
in 3d space and produce an $m \times 2$ matrix in 2d space. Remember the rule for matrix multiplication:
an $a \times b$ matrix times a $b \times c$ matrix will give you an $a \times c$ one.
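Here is the non-square case as a sketch: a $4 \times 3$ matrix of made-up 3d points times a $3 \times 2$ matrix gives a $4 \times 2$ matrix of 2d points. The particular $3 \times 2$ matrix just throws away the $z$ coordinate (a simple orthographic projection onto the $x$-$y$ plane); it is an illustrative choice, much simpler than a real perspective projection.

```python
def matmul(A, B):
    # (a x b) times (b x c) gives (a x c); the inner sizes must agree.
    if len(A[0]) != len(B):
        raise ValueError(f"can't multiply {len(A)}x{len(A[0])} by {len(B)}x{len(B[0])}")
    return [[sum(A[r][k] * B[k][c] for k in range(len(B))) for c in range(len(B[0]))]
            for r in range(len(A))]


# Four made-up points in 3d space, one per row: a 4x3 matrix.
points_3d = [[1.0, 2.0, 3.0], [0.0, 0.0, 1.0], [-1.0, 4.0, 0.5], [2.0, 2.0, 2.0]]

# A 3x2 matrix that keeps x and y and drops z.
drop_z = [[1.0, 0.0],
          [0.0, 1.0],
          [0.0, 0.0]]

points_2d = matmul(points_3d, drop_z)
print(len(points_2d), "x", len(points_2d[0]))
```

Any other $3 \times 2$ matrix would also map the points into some 2d space; this one was picked because its effect is easy to see.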
That is actually super-useful;
if you’ve done any 3d graphics, you might remember the
frustum matrix which is used
to convert the 3d points you’re working with to 2d points on a screen. Without
going into too much detail, it allows you to project those 3d points into a 2d
space with a single matrix multiplication.
So: a $d_1 \times d_2$ matrix can be seen as a way to project a vector that represents a
point in $d_1$-dimensional space into one that represents a point in a different $d_2$-dimensional
space.
What we’re doing in self-attention is taking our $d_\text{in}$-dimensional vectors that make
up the input embedding sequence, then projecting them into three different $d_\text{out}$-dimensional
spaces, and working with the projected versions. Why do we do this? That’s the
question I want to look into in my future post on the “why”, but for now, I think one thing that is fairly
clear is that because these projections are learned as part of the training (remember,
the three matrices we’re using for the projections are made up of trainable weights),
it’s putting some kind of indirection into the mix that the simple dot product attention
that we were using before didn’t have.