The Unofficial PyTorch Optimization Loop Song
Sing it with me! Do the forward pass, calculate the loss, optimizer zero grad, lossssss backward, optimizer step, step, step.
I've been learning PyTorch.
And if you're learning PyTorch too but have come from frameworks such as Scikit-Learn or TensorFlow, writing your own training and testing loops may be a foreign concept (both of these frameworks have a beautiful model.fit() system).
To help myself learn, I wrote a jingle.
So I'd never forget.
You can see the video version of me singing this song on YouTube.
Torchy, hit the music!
What is an optimization loop?
Before we get to the lyrics, let's define what I mean by optimization loop (or training loop).
In PyTorch, or machine learning in general, an optimization loop steps through a dataset with the goal of training a model (often a neural network) to learn patterns in that dataset via a combination of forward propagation (the forward pass), backpropagation (the backward pass) and gradient descent.
The testing loop has the goal of evaluating the patterns the model has learned in a training loop.
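In code terms, one pass of the training loop has a skeleton something like this (a minimal sketch assuming a model, loss_fn, optimizer and training data already exist; a fully runnable version appears in "Explaining the lyrics" below):

```python
for epoch in range(epochs):
    y_pred = model(X_train)          # forward propagation (forward pass)
    loss = loss_fn(y_pred, y_train)  # measure how wrong the model is
    optimizer.zero_grad()            # clear old gradients
    loss.backward()                  # backpropagation (backward pass)
    optimizer.step()                 # gradient descent
```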
Lyrics
Let's train!
For an epoch in a range
Call model dot train
Do the forward pass
Calculate the loss
Optimizer zero grad
Lossssss backward
Optimizer step step step
Test time!
Call model dot eval
With torch inference mode
Do the forward pass
Calculate the loss
Print out what's happenin'
Let's do it again 'gain 'gain
For another epoch
Call model dot train
Do the forward pass
Calculate the loss
Optimizer zero grad
Lossssss backward
Optimizer step step step
Test time!
Call model dot eval
With torch inference mode
Do the forward pass
Calculate the loss
Print out what's happenin'
Let's do it again 'gain 'gain
For another epoch
Call model dot train
Do the forward pass
Calculate the loss
Optimizer zero grad
Lossssss backward
Optimizer step step step
Test time!
Call model dot eval
With torch inference mode
Do the forward pass
Calculate the loss
Print out what's happenin'
Keep going if you want
But don't forget to save save save
Explaining the lyrics
The lyrics describe what happens in a training loop (forward pass on the training data, loss calculation, zeroing the optimizer gradients, performing backpropagation with loss.backward() and then gradient descent by stepping the optimizer).
As well as what happens in the testing loop (forward pass, loss calculation on the test data – data the model has never seen before).
They're called loops because they step through the various samples in the data (training and test sets).
Because that's the whole idea of machine learning models: look at samples of data and perform numerical calculations to find patterns in that data.
Order-wise, you'll want to perform the forward pass before calculating the loss.
And you'll want to perform backpropagation (loss.backward()) before stepping the optimizer (optimizer.step()).
You'll also want to zero the optimizer's gradients (optimizer.zero_grad()) before calling loss.backward(), otherwise gradients accumulate from previous steps.
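Put together, the lyrics translate to something like the following. This is a small, self-contained sketch (the linear model and data are made up purely so the loop runs end to end, the loop itself is the point):

```python
import torch
from torch import nn

# A tiny, made-up setup so the loop below actually runs
# (illustrative model and data, not part of the song)
torch.manual_seed(42)
X = torch.rand(100, 1)
y = 0.7 * X + 0.3  # simple linear data to learn
X_train, y_train = X[:80], y[:80]
X_test, y_test = X[80:], y[80:]

model = nn.Linear(in_features=1, out_features=1)
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

epochs = 100

for epoch in range(epochs):                     # for an epoch in a range
    model.train()                               # call model dot train
    y_pred = model(X_train)                     # do the forward pass
    loss = loss_fn(y_pred, y_train)             # calculate the loss
    optimizer.zero_grad()                       # optimizer zero grad
    loss.backward()                             # lossssss backward
    optimizer.step()                            # optimizer step step step

    model.eval()                                # call model dot eval
    with torch.inference_mode():                # with torch inference mode
        test_pred = model(X_test)               # do the forward pass
        test_loss = loss_fn(test_pred, y_test)  # calculate the loss

    if epoch % 10 == 0:                         # print out what's happenin'
        print(f"Epoch: {epoch} | Train loss: {loss:.4f} | Test loss: {test_loss:.4f}")
```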
Making training and testing functions
Of course, you can functionize the code above.
And you should.
It'll prevent you going mad and learning jingles like this.
But you only have to learn it once.
Then you functionize the code above.
And use it again 'gain 'gain.
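For example, here's one way the loops might look as functions (the names train_step and test_step are my own convention, not a PyTorch API):

```python
import torch

def train_step(model, dataloader, loss_fn, optimizer):
    """Performs one training epoch and returns the average loss."""
    model.train()                    # call model dot train
    total_loss = 0
    for X, y in dataloader:
        y_pred = model(X)            # do the forward pass
        loss = loss_fn(y_pred, y)    # calculate the loss
        optimizer.zero_grad()        # optimizer zero grad
        loss.backward()              # lossssss backward
        optimizer.step()             # optimizer step step step
        total_loss += loss.item()
    return total_loss / len(dataloader)

def test_step(model, dataloader, loss_fn):
    """Evaluates the model on held-out data and returns the average loss."""
    model.eval()                     # call model dot eval
    total_loss = 0
    with torch.inference_mode():     # with torch inference mode
        for X, y in dataloader:
            y_pred = model(X)                    # do the forward pass
            total_loss += loss_fn(y_pred, y).item()  # calculate the loss
    return total_loss / len(dataloader)
```

Then each epoch is just a train_step() call followed by a test_step() call. And when you're done, torch.save(model.state_dict(), "model.pth") handles the save save save part (the filename is up to you).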
Materials and resources
This song came out of trying to find a way to remember and teach the steps in a PyTorch optimization loop.
Many of the resources for learning PyTorch never really discussed each step or the order of the steps.
So I made my own.
- Learn PyTorch for Deep Learning GitHub – a resource for learning PyTorch code-first from the fundamentals.
- Learn PyTorch for Deep Learning Book – an online book version of the code materials above.
- Code for writing the jingle above – a small functioning version of the PyTorch code used to write this jingle.
The two main algorithms that enable a neural network to learn are backpropagation and gradient descent.
PyTorch implements the mechanics of those behind the scenes.
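For a taste of what's going on behind the scenes, here's vanilla gradient descent written out by hand on a one-parameter problem (a simplified sketch for intuition, not PyTorch's actual internals):

```python
import torch

# Toy problem: find w that minimizes (3 * w - 6) ** 2 (answer: w = 2)
w = torch.tensor(1.0, requires_grad=True)
lr = 0.1  # learning rate

for _ in range(20):
    loss = (3 * w - 6) ** 2  # forward pass + loss calculation
    loss.backward()          # backpropagation: fills w.grad with dloss/dw
    with torch.no_grad():
        w -= lr * w.grad     # gradient descent: what optimizer.step() does for plain SGD
        w.grad.zero_()       # what optimizer.zero_grad() does

print(w)  # a tensor close to 2.0
```

loss.backward() is the backpropagation half (it computes w.grad), and the w -= lr * w.grad update is the gradient descent half, which is roughly what optimizer.step() performs for plain SGD.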
But if you'd like to learn more about each, I'd recommend the following:
- Gradient descent, how neural networks learn video by 3blue1brown
- What is backpropagation really doing? video by 3blue1brown
PS: as you might have guessed, this song is not endorsed by PyTorch whatsoever. I just made it for fun, hence the "unofficial".