Advanced Usage
This guide covers advanced features and customization options.
Custom Models
You can use your own PyTorch model instead of the default SimpleCNN:
import torch.nn as nn
from deepaugment import DeepAugment

class MyModel(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        # 128 channels * 8 * 8 spatial positions: assumes 32x32 inputs
        # (each of the two 2x2 max-pools halves the spatial size)
        self.fc = nn.Linear(128 * 8 * 8, num_classes)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten to (batch, features)
        return self.fc(x)

# Use your model
aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    model=MyModel  # Pass the class, not an instance
)
best_policy = aug.optimize(iterations=50)
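The `128 * 8 * 8` in `MyModel` is tied to the input resolution. The following is a minimal sketch, in plain Python, of how that number can be derived for other image sizes. The helper names are hypothetical and not part of DeepAugment; the formulas are the standard output-size rules for a convolution and for non-overlapping max pooling.

```python
def conv_output_size(size, kernel=3, padding=1, stride=1):
    """Spatial size after a square convolution (standard output-size formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_output_size(size, kernel=2):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // kernel


def flattened_features(input_size, out_channels=128, n_blocks=2):
    """Features entering the final linear layer of a MyModel-style network."""
    size = input_size
    for _ in range(n_blocks):
        # Each block is conv (3x3, padding 1) followed by a 2x2 max-pool
        size = pool_output_size(conv_output_size(size))
    return out_channels * size * size


print(flattened_features(32))  # 8192, i.e. 128 * 8 * 8
print(flattened_features(64))  # 32768, i.e. 128 * 16 * 16
```

If your images are not 32x32, the `in_features` of `self.fc` would need to change accordingly.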
Filtering Transform Categories
Restrict augmentations to specific categories:
from deepaugment import DeepAugment

# Use only geometric transforms
aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    transform_categories=["geometric"]
)

# Use multiple categories
aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    transform_categories=["geometric", "color"]
)
Available categories:
"geometric": rotate, flip, affine, shear, perspective, elastic, random_crop
"color": brightness, contrast, saturation, hue, color_jitter
"advanced_color": sharpen, autocontrast, equalize, invert, solarize, posterize, grayscale
"blur_noise": blur, gaussian_noise
"occlusion": erasing, cutout
"advanced": channel_permute, photometric_distort
Custom Reward Function
Define your own reward function for optimization:
from deepaugment import DeepAugment

def my_reward(entry):
    """Custom reward function.

    Args:
        entry: Dict with keys 'policy', 'score', 'iteration'

    Returns:
        float: Custom reward value
    """
    score = entry['score']
    iteration = entry['iteration']

    # Penalize complex policies
    policy_length = len(entry['policy'])
    complexity_penalty = policy_length * 0.01

    # Bonus for early convergence
    early_bonus = (1.0 - iteration / 100) * 0.05

    return score - complexity_penalty + early_bonus

aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    custom_reward_fn=my_reward
)
best_policy = aug.optimize(iterations=100)
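To see how the penalty and bonus terms interact, it can help to evaluate the reward by hand on a couple of entries. The sketch below repeats `my_reward` so it runs on its own. The two entries are made up for illustration, and modeling the policy as a list of (transform, magnitude) pairs is an assumption; the only thing the reward relies on is `len(entry['policy'])`.

```python
def my_reward(entry):
    """Same reward as above: score, minus complexity penalty, plus early bonus."""
    complexity_penalty = len(entry["policy"]) * 0.01
    early_bonus = (1.0 - entry["iteration"] / 100) * 0.05
    return entry["score"] - complexity_penalty + early_bonus


# Hypothetical entries (illustrative values, not real optimization output)
short_early = {
    "policy": [("rotate", 0.3), ("flip", 0.5)],  # 2 operations
    "score": 0.80,
    "iteration": 10,
}
long_late = {
    "policy": [("rotate", 0.3)] * 6,  # 6 operations
    "score": 0.82,
    "iteration": 90,
}

print(round(my_reward(short_early), 3))  # 0.80 - 0.02 + 0.045 = 0.825
print(round(my_reward(long_late), 3))    # 0.82 - 0.06 + 0.005 = 0.765
```

With these weights, the shorter policy found early wins even though its raw score is lower, which is worth checking against your own priorities before committing to specific coefficients.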
Adjusting Policy Size
Change the number of operations per policy:
from deepaugment import DeepAugment

# More operations (more complex policies)
aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    n_operations=6  # Default is 4
)

# Fewer operations (simpler policies)
aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    n_operations=2
)
Trade-offs:
More operations: Potentially better policies, but larger search space
Fewer operations: Faster optimization, but may miss complex patterns
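A rough count shows how quickly the search space grows with `n_operations`. The sketch below assumes each operation independently picks one of the 25 transforms listed under "Available categories" and one of 10 discrete magnitude levels. The magnitude discretization is purely an assumption for illustration; the actual search may treat magnitudes as continuous and may include other parameters.

```python
def search_space_size(n_operations, n_transforms=25, n_magnitudes=10):
    """Number of distinct policies under the simplified model described above."""
    choices_per_operation = n_transforms * n_magnitudes
    return choices_per_operation ** n_operations


for n in (2, 4, 6):
    print(f"n_operations={n}: {search_space_size(n):.2e} candidate policies")
```

Under this model, every two additional operations multiply the number of candidates by 62,500, which is why larger policies typically need more iterations to search well.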
Controlling Dataset Size
Use subsets of your data for faster optimization:
from deepaugment import DeepAugment

aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    train_size=1000,  # Use 1000 training samples
    val_size=200      # Use 200 validation samples
)
This is recommended for:
Quick experiments
Large datasets
Limited computational resources
Note: The subset is randomly sampled based on random_state.
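Seeded sampling of this kind can be sketched with the standard library alone. This is not DeepAugment's implementation, only an illustration of the general idea that a fixed seed gives a reproducible subset; `sample_indices` is a hypothetical helper.

```python
import random


def sample_indices(n_total, subset_size, random_state=42):
    """Pick subset_size distinct indices out of n_total, reproducibly."""
    rng = random.Random(random_state)  # Private RNG, does not touch global state
    return sorted(rng.sample(range(n_total), subset_size))


first = sample_indices(50_000, 1000, random_state=0)
second = sample_indices(50_000, 1000, random_state=0)
other = sample_indices(50_000, 1000, random_state=1)

print(first == second)  # Same seed, same subset
print(first == other)   # Different seed, different subset
```

Keeping the seed fixed across experiments makes score differences easier to attribute to the configuration rather than to the particular subset drawn.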
Search Method Selection
Choose between Bayesian Optimization and Random Search:
from deepaugment import DeepAugment

# Bayesian Optimization (default, recommended)
aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    method="bayesian"
)

# Random Search (baseline comparison)
aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    method="random"
)
Bayesian Optimization usually finds better policies within the same iteration budget; Random Search is mainly useful as a baseline for comparison.
Training Configuration
Fine-tune the training process:
from deepaugment import DeepAugment

aug = DeepAugment(X_train, y_train, X_val, y_val)
best_policy = aug.optimize(
    iterations=100,
    epochs=20,            # More epochs for better evaluation
    samples=3,            # Run 3 times and average (reduces noise)
    batch_size=128,       # Larger batch size
    learning_rate=0.005,  # Custom learning rate
)
Samples Parameter:
Setting samples > 1 runs training multiple times and averages the results. This:
Reduces noise in evaluation
Gives more reliable results
Increases computation time proportionally
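The noise reduction from averaging can be illustrated with a small simulation. In the sketch below, `noisy_score` is a hypothetical stand-in for one train-and-evaluate run, modeled as a fixed true score plus Gaussian noise; real training noise need not be Gaussian or independent, so treat the numbers as indicative only.

```python
import random
import statistics


def noisy_score(rng, true_score=0.75, noise_std=0.02):
    """Hypothetical stand-in for a single train-and-evaluate run."""
    return rng.gauss(true_score, noise_std)


def averaged_score(rng, samples):
    """Average of `samples` independent noisy evaluations."""
    return statistics.mean(noisy_score(rng) for _ in range(samples))


rng = random.Random(0)
single = [averaged_score(rng, samples=1) for _ in range(2000)]
triple = [averaged_score(rng, samples=3) for _ in range(2000)]

spread_single = statistics.stdev(single)
spread_triple = statistics.stdev(triple)
print(f"samples=1 spread: {spread_single:.4f}")
print(f"samples=3 spread: {spread_triple:.4f}")
```

Under these assumptions the spread shrinks by roughly a factor of the square root of `samples`, while the cost grows linearly, so the returns diminish as `samples` increases.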
Early Stopping
Stop optimization if no improvement is seen:
from deepaugment import DeepAugment

aug = DeepAugment(X_train, y_train, X_val, y_val)
best_policy = aug.optimize(
    iterations=100,
    early_stopping=True,
    patience=10  # Stop after 10 iterations without improvement
)
This can save time when optimization has converged.
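Patience-based stopping is commonly implemented as a counter that resets on every improvement. The sketch below shows that general pattern on a made-up score sequence; it is not taken from DeepAugment's source, and details such as whether a tie counts as an improvement are assumptions.

```python
def run_with_early_stopping(scores, patience):
    """Return (iterations_run, best_score) for a sequence of per-iteration scores."""
    best_score = float("-inf")
    since_improvement = 0
    for iteration, score in enumerate(scores, start=1):
        if score > best_score:  # Strict improvement; ties do not reset the counter
            best_score = score
            since_improvement = 0
        else:
            since_improvement += 1
        if since_improvement >= patience:
            return iteration, best_score
    return len(scores), best_score


# Illustrative scores, one per optimization iteration
scores = [0.60, 0.65, 0.64, 0.70, 0.69, 0.68, 0.70, 0.66, 0.71, 0.70]

print(run_with_early_stopping(scores, patience=3))  # (7, 0.7)
print(run_with_early_stopping(scores, patience=5))  # (10, 0.71)
```

The second call shows the trade-off: a small `patience` stops sooner but can miss a late improvement that a larger value would have reached.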
Resuming Optimization
Resume from a saved checkpoint:
from deepaugment import DeepAugment

# Initial run
aug = DeepAugment(
    X_train, y_train, X_val, y_val,
    save_history=True,
    experiment_name="my_experiment"
)
aug.optimize(iterations=50)

# Resume later
aug_resumed = DeepAugment(
    X_train, y_train, X_val, y_val,
    resume_from="experiments/my_experiment_checkpoint.json"
)
aug_resumed.optimize(iterations=50)  # Continue for 50 more iterations
Experiment Tracking
Track multiple experiments:
from deepaugment import DeepAugment

experiments = {
    "baseline": {"method": "random"},
    "bayesian_4ops": {"method": "bayesian", "n_operations": 4},
    "bayesian_6ops": {"method": "bayesian", "n_operations": 6},
    "geometric_only": {"transform_categories": ["geometric"]},
}

results = {}
for name, config in experiments.items():
    print(f"\nRunning: {name}")
    aug = DeepAugment(
        X_train, y_train, X_val, y_val,
        experiment_name=name,
        save_history=True,
        **config
    )
    best = aug.optimize(iterations=50)
    results[name] = {
        "best_policy": best,
        "best_score": aug.best_score(),
    }

# Compare results
for name, result in results.items():
    print(f"{name}: {result['best_score']:.3f}")
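Once `results` is populated, ranking the runs and reporting each gain over the baseline makes the comparison easier to read. The sketch below uses placeholder scores so that it runs on its own; the numbers are invented for illustration and are not measured results.

```python
# Placeholder scores for illustration only; not measured results
results = {
    "baseline": {"best_score": 0.70},
    "bayesian_4ops": {"best_score": 0.74},
    "bayesian_6ops": {"best_score": 0.73},
    "geometric_only": {"best_score": 0.72},
}

baseline = results["baseline"]["best_score"]

# Sort experiments from highest to lowest best score
ranked = sorted(
    results.items(),
    key=lambda item: item[1]["best_score"],
    reverse=True,
)

for name, result in ranked:
    gain = result["best_score"] - baseline
    print(f"{name:<16} {result['best_score']:.3f} ({gain:+.3f} vs baseline)")

best_name = ranked[0][0]
print(f"Best configuration: {best_name}")
```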
Applying Policies to New Data
Use discovered policies to augment new data:
from deepaugment import apply_policy
import torch

# After optimization
best_policy = aug.best_policy()

# Apply to new images
def augment_dataset(images, policy):
    augmented = []
    for img in images:
        # Convert to PIL/tensor if needed
        # (this conversion assumes each img is an HxWxC NumPy array)
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)
        # Apply policy
        aug_img = apply_policy(img_tensor, policy)
        augmented.append(aug_img)
    return augmented

# Use in training loop
# (num_epochs and dataloader come from your own training setup)
for epoch in range(num_epochs):
    for batch in dataloader:
        images = batch['images']
        augmented_images = augment_dataset(images, best_policy)
        # Train your model...
Next Steps
Configuration - Complete configuration reference
CIFAR-10 Example - Complete CIFAR-10 example
Custom Models - More custom model examples
API Reference - Full API documentation