Building and Training a CNN on CIFAR-10
Learners will stack Conv2d-ReLU-MaxPool blocks, flatten the feature map, attach a linear classifier, and train on CIFAR-10 with data augmentation.
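As a rough preview of that architecture, here is a minimal sketch rather than the course's reference model. The class name SmallCNN, the helper conv_block, the three-block depth, and the channel widths (32, 64, 128) are all illustrative assumptions. A random tensor stands in for a batch of CIFAR-10 images so the shape flow can be checked without downloading anything.

```python
import torch
import torch.nn as nn


def conv_block(in_channels, out_channels):
    # One Conv2d-ReLU-MaxPool block. padding=1 makes the 3x3 convolution
    # size-preserving; the 2x2 max-pool then halves height and width.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    )


class SmallCNN(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            conv_block(3, 32),    # (3, 32, 32)  -> (32, 16, 16)
            conv_block(32, 64),   # (32, 16, 16) -> (64, 8, 8)
            conv_block(64, 128),  # (64, 8, 8)   -> (128, 4, 4)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                         # (128, 4, 4) -> 2048 features
            nn.Linear(128 * 4 * 4, num_classes),  # one logit per class
        )

    def forward(self, x):
        return self.classifier(self.features(x))


model = SmallCNN()
dummy_batch = torch.randn(4, 3, 32, 32)  # random stand-in for 4 images
logits = model(dummy_batch)
print(logits.shape)  # torch.Size([4, 10])
```

Each pooling step halves the spatial size, so three blocks reduce a 32x32 input to 4x4, which fixes the input width of the linear layer at 128 * 4 * 4 = 2048.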
CIFAR-10: The Benchmark Dataset
CIFAR-10 is a classic image classification benchmark containing 60,000 colour images (32x32 pixels, RGB) in 10 classes, with 6,000 images per class: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. There are 50,000 training images and 10,000 test images. It is small enough to train on a laptop in a few hours but complex enough that simple models perform poorly, which makes it ideal for learning CNN design. The torchvision package provides it as torchvision.datasets.CIFAR10, which can download the data automatically.
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010)
    )
])
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True,
    download=True, transform=transform
)
print('Training images:', len(train_set))  # 50000
print('Image shape:', train_set[0][0].shape)  # (3, 32, 32)
Loading Data with DataLoader
After defining the dataset, wrap it in a DataLoader, which handles batching, shuffling, and parallel loading. For CIFAR-10, a batch size of 64 or 128 is typical. Set shuffle=True for training data to randomise the order each epoch, and shuffle=False for the test set (order doesn't matter for evaluation). Setting num_workers=2 starts two worker processes that prepare batches in parallel with training, which reduces the time the GPU sits idle waiting for data. Setting pin_memory=True places batches in page-locked memory, which speeds up copies from host to GPU.
from torch.utils.data import DataLoader
train_loader = DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=True  # faster GPU transfer
)
# Peek at one batch
X_batch, y_batch = next(iter(train_loader))
print('Batch images:', X_batch.shape) # (128, 3, 32, 32)
print('Batch labels:', y_batch.shape) # (128,)
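The evaluation-side loader described earlier (shuffle=False) can be sketched in the same way. To keep the example runnable without a download, it uses a small synthetic TensorDataset as a stand-in for the CIFAR-10 test split. That stand-in, the batch size of 4, and num_workers=0 are assumptions made for illustration, not recommended settings.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the CIFAR-10 test split (an assumption, so the
# sketch runs offline): 10 fake images with labels 0..9.
fake_images = torch.randn(10, 3, 32, 32)
fake_labels = torch.arange(10)
test_set = TensorDataset(fake_images, fake_labels)

test_loader = DataLoader(
    test_set,
    batch_size=4,
    shuffle=False,  # keep a fixed order for evaluation
    num_workers=0,  # load in the main process; raise for real data
)

# With shuffle=False, every pass yields the samples in dataset order.
first_pass = torch.cat([y for _, y in test_loader])
second_pass = torch.cat([y for _, y in test_loader])
print(first_pass.tolist())
print(torch.equal(first_pass, second_pass))
```

For the real test split, the dataset would instead be built with torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) and wrapped in a DataLoader the same way, typically with a larger batch size.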