Data Augmentation

Comprehensive guide to data augmentation techniques for training robust deep learning models.


Why Data Augmentation?

Benefits:

  • Increases training data diversity
  • Reduces overfitting
  • Improves model generalization
  • Makes models robust to variations
  • Cost-effective alternative to collecting more data

Common Use Cases:

  • Computer vision (images)
  • Natural language processing (text)
  • Audio processing
  • Time series data

Keras Data Augmentation

Basic Image Augmentation (Legacy)

 1from tensorflow.keras.preprocessing.image import ImageDataGenerator
 2
 3# Create augmentation generator
 4datagen = ImageDataGenerator(
 5    rotation_range=20,           # Random rotation ±20 degrees
 6    width_shift_range=0.2,       # Horizontal shift 20%
 7    height_shift_range=0.2,      # Vertical shift 20%
 8    shear_range=0.2,            # Shear transformation
 9    zoom_range=0.2,             # Random zoom
10    horizontal_flip=True,        # Random horizontal flip
11    vertical_flip=False,         # No vertical flip
12    fill_mode='nearest',         # Fill strategy for new pixels
13    brightness_range=[0.8, 1.2], # Brightness adjustment
14    rescale=1./255               # Normalize to [0,1]
15)
16
17# Fit on training data
18datagen.fit(X_train)
19
20# Generate augmented batches
21for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=32):
22    # Train model on augmented batch
23    model.fit(X_batch, y_batch)
 1from tensorflow import keras
 2from tensorflow.keras import layers
 3
# Modern approach: augmentation expressed as Keras preprocessing layers.
# Build augmentation pipeline as model layers
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),      # Random left/right flip
    layers.RandomRotation(0.1),           # Random rotation, factor 0.1
    layers.RandomZoom(0.1),               # Random zoom, factor 0.1
    layers.RandomContrast(0.1),           # Random contrast, factor 0.1
    # NOTE(review): RandomBrightness presumably expects inputs in its default
    # value range — confirm it matches how X_train is scaled.
    layers.RandomBrightness(0.1),
    layers.RandomTranslation(0.1, 0.1),   # Random shift (height, width factors)
])

# Integrate into model
model = keras.Sequential([
    # Augmentation layers (only active during training)
    data_augmentation,

    # Model architecture
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Augmentation automatically applied during training
model.fit(X_train, y_train, epochs=10)

Advanced Augmentation Techniques

 1import tensorflow as tf
 2from tensorflow.keras import layers
 3
 4class MixupLayer(layers.Layer):
 5    """Mixup augmentation: blend two images and their labels"""
 6    def __init__(self, alpha=0.2, **kwargs):
 7        super().__init__(**kwargs)
 8        self.alpha = alpha
 9    
10    def call(self, inputs, training=None):
11        if not training:
12            return inputs
13        
14        images, labels = inputs
15        batch_size = tf.shape(images)[0]
16        
17        # Sample lambda from Beta distribution
18        lam = tf.random.uniform([batch_size, 1, 1, 1], 0, 1)
19        lam = tf.maximum(lam, 1 - lam)
20        
21        # Shuffle indices
22        indices = tf.random.shuffle(tf.range(batch_size))
23        
24        # Mix images
25        mixed_images = lam * images + (1 - lam) * tf.gather(images, indices)
26        
27        # Mix labels
28        lam_labels = tf.reshape(lam, [batch_size, 1])
29        mixed_labels = lam_labels * labels + (1 - lam_labels) * tf.gather(labels, indices)
30        
31        return mixed_images, mixed_labels
32
class CutMixLayer(layers.Layer):
    """CutMix augmentation: cut and paste patches between images.

    One rectangular box is sampled per call and shared by the whole batch.
    Inside the box every image is replaced by the pixels of another, randomly
    chosen image from the same batch. Labels are mixed in proportion to the
    area that was actually replaced.

    Args:
        alpha: Concentration parameter of the Beta(alpha, alpha) distribution
            the initial mixing ratio is drawn from. ``alpha=1.0`` is uniform.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha

    def call(self, inputs, training=None):
        """Apply CutMix to a batch.

        Args:
            inputs: Pair ``(images, labels)``. Images are indexed as
                ``(batch, height, width, channels)``; labels are expected to
                be float (e.g. one-hot) so they can be blended.
            training: When falsy the inputs are returned unchanged.

        Returns:
            Tuple ``(mixed_images, mixed_labels)``.
        """
        if not training:
            return inputs

        images, labels = inputs
        batch_size = tf.shape(images)[0]
        height = tf.shape(images)[1]
        width = tf.shape(images)[2]

        # Sample lambda ~ Beta(alpha, alpha) as a ratio of two Gamma draws.
        gamma_a = tf.random.gamma([], alpha=self.alpha)
        gamma_b = tf.random.gamma([], alpha=self.alpha)
        lam = tf.math.divide_no_nan(gamma_a, gamma_a + gamma_b)

        # Random box whose area is roughly (1 - lam) of the image
        cut_ratio = tf.sqrt(1 - lam)
        cut_h = tf.cast(cut_ratio * tf.cast(height, tf.float32), tf.int32)
        cut_w = tf.cast(cut_ratio * tf.cast(width, tf.float32), tf.int32)

        cx = tf.random.uniform([], 0, width, dtype=tf.int32)
        cy = tf.random.uniform([], 0, height, dtype=tf.int32)

        x1 = tf.clip_by_value(cx - cut_w // 2, 0, width)
        y1 = tf.clip_by_value(cy - cut_h // 2, 0, height)
        x2 = tf.clip_by_value(cx + cut_w // 2, 0, width)
        y2 = tf.clip_by_value(cy + cut_h // 2, 0, height)

        # Shuffle and mix
        indices = tf.random.shuffle(tf.range(batch_size))
        shuffled_images = tf.gather(images, indices)
        shuffled_labels = tf.gather(labels, indices)

        # Create mask (1 = keep original pixel, 0 = take the pasted patch).
        # Built from broadcast comparisons so it works with symbolic shapes;
        # Python range() over tensors fails when the layer is traced.
        rows = tf.range(height)[:, tf.newaxis]   # (H, 1)
        cols = tf.range(width)[tf.newaxis, :]    # (1, W)
        inside_box = tf.logical_and(
            tf.logical_and(rows >= y1, rows < y2),
            tf.logical_and(cols >= x1, cols < x2),
        )
        mask = tf.cast(tf.logical_not(inside_box), images.dtype)
        mask = mask[tf.newaxis, :, :, tf.newaxis]  # broadcast over batch/channels

        mixed_images = mask * images + (1 - mask) * shuffled_images

        # Adjust lambda based on actual box size (clipping may shrink the box)
        lam = 1 - (tf.cast((x2 - x1) * (y2 - y1), tf.float32) /
                   tf.cast(height * width, tf.float32))
        mixed_labels = lam * labels + (1 - lam) * shuffled_labels

        return mixed_images, mixed_labels
86
# Use in model
# `tf.keras.Input` is used because this snippet imports only `tf` and
# `layers`; a bare `keras` name is not in scope here.
inputs = tf.keras.Input(shape=(224, 224, 3))
labels = tf.keras.Input(shape=(10,))

# Apply augmentation (images and labels are mixed together)
x, y = MixupLayer()([inputs, labels])

# Continue with model architecture
x = layers.Conv2D(64, 3, activation='relu')(x)
# ... rest of model

Custom Augmentation Pipeline

 1import tensorflow as tf
 2
 3def augment_image(image, label):
 4    """Custom augmentation function"""
 5    # Random crop and resize
 6    image = tf.image.random_crop(image, size=[200, 200, 3])
 7    image = tf.image.resize(image, [224, 224])
 8    
 9    # Color augmentation
10    image = tf.image.random_brightness(image, 0.2)
11    image = tf.image.random_contrast(image, 0.8, 1.2)
12    image = tf.image.random_saturation(image, 0.8, 1.2)
13    image = tf.image.random_hue(image, 0.1)
14    
15    # Geometric augmentation
16    image = tf.image.random_flip_left_right(image)
17    image = tf.image.rot90(image, k=tf.random.uniform([], 0, 4, dtype=tf.int32))
18    
19    # Normalize
20    image = tf.clip_by_value(image, 0, 1)
21    
22    return image, label
23
24# Apply to dataset
25train_dataset = train_dataset.map(
26    augment_image,
27    num_parallel_calls=tf.data.AUTOTUNE
28)

PyTorch Data Augmentation

Basic Transforms

 1import torch
 2from torchvision import transforms
 3from torch.utils.data import DataLoader
 4
 5# Define augmentation pipeline
 6train_transform = transforms.Compose([
 7    # Resize
 8    transforms.Resize(256),
 9    transforms.RandomCrop(224),
10    
11    # Geometric transformations
12    transforms.RandomHorizontalFlip(p=0.5),
13    transforms.RandomVerticalFlip(p=0.1),
14    transforms.RandomRotation(degrees=15),
15    transforms.RandomAffine(
16        degrees=0,
17        translate=(0.1, 0.1),
18        scale=(0.9, 1.1),
19        shear=10
20    ),
21    
22    # Color augmentation
23    transforms.ColorJitter(
24        brightness=0.2,
25        contrast=0.2,
26        saturation=0.2,
27        hue=0.1
28    ),
29    transforms.RandomGrayscale(p=0.1),
30    
31    # Advanced
32    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
33    transforms.RandomErasing(p=0.5, scale=(0.02, 0.2)),
34    
35    # Normalize
36    transforms.ToTensor(),
37    transforms.Normalize(mean=[0.485, 0.456, 0.406],
38                        std=[0.229, 0.224, 0.225])
39])
40
41# Validation transform (no augmentation)
42val_transform = transforms.Compose([
43    transforms.Resize(256),
44    transforms.CenterCrop(224),
45    transforms.ToTensor(),
46    transforms.Normalize(mean=[0.485, 0.456, 0.406],
47                        std=[0.229, 0.224, 0.225])
48])
49
50# Apply to dataset
51from torchvision.datasets import ImageFolder
52
53train_dataset = ImageFolder('data/train', transform=train_transform)
54val_dataset = ImageFolder('data/val', transform=val_transform)
55
56train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
57val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Advanced Augmentation with Albumentations

 1import albumentations as A
 2from albumentations.pytorch import ToTensorV2
 3import cv2
 4
# Powerful augmentation library
# NOTE(review): several argument names below (RandomResizedCrop(224, 224),
# GaussNoise(var_limit=...), CoarseDropout(max_holes=..., max_height=...,
# max_width=..., fill_value=...)) look like the older Albumentations API —
# confirm against the installed version.
train_transform = A.Compose([
    # Resize and crop
    A.RandomResizedCrop(224, 224, scale=(0.8, 1.0)),

    # Geometric
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1,
                       rotate_limit=15, p=0.5),
    A.Perspective(scale=(0.05, 0.1), p=0.5),
    A.ElasticTransform(alpha=1, sigma=50, p=0.3),

    # Color and lighting
    A.RandomBrightnessContrast(brightness_limit=0.2,
                               contrast_limit=0.2, p=0.5),
    A.HueSaturationValue(hue_shift_limit=20,
                         sat_shift_limit=30,
                         val_shift_limit=20, p=0.5),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15,
               b_shift_limit=15, p=0.5),
    A.RandomGamma(gamma_limit=(80, 120), p=0.5),

    # Blur and noise (OneOf applies a single transform from each group)
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
    ], p=0.3),

    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0)),
        A.ISONoise(),
    ], p=0.3),

    # Weather effects
    A.OneOf([
        A.RandomRain(p=1.0),
        A.RandomFog(p=1.0),
        A.RandomSunFlare(p=1.0),
    ], p=0.1),

    # Cutout/Erasing
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32,
                    fill_value=0, p=0.5),

    # Normalize and convert
    A.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])
55
56# Custom dataset with albumentations
class AlbumentationsDataset(torch.utils.data.Dataset):
    """Dataset that loads images with OpenCV and applies an Albumentations transform.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, indexed in parallel with ``image_paths``.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, colour-convert and optionally augment the sample at ``idx``.

        Returns:
            Tuple ``(image, label)``.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        path = self.image_paths[idx]
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads images as BGR; convert to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.labels[idx]

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label

Mixup and CutMix in PyTorch

 1import numpy as np
 2
 3def mixup_data(x, y, alpha=1.0):
 4    """Mixup augmentation"""
 5    if alpha > 0:
 6        lam = np.random.beta(alpha, alpha)
 7    else:
 8        lam = 1
 9    
10    batch_size = x.size(0)
11    index = torch.randperm(batch_size).to(x.device)
12    
13    mixed_x = lam * x + (1 - lam) * x[index]
14    y_a, y_b = y, y[index]
15    
16    return mixed_x, y_a, y_b, lam
17
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the loss against both target sets, weighted by ``lam``.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
21
def cutmix_data(x, y, alpha=1.0):
    """CutMix augmentation.

    A single random box is pasted into every image of the batch from a
    randomly chosen partner image. The input batch is left untouched; a
    modified copy is returned.

    Args:
        x: Batch tensor of shape ``(N, C, H, W)``.
        y: Targets indexed in parallel with ``x``.
        alpha: Beta(alpha, alpha) concentration for the initial mixing ratio.
            When not positive, no patch is pasted (``lam`` is 1), matching
            ``mixup_data``.

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where ``lam`` is the fraction of
        each image that still comes from the original sample.
    """
    # np.random.beta rejects non-positive parameters; treat them as "no mixing"
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    # Random box
    _, _, H, W = x.size()
    cut_ratio = np.sqrt(1 - lam)
    cut_h = int(H * cut_ratio)
    cut_w = int(W * cut_ratio)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    x1 = int(np.clip(cx - cut_w // 2, 0, W))
    y1 = int(np.clip(cy - cut_h // 2, 0, H))
    x2 = int(np.clip(cx + cut_w // 2, 0, W))
    y2 = int(np.clip(cy + cut_h // 2, 0, H))

    # Apply cutmix on a copy so the caller's batch is not overwritten
    mixed_x = x.clone()
    mixed_x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]

    # Adjust lambda to the area actually replaced (clipping may shrink the box)
    lam = 1 - ((x2 - x1) * (y2 - y1) / (H * W))

    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
50
# Training loop: each batch is augmented with either mixup or cutmix (50/50)
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Pick the augmentation for this batch
        use_mixup = np.random.rand() < 0.5
        augment_fn = mixup_data if use_mixup else cutmix_data
        images, labels_a, labels_b, lam = augment_fn(images, labels)

        # Forward pass; the loss is weighted between both label sets
        outputs = model(images)
        loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

AutoAugment and RandAugment

 1from torchvision.transforms import autoaugment, RandAugment
 2
# AutoAugment (learned policies)
# NOTE: `transforms` is not imported in this snippet; it presumably relies on
# an earlier `from torchvision import transforms`.
auto_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    # Applies the policy set selected by AutoAugmentPolicy.IMAGENET
    autoaugment.AutoAugment(
        policy=autoaugment.AutoAugmentPolicy.IMAGENET
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# RandAugment (simpler, often better)
rand_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    # num_ops transforms per image, all at the given magnitude
    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

Text Data Augmentation

Keras Text Augmentation

 1import tensorflow as tf
 2import random
 3
 4def augment_text(text, label):
 5    """Text augmentation techniques"""
 6    # Synonym replacement (requires nlpaug or similar)
 7    # Back translation (translate to another language and back)
 8    # Random insertion, deletion, swap
 9    
10    # Simple example: random word dropout
11    words = tf.strings.split(text).numpy()
12    if len(words) > 3:
13        # Drop 10% of words randomly
14        keep_prob = 0.9
15        words = [w for w in words if random.random() < keep_prob]
16    
17    augmented_text = ' '.join(words)
18    return augmented_text, label

PyTorch Text Augmentation

 1import nlpaug.augmenter.word as naw
 2import nlpaug.augmenter.sentence as nas
 3
# Synonym replacement
syn_aug = naw.SynonymAug(aug_src='wordnet')

# Contextual word embeddings
bert_aug = naw.ContextualWordEmbsAug(
    model_path='bert-base-uncased',
    action="substitute"
)

# Back translation
back_trans_aug = naw.BackTranslationAug(
    from_model_name='facebook/wmt19-en-de',
    to_model_name='facebook/wmt19-de-en'
)

# Apply augmentation
# NOTE(review): depending on the nlpaug version, augment() may return a list
# of strings rather than a single string — confirm before flattening.
text = "The movie was absolutely fantastic"
augmented_texts = [
    syn_aug.augment(text),
    bert_aug.augment(text),
    back_trans_aug.augment(text)
]

Best Practices

When to Use Augmentation

# ✅ Good practices
# 1. Use augmentation only on training data
# (`augment` stands for the project's own augmentation function; it is not
# defined in this snippet.)
train_dataset = train_dataset.map(augment)
val_dataset = val_dataset  # No augmentation

# 2. Start with simple augmentations
simple_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224),
    transforms.ColorJitter(0.1, 0.1, 0.1)  # brightness, contrast, saturation
])

# 3. Gradually add complexity
# 4. Monitor validation performance
# 5. Use domain-appropriate augmentations

Augmentation Strength

 1# Light augmentation (high-quality data)
 2light_transform = transforms.Compose([
 3    transforms.RandomHorizontalFlip(p=0.5),
 4    transforms.ColorJitter(brightness=0.1, contrast=0.1),
 5    transforms.ToTensor()
 6])
 7
 8# Medium augmentation (standard)
 9medium_transform = transforms.Compose([
10    transforms.RandomHorizontalFlip(p=0.5),
11    transforms.RandomRotation(15),
12    transforms.ColorJitter(brightness=0.2, contrast=0.2, 
13                          saturation=0.2, hue=0.1),
14    transforms.ToTensor()
15])
16
17# Heavy augmentation (small datasets)
18heavy_transform = transforms.Compose([
19    transforms.RandomHorizontalFlip(p=0.5),
20    transforms.RandomRotation(30),
21    transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),
22    transforms.ColorJitter(brightness=0.3, contrast=0.3, 
23                          saturation=0.3, hue=0.2),
24    transforms.RandomErasing(p=0.5),
25    transforms.ToTensor()
26])

Performance Optimization

# Keras: Use prefetching
# (`augment` stands for the project's own augmentation function.)
train_dataset = train_dataset.map(
    augment,
    num_parallel_calls=tf.data.AUTOTUNE  # Let tf.data pick the parallelism
).prefetch(tf.data.AUTOTUNE)             # Prepare upcoming elements ahead of time

# PyTorch: Use multiple workers
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,  # Parallel data loading
    pin_memory=True,  # Faster GPU transfer
    persistent_workers=True  # Keep workers alive
)

Common Pitfalls

 1# ❌ Don't augment validation/test data
 2# ❌ Don't use unrealistic augmentations
 3#    (e.g., vertical flip for text/faces)
 4# ❌ Don't over-augment (can hurt performance)
 5# ❌ Don't forget to normalize after augmentation
 6# ❌ Don't apply augmentation twice accidentally
 7
 8# ✅ Do validate augmentation visually
 9import matplotlib.pyplot as plt
10
def visualize_augmentations(dataset, num_samples=5):
    """Plot original samples (top row) next to augmented versions (bottom row).

    Args:
        dataset: Object with a ``take(n)`` method yielding ``(image, label)``
            pairs.
        num_samples: Number of samples (columns) to show.

    Relies on a module-level ``augment(image, label)`` function returning the
    augmented ``(image, label)`` pair.
    """
    # squeeze=False keeps `axes` two-dimensional even when num_samples == 1;
    # otherwise axes[0, i] fails on the 1-D array subplots would return.
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 6), squeeze=False)

    for i, (image, label) in enumerate(dataset.take(num_samples)):
        # Original
        axes[0, i].imshow(image)
        axes[0, i].set_title(f'Original: {label}')
        axes[0, i].axis('off')

        # Augmented
        aug_image, _ = augment(image, label)
        axes[1, i].imshow(aug_image)
        axes[1, i].set_title('Augmented')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

  • Transfer Learning: Pre-trained models with augmentation
  • Semi-Supervised Learning: Augmentation for unlabeled data
  • Test-Time Augmentation: Average predictions over augmented test samples
  • Adversarial Training: Augmentation with adversarial examples

Related Snippets