This post is a long form essay version of a talk about PyTorch internals, that I gave at the PyTorch NYC meetup on May 14, 2019.
Hi everyone! Today I want to talk about the internals of PyTorch.
This talk is for those of you who have used PyTorch, and thought to yourself, “It would be great if I could contribute to PyTorch,” but were scared by PyTorch’s behemoth of a C++ codebase. I’m not going to lie: the PyTorch codebase can be a bit overwhelming at times. The purpose of this talk is to put a map in your hands: to tell you about the basic conceptual structure of a “tensor library that supports automatic differentiation”, and give you some tools and tricks for finding your way around the codebase. I’m going to assume that you’ve written some PyTorch before, but haven’t necessarily delved deeper into how a machine learning library is written.
The talk is in two parts: in the first part, I’m going to first introduce you to the conceptual universe of a tensor library. I’ll start by talking about the tensor data type you know and love, and give a more detailed discussion about what exactly this data type provides, which will lead us to a better understanding of how it is actually implemented under the hood. If you’re an advanced user of PyTorch, you’ll be familiar with most of this material. We’ll also talk about the trinity of “extension points”, layout, device and dtype, which guide how we think about extensions to the tensor class. In the live talk at PyTorch NYC, I skipped the slides about autograd, but I’ll talk a little bit about them in these notes as well.
The second part grapples with the actual nitty gritty details involved with actually coding in PyTorch. I’ll tell you how to cut your way through swaths of autograd code, what code actually matters and what is legacy, and also all of the cool tools that PyTorch gives you for writing kernels.
The tensor is the central data structure in PyTorch. You probably have a pretty good idea about what a tensor intuitively represents: it’s an n-dimensional data structure containing some sort of scalar type, e.g., floats, ints, et cetera. We can think of a tensor as consisting of some data, along with some metadata describing the size of the tensor, the type of the elements it contains (dtype), and what device the tensor lives on (CPU memory? CUDA memory?).
There’s also a little piece of metadata you might be less familiar with: the stride. Strides are actually one of the distinctive features of PyTorch, so it’s worth discussing them a little more.
A tensor is a mathematical concept. But to represent it on our computers, we have to define some sort of physical representation for it. The most common representation is to lay out each element of the tensor contiguously in memory (that’s where the term contiguous comes from), writing out each row to memory, as you see above. In the example above, I’ve specified that the tensor contains 32-bit integers, so you can see that each integer occupies a physical address, each offset four bytes from the next. To remember what the actual dimensions of the tensor are, we also have to record the sizes as extra metadata.
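Here’s a quick sketch of how you can observe this layout from the Python API (the concrete 2x2 tensor is my own stand-in for the one in the diagram):

```python
import torch

# A small 2x2 tensor of 32-bit integers, laid out contiguously in memory.
t = torch.tensor([[1, 2], [3, 4]], dtype=torch.int32)

print(t.element_size())            # 4 -- each int32 occupies four bytes
base = t.data_ptr()
print(t[0, 1].data_ptr() - base)   # 4 -- the next element is four bytes over
print(t[1, 0].data_ptr() - base)   # 8 -- rows follow one another in memory
print(t.shape)                     # torch.Size([2, 2]) -- the sizes, kept as metadata
```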
So, what do strides have to do with this picture?
Suppose that I want to access the element at position tensor[1, 0] in my logical representation. How do I translate this logical position into a location in physical memory? Strides tell me how to do this: to find out where any element of a tensor lives, I multiply each index with the respective stride for that dimension, and sum them all together. In the picture above, I’ve color coded the first dimension blue and the second dimension red, so you can follow the index and stride in the stride calculation. Doing this sum, I get two (zero-indexed), and indeed, the number three lives two elements past the beginning of the contiguous array.
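Here is that calculation written out as a small sketch (the 2x2 tensor mirrors the example above):

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
print(t.stride())        # (2, 1): advance 2 elements per row step, 1 per column step

index = (1, 0)
offset = sum(i * s for i, s in zip(index, t.stride()))
print(offset)                      # 2
print(t.view(-1)[offset].item())   # 3 -- two elements past the start of the storage
```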
(Later in the talk, I’ll talk about TensorAccessor, a convenience class that handles the indexing calculation. When you use TensorAccessor, rather than raw pointers, this calculation is handled under the covers for you.)
Strides are the fundamental basis of how we provide views to PyTorch users. For example, suppose that I want to extract out a tensor that represents the second row of the tensor above:
Using advanced indexing support, I can just write tensor[1, :] to get this row. Here’s the important thing: when I do this, I don’t create a new tensor; instead, I just return a tensor which is a different view on the underlying data. This means that if I, for example, edit the data in that view, it will be reflected in the original tensor. In this case, it’s not too hard to see how to do this: three and four live in contiguous memory, and all we need to do is record an offset saying that the data of this (logical) tensor lives two down from the top. (Every tensor records an offset, but most of the time it’s zero, and I’ll omit it from my diagrams when that’s the case.)
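As a quick illustration (the concrete values are mine, not from the talk):

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
row = t[1, :]                    # a view, not a new tensor with its own memory
print(row.storage_offset())      # 2 -- the view starts two elements into the storage
print(row.data_ptr() == t.data_ptr() + 2 * t.element_size())   # True

row[0] = 99
print(t)                         # tensor([[ 1,  2], [99,  4]]) -- the edit shows through
```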
Question from the talk: If I take a view on a tensor, how do I free the memory of the underlying tensor?
Answer: You have to make a copy of the view, thus disconnecting it from the original physical memory. There’s really not much else you can do. By the way, if you wrote Java in the old days, taking substrings of strings had a similar problem, because by default no copy was made, so the substring retained a reference to the (possibly very large) original string. Apparently, they fixed this in Java 7u6.
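In PyTorch, that copy-to-free pattern looks roughly like this sketch (the sizes are made up for illustration):

```python
import torch

big = torch.zeros(10_000_000)
first_ten = big[:10]             # a view: it keeps the entire 10M-element storage alive
first_ten = big[:10].clone()     # a copy: only these ten elements are retained
del big                          # ...so the large buffer can now actually be freed
```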
A more interesting case is if I want to take the first column:
When we look at the physical memory, we see that the elements of the column are not contiguous: there’s a gap of one element between each one. Here, strides come to the rescue: instead of specifying a stride of one, we specify a stride of two, saying that between one element and the next, you need to jump two slots. (By the way, this is why it’s called a “stride”: if we think of an index as walking across the layout, the stride says how many locations we stride forward every time we take a step.)
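Continuing the toy example from above:

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
col = t[:, 0]                 # still a view on the same storage
print(col.stride())           # (2,): jump two slots in memory to reach the next element
print(col.is_contiguous())    # False -- the elements are not adjacent in memory
print(col)                    # tensor([1, 3])
```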
The stride representation can actually let you represent all sorts of interesting views on tensors; if you want to play around with the possibilities, check out the Stride Visualizer.
Let’s step back for a moment, and think about how we would actually implement this functionality (after all, this is an internals talk). If we can have views on a tensor, we have to decouple the notion of the tensor (the user-visible concept that you know and love) from the actual physical memory that stores the tensor’s data (called the storage):
There may be multiple tensors which share the same storage. Storage defines the dtype and physical size of the tensor, while each tensor records the sizes, strides and offset, defining the logical interpretation of the physical memory.
One thing to realize is that there is always a Tensor-Storage pair, even for “simple” cases where you don’t really need a storage (e.g., you just allocated a contiguous tensor with torch.zeros(2, 2)).
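You can poke at this pair directly from Python; a rough sketch (newer PyTorch versions may prefer untyped_storage() over storage()):

```python
import torch

t = torch.zeros(2, 2)          # even a "simple" contiguous tensor has a storage
v = t[1, :]                    # a second tensor sharing that same storage

print(t.storage().data_ptr() == v.storage().data_ptr())   # True -- one buffer, two tensors
print(len(t.storage()))                                    # 4 -- the physical size, in elements
print(v.shape, v.stride(), v.storage_offset())             # the logical interpretation of it
```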
By the way, we’re interested in making this picture no longer true; instead of having a separate concept of storage, we would just define a view to be a tensor that is backed by a base tensor. This is a little more complicated, but it has the benefit that contiguous tensors get a much more direct representation without the Storage indirection. A change like this would make PyTorch’s internal representation a bit more like Numpy’s.
We’ve talked quite a bit about the data layout of a tensor (some might say, if you get the data representation right, everything else falls into place). But it’s also worth briefly talking about how operations on the tensor are implemented. At the most abstract level, when you call torch.mm, two dispatches happen:
The first dispatch is based on the device type and layout of a tensor: e.g., whether it is a CPU tensor or a CUDA tensor (and also, e.g., whether it is a strided tensor or a sparse one). This is a dynamic dispatch: it’s a virtual function call (exactly where that virtual function call occurs will be the subject of the second half of this talk). It should make sense that you need to do a dispatch here: the implementation of CPU matrix multiply is quite different from a CUDA implementation. It is a dynamic dispatch because these kernels may live in separate libraries (e.g., libcaffe2.so versus libcaffe2_gpu.so), and so you have no choice: if you want to get into a library that you don’t have a direct dependency on, you have to dynamically dispatch your way there.
The second dispatch is a dispatch on the dtype in question. This dispatch is just a simple switch-statement for whatever dtypes a kernel chooses to support. Upon reflection, it should also make sense that we need to do a dispatch here: the CPU code (or CUDA code, as the case may be) that implements multiplication on float is different from the code for int. It stands to reason you need separate kernels for each dtype.
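To make the two levels concrete, here is a toy sketch; this is only a conceptual model of mine, not how PyTorch actually wires the dispatch up:

```python
import torch

def mm_cpu_strided(a, b):
    # Second dispatch: a plain switch on dtype inside the chosen implementation.
    if a.dtype in (torch.float32, torch.float64):
        return a @ b            # stand-in for the float kernel
    if a.dtype in (torch.int32, torch.int64):
        return a @ b            # stand-in for the int kernel
    raise TypeError(f"mm: unsupported dtype {a.dtype}")

KERNELS = {
    ("cpu", torch.strided): mm_cpu_strided,
    # ("cuda", torch.strided): ...   # would live in a separate GPU library
}

def my_mm(a, b):
    # First dispatch: device type and layout pick which implementation runs.
    return KERNELS[(a.device.type, a.layout)](a, b)

print(my_mm(torch.ones(2, 3), torch.ones(3, 2)))
```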
This is probably the most important mental picture to have in your head, if you’re trying to understand the way operators in PyTorch are invoked. We’ll return to this picture when it’s time to look more at code.
Since we have been talking about Tensor, I also want to take a little time to talk about the world of tensor extensions. After all, there’s more to life than dense, CPU float tensors. There are all sorts of interesting extensions going on, like XLA tensors, or quantized tensors, or MKL-DNN tensors, and one of the things we have to think about, as a tensor library, is how to accommodate these extensions.
Our current model for extensions offers four extension points on tensors. First, there is the trinity of three parameters which uniquely determine what a tensor is (a short snippet after this list shows how each of them shows up on a tensor):
- The device, the description of where the tensor’s physical memory is actually stored, e.g., on a CPU, on an NVIDIA GPU (cuda), or perhaps on an AMD GPU (hip) or a TPU (xla). The distinguishing characteristic of a device is that it has its own allocator, that doesn’t work with any other device.
- The layout, which describes how we logically interpret this physical memory. The most common layout is a strided tensor, but sparse tensors have a different layout involving a pair of tensors, one for indices, and one for data; MKL-DNN tensors may have an even more exotic layout, like a blocked layout, which can’t be represented using strides alone.
- The dtype, which describes what it is that is actually stored in each element of the tensor. This could be floats or integers, or it could be, for example, quantized integers.
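Here is the promised snippet (my own illustration, not from the talk) showing how the trinity shows up on a few tensors:

```python
import torch

dense = torch.zeros(2, 2, dtype=torch.float32)
sparse = dense.to_sparse()                       # same device and dtype, different layout
on_gpu = dense.to("cuda") if torch.cuda.is_available() else dense

for t in (dense, sparse, on_gpu):
    print(t.device, t.layout, t.dtype)
# e.g.  cpu torch.strided torch.float32
#       cpu torch.sparse_coo torch.float32
#       cuda:0 torch.strided torch.float32   (if a GPU is available)
```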
17 Comments
nitrogen99
2019. How much of this is still relevant?
chuckledog
Great article, thanks for posting. Here’s a nice summary of automatic differentiation, mentioned in the article and core to how NN’s are implemented: https://medium.com/@rhome/automatic-differentiation-26d5a993…
vimgrinder
In case it helps someone: if you are having trouble reading long articles, try text-to-audio with line highlighting. It helps a lot; it has cured my lack of attention.
brutus1979
Is there a video version of this? It seems to be from a talk?
hargun2010
I guess it's a longer version of the slides, but not new; I saw comments from as far back as 2023. Nonetheless, good content (reshareable).
https://web.mit.edu/~ezyang/Public/pytorch-internals.pdf
smokel
Also interesting in this context is the PyTorch Developer Podcast [1] by the same author. Very comforting to learn about PyTorch internals while doing the dishes.
[1] https://pytorch-dev-podcast.simplecast.com/
bilal2vec
See also dev forum roadmaps [1] and design docs (e.g. [2], [3], [4])
[1]: https://dev-discuss.pytorch.org/t/meta-pytorch-team-2025-h1-…
[2]: https://dev-discuss.pytorch.org/t/pytorch-symmetricmemory-ha…
[3]: https://dev-discuss.pytorch.org/t/where-do-the-2000-pytorch-…
[4]: https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-s…
alexrigler
This is a fun blast from the near past. I helped organize the PyTorch NYC meetup where Ed presented this and still think it's one of the best technical presentations I've seen. Hand drawn slides for the W. Wish I had recorded it.
aduffy
Edward taught a Programming Languages class I took nearly a decade ago, and clicking through here I immediately recognized the illustrated slides, which brought a smile to my face.
zcbenz
For learning internals of ML frameworks I recommend reading the source code of MLX: https://github.com/ml-explore/mlx .
It is a modern and clean codebase without legacy baggage, and I could understand most things without consulting external articles.
quotemstr
Huh. I'd have written TORCH_CHECK like this:
Turns out it's possible to write TORCH_CHECK() so that it evaluates the streaming operators only if the check fails. (Check out how glog works.)
pizza
Btw, would anyone have any good resources on using pytorch as a general-purpose graph library? Like stuff beyond the assumption of nets = forward-only (acyclic) digraph