"""
Image transformations - elegant, minimal, powerful.
Uses torchvision v2 transforms. Convention: magnitude in [0, 1].
"""
from torchvision.transforms import v2 as T
from toolz import pipe
from .config import TRANSFORM_CATEGORIES
from .utils import numpy_to_pil, pil_to_numpy, clamp
# ============================================================
# TRANSFORM REGISTRY - Single Source of Truth
# ============================================================
def _band(m, width):
    """Return the interval ``(1 - width * m, 1 + width * m)`` centred on 1."""
    return (1 - width * m, 1 + width * m)


# Registry mapping a transform name to a factory. Each factory takes a
# magnitude ``m`` (module convention: a value in [0, 1]) and returns a freshly
# constructed torchvision v2 transform. Nothing is instantiated until a
# factory is called.
TRANSFORMS = {
    # --- Geometric ---------------------------------------------------------
    "rotate": lambda m: T.RandomRotation(degrees=int(m * 180)),
    # For the flips, the magnitude is used as the application probability.
    "flip_h": lambda m: T.RandomHorizontalFlip(p=m),
    "flip_v": lambda m: T.RandomVerticalFlip(p=m),
    # Translation only (no rotation); same fraction on both axes.
    "affine": lambda m: T.RandomAffine(degrees=0, translate=(0.2 * m,) * 2),
    "shear": lambda m: T.RandomAffine(degrees=0, shear=int(m * 45)),
    "perspective": lambda m: T.RandomPerspective(
        distortion_scale=m * 0.5,
        p=1.0,
    ),
    "elastic": lambda m: T.ElasticTransform(alpha=50.0 * m),
    # Output size is fixed at 32; only the lower bound of the crop scale
    # depends on the magnitude.
    "random_crop": lambda m: T.RandomResizedCrop(
        size=32,
        scale=(1 - 0.3 * m, 1.0),
        ratio=(0.75, 1.33),
        antialias=True,
    ),
    # --- Color -------------------------------------------------------------
    "brightness": lambda m: T.ColorJitter(brightness=0.5 * m),
    "contrast": lambda m: T.ColorJitter(contrast=0.5 * m),
    "saturation": lambda m: T.ColorJitter(saturation=0.5 * m),
    "hue": lambda m: T.ColorJitter(hue=m * 0.1),
    "color_jitter": lambda m: T.ColorJitter(
        brightness=0.3 * m,
        contrast=0.3 * m,
        saturation=0.3 * m,
        hue=m * 0.05,
    ),
    # --- Advanced color ----------------------------------------------------
    "sharpen": lambda m: T.RandomAdjustSharpness(
        sharpness_factor=3 * m + 1,
        p=1.0,
    ),
    # The next three ignore the magnitude and are always applied.
    "autocontrast": lambda m: T.RandomAutocontrast(p=1.0),
    "equalize": lambda m: T.RandomEqualize(p=1.0),
    "invert": lambda m: T.RandomInvert(p=1.0),
    # NOTE(review): the threshold (solarize) and bit count (posterize) both
    # grow with the magnitude, which looks like a *weaker* effect at higher
    # magnitude - confirm this direction is intended.
    "solarize": lambda m: T.RandomSolarize(
        threshold=int(127 * m + 128),
        p=1.0,
    ),
    "posterize": lambda m: T.RandomPosterize(
        bits=max(1, int(6 * m + 2)),
        p=1.0,
    ),
    "grayscale": lambda m: T.RandomGrayscale(p=m),
    # --- Blur --------------------------------------------------------------
    # Fixed 3x3 kernel; the magnitude only widens the upper sigma bound.
    "blur": lambda m: T.GaussianBlur(kernel_size=3, sigma=(0.1, 20 * m + 3)),
    # --- Occlusion ---------------------------------------------------------
    "erasing": lambda m: T.RandomErasing(
        p=m,
        scale=(0.02, 0.33),
        ratio=(0.3, 3.3),
    ),
    # --- Advanced ----------------------------------------------------------
    # Acts as an on/off switch: permute channels only above magnitude 0.5.
    "channel_permute": lambda m: (
        T.RandomChannelPermutation() if m > 0.5 else T.Identity()
    ),
    "photometric_distort": lambda m: T.RandomPhotometricDistort(
        brightness=_band(m, 0.3),
        contrast=_band(m, 0.3),
        saturation=_band(m, 0.3),
        hue=(-0.05 * m, 0.05 * m),
    ),
}
# ============================================================
# TRANSFORM OPERATIONS - Functional, composable
# ============================================================
def apply_policy(image, policy):
    """
    Apply an augmentation policy to a single image.

    Args:
        image: Image to augment. An object exposing a ``shape`` attribute is
            treated as array input; anything else is treated as already being
            in the pipeline's image format.
        policy: Iterable of (transform_name, magnitude) pairs, applied in the
            order given. An empty policy applies no augmentation.

    Returns:
        The augmented image in the form it was passed in: array input is
        converted back with ``pil_to_numpy``; any other input is returned
        exactly as the transform pipeline produced it.
    """
    # Remember the input form so the result can mirror it.
    was_numpy = hasattr(image, "shape")
    # NOTE(review): numpy_to_pil is applied to every input, including ones
    # that are not arrays - presumably it passes PIL images through; confirm.
    pil_img = numpy_to_pil(image)

    transforms = [make_transform(name, mag) for name, mag in policy]
    # T.Compose rejects an empty sequence, so an empty policy falls back to
    # the identity transform instead of raising.
    pipeline = T.Compose(transforms) if transforms else T.Identity()
    augmented = pipeline(pil_img)

    return pil_to_numpy(augmented) if was_numpy else augmented
def create_augmenter(policy):
    """
    Create a reusable augmenter from a policy.

    The transform pipeline is built once, when this function is called, and
    is shared by every call to the returned function.

    Args:
        policy: Iterable of (transform_name, magnitude) pairs, applied in the
            order given. An empty policy yields an augmenter that applies no
            augmentation.

    Returns:
        Callable ``augment(image)`` that converts the image with
        ``numpy_to_pil``, runs the pipeline, and converts the result back
        with ``pil_to_numpy``.
    """
    transforms = [make_transform(name, mag) for name, mag in policy]
    # T.Compose rejects an empty sequence, so an empty policy falls back to
    # the identity transform instead of raising.
    pipeline = T.Compose(transforms) if transforms else T.Identity()

    def augment(image):
        """Run the prebuilt pipeline on one image and return the converted result."""
        pil_img = numpy_to_pil(image)
        augmented = pipeline(pil_img)
        return pil_to_numpy(augmented)

    return augment
# ============================================================
# TRANSFORM CATEGORIES - Derived from config (SSOT)
# ============================================================
# ============================================================
# EXPORTS
# ============================================================
# Names exported by `from <this module> import *`.
# NOTE(review): "make_transform" and "get_transform_names" are listed here but
# are not defined in the part of the file visible at this point - confirm they
# are defined elsewhere in the module, otherwise a star-import raises
# AttributeError.
__all__ = [
"TRANSFORMS",
"make_transform",
"apply_policy",
"create_augmenter",
"get_transform_names",
]