Data Augmentation: Transforms for Robustness
Learners will apply random horizontal flip, crop, and colour jitter via torchvision.transforms, and measure the accuracy improvement from augmentation.
What Is Data Augmentation?
Data augmentation artificially increases the size and diversity of the training dataset by applying random transformations to existing images. Collecting and labelling new data is expensive; augmentation instead creates new training examples on the fly. The key insight is that for many tasks the label stays the same under the transformation: a horizontally flipped image of a cat is still a cat. Augmentation reduces overfitting and teaches the model to be invariant to irrelevant variations.
import torchvision.transforms as transforms
from PIL import Image
import torch
# A simple augmentation pipeline
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.3),
    transforms.ToTensor()
])
# Apply to same image -> different result each time
# img = Image.open('cat.jpg')
# aug1 = transform(img) # one augmented version
# aug2 = transform(img) # different augmented version
print('Augmentation pipeline defined')
torchvision.transforms: The Augmentation Toolkit
torchvision.transforms provides a rich library of image transforms in four broad categories: geometric (flip, rotate, crop, resize, perspective); color (jitter, grayscale, solarize, equalize); pixel-level (Gaussian blur, random erasing); and tensor operations (ToTensor, Normalize). Transforms are chained with transforms.Compose and applied in the order listed. Since torchvision 0.15, the release paired with PyTorch 2.0, the torchvision.transforms.v2 namespace provides improved speed and additional augmentation primitives.
import torchvision.transforms as transforms
# Catalog of key transforms
train_tf = transforms.Compose([
    transforms.Resize(36),                      # scale
    transforms.RandomCrop(32, padding=4),       # spatial
    transforms.RandomHorizontalFlip(),          # mirror
    transforms.RandomRotation(10),              # rotate
    transforms.RandomGrayscale(p=0.1),          # de-color
    transforms.ColorJitter(
        brightness=0.3, contrast=0.3,
        saturation=0.3, hue=0.1),               # color
    transforms.GaussianBlur(kernel_size=3,      # blur
                            sigma=(0.1, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])
print('Full augmentation pipeline ready')