Data Augmentation
Comprehensive guide to data augmentation techniques for training robust deep learning models.
Why Data Augmentation?
Benefits:
- Increases training data diversity
- Reduces overfitting
- Improves model generalization
- Makes models robust to variations
- Cost-effective alternative to collecting more data
Common Use Cases:
- Computer vision (images)
- Natural language processing (text)
- Audio processing
- Time series data
Keras Data Augmentation
Basic Image Augmentation (Legacy)
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Create augmentation generator
datagen = ImageDataGenerator(
    rotation_range=20,            # Random rotation ±20 degrees
    width_shift_range=0.2,        # Horizontal shift up to 20% of width
    height_shift_range=0.2,       # Vertical shift up to 20% of height
    shear_range=0.2,              # Shear transformation
    zoom_range=0.2,               # Random zoom
    horizontal_flip=True,         # Random horizontal flip
    vertical_flip=False,          # No vertical flip
    fill_mode='nearest',          # Fill strategy for new pixels
    brightness_range=[0.8, 1.2],  # Brightness adjustment
    rescale=1./255                # Normalize to [0, 1]
)

# fit() is only required for featurewise statistics
# (featurewise_center, featurewise_std_normalization, zca_whitening)
datagen.fit(X_train)

# Train on augmented batches generated on the fly
model.fit(datagen.flow(X_train, y_train, batch_size=32), epochs=10)
Modern Keras Layers (Recommended)
from tensorflow import keras
from tensorflow.keras import layers

# Build augmentation pipeline as model layers
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1),
    layers.RandomBrightness(0.1),
    layers.RandomTranslation(0.1, 0.1),
])

# Integrate into model
model = keras.Sequential([
    # Augmentation layers (only active during training)
    data_augmentation,

    # Model architecture
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Augmentation is applied automatically during training
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(X_train, y_train, epochs=10)
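The same pipeline can also be applied in the tf.data input pipeline instead of inside the model; since map() runs outside fit()'s training context, the random layers must be forced on with training=True. A minimal sketch, assuming train_dataset yields (image, label) batches:

train_dataset = train_dataset.map(
    lambda x, y: (data_augmentation(x, training=True), y),
    num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)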
Advanced Augmentation Techniques
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class MixupLayer(layers.Layer):
    """Mixup augmentation: blend two images and their labels"""
    def __init__(self, alpha=0.2, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha

    def call(self, inputs, training=None):
        if not training:
            return inputs

        images, labels = inputs
        batch_size = tf.shape(images)[0]

        # Sample lambda from Beta(alpha, alpha) via two Gamma draws
        g1 = tf.random.gamma([batch_size, 1, 1, 1], self.alpha)
        g2 = tf.random.gamma([batch_size, 1, 1, 1], self.alpha)
        lam = g1 / (g1 + g2)
        lam = tf.maximum(lam, 1 - lam)  # keep majority weight on the original image

        # Shuffle indices
        indices = tf.random.shuffle(tf.range(batch_size))

        # Mix images
        mixed_images = lam * images + (1 - lam) * tf.gather(images, indices)

        # Mix labels
        lam_labels = tf.reshape(lam, [batch_size, 1])
        mixed_labels = lam_labels * labels + (1 - lam_labels) * tf.gather(labels, indices)

        return mixed_images, mixed_labels

class CutMixLayer(layers.Layer):
    """CutMix augmentation: cut and paste patches between images"""
    def __init__(self, alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha

    def call(self, inputs, training=None):
        if not training:
            return inputs

        images, labels = inputs
        batch_size = tf.shape(images)[0]
        height = tf.shape(images)[1]
        width = tf.shape(images)[2]

        # Sample lambda (Beta(1, 1) is uniform, matching the default alpha=1.0)
        lam = tf.random.uniform([], 0, 1)

        # Random box
        cut_ratio = tf.sqrt(1 - lam)
        cut_h = tf.cast(cut_ratio * tf.cast(height, tf.float32), tf.int32)
        cut_w = tf.cast(cut_ratio * tf.cast(width, tf.float32), tf.int32)

        cx = tf.random.uniform([], 0, width, dtype=tf.int32)
        cy = tf.random.uniform([], 0, height, dtype=tf.int32)

        x1 = tf.clip_by_value(cx - cut_w // 2, 0, width)
        y1 = tf.clip_by_value(cy - cut_h // 2, 0, height)
        x2 = tf.clip_by_value(cx + cut_w // 2, 0, width)
        y2 = tf.clip_by_value(cy + cut_h // 2, 0, height)

        # Shuffle and mix
        indices = tf.random.shuffle(tf.range(batch_size))
        shuffled_images = tf.gather(images, indices)
        shuffled_labels = tf.gather(labels, indices)

        # Create mask: 0 inside the box, 1 outside (built by broadcasting,
        # which works with tensor-valued box coordinates in graph mode)
        rows = tf.range(height)[:, tf.newaxis]
        cols = tf.range(width)[tf.newaxis, :]
        inside_box = (rows >= y1) & (rows < y2) & (cols >= x1) & (cols < x2)
        mask = 1.0 - tf.cast(inside_box, tf.float32)
        mask = mask[tf.newaxis, :, :, tf.newaxis]

        mixed_images = mask * images + (1 - mask) * shuffled_images

        # Adjust lambda based on actual box size
        lam = 1 - (tf.cast((x2 - x1) * (y2 - y1), tf.float32) /
                   tf.cast(height * width, tf.float32))
        mixed_labels = lam * labels + (1 - lam) * shuffled_labels

        return mixed_images, mixed_labels

# Use in model
inputs = keras.Input(shape=(224, 224, 3))
labels = keras.Input(shape=(10,))

# Apply augmentation
x, y = MixupLayer()([inputs, labels])

# Continue with model architecture
x = layers.Conv2D(64, 3, activation='relu')(x)
# ... rest of model
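Routing labels through a functional model as above works but is unusual; a more common pattern is to apply the layer in the input pipeline. A minimal sketch, assuming train_dataset yields batched (images, one_hot_labels) pairs:

mixup = MixupLayer(alpha=0.2)
train_dataset = train_dataset.map(
    lambda x, y: mixup((x, y), training=True),
    num_parallel_calls=tf.data.AUTOTUNE
)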
Custom Augmentation Pipeline
import tensorflow as tf

def augment_image(image, label):
    """Custom augmentation function (expects float images in [0, 1])"""
    # Random crop and resize (assumes input is at least 200x200)
    image = tf.image.random_crop(image, size=[200, 200, 3])
    image = tf.image.resize(image, [224, 224])

    # Color augmentation
    image = tf.image.random_brightness(image, 0.2)
    image = tf.image.random_contrast(image, 0.8, 1.2)
    image = tf.image.random_saturation(image, 0.8, 1.2)
    image = tf.image.random_hue(image, 0.1)

    # Geometric augmentation
    image = tf.image.random_flip_left_right(image)
    image = tf.image.rot90(image, k=tf.random.uniform([], 0, 4, dtype=tf.int32))

    # Clip back to the valid range
    image = tf.clip_by_value(image, 0, 1)

    return image, label

# Apply to dataset
train_dataset = train_dataset.map(
    augment_image,
    num_parallel_calls=tf.data.AUTOTUNE
)
PyTorch Data Augmentation
Basic Transforms
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

# Define augmentation pipeline
train_transform = transforms.Compose([
    # Resize
    transforms.Resize(256),
    transforms.RandomCrop(224),

    # Geometric transformations
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.1),
    transforms.RandomRotation(degrees=15),
    transforms.RandomAffine(
        degrees=0,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        shear=10
    ),

    # Color augmentation
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    transforms.RandomGrayscale(p=0.1),

    # Advanced
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),

    # Convert and normalize
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),

    # RandomErasing operates on tensors, so it must come after ToTensor
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.2))
])

# Validation transform (no augmentation)
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Apply to dataset
from torchvision.datasets import ImageFolder

train_dataset = ImageFolder('data/train', transform=train_transform)
val_dataset = ImageFolder('data/val', transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
Advanced Augmentation with Albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import torch

# Powerful augmentation library
train_transform = A.Compose([
    # Resize and crop
    A.RandomResizedCrop(224, 224, scale=(0.8, 1.0)),

    # Geometric
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1,
                       rotate_limit=15, p=0.5),
    A.Perspective(scale=(0.05, 0.1), p=0.5),
    A.ElasticTransform(alpha=1, sigma=50, p=0.3),

    # Color and lighting
    A.RandomBrightnessContrast(brightness_limit=0.2,
                               contrast_limit=0.2, p=0.5),
    A.HueSaturationValue(hue_shift_limit=20,
                         sat_shift_limit=30,
                         val_shift_limit=20, p=0.5),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15,
               b_shift_limit=15, p=0.5),
    A.RandomGamma(gamma_limit=(80, 120), p=0.5),

    # Blur and noise
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
    ], p=0.3),

    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0)),
        A.ISONoise(),
    ], p=0.3),

    # Weather effects
    A.OneOf([
        A.RandomRain(p=1.0),
        A.RandomFog(p=1.0),
        A.RandomSunFlare(p=1.0),
    ], p=0.1),

    # Cutout/Erasing
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32,
                    fill_value=0, p=0.5),

    # Normalize and convert
    A.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Custom dataset with albumentations
class AlbumentationsDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.labels[idx]

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label
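Wiring this into training then follows the usual DataLoader pattern; train_paths and train_labels below are placeholders for your own file list and labels:

from torch.utils.data import DataLoader

# train_paths / train_labels are hypothetical lists of image paths and labels
train_dataset = AlbumentationsDataset(train_paths, train_labels,
                                      transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)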
Mixup and CutMix in PyTorch
import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    """Mixup augmentation"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss function for mixup (also used for cutmix)"""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def cutmix_data(x, y, alpha=1.0):
    """CutMix augmentation"""
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    # Random box
    _, _, H, W = x.size()
    cut_ratio = np.sqrt(1 - lam)
    cut_h = int(H * cut_ratio)
    cut_w = int(W * cut_ratio)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    x1 = np.clip(cx - cut_w // 2, 0, W)
    y1 = np.clip(cy - cut_h // 2, 0, H)
    x2 = np.clip(cx + cut_w // 2, 0, W)
    y2 = np.clip(cy + cut_h // 2, 0, H)

    # Apply cutmix (in place)
    x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]

    # Adjust lambda based on actual box size
    lam = 1 - ((x2 - x1) * (y2 - y1) / (H * W))

    y_a, y_b = y, y[index]
    return x, y_a, y_b, lam

# Training loop with mixup/cutmix
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Apply one of the two augmentations, chosen at random
        if np.random.rand() < 0.5:
            images, labels_a, labels_b, lam = mixup_data(images, labels)
        else:
            images, labels_a, labels_b, lam = cutmix_data(images, labels)

        # Forward pass
        outputs = model(images)
        loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
AutoAugment and RandAugment
from torchvision import transforms
from torchvision.transforms import autoaugment, RandAugment

# AutoAugment (learned policies)
auto_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    autoaugment.AutoAugment(
        policy=autoaugment.AutoAugmentPolicy.IMAGENET
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# RandAugment (simpler, often competitive or better)
rand_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
Text Data Augmentation
Keras Text Augmentation
import tensorflow as tf
import random

def augment_text(text, label):
    """Simple text augmentation: random word dropout.
    Other common techniques: synonym replacement (e.g., with nlpaug),
    back translation, and random insertion/deletion/swap.
    This uses eager Python ops, so it must be wrapped with
    tf.py_function when mapped over a tf.data pipeline (see below)."""
    words = text.numpy().decode('utf-8').split()
    if len(words) > 3:
        # Drop ~10% of words at random
        keep_prob = 0.9
        words = [w for w in words if random.random() < keep_prob]

    augmented_text = ' '.join(words)
    return augmented_text, label
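Because augment_text calls .numpy(), it cannot run inside a traced tf.data graph directly. A sketch of the usual tf.py_function wrapper, assuming a dataset of (string, int64 label) pairs:

train_dataset = train_dataset.map(
    lambda text, label: tf.py_function(
        augment_text, inp=[text, label], Tout=[tf.string, tf.int64]),
    num_parallel_calls=tf.data.AUTOTUNE
)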
PyTorch Text Augmentation
import nlpaug.augmenter.word as naw

# Synonym replacement
syn_aug = naw.SynonymAug(aug_src='wordnet')

# Contextual word embeddings
bert_aug = naw.ContextualWordEmbsAug(
    model_path='bert-base-uncased',
    action="substitute"
)

# Back translation
back_trans_aug = naw.BackTranslationAug(
    from_model_name='facebook/wmt19-en-de',
    to_model_name='facebook/wmt19-de-en'
)

# Apply augmentation (recent nlpaug versions return a list of strings)
text = "The movie was absolutely fantastic"
augmented_texts = [
    syn_aug.augment(text),
    bert_aug.augment(text),
    back_trans_aug.augment(text)
]
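These augmenters are typically run offline to expand a labeled training set before tokenization. A minimal sketch, where train_texts and train_labels are hypothetical lists of raw texts and labels:

# Augmented copies inherit the label of their source text
# (unwrap the result if your nlpaug version returns a list)
augmented_pairs = [(syn_aug.augment(t), l)
                   for t, l in zip(train_texts, train_labels)]

all_texts = train_texts + [t for t, _ in augmented_pairs]
all_labels = train_labels + [l for _, l in augmented_pairs]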
Best Practices
When to Use Augmentation
# ✅ Good practices
# 1. Use augmentation only on training data
train_dataset = train_dataset.map(augment)
val_dataset = val_dataset  # No augmentation

# 2. Start with simple augmentations
simple_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224),
    transforms.ColorJitter(0.1, 0.1, 0.1)
])

# 3. Gradually add complexity
# 4. Monitor validation performance
# 5. Use domain-appropriate augmentations
Augmentation Strength
# Light augmentation (high-quality data)
light_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor()
])

# Medium augmentation (standard)
medium_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2,
                           saturation=0.2, hue=0.1),
    transforms.ToTensor()
])

# Heavy augmentation (small datasets)
heavy_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(30),
    transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3,
                           saturation=0.3, hue=0.2),
    transforms.ToTensor(),
    # RandomErasing expects a tensor, so it goes after ToTensor
    transforms.RandomErasing(p=0.5)
])
Performance Optimization
# Keras: Use prefetching
train_dataset = train_dataset.map(
    augment,
    num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

# PyTorch: Use multiple workers
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,            # Parallel data loading
    pin_memory=True,          # Faster GPU transfer
    persistent_workers=True   # Keep workers alive between epochs
)
Common Pitfalls
# ❌ Don't augment validation/test data
# ❌ Don't use unrealistic augmentations
#    (e.g., vertical flip for text/faces)
# ❌ Don't over-augment (can hurt performance)
# ❌ Don't forget to normalize after augmentation
# ❌ Don't apply augmentation twice accidentally

# ✅ Do validate augmentation visually
import matplotlib.pyplot as plt

def visualize_augmentations(dataset, num_samples=5):
    """Show originals and their augmented versions side by side.
    Expects a tf.data dataset of (image, label) pairs and the
    augment() function defined above."""
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 6))

    for i, (image, label) in enumerate(dataset.take(num_samples)):
        # Original
        axes[0, i].imshow(image)
        axes[0, i].set_title(f'Original: {label}')
        axes[0, i].axis('off')

        # Augmented
        aug_image, _ = augment(image, label)
        axes[1, i].imshow(aug_image)
        axes[1, i].set_title('Augmented')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()
Related Topics
- Transfer Learning: Pre-trained models with augmentation
- Semi-Supervised Learning: Augmentation for unlabeled data
- Test-Time Augmentation: Average predictions over augmented test samples (see the sketch after this list)
- Adversarial Training: Augmentation with adversarial examples
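Of these, test-time augmentation is compact enough to sketch here. A minimal PyTorch version, where tta_transform stands in for a hypothetical stochastic test-time pipeline of your choosing (e.g., random flip and crop followed by ToTensor/Normalize):

import torch

def tta_predict(model, image, n_augments=8):
    """Average softmax outputs over randomly augmented copies of one image."""
    model.eval()
    with torch.no_grad():
        # Build a batch of differently augmented views of the same image
        batch = torch.stack([tta_transform(image) for _ in range(n_augments)])
        probs = torch.softmax(model(batch), dim=1)
    return probs.mean(dim=0)  # Averaged class probabilities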