Why Batch Normalization works

Shivam Chhetry
10 min read · Mar 16, 2022


Batch normalization was introduced in Sergey Ioffe and Christian Szegedy’s 2015 paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. The idea is that, instead of just normalizing the inputs to the network, we normalize the inputs to layers within the network.

It’s called batch normalization because during training, we normalize each layer’s inputs by using the mean and variance of the values in the current batch.
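To make the arithmetic concrete, here is a minimal sketch in plain Python of what happens to a single feature across one batch. The function name batch_norm and the gamma, beta, and eps parameters are illustrative assumptions (gamma and beta stand for the learnable scale and shift, eps for a small constant that avoids division by zero); this is not the code from the notebook.

```python
import math

def batch_norm(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one feature across a batch, then scale and shift it."""
    n = len(values)
    mean = sum(values) / n
    # Biased (population) variance of the current batch.
    var = sum((v - mean) ** 2 for v in values) / n
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]

batch = [1.0, 2.0, 3.0, 4.0]  # one feature, batch of four examples
normalized = batch_norm(batch)
print([round(v, 3) for v in normalized])  # [-1.342, -0.447, 0.447, 1.342]
```

After this step the feature has a mean of roughly 0 and a variance of roughly 1 within the batch, whatever its original scale was.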

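During training, the distribution of each layer’s inputs changes as the parameters of the preceding layers are updated. This presents a problem because the layers need to continuously adapt to the new distribution. When the input distribution to a learning system changes, it is said to experience covariate shift; the paper calls the version that happens between the layers of a network internal covariate shift.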

Batch Normalization in PyTorch

This section of the code shows you one way to add batch normalization to a neural network built in PyTorch.

The following code imports the packages we need in the notebook and loads the MNIST dataset to use in our experiments.

Visualising the data

Neural network classes for testing

The following class, NeuralNet, allows us to create identical neural networks with and without batch normalization to compare. The code is heavily documented, but there is also some additional discussion later. You do not need to read through it all before going through the rest of the notebook, but the comments within the code blocks may answer some of your questions.

About the code:

We are defining a simple MLP for classification; this design choice was made to support the discussion related to batch normalization and not to get the best classification accuracy.

(Important) Model Details

There are quite a few comments in the code, so those should answer most of your questions. However, let’s take a look at the most important lines.

We add batch normalization to layers inside the __init__ function. Here are some important points about that code:

  1. Layers with batch normalization do not include a bias term, because the learnable shift in the batch normalization layer makes it redundant.
  2. We use PyTorch’s BatchNorm1d layer to handle the math. This is the layer you use on linear layer outputs; you’ll use BatchNorm2d for 2D outputs like filtered images from convolutional layers.
  3. We add the batch normalization layer before calling the activation function.
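The three points above can be sketched as follows. This is a hedged reconstruction, not the notebook's exact class: the constructor argument use_batch_norm matches the calls later in the article, but the layer sizes (784 inputs, hidden_dim=256, 10 outputs) and the attribute names are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeuralNet(nn.Module):
    """MLP for 28x28 images with optional batch normalization.

    Layer sizes and attribute names are assumptions for this sketch.
    """

    def __init__(self, use_batch_norm, input_size=784, hidden_dim=256, output_size=10):
        super().__init__()
        self.input_size = input_size
        self.use_batch_norm = use_batch_norm
        # Point 1: hidden layers drop their bias when batch norm follows them.
        self.fc1 = nn.Linear(input_size, hidden_dim * 2, bias=not use_batch_norm)
        self.fc2 = nn.Linear(hidden_dim * 2, hidden_dim, bias=not use_batch_norm)
        # The output layer is not normalized, so it keeps its bias.
        self.fc3 = nn.Linear(hidden_dim, output_size)
        if use_batch_norm:
            # Point 2: BatchNorm1d takes the number of features it normalizes.
            self.batch_norm1 = nn.BatchNorm1d(hidden_dim * 2)
            self.batch_norm2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        x = x.view(-1, self.input_size)  # flatten each image
        x = self.fc1(x)
        if self.use_batch_norm:
            x = self.batch_norm1(x)  # Point 3: normalize before the activation
        x = F.relu(x)
        x = self.fc2(x)
        if self.use_batch_norm:
            x = self.batch_norm2(x)
        x = F.relu(x)
        return self.fc3(x)  # raw class scores (logits)
```

Keeping both variants in one class means the only difference between the two models is the flag, which is what makes the later comparison fair.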

Create two different models for testing

  • net_batchnorm is an MLP classifier with batch normalization applied to the outputs of its hidden layers
  • net_no_norm is a plain MLP, without batch normalization

Besides the normalization layers, everything about these models is the same.

net_batchnorm = NeuralNet(use_batch_norm=True)
net_no_norm = NeuralNet(use_batch_norm=False)
print(net_batchnorm)
print()
print(net_no_norm)

Training

The train function below takes in a model and some number of epochs. We'll use cross entropy loss and stochastic gradient descent for optimization. This function returns the losses, recorded after each epoch, so that we can display and compare the behavior of different models.

.train() mode

Note that we put our model in training mode by calling model.train(). This is an important step because batch normalization behaves differently when training on a batch than when testing or evaluating on a larger dataset.
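A training loop along these lines might look like the sketch below. It follows the description above (cross entropy loss, SGD, one recorded loss per epoch, model.train() before the loop), but the explicit loader argument, the learning rate, and the synthetic data standing in for MNIST are assumptions so that the example runs without a download.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train(model, loader, n_epochs=10, lr=0.01):
    """Train `model` and return the average loss of each epoch."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()  # batch norm layers use batch statistics from here on
    losses = []
    for _ in range(n_epochs):
        running_loss, n_batches = 0.0, 0
        for data, target in loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        losses.append(running_loss / n_batches)
    return losses

# Synthetic stand-in for MNIST (an assumption, not the article's data).
torch.manual_seed(0)
features = torch.randn(256, 20)
labels = (features[:, 0] > 0).long()
loader = DataLoader(TensorDataset(features, labels), batch_size=64, shuffle=True)
model = nn.Sequential(
    nn.Linear(20, 32, bias=False), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2)
)
losses = train(model, loader, n_epochs=5)
print(losses)
```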

Comparing Models

In the cells below, we train our two different models and compare their training loss over time.

# batchnorm model losses
# this may take some time to train
losses_batchnorm = train(net_batchnorm)
# *no* norm model losses
# you should already start to see a difference in training losses
losses_no_norm = train(net_no_norm)

Let’s plot the training losses of the two models to see the difference.

# compare the training losses of the two models
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(losses_batchnorm, label='Using batchnorm', alpha=0.5)
ax.plot(losses_no_norm, label='No norm', alpha=0.5)
ax.set_title("Training Losses")
ax.legend()
plt.show()

Figure: Batchnorm vs. no norm

Testing

You should see that the model with batch normalization starts off with a lower training loss and, over ten epochs of training, gets to a training loss that is noticeably lower than our model without normalization.

Next, let’s see how both these models perform on our test data! Below, we have a function test that takes in a model and a parameter train (True or False), which indicates whether the model should be in training or evaluation mode. This is for comparison purposes later. The function calculates some test statistics, including the overall test accuracy of the model passed in.

Training and Evaluation Mode

Setting a model to evaluation mode is important for models with batch normalization layers!

  • Training mode means that the batch normalization layers will use batch statistics to calculate the batch norm.
  • Evaluation mode, on the other hand, uses estimates of the population mean and variance, accumulated as running averages over the training batches, which should give us increased performance on this test data!

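The difference between the two modes can be seen on a single layer. The sketch below uses a standalone BatchNorm1d layer and a random batch (both assumptions for illustration, not part of the notebook) and inspects the layer's running_mean and running_var buffers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)          # running_mean starts at 0, running_var at 1
x = torch.randn(16, 3) * 2 + 5  # batch with mean near 5 and std near 2

bn.train()
out_train = bn(x)               # normalized with this batch's statistics
# The forward pass in training mode also moved the running estimates
# toward the batch statistics (the default momentum is 0.1).
print(bn.running_mean)

bn.eval()
before = bn.running_mean.clone()
out_eval = bn(x)                # normalized with the running estimates
# Evaluation mode leaves the running estimates untouched.
print(torch.equal(before, bn.running_mean))

# In eval mode the output can be reproduced by hand from the stored buffers
# (the scale and shift are still at their initial values of 1 and 0).
manual = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
print(torch.allclose(out_eval, manual, atol=1e-5))
```

Because the running estimates are only updated in training mode, forgetting to call model.eval() before testing means each test batch is normalized by its own statistics instead.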
# test batchnorm case, in *train* mode
test(net_batchnorm, train=True)
Test Loss: 0.086881
Test Accuracy of 0: 98% (967/980)
Test Accuracy of 1: 99% (1126/1135)
Test Accuracy of 2: 96% (999/1032)
Test Accuracy of 3: 97% (989/1010)
Test Accuracy of 4: 96% (952/982)
Test Accuracy of 5: 96% (864/892)
Test Accuracy of 6: 97% (933/958)
Test Accuracy of 7: 96% (990/1028)
Test Accuracy of 8: 96% (939/974)
Test Accuracy of 9: 95% (966/1009)
Test Accuracy (Overall): 97% (9725/10000)
# test batchnorm case, in *evaluation* mode
test(net_batchnorm, train=False)
Test Loss: 0.073484
Test Accuracy of 0: 98% (968/980)
Test Accuracy of 1: 99% (1127/1135)
Test Accuracy of 2: 97% (1005/1032)
Test Accuracy of 3: 98% (991/1010)
Test Accuracy of 4: 97% (955/982)
Test Accuracy of 5: 97% (874/892)
Test Accuracy of 6: 97% (932/958)
Test Accuracy of 7: 96% (995/1028)
Test Accuracy of 8: 96% (940/974)
Test Accuracy of 9: 97% (983/1009)
Test Accuracy (Overall): 97% (9770/10000)
# for posterity, test no norm case in eval mode
test(net_no_norm, train=False)
Test Loss: 0.207286
Test Accuracy of 0: 98% (963/980)
Test Accuracy of 1: 98% (1113/1135)
Test Accuracy of 2: 91% (943/1032)
Test Accuracy of 3: 93% (943/1010)
Test Accuracy of 4: 93% (918/982)
Test Accuracy of 5: 92% (824/892)
Test Accuracy of 6: 95% (912/958)
Test Accuracy of 7: 92% (954/1028)
Test Accuracy of 8: 91% (891/974)
Test Accuracy of 9: 93% (940/1009)
Test Accuracy (Overall): 94% (9401/10000)

Which model has the highest accuracy?

You should see a small improvement when comparing the batch norm model’s accuracy in training and evaluation mode: in the run above, evaluation mode raises the overall test accuracy from 9725 to 9770 correct out of 10000.

You should also see that the model that uses batch norm layers shows a marked improvement in overall accuracy when compared with the no-normalization model (9770 vs. 9401 correct out of 10000 in the run above).

Considerations for other network types

This notebook demonstrates batch normalization in a standard neural network with fully connected layers. You can also use batch normalization in other types of networks, but there are some special considerations.

ConvNets

Convolution layers consist of multiple feature maps. (Remember, the depth of a convolutional layer refers to its number of feature maps.) And the weights for each feature map are shared across all the inputs that feed into the layer. Because of these differences, batch normalizing convolutional layers requires batch/population mean and variance per feature map rather than per node in the layer.

To apply batch normalization on the outputs of convolutional layers, we use BatchNorm2d.
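As a small illustration of the per-feature-map statistics, the sketch below passes a random batch through one convolutional layer followed by BatchNorm2d. The layer sizes and the random input are assumptions chosen for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
bn = nn.BatchNorm2d(4)              # one mean/variance pair per feature map

images = torch.randn(8, 1, 28, 28)  # (batch, channels, height, width)
out = bn(conv(images))              # modules start in training mode

# Statistics are pooled over the batch and both spatial dimensions,
# so each of the 4 feature maps ends up with mean ~0 and variance ~1.
print(out.shape)
print(out.mean(dim=(0, 2, 3)))
print(out.var(dim=(0, 2, 3), unbiased=False))
```

Note that the argument to BatchNorm2d is the number of feature maps (channels), not the number of individual activations in the layer.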

RNNs

Batch normalization can work with recurrent neural networks, too, as shown in the 2016 paper Recurrent Batch Normalization. It’s a bit more work to implement, but basically involves calculating the means and variances per time step instead of per layer. You can find an example where someone implemented recurrent batch normalization in PyTorch, in this GitHub repo.
