Machine Learning Academy · Lesson

Domain Adaptation: Medical Imaging with Scarce Labels

Learners will apply transfer learning from ImageNet to a chest X-ray dataset, implement a class-weighted loss for imbalanced pathologies, and evaluate performance with AUC-ROC.

The Medical Imaging Challenge

Medical imaging presents distinctive transfer learning challenges. Chest X-rays, MRI scans, and histology slides look very different from the natural photographs in ImageNet. X-rays and MRI scans are grayscale, while histology slides have stain-specific colour patterns. The features that matter (lesions, nodules, calcifications) are small, subtle, and domain-specific. Finally, labelling requires expert radiologists or pathologists, so large labelled datasets are expensive and rare.

Despite these challenges, ImageNet pre-trained models typically outperform training from scratch on medical imaging tasks, especially when labels are scarce, even though the visual appearance differs substantially. Low-level features learned on ImageNet (edge detectors, texture filters) transfer across domains and provide a strong initialisation that speeds convergence and improves generalisation when only a few labelled images are available.

The CheXpert Dataset: Multi-Label X-Ray Classification

CheXpert is a benchmark chest X-ray dataset with 224,316 images and 14 labels (Cardiomegaly, Pleural Effusion, Pneumonia, Atelectasis, etc.). To simulate a scarce-label scenario, we use only a small fraction of it, for example 1% (about 2,243 images), mirroring real-world clinical settings where the annotation budget is limited.
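One way to carve out such a subset is a seeded random sample of the label CSV. The sketch below assumes the labels are already loaded into a pandas DataFrame; subsample_labelled is a hypothetical helper, and it samples per image for simplicity, whereas a patient-level sample would be stricter about keeping one patient's studies together.

```python
import pandas as pd


def subsample_labelled(df: pd.DataFrame, frac: float = 0.01, seed: int = 0) -> pd.DataFrame:
    """Keep a reproducible random fraction of the labelled rows (hypothetical helper)."""
    # pandas draws round(frac * len(df)) rows; a fixed seed makes the subset repeatable
    return df.sample(frac=frac, random_state=seed).reset_index(drop=True)


if __name__ == "__main__":
    # Stand-in table with one row per CheXpert image
    full = pd.DataFrame({"Path": [f"img_{i}.jpg" for i in range(224_316)]})
    small = subsample_labelled(full, frac=0.01, seed=0)
    print(len(small))  # 2243
```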

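This is a multi-label classification problem: unlike single-label classification, each image can show several pathologies at once. The target is therefore a vector of binary values, one per label (14 when all CheXpert observations are used), and the loss is binary cross-entropy with logits (BCE, nn.BCEWithLogitsLoss in PyTorch) applied element-wise, so each label is treated as an independent binary task. Evaluation computes AUC-ROC separately for each pathology and then averages the per-label scores.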
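The class-weighted loss and the per-label AUC can be sketched as follows. The weighting uses the pos_weight argument of nn.BCEWithLogitsLoss, which scales the positive term of each label; setting it to the ratio of negatives to positives per label is one common heuristic, assumed here rather than prescribed by the lesson. scikit-learn's roc_auc_score accepts a multi-label indicator matrix and, with average=None, returns one AUC per column. Both helper names are hypothetical.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score


def make_weighted_bce(train_targets: torch.Tensor) -> nn.BCEWithLogitsLoss:
    """BCE with logits whose positive term is scaled by #negatives / #positives per label."""
    pos = train_targets.sum(dim=0)          # positives per label
    neg = train_targets.shape[0] - pos      # negatives per label
    # clamp avoids division by zero when a label has no positives in the subset
    pos_weight = neg / pos.clamp(min=1.0)
    return nn.BCEWithLogitsLoss(pos_weight=pos_weight)


def per_label_auc(targets: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """AUC-ROC for each label column; scores may be logits or probabilities.

    Raises ValueError if a label column contains only one class, which can
    happen for rare pathologies in a small validation set.
    """
    return roc_auc_score(targets, scores, average=None)


if __name__ == "__main__":
    y = torch.tensor([[1., 0.], [0., 0.], [0., 1.], [0., 0.]])
    loss_fn = make_weighted_bce(y)
    print(loss_fn.pos_weight)  # tensor([3., 3.])
    scores = np.array([[0.9, 0.1], [0.2, 0.3], [0.1, 0.8], [0.3, 0.2]])
    aucs = per_label_auc(y.numpy(), scores)
    print(aucs, aucs.mean())  # [1. 1.] 1.0
```

Because AUC depends only on how scores are ranked, raw logits and sigmoid probabilities give the same per-label AUC.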

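# Dataset setup: CheXpert CSV rows -> (image tensor, multi-hot label vector)
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class CheXpertDataset(Dataset):
    # The five pathologies used in the official CheXpert evaluation;
    # append the remaining observations to train on all 14 labels.
    LABELS = ['Atelectasis', 'Cardiomegaly',
              'Consolidation', 'Edema', 'Pleural Effusion']

    def __init__(self, csv_path, img_dir, transform=None, uncertain_value=0.0):
        self.df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.transform = transform
        self.labels = list(self.LABELS)
        # Blank entries (NaN) are treated as negative; uncertain entries (-1)
        # become uncertain_value (0.0 = "U-Zeros", 1.0 = "U-Ones" policy).
        targets = self.df[self.labels].fillna(0.0).replace(-1.0, uncertain_value)
        # float32 to match the model's logits in BCEWithLogitsLoss
        self.targets = torch.tensor(targets.values, dtype=torch.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 'Path' is assumed to be relative to img_dir
        img_path = os.path.join(self.img_dir, self.df.iloc[idx]['Path'])
        img = Image.open(img_path).convert('RGB')  # replicate grayscale into 3 channels
        if self.transform:
            img = self.transform(img)
        return img, self.targets[idx]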

All lessons in this course

  1. Pre-trained Models in torchvision: ResNet, EfficientNet, and ViT
  2. Feature Extraction: Freezing the Backbone
  3. Fine-Tuning: Unfreezing and Low Learning Rates
  4. Domain Adaptation: Medical Imaging with Scarce Labels