Introduction
So I recently completed my PhD in Mathematics at the University of Oxford. (Hurrah! It was so much fun.)
In 2-and-a-bit years I wrote 12 papers, received 4139 GitHub stars, got 3271 Twitter followers, authored 1 textbook – doing double-duty as my thesis – and got the coveted big-tech job-offer.
If you’re interested in a textbook on Neural Differential Equations with a smattering of scientific computing, then my thesis, On Neural Differential Equations, is available online.
Quite a few folks seem to have looked at this, and messaged me – mostly on Twitter or Mastodon – asking for advice on how to achieve success in a machine learning PhD.
Each time my answer is: Just Know Stuff.
Now, I don’t think “Just Know Stuff” is a terribly controversial opinion – undergraduate classes are largely based around imparting knowledge; the first year of a new PhD’s life is usually spent reading up on the literature – but from the number of questions I get it would seem that this is something worth restating.
Know your field inside-out. Know as much about adjacent fields (in math, statistics, …) as you can. Don’t just know how to program; know how to do software development. Know the mathematical underpinnings your work is built upon. And so on and so on. Indeed: possessing a technical depth of knowledge is how you come up with new ideas and learn to recognise bad ones.
This does raise the follow-up question: what is worth knowing? What is worth learning?
And the answer to that is what I started repeating to all of you folks messaging me. But then that started taking up way too much time, mostly because I write way too much. So now I’m writing this post instead – this way I’ll only have to write way too much once!
The following is a highly personal list of the things I found to be useful during my PhD, and which I think are of broad enough appeal that they probably represent a reasonable core of knowledge for those just starting an ML PhD. It’s by no means exhaustive, and you should certainly expect to add a lot of domain-specific stuff on top of it. But perhaps it’s a useful starting point.
This list is targeted towards early-stage PhD or pre-PhD students. If you’re late-stage and reading through this thinking “yeah, of course I know this stuff”, then well… that’s the point!
Machine learning
- Know both forward- and reverse-mode autodifferentiation. (Nice reference: Appendix A of my thesis. ;) )
- Write some custom gradient operations in both PyTorch and JAX.
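  For instance, here’s a minimal JAX sketch of a custom reverse-mode rule (the PyTorch analogue is subclassing torch.autograd.Function); the gradient-clipping rule is just an illustrative toy:

  ```python
  import jax
  import jax.numpy as jnp

  # Toy custom rule: the forward pass is the identity, but the backward pass
  # clips the incoming cotangent to [-1, 1].
  @jax.custom_vjp
  def clip_gradient(x):
      return x

  def clip_gradient_fwd(x):
      return x, None  # no residuals need saving for the backward pass

  def clip_gradient_bwd(_, g):
      return (jnp.clip(g, -1.0, 1.0),)

  clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

  print(jax.grad(lambda x: 10.0 * clip_gradient(x))(2.0))  # 1.0, not 10.0
  ```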
- Look up “optimal Jacobian accumulation” on the autodifferentiation page on Wikipedia.
- Optional: learn how JAX derives reverse-mode autodifferentiation by combining partial evaluation, forward-mode autodifferentiation, and transposition.
- Optional: why is computing a divergence expensive using autodifferentiation? Learn Hutchinson’s trace estimator. (Why is that efficient?) Learn the Hutch++ trace estimator. (Which is surprisingly poorly known.)
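  As a sketch of the Hutchinson part of the answer (names here are just illustrative): the divergence is the trace of the Jacobian, and Hutchinson’s estimator replaces materialising the whole Jacobian with a handful of Jacobian-vector products.

  ```python
  import jax
  import jax.numpy as jnp

  def divergence_exact(f, x):
      # Materialises the full Jacobian: needs O(dim) autodiff passes.
      return jnp.trace(jax.jacfwd(f)(x))

  def divergence_hutchinson(f, x, key, num_samples):
      # div f(x) = trace(J) = E[eps^T J eps] whenever E[eps eps^T] = I.
      # Each sample needs only a single Jacobian-vector product.
      def single_estimate(eps):
          _, jvp_out = jax.jvp(f, (x,), (eps,))
          return jnp.dot(eps, jvp_out)
      eps = jax.random.normal(key, (num_samples, x.shape[0]))
      return jnp.mean(jax.vmap(single_estimate)(eps))

  f = lambda x: x * jnp.sin(x)  # elementwise vector field
  x = jnp.arange(1.0, 5.0)
  key = jax.random.PRNGKey(0)
  print(divergence_exact(f, x), divergence_hutchinson(f, x, key, 1000))
  ```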
- What is meant by Strassen’s algorithm? Learn how matrix multiplies are actually done in practice. Learn Winograd convolutions.
- Write your own implementation of a convolutional layer. Write your own implementation of multihead attention.
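  A bare-bones single-example sketch of multihead attention (no masking, no dropout), just to fix the shapes in your head:

  ```python
  import numpy as np

  def multihead_attention(x, wq, wk, wv, wo, num_heads):
      # x: (seq_len, d_model); wq, wk, wv, wo: (d_model, d_model).
      seq_len, d_model = x.shape
      d_head = d_model // num_heads
      # Project and split into heads: (num_heads, seq_len, d_head).
      def split(w):
          return (x @ w).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)
      q, k, v = split(wq), split(wk), split(wv)
      scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)      # (heads, seq, seq)
      weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
      weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
      out = weights @ v                                        # (heads, seq, d_head)
      out = out.transpose(1, 0, 2).reshape(seq_len, d_model)   # concatenate heads
      return out @ wo

  rng = np.random.default_rng(0)
  x = rng.normal(size=(5, 8))
  wq, wk, wv, wo = (rng.normal(size=(8, 8)) for _ in range(4))
  print(multihead_attention(x, wq, wk, wv, wo, num_heads=2).shape)  # (5, 8)
  ```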
- Know the universal approximation theorem. (I recommend Leshno et al. 1993 or Pinkus 1999 as references. Not the much-more-frequently cited references to older results by Cybenko or Hornik, who give much weaker results.)
- Optional: if you’re really keen then look up the modern line of work on alternate universal approximation theorems.
- Learn the basics of graph neural networks. (E.g. what is oversmoothing?) How do these generalise CNNs?
- Learn modern Transformer architectures. Look up recent papers (or implementations) to see some of the more common architectural tricks. Build a toy implementation.
- Learn U-Nets. Build a toy implementation.
- Know how residual networks are discretised ordinary differential equations.
- Know how Gated Recurrent Units (GRUs) are also discretised differential equations.
- Know how stochastic gradient descent is a discretised differential equation too! (Yes, including the “stochastic”: that’s a Monte-Carlo discretisation of an expectation.) These are gradient flows.
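  To make the residual-network case concrete, here’s a minimal sketch: a residual block is exactly an explicit Euler step of dx/dt = f(x) with unit step size.

  ```python
  import jax.numpy as jnp

  def residual_block(x, f, params):
      # A residual connection: x_{k+1} = x_k + f(x_k).
      return x + f(params, x)

  def euler_step(x, f, params, dt):
      # Explicit Euler for dx/dt = f(params, x): x_{k+1} = x_k + dt * f(params, x_k).
      return x + dt * f(params, x)

  # With dt = 1 the two coincide: a ResNet is an Euler discretisation of an ODE.
  f = lambda params, x: jnp.tanh(params @ x)
  params = 0.1 * jnp.eye(3)
  x = jnp.ones(3)
  print(residual_block(x, f, params), euler_step(x, f, params, dt=1.0))
  ```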
- Know what is meant by the manifold hypothesis.
- Learn the basics of policy gradients. Implement PPO to solve cart-pole. (Spinning Up is a great resource.)
- Learn the KL divergence, Wasserstein distance, and maximum mean discrepancy (MMD).
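  For example, a sketch of the (biased, V-statistic) estimator of the squared MMD with an RBF kernel; two samples from the same distribution should give a value near zero:

  ```python
  import numpy as np

  def rbf_kernel(a, b, lengthscale=1.0):
      # Pairwise k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * lengthscale^2)).
      sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
      return np.exp(-sq_dists / (2 * lengthscale**2))

  def mmd_squared(x, y, lengthscale=1.0):
      # Biased (V-statistic) estimate of MMD^2 between samples x ~ P, y ~ Q:
      # E[k(x, x')] - 2 E[k(x, y)] + E[k(y, y')].
      kxx = rbf_kernel(x, x, lengthscale).mean()
      kxy = rbf_kernel(x, y, lengthscale).mean()
      kyy = rbf_kernel(y, y, lengthscale).mean()
      return kxx - 2 * kxy + kyy

  rng = np.random.default_rng(0)
  x = rng.normal(0.0, 1.0, size=(500, 2))
  x2 = rng.normal(0.0, 1.0, size=(500, 2))
  y = rng.normal(0.5, 1.0, size=(500, 2))
  print(mmd_squared(x, x2), mmd_squared(x, y))  # near zero vs. clearly positive
  ```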
- Learn normalising flows, VAEs, WGANs, and score-based diffusion models. Implement a basic score-based diffusion model from scratch.
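  The heart of such an implementation is the denoising score matching loss; here’s a sketch (score_fn is whatever network you like, and a single noise level is fixed for simplicity):

  ```python
  import jax
  import jax.numpy as jnp

  def dsm_loss(score_fn, params, x, sigma, key):
      # Perturb the data with Gaussian noise of scale sigma. The score of the
      # perturbation kernel N(x, sigma^2 I), evaluated at x_t = x + sigma * eps,
      # is -(x_t - x) / sigma^2 = -eps / sigma: that is the regression target.
      eps = jax.random.normal(key, x.shape)
      x_t = x + sigma * eps
      target = -eps / sigma
      return jnp.mean((score_fn(params, x_t, sigma) - target) ** 2)
  ```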
- Try the basics of distributed training of a model. (Over multiple GPUs; multiple computers.) Start with jax.pmap.
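  A minimal data-parallel sketch, in which each device gets a shard of the batch and gradients are averaged with an all-reduce:

  ```python
  import functools
  import jax
  import jax.numpy as jnp

  def loss_fn(params, x, y):
      return jnp.mean((x @ params - y) ** 2)

  # Each device computes gradients on its shard, then they are averaged
  # across devices with an all-reduce (lax.pmean).
  @functools.partial(jax.pmap, axis_name="devices")
  def parallel_grad(params, x, y):
      grads = jax.grad(loss_fn)(params, x, y)
      return jax.lax.pmean(grads, axis_name="devices")

  n_devices = jax.local_device_count()
  params = jnp.zeros(8)
  # The leading axis of every argument must equal the number of devices.
  params_repl = jnp.broadcast_to(params, (n_devices, 8))
  x = jnp.ones((n_devices, 32, 8))
  y = jnp.ones((n_devices, 32))
  grads = parallel_grad(params_repl, x, y)  # one identical copy per device
  ```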
- Know how to do hyperparameter optimisation via Bayesian optimisation. My favourite library for this is Ax.
- Optional: try doing this in a distributed fashion, with a main thread sending hyperparameter jobs to different machines, and receiving results back. (The “Service API” for Ax is the appropriate tool here.)
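  Roughly, the Service API is an ask/tell loop. A sketch along these lines (the exact create_experiment arguments have changed between Ax versions, so treat this as the shape of the thing rather than gospel; train_and_evaluate is a hypothetical stand-in for the job you’d farm out to workers):

  ```python
  from ax.service.ax_client import AxClient

  def train_and_evaluate(lr, weight_decay):
      # Hypothetical stand-in for your real (possibly remote) training job;
      # it should return the validation metric being optimised.
      return -(lr - 1e-2) ** 2 - (weight_decay - 1e-4) ** 2

  ax_client = AxClient()
  ax_client.create_experiment(
      name="toy_hparam_search",
      parameters=[
          {"name": "lr", "type": "range", "bounds": [1e-5, 1e-1], "log_scale": True},
          {"name": "weight_decay", "type": "range", "bounds": [1e-6, 1e-2], "log_scale": True},
      ],
      objective_name="val_metric",  # newer Ax versions use an `objectives=` dict instead
      minimize=False,
  )

  for _ in range(20):
      params, trial_index = ax_client.get_next_trial()  # "ask"
      metric = train_and_evaluate(**params)             # run remotely in practice
      ax_client.complete_trial(trial_index=trial_index, raw_data=metric)  # "tell"

  print(ax_client.get_best_parameters())
  ```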
- Learn the formulae for Adadelta, Adam, etc. What were the innovations for each optimiser? (Momentum, second moments, …) What are some of the newer ones now being used? (AdaBelief, RAdam, NAdamW, … etc. etc. – this is a flavour-of-the-month kind of field.)
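  For reference, a from-scratch sketch of a single Adam update (momentum on the first moment, an EMA of squared gradients for the second, plus bias correction):

  ```python
  import numpy as np

  def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
      # Momentum: exponential moving average of gradients (first moment).
      m = beta1 * m + (1 - beta1) * grad
      # Exponential moving average of squared gradients (second moment).
      v = beta2 * v + (1 - beta2) * grad**2
      # Bias correction for the zero-initialised moving averages.
      m_hat = m / (1 - beta1**t)
      v_hat = v / (1 - beta2**t)
      param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
      return param, m, v

  # One step on f(p) = p^2 starting from p = 1, so grad = 2p.
  param, m, v = np.array(1.0), np.array(0.0), np.array(0.0)
  param, m, v = adam_step(param, 2 * param, m, v, t=1)
  print(param)  # roughly 1 - lr
  ```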
- Learn why we use first-order optimisation techniques (SGD and friends), rather than second-order techniques such as Newton's method.