Batch Normalisation: Stable and Faster Training
Learners will insert nn.BatchNorm1d between layers, observe faster convergence on a deep network, and understand how BatchNorm normalises activations within each mini-batch.
The Problem Batch Norm Solves
Ioffe and Szegedy, who introduced Batch Normalisation (Batch Norm) in 2015, attributed the slow training of deep networks to internal covariate shift. As the earlier layers' weights update during training, the distribution of each layer's inputs changes, so each layer must keep adapting to a moving input distribution. Batch Norm addresses this by normalising layer inputs within each mini-batch. In practice this often speeds up training considerably and reduces sensitivity to weight initialisation.
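The lesson objective mentions inserting nn.BatchNorm1d between layers. The sketch below builds the same deep MLP with and without a BatchNorm1d after each hidden Linear layer, trains both briefly, and prints each final training loss. Several choices here are assumptions made for illustration: the toy synthetic regression task, depth 8, width 32, plain SGD with lr=0.05, and 200 steps. make_mlp and train are hypothetical helpers, not part of PyTorch.

```python
import torch
import torch.nn as nn

def make_mlp(in_features, hidden, depth, out_features, use_bn):
    # Hypothetical helper: `depth` hidden blocks of Linear -> [BatchNorm1d] -> ReLU
    layers = []
    width = in_features
    for _ in range(depth):
        layers.append(nn.Linear(width, hidden))
        if use_bn:
            # Normalise each of the `hidden` features over the mini-batch
            layers.append(nn.BatchNorm1d(hidden))
        layers.append(nn.ReLU())
        width = hidden
    layers.append(nn.Linear(width, out_features))  # output layer, no normalisation
    return nn.Sequential(*layers)

def train(model, x, y, lr, steps):
    # Hypothetical helper: full-batch SGD on mean squared error,
    # returns the loss computed at the last step
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    model.train()  # BatchNorm1d uses the current batch's statistics in training mode
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    return loss.item()

# Synthetic regression data (an assumption for illustration, not from the lesson)
torch.manual_seed(0)
x = torch.randn(256, 20)
y = 2.0 * x[:, :1] - x[:, 1:2]  # target shape (256, 1)

final_losses = {}
for use_bn in (False, True):
    torch.manual_seed(1)  # reseed so both models start from the same Linear weights
    model = make_mlp(20, 32, depth=8, out_features=1, use_bn=use_bn)
    name = "with BN" if use_bn else "without BN"
    final_losses[name] = train(model, x, y, lr=0.05, steps=200)
    print(f"{name}: final training loss {final_losses[name]:.4f}")
```

Placing BatchNorm1d between the Linear layer and the activation is one common convention. The loss values depend on the seed, depth and learning rate, so a single run is an observation, not a benchmark.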
# Without batch norm: deep networks train slowly and
# require very careful weight init and LR tuning.
# With batch norm: higher learning rates usually work,
# training is less sensitive to initialisation, and the
# noise in batch statistics acts as a mild regulariser.
# Batch norm normalises each feature to:
# mean=0, std=1 within the batch, then
# applies learnable scale (gamma) and shift (beta).
print('Batch Norm: normalize -> scale -> shift')
How Batch Norm Works Mathematically
For each feature dimension, Batch Norm computes the mean and variance across the current mini-batch, then normalises each value. After normalisation it applies two learnable parameters: gamma (scale) and beta (shift). These let the network undo the normalisation where that helps, because suitable values of gamma and beta recover the identity transform. A small constant epsilon is added to the variance to prevent division by zero.
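Written out per feature for a mini-batch of m values x_1, ..., x_m, the steps are as follows. The symbols mu_B, sigma_B^2 and m are our notation. The variance divides by m, which is what unbiased=False does in the code below. The last line shows one choice of gamma and beta that gives back the original input.

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta \\
\gamma &= \sqrt{\sigma_B^2 + \epsilon},\quad \beta = \mu_B
\;\Longrightarrow\; y_i = x_i
\end{aligned}
```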
import torch
def batch_norm_manual(x, gamma, beta, eps=1e-5):
    # x shape: (batch_size, features)
    mu = x.mean(dim=0)                       # mean per feature
    var = x.var(dim=0, unbiased=False)       # biased variance per feature (divides by N)
    x_norm = (x - mu) / (var + eps).sqrt()   # eps guards against division by zero
    return gamma * x_norm + beta             # scale and shift
x = torch.randn(32, 8) # batch=32, 8 features
gamma = torch.ones(8)
beta = torch.zeros(8)
out = batch_norm_manual(x, gamma, beta)
print('Mean near 0:', out.mean(dim=0).abs().max().item() < 0.01)
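# out has biased variance ~1, so compare with the biased std (unbiased=False);
# the default unbiased std would be about sqrt(32/31) = 1.016 and fail the check
print('Std near 1:', (out.std(dim=0, unbiased=False) - 1).abs().max().item() < 0.01)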
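As a sanity check, the manual function can be compared with PyTorch's nn.BatchNorm1d in training mode. In that layer, weight and bias play the roles of gamma and beta. The sketch repeats batch_norm_manual so that it runs on its own. The input scaling and the random gamma/beta values are arbitrary choices for the comparison.

```python
import torch
import torch.nn as nn

def batch_norm_manual(x, gamma, beta, eps=1e-5):
    # Same manual batch norm as above: per-feature batch mean and biased variance
    mu = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    x_norm = (x - mu) / (var + eps).sqrt()
    return gamma * x_norm + beta

torch.manual_seed(0)
x = torch.randn(32, 8) * 3.0 + 1.5  # features deliberately not zero-mean / unit-variance

bn = nn.BatchNorm1d(8, eps=1e-5)    # weight (gamma) starts at 1, bias (beta) at 0
bn.train()                          # training mode: normalise with this batch's statistics
with torch.no_grad():
    bn.weight.uniform_(0.5, 2.0)    # non-trivial gamma for the comparison
    bn.bias.uniform_(-1.0, 1.0)     # non-trivial beta for the comparison
    ref = bn(x)
    mine = batch_norm_manual(x, bn.weight, bn.bias, eps=bn.eps)

max_diff = (ref - mine).abs().max().item()
print('Max abs difference vs nn.BatchNorm1d:', max_diff)
```

In evaluation mode (after bn.eval()), nn.BatchNorm1d normalises with its running mean and variance instead of the current batch's statistics. This agreement therefore holds only in training mode.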