Source code for standard_e2e.dataset_utils.augmentation.augmentation
from abc import ABC, abstractmethod
from standard_e2e.data_structures import TransformedFrameData
[docs]
class FrameAugmentation(ABC):
    """Abstract base class for per-frame augmentations.

    Subclasses implement ``_augment`` to transform a mapping of
    ``TransformedFrameData`` keyed by frame name. ``augment`` is the public
    entry point: it validates ``regime`` against ``ALLOWED_REGIMES`` and then
    delegates to the subclass hook.
    """

    # Regime names accepted by ``augment``; any other value raises ValueError.
    ALLOWED_REGIMES = ["train", "val", "test"]

    def __init__(self, *args, **kwargs):
        """Construct an augmentation; accepts arbitrary kwargs for subclasses."""

    def augment(
        self, frames: dict[str, TransformedFrameData], regime: str
    ) -> dict[str, TransformedFrameData]:
        """Apply augmentation to the given frame data.

        Args:
            frames (dict[str, TransformedFrameData]): The frame data to augment.
            regime (str): The regime for which the augmentation is applied.
                Must be one of ``ALLOWED_REGIMES``.

        Returns:
            dict[str, TransformedFrameData]: The augmented frame data, exactly
            as returned by the subclass's ``_augment``.

        Raises:
            ValueError: If ``regime`` is not listed in ``ALLOWED_REGIMES``.
        """
        # Reject unknown regimes here so subclasses never see an invalid one.
        if regime not in self.ALLOWED_REGIMES:
            raise ValueError(
                f"Invalid regime: {regime}. Must be one of {self.ALLOWED_REGIMES}."
            )
        return self._augment(frames, regime)

    @abstractmethod
    def _augment(
        self, frames: dict[str, TransformedFrameData], regime: str
    ) -> dict[str, TransformedFrameData]:
        """Subclass hook that performs the actual augmentation.

        Called by ``augment`` after ``regime`` has been validated.

        Args:
            frames (dict[str, TransformedFrameData]): The frame data to augment.
            regime (str): The regime for which the augmentation is applied.

        Returns:
            dict[str, TransformedFrameData]: The augmented frame data.

        Raises:
            NotImplementedError: Always, in this base implementation (reached
                only if a subclass explicitly delegates to it via ``super()``).
        """
        # Name the hook that actually has to be overridden (``_augment``),
        # not the public ``augment`` wrapper, which is already concrete.
        raise NotImplementedError("Subclasses must implement the _augment method.")
[docs]
class IdentityFrameAugmentation(FrameAugmentation):
    """No-op augmentation: hands back the frame mapping it was given."""

    def _augment(
        self, frames: dict[str, TransformedFrameData], regime: str
    ) -> dict[str, TransformedFrameData]:
        """Pass ``frames`` through untouched.

        Args:
            frames (dict[str, TransformedFrameData]): The frame data to return.
            regime (str): Not consulted by this implementation.

        Returns:
            dict[str, TransformedFrameData]: The very same ``frames`` object
            that was passed in (no copy is made).
        """
        unchanged = frames
        return unchanged