Feature Extraction: Freezing the Backbone
Learners will freeze everything except the final classification head, train only the new layers on a small custom dataset, and confirm that training is substantially faster than full fine-tuning.
Feature Extraction vs Fine-Tuning
Transfer learning has two main strategies. In feature extraction, the pre-trained backbone is completely frozen: its weights do not change during training, and only the new classification head that you add on top learns from your data. In fine-tuning, some or all of the backbone layers are updated as well.
Feature extraction is the right choice when your dataset is small (fewer than a few thousand images) or when your images are similar to ImageNet (natural photographs of everyday objects). It is also much faster, because no gradients are computed for the backbone, and it avoids overwriting well-learned features with noisy updates estimated from too little data.
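To make "no gradients are computed for the backbone" concrete, here is a minimal sketch. It assumes a hypothetical tiny CNN as a stand-in for a pre-trained backbone (so it runs without downloading weights) and freezes it by setting requires_grad = False on each backbone parameter, the mechanism described in the next section. After one backward pass, the frozen parameters still have no gradient (.grad is None), while the new head's parameters do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in backbone: a tiny CNN instead of a pre-trained ResNet
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # output shape: (batch, 8)
)
head = nn.Linear(8, 10)  # new classification head for 10 classes

# Feature extraction: freeze every backbone parameter
for param in backbone.parameters():
    param.requires_grad = False

x = torch.randn(4, 3, 32, 32)   # dummy batch of 4 RGB images
y = torch.randint(0, 10, (4,))  # dummy labels
loss = nn.functional.cross_entropy(head(backbone(x)), y)
loss.backward()

# Only the head's parameters received gradients
backbone_has_grad = [p.grad is not None for p in backbone.parameters()]
head_has_grad = [p.grad is not None for p in head.parameters()]
print("backbone grads:", backbone_has_grad)
print("head grads:", head_has_grad)
```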
Freezing Parameters in PyTorch
In PyTorch, each parameter tensor has a requires_grad attribute. Setting it to False prevents gradient computation for that tensor, effectively freezing it. The simplest way to freeze the backbone is to iterate over model.parameters(), set requires_grad = False on each one, and only then replace the classification head. The new head is a freshly constructed layer with randomly initialized weights, and its parameters have requires_grad=True by default, so it is the only part of the model that trains.
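The order of those two steps matters: if you replace the head first and freeze afterwards, the loop also freezes the new head and nothing is left to train. The sketch below shows both orders on a hypothetical tiny model whose head is named fc (mimicking torchvision's ResNets). It also uses nn.Module.requires_grad_(False), a built-in module method that sets the flag on every parameter of the module in one call, as a shorthand equivalent to the loop.

```python
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for a pre-trained model; the head is named `fc` like torchvision's ResNets."""
    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))

def count_trainable(model):
    # Number of scalar parameters that will receive gradients
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Correct order: freeze everything, then attach a new head
good = TinyNet()
good.requires_grad_(False)                    # same effect as the for-loop over parameters()
good.fc = nn.Linear(good.fc.in_features, 10)  # new parameters default to requires_grad=True

# Wrong order: the new head is frozen along with the backbone
bad = TinyNet()
bad.fc = nn.Linear(bad.fc.in_features, 10)
bad.requires_grad_(False)

print(count_trainable(good))  # 170 (16*10 weights + 10 biases)
print(count_trainable(bad))   # 0
```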
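Freezing is efficient: autograd does not compute or store gradients for frozen parameters, and because no backbone tensor requires a gradient, the backward pass stops at the head and the backbone's intermediate activations do not need to be saved for it. The optimizer also only has to track and update the head's parameters. The forward pass still runs through the entire backbone, so the savings come from the backward pass and the optimizer step, which together make each training step cheaper in both memory and time than fine-tuning the whole network.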
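A rough way to check these claims on your own hardware is to time a few training steps with and without a frozen backbone and to count how many parameter tensors the optimizer keeps state for. The sketch below assumes a hypothetical stand-in CNN (so it runs on CPU without downloads) and passes only parameters with requires_grad=True to torch.optim.Adam, which keeps its running averages only for parameters it actually updates. Absolute timings depend on hardware and model size; only the relative comparison is meaningful.

```python
import time
import torch
import torch.nn as nn

def make_model(num_classes=10):
    # Hypothetical stand-in for a pre-trained backbone followed by a new head
    backbone = nn.Sequential(
        nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    )
    return nn.Sequential(backbone, nn.Linear(64, num_classes))

def time_steps(model, x, y, steps=5):
    """Average seconds per training step; only trainable parameters go to the optimizer."""
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.Adam(params, lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    start = time.perf_counter()
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    return (time.perf_counter() - start) / steps, opt

torch.manual_seed(0)
x = torch.randn(16, 3, 64, 64)
y = torch.randint(0, 10, (16,))

full = make_model()
frozen = make_model()
frozen[0].requires_grad_(False)  # freeze the backbone only; the head stays trainable

t_full, opt_full = time_steps(full, x, y)
t_frozen, opt_frozen = time_steps(frozen, x, y)
print(f"full fine-tuning:   {t_full * 1000:.1f} ms/step, optimizer state for {len(opt_full.state)} tensors")
print(f"feature extraction: {t_frozen * 1000:.1f} ms/step, optimizer state for {len(opt_frozen.state)} tensors")
```

With the full model, Adam holds state for all six parameter tensors (two convolutions and the head, each with a weight and a bias); with the frozen backbone, it holds state only for the head's two.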
```python
import torch.nn as nn
import torchvision.models as models

# Load ResNet-50 with ImageNet weights (weights API available in torchvision >= 0.13)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Freeze ALL backbone parameters
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head (creates new trainable parameters)
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)
# model.fc.parameters() have requires_grad=True by default

# Confirm that only the new head is trainable
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable: {trainable:,} / Total: {total:,} ({trainable/total:.1%})')
```
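One thing requires_grad = False does not cover: ResNet-50 contains BatchNorm layers, and in training mode they keep updating their running mean and variance from every batch even when their weights are frozen. A common way to keep the backbone truly fixed is to train the head with the model in evaluation mode, which freezes those statistics without disabling gradients. The sketch below is a minimal training loop along these lines; train_head is a hypothetical helper that assumes the head is an attribute named fc, as in torchvision's ResNets. To stay runnable without downloads, the usage part substitutes a small hypothetical stand-in model and random data for ResNet-50 and a real dataset.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_head(model, loader, epochs=1, lr=1e-3):
    """Hypothetical helper: train only model.fc, assuming the backbone is already frozen."""
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    # eval() stops BatchNorm from updating its running statistics (and disables dropout, if any);
    # it does not turn off autograd, so model.fc still receives gradients.
    model.eval()
    last_loss = float("nan")
    for _ in range(epochs):
        for images, labels in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
    return last_loss

# Usage with a hypothetical stand-in model (a real run would use the ResNet-50 above)
class StandIn(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.fc = nn.Linear(8, 1000)

    def forward(self, x):
        return self.fc(self.body(x))

torch.manual_seed(0)
model = StandIn()
for p in model.parameters():
    p.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 10)

bn = model.body[1]
stats_before = bn.running_mean.clone()
fc_before = model.fc.weight.detach().clone()

data = TensorDataset(torch.randn(32, 3, 16, 16), torch.randint(0, 10, (32,)))
final_loss = train_head(model, DataLoader(data, batch_size=8, shuffle=True), epochs=2)
print("BatchNorm stats unchanged:", torch.equal(stats_before, bn.running_mean))
print("head weights changed:", not torch.equal(fc_before, model.fc.weight))
```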