Machine Learning Academy · Lesson

Data Augmentation: Transforms for Robustness

Learners will apply random horizontal flip, crop, and colour jitter via torchvision.transforms, and measure the accuracy improvement from augmentation.

What Is Data Augmentation?

Data augmentation increases the effective size and diversity of a training set by applying random transformations to existing images. Instead of collecting and annotating new data, which is expensive, augmentation generates fresh variants on the fly each time an image is loaded, so the model rarely sees exactly the same input twice. The key insight is that for many tasks the label stays the same under the transformation: a horizontally flipped image of a cat is still a cat. Augmentation therefore reduces overfitting and teaches the model to be invariant to variations that are irrelevant to the label.
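To make "on the fly" concrete, here is a minimal sketch in plain PyTorch. `AugmentedDataset` and `random_hflip` are hypothetical names written for this illustration, not library APIs. The wrapper re-applies a random flip every time an item is read and passes the label through untouched, which is the label-preservation assumption described above.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


def random_hflip(img: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: mirror a (C, H, W) tensor left-right with probability p."""
    if torch.rand(1).item() < p:
        return torch.flip(img, dims=[-1])  # reverse the width axis
    return img


class AugmentedDataset(Dataset):
    """Hypothetical wrapper: applies `transform` on every read, so each epoch
    sees a freshly sampled view. The label is returned unchanged."""

    def __init__(self, base: Dataset, transform):
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        img, label = self.base[idx]
        return self.transform(img), label  # only the image is transformed


# Toy data: four random 3x8x8 "images" with labels 0..3
images = torch.rand(4, 3, 8, 8)
labels = torch.arange(4)
train_set = AugmentedDataset(TensorDataset(images, labels), random_hflip)

img0, label0 = train_set[0]
print(label0.item(), img0.shape)
```

In practice you rarely need such a wrapper: torchvision's built-in datasets (for example `torchvision.datasets.CIFAR10`) accept a `transform=` argument that is applied per item in the same way.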

import torchvision.transforms as transforms
from PIL import Image
import torch

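# A simple augmentation pipeline; random parameters are re-drawn on every call
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),   # mirror left-right with probability 0.5
    transforms.RandomRotation(degrees=15),    # rotate by a random angle in [-15, 15] degrees
    transforms.ColorJitter(brightness=0.3),   # scale brightness by a random factor in [0.7, 1.3]
    transforms.ToTensor()                     # PIL image -> float tensor in [0, 1], shape (C, H, W)
])

# Apply to the same image -> a different result each time
# img = Image.open('cat.jpg').convert('RGB')
# aug1 = transform(img)  # one augmented version
# aug2 = transform(img)  # a different augmented version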
print('Augmentation pipeline defined')

torchvision.transforms: The Augmentation Toolkit

torchvision.transforms provides a rich library of image transforms. Key categories: geometric (flip, rotate, crop, resize, perspective); color (jitter, grayscale, solarize, equalize); pixel-level (Gaussian blur, random erasing); and tensor conversion and normalization (ToTensor, Normalize). Transforms are chained with transforms.Compose, which applies them in list order. Since torchvision 0.15 (the release that accompanied PyTorch 2.0), the torchvision.transforms.v2 namespace provides improved speed and additional augmentation primitives.

import torchvision.transforms as transforms

# Catalog of key transforms
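train_tf = transforms.Compose([
    transforms.Resize(36),                    # scale: shorter side -> 36 px
    transforms.RandomCrop(32, padding=4),     # spatial: pad 4 px per side, random 32x32 crop
    transforms.RandomHorizontalFlip(),        # mirror (p=0.5 by default)
    transforms.RandomRotation(10),            # rotate by a random angle in [-10, 10] degrees
    transforms.RandomGrayscale(p=0.1),        # de-color about 10% of images
    transforms.ColorJitter(
        brightness=0.3, contrast=0.3,
        saturation=0.3, hue=0.1),             # color: factors in [0.7, 1.3], hue shift in [-0.1, 0.1]
    transforms.GaussianBlur(kernel_size=3,    # blur with sigma drawn from [0.1, 1.0]
                            sigma=(0.1, 1.0)),
    transforms.ToTensor(),                    # PIL -> float tensor in [0, 1]
    transforms.Normalize([0.5]*3, [0.5]*3)    # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])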
print('Full augmentation pipeline ready')

All lessons in this course

  1. Convolution and Filters: Detecting Edges and Patterns
  2. Pooling Layers: Spatial Downsampling and Invariance
  3. Building and Training a CNN on CIFAR-10
  4. Data Augmentation: Transforms for Robustness