Machine Learning Academy · Lesson

Vanilla RNNs: Hidden State and Sequence Unrolling

Learners will implement a one-step RNN cell manually, unroll it across a short sequence, and visualise how the hidden state accumulates context.

Why We Need Recurrent Networks

Standard feedforward networks treat each input independently — they have no memory of previous inputs. But many real-world problems involve sequential data where context from the past matters: predicting the next word in a sentence, forecasting tomorrow's stock price from historical data, or classifying a gesture from a video frame sequence. Recurrent Neural Networks (RNNs) solve this by maintaining a hidden state that carries information across timesteps.

# Examples of sequential data:
sequences = {
    'NLP': 'The cat sat on the ___  (predict next word)',
    'Time Series': '[1.2, 1.5, 1.3, 1.8, ???]',
    'Speech': '[audio_t0, audio_t1, ..., audio_tN]',
    'Video': '[frame_1, frame_2, ..., frame_T]',
    'DNA': 'ATCGATCG... (biological sequence)',
}
for name, example in sequences.items():
    print(f'{name}: {example}')

The Vanilla RNN Cell

A vanilla RNN cell takes two inputs: the current input x_t and the previous hidden state h_{t-1}. It produces the new hidden state h_t using the formula: h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b), where * denotes a matrix-vector product and tanh is applied elementwise. The same weight matrices W_hh and W_xh, and the same bias b, are used at every timestep. This is weight sharing across time, analogous to how CNNs share weights across space. Because h_t depends on h_{t-1}, which in turn depends on every earlier input, the hidden state acts as a fixed-size, lossy summary of the sequence seen so far.

import torch

def rnn_cell(x_t, h_prev, W_xh, W_hh, b):
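    '''One step of a vanilla RNN cell: h_t = tanh(x_t W_xh^T + h_prev W_hh^T + b)'''
    # x_t:    (batch, input_size)
    # h_prev: (batch, hidden_size)
    # W_xh:   (hidden_size, input_size)
    # W_hh:   (hidden_size, hidden_size)
    # b:      (hidden_size,), broadcast over the batch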
    h_t = torch.tanh(
        x_t @ W_xh.T +    # input contribution
        h_prev @ W_hh.T +  # hidden-to-hidden contribution
        b                  # bias
    )
    return h_t

# Example: input_size=4, hidden_size=8
batch = 3
x_t   = torch.randn(batch, 4)
h_prev = torch.zeros(batch, 8)
W_xh  = torch.randn(8, 4) * 0.01
W_hh  = torch.randn(8, 8) * 0.01
b     = torch.zeros(8)
h_t = rnn_cell(x_t, h_prev, W_xh, W_hh, b)
print(h_t.shape)   # torch.Size([3, 8])
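
The cell above handles a single timestep. As a minimal sketch of unrolling, the block below applies the same cell, with the same weights, at each step of a short sequence and collects the hidden states. The helper name `unroll`, the variables `xs` and `xs_alt`, the sequence length of 5, and the fixed seed are assumptions made for illustration and are not part of the lesson's code. `rnn_cell` is repeated so the block runs on its own.

```python
import torch

def rnn_cell(x_t, h_prev, W_xh, W_hh, b):
    """One step of a vanilla RNN cell (same formula as in the lesson)."""
    return torch.tanh(x_t @ W_xh.T + h_prev @ W_hh.T + b)

def unroll(xs, h0, W_xh, W_hh, b):
    """Hypothetical helper: run the cell over xs of shape (seq_len, batch, input_size)."""
    h = h0
    states = []
    for x_t in xs:                           # iterate over the time dimension
        h = rnn_cell(x_t, h, W_xh, W_hh, b)  # the same weights are reused at every step
        states.append(h)
    return torch.stack(states)               # (seq_len, batch, hidden_size)

torch.manual_seed(0)  # assumption: fixed seed so repeated runs match
seq_len, batch, input_size, hidden_size = 5, 3, 4, 8
xs   = torch.randn(seq_len, batch, input_size)
h0   = torch.zeros(batch, hidden_size)
W_xh = torch.randn(hidden_size, input_size) * 0.01
W_hh = torch.randn(hidden_size, hidden_size) * 0.01
b    = torch.zeros(hidden_size)

hs = unroll(xs, h0, W_xh, W_hh, b)
print(hs.shape)   # torch.Size([5, 3, 8])

# Perturb only the first input and see how far the change travels
# through the hidden state at later timesteps.
xs_alt = xs.clone()
xs_alt[0] += 1.0
hs_alt = unroll(xs_alt, h0, W_xh, W_hh, b)
for t in range(seq_len):
    delta = (hs_alt[t] - hs[t]).abs().max().item()
    print(f't={t}: max |change in h_t| = {delta:.2e}')
```

The printed differences give a rough view of how context is carried forward: the perturbation at t=0 reaches later hidden states only through h_{t-1}. With weights initialised this small, its influence shrinks quickly at each subsequent step.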
