In this post, we will dive into the internals of Large Language Models (LLMs) to gain a practical understanding of how they work. To aid us in this exploration, we will be using the source code of llama.cpp, a pure C/C++ implementation of Meta’s LLaMA model.
Personally, I have found llama.cpp to be an excellent learning aid for understanding LLMs on a deeper level. Its code is clean, concise and straightforward, without involving excessive abstractions.
We will use this commit.
We will focus on the inference aspect of LLMs, meaning: how the already-trained model generates responses based on user prompts.
This post is written for engineers in fields other than ML and AI who are interested in better understanding LLMs.
It focuses on the internals of an LLM from an engineering perspective, rather than an AI perspective.
Therefore, it does not assume extensive knowledge of math or deep learning.
Throughout this post, we will go over the inference process from beginning to end, covering the following subjects:
- Tensors: A basic overview of how the mathematical operations are carried out using tensors, potentially offloaded to a GPU.
- Tokenization: The process of splitting the user’s prompt into a list of tokens, which the LLM uses as its input.
- Embedding: The process of converting the tokens into a vector representation.
- The Transformer: The central part of the LLM architecture, responsible for the actual inference process. We will focus on the self-attention mechanism.
- Sampling: The process of choosing the next predicted token. We will explore two sampling techniques.
- The KV cache: A common optimization technique used to speed up inference on long prompts. We will explore a basic KV cache implementation.
By the end of this post you will hopefully gain an end-to-end understanding of how LLMs work. This will enable you to explore more advanced topics, some of which are detailed in the last section.
High-level flow from prompt to output
As a large language model, LLaMA works by taking an input text, the “prompt”, and predicting what the next tokens, or words, should be.
To illustrate this, we will use the first sentence from the Wikipedia article about Quantum Mechanics as an example.
Our prompt is:
Quantum mechanics is a fundamental theory in physics that
The LLM attempts to continue the sentence according to what it was trained to believe is the most likely continuation.
Using llama.cpp, we get the following continuation:
provides insights into how matter and energy behave at the atomic scale.
Let’s begin by examining the high-level flow of how this process works.
At its core, an LLM only predicts a single token at a time.
The generation of a complete sentence (or more) is achieved by repeatedly applying the LLM to the same prompt, with the previously generated tokens appended to it.
This type of model is referred to as an autoregressive model.
Thus, our focus will primarily be on the generation of a single token, as depicted in the high-level diagram below:

Following the diagram, the flow is as follows:
- The tokenizer splits the prompt into a list of tokens. Some words may be split into multiple tokens, based on the model’s vocabulary. Each token is represented by a unique number.
- Each numerical token is converted into an embedding. An embedding is a vector of fixed size that represents the token in a way that is more efficient for the LLM to process. All the embeddings together form an embedding matrix.
- The embedding matrix serves as the input to the Transformer. The Transformer is a neural network that acts as the core of the LLM. The Transformer consists of a chain of multiple layers. Each layer takes an input matrix and performs various mathematical operations on it using the model parameters, the most notable being the self-attention mechanism. The layer’s output is used as the next layer’s input.
- A final neural network converts the output of the Transformer into logits. Each possible next token has a corresponding logit: an unnormalized score that, once converted with a softmax, gives the probability that the token is the “correct” continuation of the sentence.
- One of several sampling techniques is used to choose the next token from the list of logits.
- The chosen token is returned as the output. To continue generating tokens, the chosen token is appended to the list of tokens from step (1), and the process is repeated. This can be continued until the desired number of tokens is generated, or the LLM emits a special end-of-stream (EOS) token.
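The sketch below shows the autoregressive loop in C++. It is not llama.cpp code. `toy_forward()` is a hypothetical stand-in for steps 2 to 4 (embedding, Transformer and output layer); it always favors the token that follows the last input token. The next token is chosen greedily by taking the largest logit, which is only one of the possible sampling techniques.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using token = int;

// Hypothetical stand-in for the model (embedding, Transformer and output layer):
// it returns one logit per vocabulary entry. This toy version always favors the
// token that follows the last input token, modulo the vocabulary size.
std::vector<float> toy_forward(const std::vector<token> & tokens, int n_vocab) {
    assert(!tokens.empty() && n_vocab > 0);
    std::vector<float> logits(static_cast<std::size_t>(n_vocab), 0.0f);
    logits[static_cast<std::size_t>((tokens.back() + 1) % n_vocab)] = 1.0f;
    return logits;
}

// Greedy sampling: pick the token with the largest logit.
token argmax(const std::vector<float> & logits) {
    std::size_t best = 0;
    for (std::size_t i = 1; i < logits.size(); ++i) {
        if (logits[i] > logits[best]) best = i;
    }
    return static_cast<token>(best);
}

// Autoregressive generation: predict one token, append it to the input,
// and run the model again, until EOS or max_new_tokens is reached.
std::vector<token> generate(std::vector<token> tokens, int n_vocab, token eos, int max_new_tokens) {
    std::vector<token> output;
    for (int i = 0; i < max_new_tokens; ++i) {
        const token next = argmax(toy_forward(tokens, n_vocab));
        if (next == eos) break;    // the model signalled end-of-stream
        output.push_back(next);
        tokens.push_back(next);    // the new token becomes part of the next input
    }
    return output;
}
```

Consider `generate({1}, 5, 4, 10)`, which uses a five-token vocabulary with token 4 as EOS. It returns `{2, 3}`, because the loop stops as soon as the EOS token is predicted.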
In the following sections, we will delve into each of these steps in detail.
But before doing that, we need to familiarize ourselves with tensors.
Understanding tensors with ggml
Tensors are the main data structure used for performing mathematical operations in neural networks.
llama.cpp uses ggml, a tensor library implemented in pure C, which plays a role similar to PyTorch or TensorFlow in the Python ecosystem.
We will use ggml to get an understanding of how tensors operate.
A tensor represents a multi-dimensional array of numbers. A tensor may hold a single number, a vector (one-dimensional array), a matrix (two-dimensional array) or even a three- or four-dimensional array. More dimensions than that are not needed in practice.
It is important to distinguish between two types of tensors.
There are tensors that hold actual data, containing a multi-dimensional array of numbers.
On the other hand, there are tensors that only represent the result of a computation between one or more other tensors, and do not hold data until actually computed.
We will explore this distinction soon.
Basic structure of a tensor
In ggml, tensors are represented by the `ggml_tensor` struct. Simplified slightly for our purposes, it looks like the following:
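```c
// ggml.h
struct ggml_tensor {
    enum ggml_type    type;     // element type, e.g. GGML_TYPE_F32
    enum ggml_backend backend;  // whether the tensor is CPU- or GPU-backed

    int n_dims;                 // number of dimensions (1 to 4)
    // number of elements
    int64_t ne[GGML_MAX_DIMS];
    // stride in bytes
    size_t  nb[GGML_MAX_DIMS];

    enum ggml_op op;            // the operation producing this tensor, or GGML_OP_NONE
    struct ggml_tensor * src[GGML_MAX_SRC];  // the operation's source tensors

    void * data;                // pointer to the tensor's data

    char name[GGML_MAX_NAME];
};
```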
The first few fields are straightforward:
- `type` contains the primitive type of the tensor’s elements. For example, `GGML_TYPE_F32` means that each element is a 32-bit floating point number.
- `backend` indicates whether the tensor is CPU-backed or GPU-backed. We’ll come back to this bit later.
- `n_dims` is the number of dimensions, which may range from 1 to 4.
- `ne` contains the number of elements in each dimension. ggml uses row-major order, meaning that `ne[0]` is the size of each row, `ne[1]` is the size of each column, and so on.
`nb` is a bit more sophisticated. It contains the stride: the number of bytes between consecutive elements in each dimension. In the first dimension, this is the size of the primitive element. In the second dimension, it is the row size times the size of an element, and so on. For example, for a 4x3x2 tensor:

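Here is a small worked example of this rule. It assumes the 4x3x2 tensor has `ne = {4, 3, 2, 1}`, so each row holds 4 elements. It also assumes a non-quantized element type: 32-bit floats of 4 bytes each. The hypothetical helpers below compute the strides of a contiguous tensor and use them to locate an element. They are illustrations, not ggml functions.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Strides (in bytes) of a contiguous, row-major tensor: nb[0] is the element
// size, and each following stride is the previous stride times the previous size.
std::array<std::size_t, 4> contiguous_strides(const std::array<int64_t, 4> & ne, std::size_t elem_size) {
    assert(elem_size > 0);
    std::array<std::size_t, 4> nb{};
    nb[0] = elem_size;
    for (int i = 1; i < 4; ++i) {
        nb[i] = nb[i - 1] * static_cast<std::size_t>(ne[i - 1]);
    }
    return nb;
}

// Byte offset of element (i0, i1, i2, i3): each index times its stride, summed.
std::size_t byte_offset(const std::array<std::size_t, 4> & nb,
                        std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
    return i0 * nb[0] + i1 * nb[1] + i2 * nb[2] + i3 * nb[3];
}
```

With 4-byte elements this gives `nb = {4, 16, 48, 96}`. Element (1, 2, 1, 0) therefore starts at byte 1*4 + 2*16 + 1*48 = 84.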
The purpose of using a stride is to allow certain tensor operations to be performed without copying any data.
For example, transposing a two-dimensional tensor, which turns rows into columns, can be carried out by just swapping `ne` and `nb` and pointing to the same underlying data:
```c
// ggml.c (the function was slightly simplified).
struct ggml_tensor * ggml_transpose(
        struct ggml_context * ctx,
        struct ggml_tensor  * a) {
    // Initialize `result` to point to the same data as `a`
    struct ggml_tensor * result = ggml_view_tensor(ctx, a);

    result->ne[0] = a->ne[1];
    result->ne[1] = a->ne[0];

    result->nb[0] = a->nb[1];
    result->nb[1] = a->nb[0];

    result->op     = GGML_OP_TRANSPOSE;
    result->src[0] = a;

    return result;
}
```
In the above function, `result` is a new tensor initialized to point to the same multi-dimensional array of numbers as the source tensor `a`.
By exchanging the dimensions in `ne` and the strides in `nb`, it performs the transpose operation without copying any data.
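The following self-contained C++ sketch shows why swapping `ne` and `nb` is enough. `View2D`, `at()` and `transpose()` are hypothetical names, not ggml's API. The view keeps a pointer to shared data plus sizes and strides in bytes, and every element access goes through the strides. The transposed view therefore reads the same buffer in a different order.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical minimal 2-D view in the spirit of ggml_tensor: a pointer to
// shared data plus sizes (ne) and strides in bytes (nb).
struct View2D {
    const float * data;
    std::size_t ne[2];   // ne[0]: elements per row, ne[1]: number of rows
    std::size_t nb[2];   // nb[0]: bytes between elements of a row, nb[1]: bytes between rows
};

// Read element (i0, i1) by walking the strides from the start of the data.
float at(const View2D & v, std::size_t i0, std::size_t i1) {
    assert(i0 < v.ne[0] && i1 < v.ne[1]);
    const char * base = reinterpret_cast<const char *>(v.data);
    return *reinterpret_cast<const float *>(base + i0 * v.nb[0] + i1 * v.nb[1]);
}

// Transpose without copying: swap sizes and strides, keep the same data pointer.
View2D transpose(const View2D & a) {
    View2D t = a;
    t.ne[0] = a.ne[1];
    t.ne[1] = a.ne[0];
    t.nb[0] = a.nb[1];
    t.nb[1] = a.nb[0];
    return t;
}
```

Take a 2x3 matrix stored as {1, 2, 3, 4, 5, 6}, i.e. two rows of three. On its transposed view, `at(t, 1, 0)` reads 4, the first element of the second row, without any data being copied. The transposed view is no longer contiguous: consecutive elements of its rows are `nb[0] = 12` bytes apart.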
Tensor operations and views
As mentioned before, some tensors hold data, while others represent the theoretical result of an operation between other tensors.
Going back to `struct ggml_tensor`:
- `op` may be any supported operation between tensors. Setting it to `GGML_OP_NONE` marks that the tensor holds data. Other values mark an operation. For example, `GGML_OP_MUL_MAT` means that this tensor does not hold data, but only represents the result of matrix multiplication between two other tensors.
- `src` is an array of pointers to the tensors between which the operation is to be taken. For example, if `op == GGML_OP_MUL_MAT`, then `src` will contain pointers to the two tensors to be multiplied. If `op == GGML_OP_NONE`, then `src` will be empty.
- `data` points to the actual tensor’s data, or `NULL` if this tensor is an operation. It may also point to another tensor’s data, in which case it’s known as a view. For example, in the `ggml_transpose()` function above, the resulting tensor is a view of the original, just with flipped dimensions and strides: its `data` points to the same location in memory.
The matrix multiplication function illustrates these concepts well:
```c
// ggml.c (simplified and commented)
struct ggml_tensor * ggml_mul_mat(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b) {
    // Check that the tensors' dimensions permit matrix multiplication.
    GGML_ASSERT(ggml_can_mul_mat(a, b));

    // Set the new tensor's dimensions
    // according to matrix multiplication rules.
    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };

    // Allocate a new ggml_tensor.
    // No data is actually allocated except the wrapper struct.
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);

    // Set the operation and sources.
    result->op     = GGML_OP_MUL_MAT;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}
```
In the above function, `result` does not contain any data. It is merely a representation of the theoretical result of multiplying `a` and `b`.
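Here is an eager, naive C++ sketch of what `GGML_OP_MUL_MAT` eventually computes. It follows the shape rule in the code above: both inputs share `ne[0]`, and the result has `ne = {a->ne[1], b->ne[1]}`. `Mat` and `mul_mat()` are hypothetical simplifications: contiguous, `float` only, with no handling of the higher dimensions. Each result element is the dot product of a row of `a` with a row of `b`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical contiguous 2-D float tensor in ggml's convention: ne0 elements
// per row and ne1 rows, with element (i0, i1) stored at index i1 * ne0 + i0.
struct Mat {
    std::size_t ne0, ne1;
    std::vector<float> v;
    float at(std::size_t i0, std::size_t i1) const { return v[i1 * ne0 + i0]; }
};

// Eager matrix multiplication with the shape rule of ggml_mul_mat():
// a and b must have the same ne0, and the result has ne = { a.ne1, b.ne1 }.
// Element (i, j) of the result is the dot product of row i of a and row j of b.
Mat mul_mat(const Mat & a, const Mat & b) {
    assert(a.ne0 == b.ne0);   // the shared inner dimension
    Mat r{a.ne1, b.ne1, std::vector<float>(a.ne1 * b.ne1, 0.0f)};
    for (std::size_t j = 0; j < b.ne1; ++j) {
        for (std::size_t i = 0; i < a.ne1; ++i) {
            float sum = 0.0f;
            for (std::size_t k = 0; k < a.ne0; ++k) {
                sum += a.at(k, i) * b.at(k, j);
            }
            r.v[j * r.ne0 + i] = sum;
        }
    }
    return r;
}
```

Suppose `a` holds the rows {1, 2} and {3, 4}, and `b` holds the single row {5, 6}. The result then has `ne = {2, 1}` and holds {17, 39}. In conventional row-major notation this is `b * transpose(a)`.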
Computing tensors
The `ggml_mul_mat()` function above, or any other tensor operation, does not calculate anything; it just prepares the tensors for the operation.
A different way to look at it is that it builds up a computation graph where each tensor operation is a node, and the operation’s sources are the node’s children.
In the matrix multiplication scenario, the graph has a parent node with operation `GGML_OP_MUL_MAT`, along with two children.
As a real example from llama.cpp, the following code implements the self-attention mechanism, which is part of each Transformer layer and will be explored in more depth later:
```cpp
// llama.cpp
static struct ggml_cgraph * llm_build_llama(/* ... */) {
    // ...

    // K,Q,V are tensors initialized earlier
    struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

    // KQ_scale is a single-number tensor initialized earlier.
    struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);

    struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);

    struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);

    struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);

    // ...
}
```
The code is a series of tensor operations and builds a computation graph that is identical to the one described in the original Transformer paper:
In order to actually compute the result tensor (here it’s `KQV`), the following steps are taken:
- Data is loaded into each leaf tensor’s `data` pointer. In the example, the leaf tensors are `K`, `Q` and `V`.
- The output tensor (`KQV`) is converted to a computation graph using `ggml_build_forward()`. This function is relatively straightforward and orders the nodes in a depth-first order, so that every node comes after the tensors it depends on.¹
- The computation graph is run using `ggml_graph_compute()`, which runs `ggml_compute_forward()` on each node in that order. `ggml_compute_forward()` does the heavy lifting of calculations: it performs the mathematical operation and fills the tensor’s `data` pointer with the result.
- At the end of this process, the output tensor’s `data` pointer points to the final result.
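The build-then-compute pattern can be reduced to a toy example. The C++ sketch below is not ggml. `Node`, `Op`, `build_forward()` and `compute()` are hypothetical, and each node holds a single number instead of an array. Creating an operation node computes nothing. `build_forward()` orders the nodes depth-first so that every node comes after its sources, and `compute()` then fills in the results in that order.

```cpp
#include <cassert>
#include <vector>

// Hypothetical scalar "tensors": leaves hold data (op == NONE), while other
// nodes only record an operation and pointers to its sources.
enum class Op { NONE, ADD, MUL };

struct Node {
    Op op = Op::NONE;
    Node * src0 = nullptr;
    Node * src1 = nullptr;
    float value = 0.0f;   // a leaf's data; op nodes get a value only when computed
};

// Depth-first post-order: visit the sources first, then append the node,
// so every node ends up after the nodes it depends on.
void visit(Node * n, std::vector<Node *> & order) {
    if (n == nullptr) return;
    for (Node * seen : order) {
        if (seen == n) return;   // already part of the graph
    }
    visit(n->src0, order);
    visit(n->src1, order);
    order.push_back(n);
}

std::vector<Node *> build_forward(Node * output) {
    std::vector<Node *> order;
    visit(output, order);
    return order;
}

// Run the nodes in graph order; leaves already hold their data.
void compute(const std::vector<Node *> & order) {
    for (Node * n : order) {
        switch (n->op) {
            case Op::NONE: break;
            case Op::ADD:  n->value = n->src0->value + n->src1->value; break;
            case Op::MUL:  n->value = n->src0->value * n->src1->value; break;
        }
    }
}
```

For example, a node `prod` that represents (a + b) * a keeps the value 0 until `compute()` runs. Afterwards it holds 10 for a = 2 and b = 3.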
Offloading calculations to the GPU
Many tensor operations like matrix addition and multiplication can be calculated on a GPU much more efficiently due to its high parallelism.
When a GPU is available, tensors can be marked with `tensor->backend = GGML_BACKEND_GPU`.
In this case, `ggml_compute_forward()` will attempt to offload the calculation to the GPU.
The GPU will perform the tensor operation, and the result will be stored in the GPU’s memory (and not in the `data` pointer).
Consider the self-attention computation graph shown before. Assuming that `K`, `Q` and `V` are fixed tensors, the computation can be offloaded to the GPU:
The process begins by copying `K`, `Q` and `V` to the GPU memory.
The CPU then drives the computation forward tensor-by-tensor, but the actual mathematical operation is offloaded to the GPU.
When the last operation in the graph ends, the result tensor’s data is copied back from the GPU memory to the CPU memory.
Note: In a real Transformer, `K`, `Q` and `V` are not fixed and `KQV` is not the final output. More on that later.
With this understanding of tensors, we can go back to the flow of LLaMA.
Tokenization
The first step in inference is tokenization.
Tokenization is the process of splitting the prompt into a list of shorter strings known as tokens.
The tokens must be part of the model’s vocabulary, which is the list of tokens the LLM was trained on.
LLaMA’s vocabulary, for example, consists of 32k tokens and is distributed as part of the model.
For our example prompt, the tokenization splits the prompt into eleven tokens, with spaces replaced by the special meta symbol ’▁’ (U+2581):
|Quant|um|▁mechan|ics|▁is|▁a|▁fundamental|▁theory|▁in|▁physics|▁that|
For tokenization, LLaMA uses the SentencePiece tokenizer with the byte-pair-encoding (BPE) algorithm.
This tokenizer is interesting because it is subword-based, meaning that words may be represented by multiple tokens. In our prompt, for example, ‘Quantum’ is split into ‘Quant’ and ‘um’. During training, when the vocabulary is derived, the BPE algorithm ensures that common words are included in the vocabulary as a single token, while rare words are broken down into subwords. In the example above, the word ‘Quantum’ is not part of the vocabulary, but ‘Quant’ and ‘um’ are as two separate tokens. White spaces are not treated specially, and are included in the tokens themselves as the meta character if they are common enough.
Subword-based tokenization is powerful for several reasons:
- It allows the LLM to learn the meaning of rare words like ‘Quantum’ while keeping the vocabulary size relatively small by representing common suffixes and prefixes as separate tokens.
- It learns language-specific features without employing language-specific tokenization schemes. Quoting from the BPE paper:
consider compounds such as the German Abwasser|behandlungs|anlange ‘sewage water treatment plant’, for which a segmented, variable-length representation is intuitively more appealing than encoding the word as a fixed-length vector.
- Similarly, it is also useful in parsing code. For example, a variable named `model_size` will be tokenized into `model|_|size`, allowing the LLM to “understand” the purpose of the variable (yet another reason to give your variables descriptive names!).
In llama.cpp, tokenization is performed using the `llama_tokenize()` function. This function takes the prompt string as input and returns a list of tokens, where each token is represented by an integer:
```cpp
// llama.h
typedef int llama_token;

// common.h
std::vector<llama_token> llama_tokenize(
        struct llama_context * ctx,
        // the prompt
        const std::string & text,
        bool   add_bos);
```
The tokenization process starts by breaking down the prompt into single-character tokens. Then, it iteratively tries to merge each pair of consecutive tokens into a larger one, as long as the merged token is part of the vocabulary. This ensures that the resulting tokens are as large as possible. For our example prompt, the tokenization steps are as follows:
Q|u|a|n|t|u|m|▁|m|e|c|h|a|n|i|c|s|▁|i|s|▁|a|▁|f|u|n|d|a|m|e|n|t|a|l|
Qu|an|t|um|▁m|e|ch|an|ic|s|▁|is|▁a|▁f|u|nd|am|en|t|al|
Qu|ant|um|▁me|chan|ics|▁is|▁a|▁f|und|am|ent|al|
Quant|um|▁mechan|ics|▁is|▁a|▁fund|ament|al|
Quant|um|▁mechan|ics|▁is|▁a|▁fund|amental|
Quant|um|▁mechan|ics|▁is|▁a|▁fundamental|
Note that each intermediate step is a valid tokenization according to the model’s vocabulary. However, only the last one is used as the input to the LLM.
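The merging process can be sketched in a few lines of C++. This is a simplified illustration, not llama.cpp’s tokenizer. The vocabulary is a made-up map from token strings to scores. At each step, the sketch merges the adjacent pair whose concatenation is in the vocabulary and has the highest score. Using scores to choose which pair to merge first is an assumption, modeled on how SentencePiece vocabularies attach a score to every token.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Split a UTF-8 string into single characters (code points), so that the
// three-byte meta symbol '▁' (U+2581) is kept in one piece.
std::vector<std::string> split_utf8(const std::string & s) {
    std::vector<std::string> chars;
    for (std::size_t i = 0; i < s.size();) {
        const unsigned char c = static_cast<unsigned char>(s[i]);
        const std::size_t len = c < 0x80 ? 1 : (c >> 5) == 0x6 ? 2 : (c >> 4) == 0xE ? 3 : 4;
        chars.push_back(s.substr(i, len));
        i += len;
    }
    return chars;
}

// Simplified subword merging with a toy vocabulary (token string -> score):
// start from single characters, then repeatedly merge the adjacent pair whose
// concatenation is in the vocabulary and has the highest score.
std::vector<std::string> tokenize(const std::string & text, const std::map<std::string, float> & vocab) {
    std::vector<std::string> tokens = split_utf8(text);
    while (true) {
        std::size_t best = tokens.size();   // tokens.size() means "no mergeable pair"
        float best_score = 0.0f;
        for (std::size_t i = 0; i + 1 < tokens.size(); ++i) {
            const auto it = vocab.find(tokens[i] + tokens[i + 1]);
            if (it != vocab.end() && (best == tokens.size() || it->second > best_score)) {
                best = i;
                best_score = it->second;
            }
        }
        if (best == tokens.size()) break;   // no pair can be merged any more
        tokens[best] += tokens[best + 1];
        tokens.erase(tokens.begin() + static_cast<std::ptrdiff_t>(best) + 1);
    }
    return tokens;
}
```

Take the toy vocabulary {"Qu": 1, "an": 2, "um": 3, "ant": 4, "Quant": 5}. With it, `tokenize("Quantum", vocab)` returns {"Quant", "um"}, matching the split shown above. The sketch merges one pair per step, so its intermediate states are finer-grained than the trace above.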
Embeddings
The tokens are used as input to LLaMA to predict the next token.
The key function here is the `llm_build_llama()` function:
```cpp
// llama.cpp (simplified)
static struct ggml_cgraph * llm_build_llama(
        llama_context & lctx,
        const llama_token * tokens,
        int   n_tokens,
        int   n_past);
```
This function takes a list of tokens, represented by the `tokens` and `n_tokens` parameters, as input.
It then builds the full tensor computation graph of LLaMA, and returns it as a `struct ggml_cgraph`.
No computation actually takes place at this stage.
The `n_past` parameter, which is currently set to zero, can be ignored for now.
We will revisit it later when discussing the KV cache.
Besides the tokens, the function makes use of the model weights, or model parameters.
These are fixed tensors learned during the LLM training process and included as part of the model.
These model parameters are pre-loaded into `lctx` before the inference begins.
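Here is one illustration of how such a fixed tensor is used. The embedding step from the high-level flow amounts to looking up one row per token in a learned table of n_vocab rows, each holding n_embd values. The C++ sketch below uses a hypothetical `EmbeddingTable` with made-up values. In ggml, this kind of row gathering is available as the `ggml_get_rows()` operation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical embedding table standing in for a learned model tensor:
// n_vocab rows of n_embd floats, where row t is the embedding of token id t.
struct EmbeddingTable {
    std::size_t n_vocab, n_embd;
    std::vector<float> weights;   // n_vocab * n_embd values, row by row
};

// Gather one row per input token. The rows, stacked in token order, form the
// n_tokens x n_embd embedding matrix that is fed to the Transformer.
std::vector<float> embed(const EmbeddingTable & table, const std::vector<int> & tokens) {
    assert(table.weights.size() == table.n_vocab * table.n_embd);
    std::vector<float> out;
    out.reserve(tokens.size() * table.n_embd);
    for (int t : tokens) {
        assert(t >= 0 && static_cast<std::size_t>(t) < table.n_vocab);
        const float * row = table.weights.data() + static_cast<std::size_t>(t) * table.n_embd;
        out.insert(out.end(), row, row + table.n_embd);
    }
    return out;
}
```

Consider a table of three tokens with two-dimensional embeddings. The input {2, 0, 2} yields rows 2, 0 and 2 of the table, so a repeated token gets an identical embedding each time.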
We will now begin exploring the computation graph structure.
The first part of this computation graph involv