Batch normalization, or batchnorm, is a popular technique used to speed up the training of neural networks by addressing a problem known as internal covariate shift. It also reduces sensitivity to how the weights and biases are initialized. Before batchnorm, the initial values of the weights and biases had to be chosen carefully to ensure stable and efficient training: poor initialization could lead to either exploding or vanishing gradients, both of which make a neural network hard to train.

In this blog post, I will explain what batchnorm tries to solve, how it works, and code out an implementation of it using PyTorch. Before jumping into the details of batchnorm, let’s take a look at why weights and biases must be carefully initialized.

Need for careful initialization

import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
from abc import ABC, abstractmethod
from typing import List
 
BATCH_SIZE = 32
MAX_STEPS = 1_500
 
transform = transforms.Compose([
    transforms.ToTensor()
])
 
train_dataset = datasets.MNIST(
    root="data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(
    root="data", train=False, download=True, transform=transform)
 
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=len(test_dataset), shuffle=True)
 
 
class Layer(ABC):
    @abstractmethod
    def __call__(self, x):
        pass
 
    @abstractmethod
    def parameters(self):
        pass
 
 
class Linear(Layer):
    def __init__(self, fan_in, fan_out):
        super().__init__()
 
        self.weight = torch.rand((fan_in, fan_out))
        self.bias = torch.zeros(fan_out)
 
    def __call__(self, x):
        return x @ self.weight + self.bias
 
    def parameters(self):
        return [self.weight, self.bias]
 
 
class Tanh(Layer):
    def __init__(self):
        super().__init__()
 
    def __call__(self, x):
        return torch.tanh(x)
 
    def parameters(self):
        return []
 
 
class Model(Layer):
    def __init__(self, layers: Layer):
        super().__init__()
 
        self.layers: List[Layer] = layers
 
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
 
        return x
 
    def parameters(self):
        parameters = []
 
        for layer in self.layers:
            parameters += layer.parameters()
 
        return parameters
 
    def zero_grad(self):
        for p in self.parameters():
            p.grad = None
 
 
model = Model([
    Linear(28 * 28, 128),
    Tanh(),
    Linear(128, 10)
])
 
parameters: List[torch.Tensor] = model.parameters()
 
for p in parameters:
    p.requires_grad = True
 
step = 0
 
for x, y in train_loader:
    if step >= MAX_STEPS:
        break
    x = x.view(-1, 28 * 28)
    out = model(x)
    loss = F.cross_entropy(out, y)
 
    print(f"step - {step + 1}; loss - {loss:.3f}")
 
    model.zero_grad()
    loss.backward()
 
    lr = 0.1 * (1.0 - step / MAX_STEPS)
    lr = max(lr, 0.01)
 
    for p in parameters:
        if p.grad is not None:
            p.data -= lr * p.grad
 
    step += 1
 
with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test = x_test.view(-1, 28 * 28)
        x_test = model(x_test)
        loss = F.cross_entropy(x_test, y_test)
        print(f"loss: {loss.item():.3f}")

The code above defines a very basic neural network trained on the MNIST dataset. It consists of a single hidden layer and uses tanh as the activation function, mini-batch SGD as the optimizer, and cross-entropy as the loss function. The weights are initialized randomly using torch.rand (uniform values in [0, 1), with no scaling), and the biases are initially set to zero.

After training the model for a maximum of 1,500 steps, it achieves a loss of around 2.3.

The loss vs. training steps plot above reveals some interesting details about the training process:

  • The initial loss starts at a higher value (around 3.6 - 3.8).
  • The curve descends slowly and contains several “plateaus” along the way.
  • The loss gets stuck at around 2.3, indicating that the optimizer has settled into a poor solution; as we’ll see below, 2.3 is roughly the loss of predicting all 10 classes with equal probability.

The reason behind the high initial loss is that the randomly initialized weights make the network produce confidently wrong predictions at the start. When this happens, the first few training steps are spent squashing down the weights rather than actually learning, i.e. optimizing the parameters in a useful direction.

In a classification problem with N classes, if the network has no prior knowledge, it should ideally predict equal probabilities for all classes. If cross-entropy is used as the loss function, then the ideal initial loss would be equal to −log(1/N). So, if the initial loss during training is close to −log(1/N), then the network will work on optimizing the parameters rather than first squashing down the weights.
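
For MNIST there are N = 10 classes, so the ideal initial loss is −log(1/10) ≈ 2.3. Here is a tiny standalone check (not part of the training code) that confirms this number:

import torch
import torch.nn.functional as F

# a network with no prior knowledge should assign probability 1/10 to every class
ideal_loss = -torch.log(torch.tensor(1 / 10))
print(ideal_loss)  # ~2.3026

# equivalent check: uniform logits (all zeros) give the same cross-entropy loss
logits = torch.zeros(1, 10)
target = torch.tensor([3])  # any class index gives the same loss here
print(F.cross_entropy(logits, target))  # ~2.3026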

One way to fix this issue is to scale down the randomly initialized weights by multiplying them by some small factor, such that the initial loss is close to −log(1/N).

class Linear(Layer):
    def __init__(self, fan_in, fan_out):
        super().__init__()

        # scale the randomly initialized weights down by a factor of 100
        self.weight = torch.rand((fan_in, fan_out)) * 0.01
        self.bias = torch.zeros(fan_out)

By simply scaling down the randomly initialized weights by a factor of 100, the optimizer reaches a far better local minimum by the end of training. The model achieves a loss of around 0.3, nearly 7.5 times better than before.

Along with bringing the initial loss closer to the ideal starting value, scaling down the weights also fixes the issue of saturated tanh. To understand the problem of saturated tanh, we first need to understand the nature of the tanh function.

tanh is a squashing function - it takes in any real number as the input and compresses it into the range of [−1, 1]. The tanh function has “plateau” regions at both ends of its graph, which means that when it’s applied to a very large positive number, the output is almost equal to 1, and similarly, for a very large negative number, the output is almost equal to -1.

The derivative of tanh is equal to $\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)$, which approaches zero as the output saturates toward 1 or -1.

Consider a layer in a neural network that undergoes a linear transformation, followed by a tanh activation function applied to the outputs. If the outputs of the linear transformation are very large, then tanh will return values close to either 1 or -1, depending on the sign. When the network backpropagates through tanh, the derivatives will be nearly zero, since the post-activation values for most of the neurons in that layer are saturated at either 1 or -1. This causes those neurons to become dead - meaning they won’t learn anything during training. This issue is known as saturated tanh.
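
To see the saturation numerically, here is a small standalone snippet (separate from the model code) that prints tanh and its gradient for a few input magnitudes; once the input is large, the gradient that flows back through tanh is essentially zero:

import torch

# for large |x|, tanh(x) saturates at ±1 and its derivative 1 - tanh(x)^2
# becomes vanishingly small, so almost no gradient flows back through it
for value in [0.5, 2.0, 5.0, 10.0]:
    x = torch.tensor(value, requires_grad=True)
    y = torch.tanh(x)
    y.backward()
    print(f"x = {value:5.1f}  tanh(x) = {y.item():.6f}  d/dx tanh(x) = {x.grad.item():.6f}")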

The bar chart above shows the mean post-activation values of the first hidden layer for all 128 neurons when the weights are randomly initialized without any scaling factor. Notice how the post-activation values for most of the neurons are equal to 1. This indicates that those neurons were barely learning anything during training, which also explains the high loss when the weights were not properly initialized.

This bar chart shows the case where the weights are properly initialized, i.e. scaled down so that the outputs of the linear transformation are not too large, which prevents them from saturating after tanh is applied.

The issue with this approach is that it takes a lot of trial and error to find a scaling factor that works well, and sprinkling magic numbers throughout the codebase isn’t good practice. To overcome this, we can use a standard weight initialization technique such as Xavier initialization (if the tanh activation function is used) or He initialization (if ReLU-like activation functions are used).

Xavier Initialization

Xavier initialization is a weight initialization technique introduced in the paper Understanding the Difficulty of Training Deep Feedforward Neural Networks. The main idea of Xavier initialization is to keep the variance of activations and gradients roughly the same across all layers.

  • If the variance of the activations grows too much, then for squashing activation functions like tanh and sigmoid, large values get pushed into the saturated regions near 1 or -1. The derivatives become almost zero, so learning slows down.
  • If the variance of the gradients grows too much, it causes the exploding gradients problem.
  • If the variance of the activations shrinks too much, the signal going forward gets smaller and smaller, and the network ends up underfitting.
  • If the variance of the gradients shrinks too much, it causes the vanishing gradients problem.

In Xavier initialization, the weights are sampled from a zero-mean distribution whose variance depends on the layer’s fan-in and fan-out. For the normal (Gaussian) variant used below, the weights are sampled from:

$W \sim \mathcal{N}\!\left(0,\ \frac{2}{\text{fan\_in} + \text{fan\_out}}\right)$

PyTorch has a utility function that fills an input tensor using the Xavier initialization technique - torch.nn.init.xavier_normal_

class Linear(Layer):
    def __init__(self, fan_in, fan_out):
        super().__init__()

        self.weight = torch.rand((fan_in, fan_out))
        self.bias = torch.zeros(fan_out)

        # fill the weight tensor in place with Xavier (normal) initialization
        torch.nn.init.xavier_normal_(self.weight)
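
As a quick sanity check (a standalone snippet, not part of the model code), the standard deviation of the weights filled in by xavier_normal_ should come out close to the formula above:

import torch

fan_in, fan_out = 28 * 28, 128
w = torch.empty(fan_in, fan_out)
torch.nn.init.xavier_normal_(w)

# the sampled weights should have std close to sqrt(2 / (fan_in + fan_out))
expected_std = (2 / (fan_in + fan_out)) ** 0.5
print(f"expected std: {expected_std:.4f}, actual std: {w.std().item():.4f}")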

Using Xavier initialization, the network converges to a good minimum much more quickly.

Introduction to batchnorm

“Training Deep Neural Networks is complicated by the fact that the distribution of each layer’s inputs changes during training, as the parameters of the previous layers change. This slows down the training by requiring lower learning rates and careful parameter initialization and makes it notoriously hard to train models with saturating nonlinearities. We refer to this phenomenon as internal covariate shift and address the problem by normalizing layer inputs. Our method draws its strength from making normalization a part of the model architecture and performing the normalization for each training mini-batch. Batch Normalization allows us to use much higher learning rates and be less careful about initialization.” (from the abstract of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift)

Batch normalization, or batchnorm, was introduced in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It helps to speed up the training process of neural networks by addressing the issue of internal covariate shift. But what is internal covariate shift?

During training, as the parameters of the earlier layers update, the distribution of their outputs (activations) changes. This means that the distribution of inputs to the next layer is constantly shifting. So when a later layer tries to learn something, it constantly has to re-adapt to the ever-changing input distribution. This phenomenon, where the distribution of activations shifts internally during training, is known as internal covariate shift.

I’ve tweaked the training loop to capture snapshots of the pre- and post-activation values of the first hidden layer every 200 steps, and to plot them at the end using matplotlib and seaborn.

snapshots = {
    "step": [],
    "pre_activations": [],
    "post_activations": []
}
 
step = 0
 
for x, y in train_loader:
    if step >= MAX_STEPS:
        break
 
    x = x.view(-1, 28 * 28)
 
    pre_act = model.layers[0](x)
    post_act = model.layers[1](pre_act)
 
    if step % 200 == 0:
        snapshots["step"].append(step)
        snapshots["pre_activations"].append(pre_act.detach().view(-1).cpu())
        snapshots["post_activations"].append(post_act.detach().view(-1).cpu())
 
    out = model.layers[2](post_act)
    loss = F.cross_entropy(out, y)
 
    print(f"step - {step + 1}; loss - {loss:.3f}")
 
    model.zero_grad()
    loss.backward()
 
    lr = 0.1 * (1.0 - step / MAX_STEPS)
    lr = max(lr, 0.01)
 
    for p in parameters:
        if p.grad is not None:
            p.data -= lr * p.grad
 
    step += 1
 
with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test = x_test.view(-1, 28 * 28)
        x_test = model(x_test)
        loss = F.cross_entropy(x_test, y_test)
        print(f"loss: {loss.item():.3f}")
 
plt.figure(figsize=(14, 6))
 
for i, step in enumerate(snapshots["step"]):
    plt.subplot(1, 2, 1)
    sns.kdeplot(snapshots["pre_activations"][i],
                label=f"step {step}", linewidth=2)
    plt.title("Distribution of first linear layer outputs")
 
    plt.subplot(1, 2, 2)
    sns.kdeplot(snapshots["post_activations"][i],
                label=f"step {step}", linewidth=2)
    plt.title("Distribution after tanh activation")
 
plt.subplot(1, 2, 1)
plt.legend()
plt.subplot(1, 2, 2)
plt.legend()
plt.tight_layout()
plt.show()

The distribution shift is clearly visible in the plot above: notice how dissimilar the distributions of the pre- and post-activation values are from one snapshot to the next after the initial step. This indicates that our current network is facing internal covariate shift, and the later layers spend a lot of effort re-adapting to their shifting inputs instead of making progress on the actual task.

Batch normalization reduces internal covariate shift by adding a normalization step that adjusts the mean and variance of the activations:

$\hat{x} = \frac{x - \mu_B}{\sigma_B}$

where $\mu_B$ is the mean of the activations in that batch and $\sigma_B$ is the standard deviation of the activations in that batch. With this formula, which is also known as naive batchnorm, the mean and std of $\hat{x}$ are always set to 0 and 1 respectively.

But there is an issue with this: it constrains what the layer’s output can represent. When batchnorm is applied to the inputs of tanh, values that would have landed in the “plateau” regions are pulled back into the “sweet” middle region. This does solve the issue of saturated tanh, but it also constrains the output of the tanh layer.

But what if values lying in the plateau regions would actually produce the best result? To address this, the batchnorm paper introduced two new trainable parameters, $\gamma$ (scale) and $\beta$ (shift):

$y = \gamma \hat{x} + \beta$

With these parameters, batchnorm reduces internal covariate shift within the network without constraining any layer’s outputs.
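
To see why $\gamma$ and $\beta$ lift that constraint, here is a tiny standalone example: if the layer learns $\gamma = \sigma_B$ and $\beta = \mu_B$, it recovers the original activations exactly, so normalization no longer limits what the layer can represent:

import torch

x = torch.randn(32, 4) * 3.0 + 5.0  # a batch with non-zero mean and non-unit std
mean = x.mean(0, keepdim=True)
std = x.std(0, keepdim=True, unbiased=True)

xhat = (x - mean) / std  # naive batchnorm: mean 0, std 1 per feature

# with gamma = std and beta = mean, the original activations come back,
# so the output is no longer forced to have mean 0 and std 1
gamma, beta = std, mean
y = gamma * xhat + beta
print(torch.allclose(y, x, atol=1e-5))  # True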

Implementation of Batchnorm

class BatchNorm(Layer):
    def __init__(self, dim):
        super().__init__()

        # gamma scales and beta shifts the normalized activations
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x: torch.Tensor):
        # per-feature statistics computed across the batch dimension
        mean = x.mean(0, keepdim=True)
        std = x.std(0, keepdim=True, unbiased=True)

        xhat = (x - mean) / std
        return self.gamma * xhat + self.beta

    def parameters(self):
        return [self.gamma, self.beta]

This is a pretty straightforward implementation of batchnorm using PyTorch, but there are a few caveats - Bessel’s correction and how batchnorm acts differently during training and inference.

Bessel’s correction

Bessel’s correction is a statistical fix applied when the population variance or standard deviation is estimated from a sample. When the population variance is estimated from a sample using the regular formula for variance, it underestimates the true variance of the population on average.

The underestimation can be removed by dividing by $n - 1$ instead of $n$, and this is known as Bessel’s correction.

std = x.std(0, keepdim=True, unbiased=True)

unbiased=True tells PyTorch to apply Bessel’s correction. It is already True by default, but I’ve passed it explicitly so that I could call it out in this section.
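
Here is a small standalone illustration of the bias: draw many small samples from a standard normal distribution (true variance 1) and compare the average of the biased (divide by n) and unbiased (divide by n − 1) estimates:

import torch

torch.manual_seed(0)
n = 5  # a small sample size makes the bias obvious

samples = torch.randn(100_000, n)  # 100k samples of size n each
biased = samples.var(dim=1, unbiased=False)  # divides by n
unbiased = samples.var(dim=1, unbiased=True)  # divides by n - 1 (Bessel's correction)

print(f"mean biased estimate:   {biased.mean().item():.3f}")    # ≈ 0.8, underestimates
print(f"mean unbiased estimate: {unbiased.mean().item():.3f}")  # ≈ 1.0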

Batchnorm during inference

During inference, a single input is generally fed to the model at a time, and the statistics computed from that single input are meaningless: the mean is just the input itself and the standard deviation (with Bessel’s correction) isn’t even defined, because batchnorm is designed to operate on a batch of inputs. To fix this, the model can compute the mean and variance explicitly from each mini-batch during training while maintaining running estimates of them, and then use those running estimates during inference. This is the “difference”, mentioned above, in how batchnorm acts during training and inference.

class Layer(ABC):
    def __init__(self):
        self.training = True

    @abstractmethod
    def __call__(self, x):
        pass

    @abstractmethod
    def parameters(self):
        pass


class Model(Layer):
    # ...

    def eval(self):
        for layer in self.layers:
            layer.training = False

    def train(self):
        for layer in self.layers:
            layer.training = True

I’ve added a new field to the Layer class - training - which can be used to run different code paths depending on whether the model is training or doing inference. I’ve also added two helper methods to the Model class: eval, which puts the entire model into inference mode, and train, which puts it back into training mode.

class BatchNorm(Layer):
    def __init__(self, dim, momentum=0.01):
        super().__init__()

        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        self.momentum = momentum
        # running estimates used during inference; they are not trained via gradients
        self.bn_mean = torch.zeros(dim)
        self.bn_std = torch.ones(dim)

    def __call__(self, x: torch.Tensor):
        if self.training:
            mean = x.mean(0, keepdim=True)
            std = x.std(0, keepdim=True, unbiased=True)
        else:
            mean = self.bn_mean
            std = self.bn_std

        xhat = (x - mean) / std
        out = self.gamma * xhat + self.beta

        if self.training:
            # update the running estimates with an exponential moving average
            with torch.no_grad():
                self.bn_mean = (1 - self.momentum) * \
                    self.bn_mean + self.momentum * mean
                self.bn_std = (1 - self.momentum) * \
                    self.bn_std + self.momentum * std

        return out

    def parameters(self):
        return [self.gamma, self.beta]

self.bn_mean and self.bn_std are the running estimates, which are updated using an exponential moving average (EMA):

$\text{bn\_mean} \leftarrow (1 - m) \cdot \text{bn\_mean} + m \cdot \mu_B$

$\text{bn\_std} \leftarrow (1 - m) \cdot \text{bn\_std} + m \cdot \sigma_B$

where $m$ is the momentum of the exponential moving average.

self.bn_mean and self.bn_std are not trainable parameters, i.e. they are not updated by gradient descent in the training loop (note that they are not returned from parameters()); instead, they are updated through the exponential moving average inside the batchnorm layer.
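
To build some intuition for the momentum value, here is a tiny standalone sketch (separate from the model code) of how an exponential moving average with momentum 0.01 gradually converges to the true mean of the batches:

import torch

momentum = 0.01
running_mean = torch.zeros(1)

# simulate 1,500 mini-batches whose true mean is 0.5
for _ in range(1_500):
    batch = 0.5 + 0.1 * torch.randn(32)
    running_mean = (1 - momentum) * running_mean + momentum * batch.mean()

# after enough steps the running estimate is very close to 0.5
print(running_mean.item())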

model.eval()
with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test = x_test.view(-1, 28 * 28)
        x_test = model(x_test)
        loss = F.cross_entropy(x_test, y_test)
        print(f"loss: {loss.item():.3f}")
model.train()

Notice how the model is set into inference/evaluation mode before testing.

On training the model with both the batchnorm layer and Xavier initialization, the final training loss is around 0.21 and the validation loss is around 0.18.

The reason there isn’t much difference between the loss curves for batchnorm + Xavier initialization and Xavier initialization alone is that this network is shallow; batchnorm shines the most in deep neural networks.
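
For illustration, a deeper network could be composed from the classes defined in this post along these lines (a sketch with arbitrary layer sizes; each batchnorm layer sits between a linear layer and its tanh activation):

model = Model([
    Linear(28 * 28, 256), BatchNorm(256), Tanh(),
    Linear(256, 128), BatchNorm(128), Tanh(),
    Linear(128, 64), BatchNorm(64), Tanh(),
    Linear(64, 10)
])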

References