tldr; techniques to speed up the training and inference of LLMs so they can use a large context window of up to 100K input tokens: ALiBi positional embedding, Sparse Attention, FlashAttention, Multi-Query Attention, Conditional Computation, and 80GB A100 GPUs.
Recently, there have been several announcements about new Large Language Models (LLMs) that can consume an extremely large context window, such as 65K tokens (MPT-7B-StoryWriter-65k+ by MosaicML) or even 100K tokens (Introducing 100K Context Windows by Anthropic). In the PaLM 2 technical report, Google doesn’t reveal the context size but mentions that they “increase the context length of the model significantly.”
For comparison, the current GPT-4 model can work with a context length of 32K input tokens, while most open-source LLMs have a context length of 2K tokens.
That’s impressive, since having such a large context length means the prompt can literally be the size of a book. The Great Gatsby is 72K tokens, 210 pages, and 6 hours of reading at a speed of 1.7 minutes per page. So the model can scan and keep this amount of “custom” information in mind while processing queries!
I was trying to wrap my head around how that is technically possible, so in this blog post, I collect scattered pieces of information (this thread was the first clue) and cover the following:
- Why context length matters and why it can be a game changer
- What are the main limitations in the original Transformer architecture when working with large context lengths
- The computational complexity of the transformer architecture
- What optimization techniques currently exist to speed up the transformer and increase the context length up to 100K
Here and later, we use the “context length,” “context window,” and “the number of input tokens” interchangeably, denoting them as n.
The blog post is a bit long, so here is a summary of the main points and tricks:
- The 1st problem is the quadratic time and space complexity of attention layer computations w.r.t. the number of input tokens n.
- When the embedding size d > n, the 2nd problem is the quadratic time complexity of the linear layers w.r.t. the embedding size d.
- The 3rd problem is the Positional Sinusoidal Embedding used in the original architecture.
- In Transformer architecture, the shapes of learnable matrix weights are agnostic to the number of input tokens n.
- So, a Transformer trained on a 2K context length can consume input sequences of any length, even 100K. But the model will not produce meaningful results on 100K tokens during inference if it wasn’t trained on a 100K context length.
- Training the vanilla Transformer on a giant corpus and only on a large context length is prohibitively expensive due to the quadratic complexity w.r.t. n and d. Training LLaMA on a 2K context length was estimated at ~$3M, so LLaMA on 100K would cost ~$150M.
- One option is to train the model on a 2K-token context and then fine-tune it on longer contexts (for example, 65K). But it won’t work with the original Transformer because of the Positional Sinusoidal Encoding.
- [Trick #1] To address this, remove the Positional Sinusoidal Encoding and use ALiBi, a simple and elegant positional embedding that doesn’t hurt accuracy. Then you can train on 2K and fine-tune on 100K (see the short sketch after this summary).
- [Trick #2] You don’t need to calculate attention scores between all tokens. Some tokens are more important than others, so Sparse Attention can be used. It will speed up both training and inference.
- [Trick #3] FlashAttention efficiently implements the attention layer for the GPU. It uses tiling and avoids materializing the big intermediate (n, n) matrices, which don’t fit into GPU SRAM. It will speed up both training and inference.
- [Trick #4] Use Multi-Query Attention instead of Multi-Head Attention. That means sharing the weights across all heads when linearly projecting K and V. It dramatically speeds up incremental inference.
- [Trick #5] Conditional computation avoids applying all model parameters to all tokens from the input sequence. CoLT5 applies heavy computations only to the most important tokens and processes the rest of the tokens with a lighter version of layers. It will speed up both training and inference.
- [Trick #6] To fit a large context, you need a lot of GPU RAM, so people use 80GB A100 GPUs.
To sum up, the more you speed up the training and inference, the larger the context length you can use.
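To make Trick #1 more concrete, here is a minimal sketch of the ALiBi idea in PyTorch (the function name and toy sizes are mine, not from any particular library): instead of adding positional embeddings to the token embeddings, ALiBi adds a static, head-specific linear penalty to the attention scores based on the distance between the query and key tokens.

```python
import torch

def alibi_bias(n_heads: int, n_tokens: int) -> torch.Tensor:
    """Build the (n_heads, n, n) bias added to attention scores before softmax."""
    # Geometric sequence of slopes, as in the ALiBi paper for a power-of-two number of heads:
    # 2^(-8/h), 2^(-16/h), ..., 2^(-8)
    slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])
    pos = torch.arange(n_tokens)
    distance = (pos.view(-1, 1) - pos.view(1, -1)).abs()  # |i - j| for every token pair
    return -slopes.view(-1, 1, 1) * distance               # farther tokens get a larger penalty

# Inside each attention head, the bias is simply added to the (n, n) score matrix:
#   scores = Q @ K.T / k ** 0.5 + alibi_bias(h, n)[head_index]
# Because the penalty depends only on relative distances, a model trained on 2K tokens
# can later be fine-tuned on (or extrapolate to) longer sequences.
```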
Let’s now discuss all these points in more detail.
Context length is one of the critical limitations of LLMs, and increasing it all the way to 100K is an incredible achievement (I wonder how this statement will look in a year).
One of the important use cases where people want to apply LLMs is “dropping a large pile of custom data into an LLM” (documents related to the company or a particular problem, various heterogeneous texts, etc.) and asking questions about this particular data, not some abstract data from the internet that the LLM saw during training.
To overcome this limitation today, people try various approaches:
- Trying summarization techniques and sophisticated chained prompts
- Maintaining vector databases to keep embeddings for custom documents and then “searching” across them by some similarity metric
- Fine-tuning the LLM with custom data when possible (not all commercial LLMs allow that, and it is not an obvious task for open-source LLMs)
- Developing custom smaller LLMs for this particular data (again, not an obvious task)
Having a large context length allows an already powerful LLM (one that has seen the whole internet) to look at your context and data and interact with you on a completely different level, with higher personalization. And all of this without changing the model’s weights, with your “training” happening on the fly, “in memory.” Overall, a large context window brings more accuracy, fluency, and creativity to the model.
One analogy here might be computer RAM, where the operating system keeps the real-time context of all your applications. With a substantial context length, LLM can be like a “reasoning computer,” keeping a lot of user context.
It’s important to note that in the Transformer architecture, the shapes of all learnable weight matrices do not depend on the number of input tokens n. All trainable parameters (the embedding lookup, projection layers, softmax layer, and attention layers) are agnostic to input length and must handle variable-length inputs. It’s great that the architecture gives us this property out of the box.
That means if you trained a Transformer model with a context length of 2K, you can run inference on token sequences of any size. The only problem is that the model will not produce meaningful results on 100K tokens during inference if it wasn’t trained on a 100K context length. In that case, the training data distribution is far from the distribution seen during inference, and the model fails as any machine learning model would in this setup.
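As a small illustration of this property (a sketch assuming PyTorch; the sizes are arbitrary), the very same attention layer accepts sequences of different lengths because its weights depend only on the embedding size, not on n:

```python
import torch
import torch.nn as nn

d = 512
# All learnable weights inside this layer depend only on d and the number of heads,
# never on the sequence length n.
attention = nn.MultiheadAttention(embed_dim=d, num_heads=8, batch_first=True)

short = torch.randn(1, 2048, d)   # (batch, n=2K, d)
long = torch.randn(1, 4096, d)    # (batch, n=4K, d)

out_short, _ = attention(short, short, short)  # works: (1, 2048, 512)
out_long, _ = attention(long, long, long)      # also works: (1, 4096, 512)
# Whether the outputs on the longer sequence are *meaningful* is a separate question:
# that depends on the context length the model was actually trained on.
```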
One solution for training a large context length Transformer is to train it in two stages: train the base model on a 2K-token context length and then continue training (fine-tune) it on longer contexts (for example, 65K or 100K). That’s precisely what MosaicML did. But the problem is that it won’t work with the original Transformer architecture, so you need to use some tricks (see Trick #1 later in the post).
Recap on Multi-Head Attention
The challenges of a large context length are related to the computational complexity of the Transformer architecture. To discuss this complexity, let’s first recap how the attention layer works.
Q — queries, K — keys, and V — values: notation from the paper, inspired by information retrieval, where you submit a “query” to the system and search for the closest “keys”
n — the number of input tokens
d — the text embedding dimension
h — the number of attention heads
k — the linear projection size for Q and K
v — the linear projection size for V
Multi-Head Attention:
- We have a lookup Embedding layer that, for a given token, returns a vector of size (1, d). Thus, for a sequence of n tokens, we get the text embedding matrix X of size (n, d). Then we add the Positional Sinusoidal Embedding to it.
- The Multi-Head Attention layer aims to calculate new embeddings for this sequence of tokens, which can be seen as the original text encoding X weighted (1) by the relative importance between tokens with regard to the context and (2) by the relative positions of the tokens.
- We process this embedding matrix X (n, d) in parallel with h attention layers (heads). To get Q, K, and V for all attention heads, you linearly project X to k, k, and v dimensions, respectively. You do it by multiplying X by h matrices of shape (d, k), (d, k), and (d, v). You can think about it as multiplying (n, d) by (h, d, k), (h, d, k), and (h, d, v).
- Each attention head returns an output matrix of size (n, v). We then concatenate the pieces from all heads into a matrix of size (n, h*v) and linearly project it for the next steps (a shape-by-shape sketch follows this list).
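Here is a minimal, shape-by-shape sketch of these steps, assuming PyTorch (toy sizes, no batch dimension, and raw weight tensors instead of nn.Linear, purely for illustration):

```python
import torch
import torch.nn.functional as F

n, d, h, k, v = 16, 512, 8, 64, 64          # toy sizes; typically k = v = d / h

X = torch.randn(n, d)                        # token embeddings (+ positional encoding)

W_q = torch.randn(h, d, k)                   # h query projections, each of shape (d, k)
W_k = torch.randn(h, d, k)                   # h key projections, each of shape (d, k)
W_v = torch.randn(h, d, v)                   # h value projections, each of shape (d, v)
W_o = torch.randn(h * v, d)                  # final output projection

Q = torch.einsum("nd,hdk->hnk", X, W_q)      # (h, n, k)
K = torch.einsum("nd,hdk->hnk", X, W_k)      # (h, n, k)
V = torch.einsum("nd,hdv->hnv", X, W_v)      # (h, n, v)

# Each head computes scaled dot-product attention (zoomed in on in the next section)
heads = F.scaled_dot_product_attention(Q, K, V)       # (h, n, v)

out = heads.transpose(0, 1).reshape(n, h * v) @ W_o   # concatenate heads -> (n, d)
```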
Scaled Dot-Product Attention:
Now, let’s zoom in on one attention head.
- Q, K, and V are three linear projections of X, of sizes (n, k), (n, k), and (n, v), obtained by multiplying X by learnable weight matrices that are separate for each head.
- We get attention scores by calculating the similarity (dot product) between Q and the transposed K: we multiply the (n, k) matrix by (k, n) and get an (n, n) matrix. Then we apply the mask to hide some of the tokens (required in the decoder), scale the scores, and apply softmax so that the values lie between 0 and 1. This way, we get a matrix of shape (n, n) whose entry a_ij is a relative attention score between 0 and 1 for the i-th and j-th tokens, showing how “close” these tokens are in this particular context of length n.
- Then we multiply this attention score matrix (n, n) by the “values” V of size (n, v) to get the text embeddings weighted by these relative attention scores (a minimal code sketch of this single head follows).
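And here is a minimal single-head sketch of the steps above, assuming PyTorch (in code, the mask is usually applied by setting future positions to -inf before softmax, which zeroes them out after softmax):

```python
import torch

def single_head_attention(Q, K, V, causal: bool = True):
    """Q: (n, k), K: (n, k), V: (n, v) -> output of shape (n, v)."""
    n, k = Q.shape
    scores = Q @ K.T / k ** 0.5                           # (n, n): the quadratic part
    if causal:
        # decoder mask: token i must not attend to future tokens j > i
        future = torch.triu(torch.ones(n, n), diagonal=1).bool()
        scores = scores.masked_fill(future, float("-inf"))
    weights = torch.softmax(scores, dim=-1)               # (n, n), each row sums to 1
    return weights @ V                                    # (n, v)

out = single_head_attention(torch.randn(8, 64), torch.randn(8, 64), torch.randn(8, 64))
```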