The Unofficial PyTorch Optimization Loop Song

Sing it with me! Do the forward pass, calculate the loss, optimizer zero grad, lossssss backward, optimizer step, step, step.


I've been learning PyTorch.

And if you're learning PyTorch too but have come from frameworks such as Scikit-Learn or TensorFlow, writing your own training and testing loops may be a foreign concept (both of these frameworks offer a beautiful model.fit() method).

To help myself learn I wrote a jingle.

So I'd never forget.

You can see the video version of me singing this song on YouTube.

Torchy, hit the music!

What is an optimization loop?

Before we get to the lyrics, let's define what I mean by optimization loop (or training loop).

In PyTorch, or machine learning in general, an optimization loop steps through a dataset with the goal of training a model (often a neural network) to learn patterns in that dataset via a combination of forward propagation (the forward pass), backpropagation (the backward pass) and gradient descent.

The testing loop has the goal of evaluating the patterns the model has learned in a training loop.
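In rough PyTorch code, one pass of the training loop boils down to five steps (a sketch only; it assumes a model, loss_fn, optimizer and training data tensors already exist, and all of those names are illustrative):

```python
for epoch in range(epochs):
    model.train()                    # put the model in training mode
    y_pred = model(X_train)          # 1. forward pass
    loss = loss_fn(y_pred, y_train)  # 2. calculate the loss
    optimizer.zero_grad()            # 3. zero the optimizer's gradients
    loss.backward()                  # 4. backpropagation
    optimizer.step()                 # 5. gradient descent
```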

Lyrics

Let's train!
For an epoch in a range
Call model dot train
Do the forward pass
Calculate the loss
Optimizer zero grad
Lossssss backward
Optimizer step step step

Test time!
Call model dot eval
With torch inference mode
Do the forward pass
Calculate the loss

Print out what's happenin'

Let's do it again 'gain 'gain

For another epoch
Call model dot train
Do the forward pass
Calculate the loss
Optimizer zero grad
Lossssss backward
Optimizer step step step

Test time!
Call model dot eval
With torch inference mode
Do the forward pass
Calculate the loss

Print out what's happenin'

Let's do it again 'gain 'gain

For another epoch
Call model dot train
Do the forward pass
Calculate the loss
Optimizer zero grad
Lossssss backward
Optimizer step step step

Test time!
Call model dot eval
With torch inference mode
Do the forward pass
Calculate the loss

Print out what's happenin'

Keep going if you want
But don't forget to save save save

Explaining the lyrics

The lyrics describe what happens in a training loop (forward pass on the training data, loss calculation, zeroing the optimizer gradients, performing backpropagation and gradient descent by stepping the optimizer).

Steps in a PyTorch training loop. Source: Learn PyTorch for Deep Learning Book Chapter 01.
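Line by line, the training-loop verse maps to runnable PyTorch code. Here's a minimal, self-contained sketch (the tiny linear model, random data, loss function and optimizer are all illustrative stand-ins, not the only choices):

```python
import torch
from torch import nn

# All of the names below (model, data, loss function, optimizer) are
# illustrative stand-ins so the sketch runs on its own.
torch.manual_seed(42)
X_train, y_train = torch.randn(100, 1), torch.randn(100, 1)

model = nn.Linear(in_features=1, out_features=1)
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

epochs = 100
for epoch in range(epochs):          # for an epoch in a range...
    model.train()                    # call model dot train

    y_pred = model(X_train)          # 1. do the forward pass
    loss = loss_fn(y_pred, y_train)  # 2. calculate the loss

    optimizer.zero_grad()            # 3. optimizer zero grad
    loss.backward()                  # 4. lossssss backward
    optimizer.step()                 # 5. optimizer step step step
```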

As well as what happens in the testing loop (forward pass, loss calculation on the test data – data the model has never seen before).

Steps in a PyTorch testing loop (notice the lack of backpropagation via loss.backward() and gradient descent via optimizer.step(); these two steps aren't needed for evaluation/testing/inference). Source: Learn PyTorch for Deep Learning Book Chapter 01.
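And the testing verse, continuing the sketch above (X_test and y_test are illustrative stand-ins):

```python
# Continuing the sketch above (X_test and y_test are illustrative stand-ins).
X_test, y_test = torch.randn(20, 1), torch.randn(20, 1)

model.eval()                                # call model dot eval
with torch.inference_mode():                # with torch inference mode
    test_pred = model(X_test)               # do the forward pass
    test_loss = loss_fn(test_pred, y_test)  # calculate the loss

# Print out what's happenin'
print(f"Epoch: {epoch} | Train loss: {loss:.4f} | Test loss: {test_loss:.4f}")
```

In practice you'd nest these testing steps inside the epoch loop, just like the lyrics do: train, test, print, then again 'gain 'gain.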

They're called loops because they step through the samples in the data (training and test sets), typically revisiting the entire dataset once per epoch.

Because that's the whole idea of machine learning: look at samples of data and perform numerical calculations to find patterns in that data.

Order-wise, you'll want to perform the forward pass before calculating the loss, because the loss is calculated by comparing the model's predictions to the target values.

And you'll want to perform backpropagation (loss.backward()) before stepping the optimizer (optimizer.step()), because the optimizer updates the parameters using the gradients that backpropagation computes.

Making training and testing functions

Of course, you can functionize the code above.

And you should.

It'll prevent you from going mad and having to learn jingles like this.

But you only have to learn it once.

Then you functionize the code above.

And use it again 'gain 'gain.

Rather than remembering all of the steps, you can functionize the training and testing loop steps (after coding them once or twice) so you can reuse them for various projects and datasets.
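Here's one way those functions might look, reusing the illustrative names from the sketches above (the final save line and its filename are also illustrative):

```python
def train_step(model, X, y, loss_fn, optimizer):
    """Runs one training step and returns the training loss."""
    model.train()
    y_pred = model(X)                  # forward pass
    loss = loss_fn(y_pred, y)          # calculate the loss
    optimizer.zero_grad()              # zero the gradients
    loss.backward()                    # backpropagation
    optimizer.step()                   # gradient descent
    return loss.item()

def test_step(model, X, y, loss_fn):
    """Runs one evaluation step and returns the test loss."""
    model.eval()
    with torch.inference_mode():
        test_pred = model(X)               # forward pass
        test_loss = loss_fn(test_pred, y)  # calculate the loss
    return test_loss.item()

# Use them again 'gain 'gain...
for epoch in range(epochs):
    train_loss = train_step(model, X_train, y_train, loss_fn, optimizer)
    test_loss = test_step(model, X_test, y_test, loss_fn)
    print(f"Epoch: {epoch} | Train loss: {train_loss:.4f} | Test loss: {test_loss:.4f}")

# ...and don't forget to save save save (filename is illustrative).
torch.save(model.state_dict(), "model.pth")
```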

Materials and resources

This song came out of finding a way to remember and teach the steps in a PyTorch optimization loop.

Many of the resources for learning PyTorch never really discussed each step or the order of the steps.

So I made my own.

The two main algorithms that enable a neural network to learn are backpropagation and gradient descent.

PyTorch implements the mechanics of those behind the scenes.
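You can peek at those mechanics in a tiny self-contained example: calling loss.backward() fills in each parameter's .grad attribute, and optimizer.step() then uses those gradients to update the parameters (every name below is illustrative):

```python
import torch
from torch import nn

# A tiny self-contained demo (every name here is illustrative).
torch.manual_seed(42)
model = nn.Linear(in_features=1, out_features=1)
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

X = torch.randn(8, 1)
y = 3 * X + 1                # a simple linear relationship to learn

loss = loss_fn(model(X), y)  # forward pass + loss calculation
print(model.weight.grad)     # None: no gradients have been computed yet

loss.backward()              # backpropagation fills in the .grad attributes
print(model.weight.grad)     # now a tensor holding the gradient

optimizer.step()             # gradient descent: nudge weights using .grad
```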

But if you'd like to learn more about each, I'd recommend the following:


PS: As you might have guessed, this song is not endorsed by PyTorch whatsoever. I just made it for fun, hence unofficial.