Machine Learning Academy · Lesson

Batch Normalisation: Stable and Faster Training

Learners will insert nn.BatchNorm1d between layers, observe faster convergence on a deep network, and understand how Batch Norm normalises activations within each mini-batch.
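As a concrete picture of the first objective, here is a minimal sketch of inserting nn.BatchNorm1d between the layers of a deep fully connected network. The helper make_mlp and the sizes (20 inputs, 64 hidden units, 8 hidden blocks, 2 outputs) are hypothetical choices for illustration. Placing Batch Norm after each Linear layer and before the ReLU follows the original paper's choice of normalising pre-activations, although other placements are also used in practice.

```python
import torch
import torch.nn as nn

def make_mlp(in_features, hidden, out_features, depth, use_bn):
    # Hypothetical helper: `depth` hidden blocks of Linear -> (BatchNorm1d) -> ReLU
    layers = []
    width = in_features
    for _ in range(depth):
        layers.append(nn.Linear(width, hidden))
        if use_bn:
            # Normalises each of the `hidden` features across the mini-batch
            layers.append(nn.BatchNorm1d(hidden))
        layers.append(nn.ReLU())
        width = hidden
    layers.append(nn.Linear(width, out_features))
    return nn.Sequential(*layers)

plain = make_mlp(20, 64, 2, depth=8, use_bn=False)
with_bn = make_mlp(20, 64, 2, depth=8, use_bn=True)

x = torch.randn(32, 20)  # batch of 32 examples, 20 input features
print(with_bn(x).shape)  # torch.Size([32, 2])
print(sum(isinstance(m, nn.BatchNorm1d) for m in with_bn))  # 8
```

Training plain and with_bn side by side with the same optimiser and learning rate is one way to observe the convergence difference the lesson describes.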

The Problem Batch Norm Solves

Ioffe and Szegedy, who introduced Batch Normalisation (Batch Norm) in 2015, attributed slow training in deep networks to internal covariate shift: the distribution of each layer's inputs changes during training as the weights of earlier layers update, so every layer must keep adapting to a moving input distribution. Batch Norm addresses this by normalising layer inputs within each mini-batch. In practice it lets networks train with much higher learning rates and makes them far less sensitive to weight initialisation, which can shorten training substantially. (Later work has questioned whether reducing covariate shift is the real reason Batch Norm helps, but its practical benefits are well established.)

# Without batch norm: deep networks train slowly and
# require very careful weight init and LR tuning.

# With batch norm: higher learning rates become usable,
# training is less sensitive to initialisation, and the
# batch-to-batch noise often acts as a mild regulariser.

# Batch norm normalises each feature to:
# mean=0, std=1 within the batch, then
# applies learnable scale (gamma) and shift (beta).
print('Batch Norm: normalize -> scale -> shift')
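The same idea using PyTorch's built-in layer, as a small sketch: nn.BatchNorm1d stores gamma as its weight (initialised to ones) and beta as its bias (initialised to zeros), one of each per feature. In training mode it normalises with the current batch's statistics, so each output feature ends up with mean close to 0 and standard deviation close to 1. The shifted, scaled input (mean about 5, spread about 3) is an arbitrary assumption chosen to make the effect visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(8)  # one (gamma, beta) pair per feature
print(bn.weight.data)   # gamma: all ones at initialisation
print(bn.bias.data)     # beta: all zeros at initialisation

# Input whose features are far from mean 0 / std 1
x = 5 + 3 * torch.randn(64, 8)
bn.train()  # training mode: normalise using this batch's mean and variance
y = bn(x)

# Per-feature mean ~0 and (biased) std ~1 after normalisation
print(y.mean(dim=0).abs().max().item() < 1e-4)                       # True
print((y.std(dim=0, unbiased=False) - 1).abs().max().item() < 1e-3)  # True
```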

How Batch Norm Works Mathematically

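For each feature dimension, Batch Norm computes the mean and variance across the current mini-batch, then normalises each value by subtracting the mean and dividing by the standard deviation. It then applies two learnable per-feature parameters: gamma (scale) and beta (shift). Because gamma and beta are learned, the network can undo the normalisation when that helps: setting gamma to the square root of (variance + epsilon) and beta to the batch mean gives back the original input, so the identity transform remains representable. A small constant epsilon is added to the variance before taking the square root to prevent division by zero.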
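Written as equations for a mini-batch of m values x_1, ..., x_m of one feature (the symbols below follow common notation and are chosen here for illustration), the steps are:

```latex
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma \, \hat{x}_i + \beta

\text{Identity recovery: } \gamma = \sqrt{\sigma_B^2 + \epsilon},\ \beta = \mu_B
\;\Rightarrow\; y_i = \frac{\sqrt{\sigma_B^2 + \epsilon}}{\sqrt{\sigma_B^2 + \epsilon}}\left(x_i - \mu_B\right) + \mu_B = x_i
```

Note that the variance divides by m (the biased estimator), which is what unbiased=False computes in the code that follows.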

import torch

def batch_norm_manual(x, gamma, beta, eps=1e-5):
    # x shape: (batch_size, features)
    mu = x.mean(dim=0)           # mean per feature
    var = x.var(dim=0, unbiased=False)  # var per feature
    x_norm = (x - mu) / (var + eps).sqrt()
    return gamma * x_norm + beta  # scale and shift

x = torch.randn(32, 8)  # batch=32, 8 features
gamma = torch.ones(8)
beta = torch.zeros(8)

out = batch_norm_manual(x, gamma, beta)
print('Mean near 0:', out.mean(dim=0).abs().max().item() < 0.01)
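# Use the biased std to match the normalisation; the default unbiased std
# would be about sqrt(32/31) ≈ 1.016 here and fail the 0.01 check.
print('Std near 1:', (out.std(dim=0, unbiased=False) - 1).abs().max().item() < 0.01)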
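A quick sanity check, assuming PyTorch's nn.BatchNorm1d in training mode (where it normalises with the current batch's statistics rather than its running averages): the manual function should agree with the built-in layer, and choosing gamma and beta as described earlier should return the input unchanged. The function is repeated so the snippet runs on its own, and the random gamma and beta values are arbitrary.

```python
import torch
import torch.nn as nn

def batch_norm_manual(x, gamma, beta, eps=1e-5):
    # Same as the manual version above: biased per-feature batch statistics
    mu = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    x_norm = (x - mu) / (var + eps).sqrt()
    return gamma * x_norm + beta

torch.manual_seed(0)
x = torch.randn(32, 8)

# 1) Compare with PyTorch's layer using non-trivial gamma (weight) and beta (bias)
bn = nn.BatchNorm1d(8, eps=1e-5)
bn.train()
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)  # arbitrary gamma
    bn.bias.uniform_(-1.0, 1.0)   # arbitrary beta
    ref = bn(x)
manual = batch_norm_manual(x, bn.weight.detach(), bn.bias.detach())
print(torch.allclose(manual, ref, atol=1e-5))  # True

# 2) Identity recovery: gamma = sqrt(var + eps), beta = mean undoes the normalisation
eps = 1e-5
gamma_id = (x.var(dim=0, unbiased=False) + eps).sqrt()
beta_id = x.mean(dim=0)
restored = batch_norm_manual(x, gamma_id, beta_id, eps)
print(torch.allclose(restored, x, atol=1e-5))  # True
```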

All lessons in this course

  1. Learning Rate: The Most Important Hyperparameter
  2. Batch Normalisation: Stable and Faster Training
  3. Dropout Regularisation to Prevent Overfitting
  4. Weight Initialisation: Xavier and He Initialisation