Machine Learning Academy · Lesson

The Vanishing Gradient Problem in Deep Time Steps

Learners will observe exploding and vanishing gradients in a deep RNN through gradient norm logging and understand why long sequences make training unstable.

Gradients Must Travel Through Time

To learn long-range dependencies in a sequence, the gradient of the loss at the final timestep must travel backward through every intermediate timestep to reach the parameters that processed the early inputs. This procedure is backpropagation through time (BPTT). For a sequence of length T, each backward step multiplies the gradient by (the transpose of) the same recurrent weight matrix W_hh, scaled by the activation's derivative, so after T steps the gradient has been multiplied by roughly T copies of W_hh. This repeated multiplication is the root cause of both vanishing gradients (exponential decay) and exploding gradients (exponential growth).

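import torch

# Conceptual illustration of a gradient travelling back through T steps.
# Ignoring the activation derivative, each BPTT step multiplies the gradient
# by W_hh (transposed), so after T steps it has been scaled by a power of W_hh.

# If W_hh has spectral radius < 1 (here 0.9):
W_small = torch.eye(4) * 0.9
for T in (10, 50):
    # torch.linalg.matrix_power computes W @ W @ ... @ W (T factors)
    print(f'T={T}: max |W_small^{T}| =',
          torch.linalg.matrix_power(W_small, T).abs().max().item())
# -> 0.9**10 ≈ 0.35, 0.9**50 ≈ 0.0052: the gradient vanishes

# If W_hh has spectral radius > 1 (here 1.1):
W_big = torch.eye(4) * 1.1
for T in (10, 50):
    # Note: W_big ** T would be an elementwise power, not a matrix power
    print(f'T={T}: max |W_big^{T}| =',
          torch.linalg.matrix_power(W_big, T).abs().max().item())
# -> 1.1**10 ≈ 2.59, 1.1**50 ≈ 117.4: the gradient explodes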
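A scaled identity is the simplest case, since every eigenvalue is identical. The sketch below is a slightly more general check. It assumes a random Gaussian 8×8 matrix as a hypothetical stand-in for W_hh, rescales it to spectral radius 0.9 and 1.1 (the helper `scaled_to_spectral_radius` is introduced here for illustration), and reports the Frobenius norm of its powers. Because the two matrices differ only by a scalar factor, their norms diverge by exactly (1.1/0.9)^T.

```python
import torch

torch.manual_seed(0)

def scaled_to_spectral_radius(A, rho):
    """Hypothetical helper: rescale A so its largest |eigenvalue| equals rho."""
    current = torch.linalg.eigvals(A).abs().max()  # eigvals is complex; abs() is real
    return A * (rho / current)

A = torch.randn(8, 8, dtype=torch.float64)  # assumed random stand-in for W_hh
W_small = scaled_to_spectral_radius(A, 0.9)
W_big = scaled_to_spectral_radius(A, 1.1)

results = {}
for T in (10, 50, 200):
    # matrix_norm defaults to the Frobenius norm
    small = torch.linalg.matrix_norm(torch.linalg.matrix_power(W_small, T)).item()
    big = torch.linalg.matrix_norm(torch.linalg.matrix_power(W_big, T)).item()
    results[T] = (small, big)
    print(f'T={T:3d}  ||W_small^T|| = {small:.3e}   ||W_big^T|| = {big:.3e}')
```

For a non-symmetric matrix, a few powers can temporarily grow even when the spectral radius is below 1. The spectral radius controls the long-run rate, which is why the contrast is clearest at large T.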

Vanishing Gradient: The Mathematical Root Cause

During BPTT, the gradient of the loss with respect to the hidden state at timestep t is a product of Jacobians ∂h_k/∂h_{k-1}, one for each step k from t+1 to T. Writing h_k = f(a_k), where a_k is the pre-activation, each Jacobian equals diag(f'(a_k)) * W_hh, with f' the derivative of the activation. For tanh, |f'| is at most 1, so the activation factors can only shrink the gradient. Combined with a W_hh whose spectral radius is below 1, as is common for small random initializations, the product tends to drive gradients to zero exponentially fast in the number of steps T - t. Decay is strictly guaranteed when the largest singular value of W_hh is below 1.
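The same argument can be written out in symbols. This assumes the standard vanilla RNN update h_k = tanh(a_k) with pre-activation a_k = W_hh h_{k-1} + W_xh x_k + b, where W_xh and b are notation introduced here, and a loss L that depends on the final state h_T. Gradients are treated as row vectors, squares are elementwise, and ‖·‖ is the spectral norm (largest singular value):

```latex
\begin{aligned}
\frac{\partial h_k}{\partial h_{k-1}} &= \operatorname{diag}\big(\tanh'(a_k)\big)\, W_{hh}
  = \operatorname{diag}\big(1 - h_k^{2}\big)\, W_{hh} \\
\frac{\partial L}{\partial h_t} &= \frac{\partial L}{\partial h_T}\,
  \frac{\partial h_T}{\partial h_{T-1}}\,
  \frac{\partial h_{T-1}}{\partial h_{T-2}} \cdots
  \frac{\partial h_{t+1}}{\partial h_t} \\
\left\lVert \frac{\partial L}{\partial h_t} \right\rVert
  &\le \left\lVert \frac{\partial L}{\partial h_T} \right\rVert
  \prod_{k=t+1}^{T} \big\lVert \operatorname{diag}(1 - h_k^{2}) \big\rVert \, \lVert W_{hh} \rVert
  \le \left\lVert \frac{\partial L}{\partial h_T} \right\rVert \, \lVert W_{hh} \rVert^{\,T-t}
\end{aligned}
```

The last step uses |tanh'| ≤ 1. When ‖W_hh‖ < 1, the bound shrinks exponentially in T - t. A large W_hh removes the bound and leaves room for the exploding case.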

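import torch

# Track the gradient norm as it is propagated back through T steps.
# Simplification: linear recurrence (activation derivative ignored) and W_hh = scale * I.
def simulate_bptt_gradient(T, weight_scale=0.9):
    W = torch.eye(8) * weight_scale
    grad = torch.ones(8)   # gradient w.r.t. the final hidden state h_T
    norms = [grad.norm().item()]
    for t in range(T):
        grad = W.T @ grad  # one BPTT step: dL/dh_{k-1} = W_hh^T dL/dh_k
        norms.append(grad.norm().item())
    return norms

norms = simulate_bptt_gradient(T=20)
print('Gradient norms every 5 steps (0, 5, ..., 20):')
print([f'{n:.4f}' for n in norms[::5]])
# Shrinks by a factor of 0.9 per step: 2.8284 -> 1.6702 -> 0.9862 -> 0.5823 -> 0.3439
# After 100 steps it would be about 2.83 * 0.9**100 ≈ 7.5e-5, effectively zero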
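That simulation omits the tanh derivative and uses a diagonal matrix. One way to log gradient norms from an actual recurrent cell is to unroll `torch.nn.RNNCell` by hand, call `retain_grad()` on each hidden state, and read ∂L/∂h_t after `backward()`. This is a sketch, not the course's own code. The helper name `hidden_grad_norms`, the sizes, the toy loss, and the rescaling of `weight_hh` to a chosen largest singular value `target_norm` are all illustrative assumptions.

```python
import torch

def hidden_grad_norms(T=50, hidden_size=32, input_size=8, target_norm=0.5, seed=0):
    """Hypothetical helper: return ||dL/dh_t|| for t = 0..T-1, where L depends only on the last state."""
    torch.manual_seed(seed)
    cell = torch.nn.RNNCell(input_size, hidden_size, nonlinearity='tanh')
    with torch.no_grad():
        # Assumed setup: rescale W_hh so its largest singular value equals target_norm
        W = cell.weight_hh
        W.mul_(target_norm / torch.linalg.matrix_norm(W, ord=2))
    x = torch.randn(T, 1, input_size)      # (time, batch, features)
    h = torch.zeros(1, hidden_size)
    states = []
    for t in range(T):
        h = cell(x[t], h)
        h.retain_grad()                    # keep dL/dh_t on this non-leaf tensor
        states.append(h)
    loss = states[-1].sum()                # toy loss on the final hidden state only
    loss.backward()
    return [s.grad.norm().item() for s in states]

norms = hidden_grad_norms()
for t in range(0, 50, 10):
    print(f't={t:2d}  ||dL/dh_t|| = {norms[t]:.3e}')
print(f't=49  ||dL/dh_t|| = {norms[-1]:.3e}')
```

With `target_norm=0.5`, each step back in time shrinks the logged norm by at least half. Setting `target_norm` above 1 removes that guarantee. Because tanh saturates (1 - h_t² approaches 0), whether the norms actually grow then depends on the weights and inputs.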

All lessons in this course

  1. Vanilla RNNs: Hidden State and Sequence Unrolling
  2. The Vanishing Gradient Problem in Deep Time Steps
  3. LSTM Cell: Input, Forget, and Output Gates
  4. Sequence-to-One: Sentiment Analysis with an LSTM