Machine Learning Academy · Lesson

Building and Training a CNN on CIFAR-10

Learners will stack Conv2d-ReLU-MaxPool blocks, flatten the feature map, attach a linear classifier, and train on CIFAR-10 with data augmentation.

CIFAR-10: The Benchmark Dataset

CIFAR-10 is a classic image classification benchmark containing 60,000 colour images (32x32 pixels, RGB) in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. The classes are balanced at 6,000 images each, split into 50,000 training images and 10,000 test images. The dataset is small enough to train a modest CNN on a laptop in a few hours, yet varied enough that simple models such as linear classifiers perform poorly. That combination makes it ideal for learning CNN design. torchvision, PyTorch's companion vision library, provides it ready to download via torchvision.datasets.CIFAR10.
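The class list above also defines the label encoding: label 0 is airplane, 1 is automobile, and so on up to 9 for truck. The plain-Python sketch below spells that mapping out and records the chance-level baseline that any trained model must beat. The helper names `label_name` and `class_counts` are hypothetical; torchvision exposes the same class list as the dataset's `classes` attribute and the per-image integer labels as its `targets` attribute.

```python
from collections import Counter

# CIFAR-10 class names in label order: label i corresponds to CIFAR10_CLASSES[i].
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

def label_name(label):
    """Hypothetical helper: map an integer label (0-9) to its class name."""
    return CIFAR10_CLASSES[label]

def class_counts(labels):
    """Hypothetical helper: count images per class name from integer labels."""
    counts = Counter(labels)
    return {name: counts.get(i, 0) for i, name in enumerate(CIFAR10_CLASSES)}

# With 10 balanced classes, guessing uniformly at random scores 1/10 accuracy.
CHANCE_ACCURACY = 1 / len(CIFAR10_CLASSES)

print(label_name(3), label_name(9))   # cat truck
print(CHANCE_ACCURACY)                # 0.1
```

Once the training set is loaded, `class_counts(train_set.targets)` should report 5,000 images for every class, confirming the balanced split.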

import torchvision
import torchvision.transforms as transforms

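transform = transforms.Compose([
    transforms.ToTensor(),               # PIL image -> float tensor in [0, 1], shape (C, H, W)
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),   # per-channel mean of the training pixels
        std=(0.2470, 0.2435, 0.2616)     # per-channel std; the widely copied
                                         # (0.2023, 0.1994, 0.2010) understates it
    )
])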

train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True,
    download=True, transform=transform
)
print('Training images:', len(train_set))   # 50000
print('Image shape:', train_set[0][0].shape) # (3, 32, 32)
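To make the two transforms concrete, here is a NumPy sketch of what ToTensor() followed by Normalize() does to a single image, plus a helper for computing per-channel statistics. NumPy is an assumption made so the arithmetic is visible; in the pipeline above, torchvision performs the same steps on tensors. The names `to_tensor_normalize` and `channel_stats` are hypothetical. torchvision's CIFAR10 keeps the raw training images in `train_set.data` as a uint8 array of shape (50000, 32, 32, 3), so `channel_stats(train_set.data)` is one way to check the constants passed to Normalize; the means should come out close to (0.4914, 0.4822, 0.4465).

```python
import numpy as np

def to_tensor_normalize(img_hwc, mean, std):
    """NumPy sketch of transforms.ToTensor() followed by transforms.Normalize().

    img_hwc: uint8 array of shape (H, W, 3) with values 0-255.
    Returns a float32 array of shape (3, H, W).
    """
    x = img_hwc.astype(np.float32) / 255.0       # ToTensor: scale to [0, 1]
    x = x.transpose(2, 0, 1)                     # ToTensor: HWC -> CHW
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (x - mean) / std                      # Normalize: per channel

def channel_stats(data_nhwc):
    """Per-channel mean and std of a uint8 (N, H, W, C) array, in [0, 1] units."""
    x = data_nhwc.reshape(-1, data_nhwc.shape[-1]).astype(np.float32) / 255.0
    # Accumulate in float64 to limit rounding error over millions of pixels.
    return x.mean(axis=0, dtype=np.float64), x.std(axis=0, dtype=np.float64)

# Tiny demo: an all-white 2x2 image normalised with mean = std = 0.5 becomes all 1.0.
white = np.full((2, 2, 3), 255, dtype=np.uint8)
out = to_tensor_normalize(white, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
print(out.shape, out.min(), out.max())   # (3, 2, 2) 1.0 1.0
```

Note that `channel_stats` converts the whole array to float32, which for the full training set takes roughly 600 MB of memory.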

Loading Data with DataLoader

After defining the dataset, wrap it in a DataLoader, which handles batching, shuffling, and parallel loading. For CIFAR-10, a batch size of 64 or 128 is typical. Set shuffle=True for the training data so the order is randomised each epoch, and shuffle=False for the test set, since order does not affect evaluation metrics. num_workers=2 starts two worker processes that load and transform batches while the model trains, reducing the time the GPU sits idle waiting for data. pin_memory=True places batches in page-locked host memory, which speeds up copies to the GPU.

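from torch.utils.data import DataLoader

# With num_workers > 0 on Windows or macOS (spawn start method), run this
# code under `if __name__ == '__main__':` when it lives in a script.
train_loader = DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,       # new random order every epoch
    num_workers=2,      # two worker processes prepare batches in parallel
    pin_memory=True     # page-locked host memory for faster GPU transfer
)

test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False,
    download=True, transform=transform
)
test_loader = DataLoader(
    test_set, batch_size=128, shuffle=False,   # fixed order for evaluation
    num_workers=2, pin_memory=True
)

# Peek at one batch
X_batch, y_batch = next(iter(train_loader))
print('Batch images:', X_batch.shape)   # (128, 3, 32, 32)
print('Batch labels:', y_batch.shape)   # (128,)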
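How many batches does one pass over the data take? With drop_last=False (the DataLoader default), a final partial batch holds the leftover images. The pure-Python sketch below works through that arithmetic and illustrates the shuffle-then-slice idea behind shuffle=True. It is a simplification for illustration, not PyTorch's actual sampler code, and the helper names are hypothetical.

```python
import math
import random

def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Batches in one epoch; a partial final batch counts unless drop_last=True."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

def last_batch_size(n_samples, batch_size):
    """Size of the final batch when drop_last=False."""
    return n_samples % batch_size or batch_size

def shuffled_batches(n_samples, batch_size, seed=None):
    """Simplified shuffle=True: permute all indices once per epoch, then slice."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]

print(batches_per_epoch(50_000, 128), last_batch_size(50_000, 128))  # 391 80
print(batches_per_epoch(10_000, 128), last_batch_size(10_000, 128))  # 79 16
```

For the loaders defined above, `len(train_loader)` should report the same 391 batches per epoch, the last of which contains only 80 images.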

All lessons in this course

  1. Convolution and Filters: Detecting Edges and Patterns
  2. Pooling Layers: Spatial Downsampling and Invariance
  3. Building and Training a CNN on CIFAR-10
  4. Data Augmentation: Transforms for Robustness