Source code for standard_e2e.unified_dataset

from typing import Dict, Optional

import pandas as pd
import torch
from torch.utils.data import Dataset

from standard_e2e.data_structures import (
    TransformedFrameData,
    TransformedFrameDataBatch,
)
from standard_e2e.dataset_utils.augmentation import (
    FrameAugmentation,
    IdentityFrameAugmentation,
)
from standard_e2e.dataset_utils.frame_loader import (
    FrameLoader,
    create_frame_loaders_from_config,
)
from standard_e2e.dataset_utils.modality_defaults import (
    ModalityDefaults,
    _check_modality_defaults_dict,
)
from standard_e2e.enums import Modality
from standard_e2e.indexing.filters import IndexFilter


class UnifiedE2EDataset(Dataset):
    """
    PyTorch dataset over processed feature and label frames.

    Wraps feature and label frame loaders, applies optional index filters and
    augmentations, and provides PyTorch-compatible access to transformed
    frames for end-to-end training or evaluation.

    For each of features and labels, exactly one of the pre-instantiated
    loader list or the loader configuration must be supplied. An explicitly
    passed empty loader list counts as supplied.

    Parameters
    ----------
    index_data : pd.DataFrame
        Index dataframe describing available frames and metadata used by
        frame loaders. A copy is stored; the caller's frame is not modified.
    processed_data_path : str
        Root path where processed feature/label frames are stored.
    regime : str
        Augmentation regime; must be one of FrameAugmentation.ALLOWED_REGIMES.
    feature_loaders : list[FrameLoader] | None, optional
        Pre-instantiated feature frame loaders; mutually exclusive with
        feature_loaders_config.
    label_loaders : list[FrameLoader] | None, optional
        Pre-instantiated label frame loaders; mutually exclusive with
        label_loaders_config.
    feature_loaders_config : dict | list[dict] | None, optional
        Configuration(s) to build feature frame loaders when feature_loaders
        is None.
    label_loaders_config : dict | list[dict] | None, optional
        Configuration(s) to build label frame loaders when label_loaders is
        None.
    augmentations : list[FrameAugmentation] | None, optional
        Ordered list of frame augmentations to apply; defaults to
        IdentityFrameAugmentation.
    index_filters : list[IndexFilter] | None, optional
        Filters applied to index_data to select usable rows.
    modality_defaults : dict[Modality, ModalityDefaults] | None, optional
        Default modality-specific settings passed to frame loaders.

    Raises
    ------
    ValueError
        If regime is not allowed, or if the loader/config arguments for
        features or labels are both missing or both given.
    TypeError
        If a loader configuration is neither a dict nor a list.
    """

    def __init__(
        self,
        index_data: pd.DataFrame,
        processed_data_path: str,
        regime: str,
        feature_loaders: list[FrameLoader] | None = None,
        label_loaders: list[FrameLoader] | None = None,
        feature_loaders_config: dict | list[dict] | None = None,
        label_loaders_config: dict | list[dict] | None = None,
        augmentations: list[FrameAugmentation] | None = None,
        index_filters: list[IndexFilter] | None = None,
        modality_defaults: dict[Modality, ModalityDefaults] | None = None,
    ):
        if regime not in FrameAugmentation.ALLOWED_REGIMES:
            # Adjacent literals instead of a backslash continuation: a
            # continuation inside the string would embed the source
            # indentation in the runtime message.
            raise ValueError(
                f"Invalid regime: {regime}. "
                f"Must be one of {FrameAugmentation.ALLOWED_REGIMES}."
            )
        _check_modality_defaults_dict(modality_defaults)
        self._regime = regime
        # Work on a private copy so filtering never mutates the caller's frame.
        self._index_data = index_data.copy()
        self._index_filters = index_filters
        self._apply_index_filters()
        # Index labels of the rows that passed filtering, in frame order;
        # __getitem__ maps an ordinal position to one of these labels.
        self._iloc_to_idx = self._index_data[
            self._index_data[IndexFilter.FILTERING_COL_NAME]
        ].index
        # Must be set before the loaders are created: loader creation reads it.
        self._processed_data_path = processed_data_path
        self._frame_loaders = self._create_feature_and_label_loaders(
            feature_loaders,
            label_loaders,
            feature_loaders_config,
            label_loaders_config,
        )
        self._augmentations = augmentations or [IdentityFrameAugmentation()]
        self._modality_defaults = modality_defaults

    def _apply_index_filters(self):
        """Run every index filter, or mark all rows usable if there are none."""
        if self._index_filters:
            for index_filter in self._index_filters:
                self._index_data = index_filter.filter(self._index_data)
        else:
            # No filters: every row is kept.
            self._index_data[IndexFilter.FILTERING_COL_NAME] = True

    def _resolve_loaders(
        self,
        loaders: list[FrameLoader] | None,
        config: dict | list[dict] | None,
        location: str,
        config_name: str,
    ) -> list[FrameLoader]:
        """Return ready loaders: bind given ones to the index or build from config."""
        # `is not None` (not truthiness) so an explicit empty list is honoured
        # instead of falling through to the config branch.
        if loaders is not None:
            for loader in loaders:
                loader.set_index_data(self._index_data)
            return loaders
        if config is None:
            raise ValueError(f"{config_name} must not be None here")
        if isinstance(config, dict):
            frames_description = [config]
        elif isinstance(config, list):
            frames_description = config
        else:
            raise TypeError(f"{config_name} must be dict or list of dict")
        return create_frame_loaders_from_config(
            processed_data_path=self._processed_data_path,
            frames_description=frames_description,
            location=location,
            index_data=self._index_data,
        )

    def _create_feature_and_label_loaders(
        self,
        feature_loaders: list[FrameLoader] | None,
        label_loaders: list[FrameLoader] | None,
        feature_loaders_config: dict | list[dict] | None,
        label_loaders_config: dict | list[dict] | None,
    ) -> list[FrameLoader]:
        """
        Validate the loader arguments and return feature then label loaders.

        Pre-instantiated loaders are bound to this dataset's index data;
        otherwise loaders are built from the given configuration. Any loader
        whose ``processed_data_path`` is None receives this dataset's path.

        Raises
        ------
        ValueError
            If, for features or labels, neither or both of the loader list
            and the configuration are given.
        TypeError
            If a configuration is neither a dict nor a list.
        """
        # All validation happens before any loader is touched, so a bad
        # argument never leaves loaders half-configured.
        if feature_loaders is None and feature_loaders_config is None:
            raise ValueError(
                "Both feature_loaders and feature_loaders_config are not defined."
            )
        if label_loaders is None and label_loaders_config is None:
            raise ValueError(
                "Both label_loaders and label_loaders_config are not defined."
            )
        if feature_loaders is not None and feature_loaders_config is not None:
            raise ValueError(
                "Both feature_loaders and feature_loaders_config are defined."
            )
        if label_loaders is not None and label_loaders_config is not None:
            raise ValueError("Both label_loaders and label_loaders_config are defined.")

        resolved_features = self._resolve_loaders(
            feature_loaders, feature_loaders_config, "features", "feature_loaders_config"
        )
        resolved_labels = self._resolve_loaders(
            label_loaders, label_loaders_config, "labels", "label_loaders_config"
        )

        # Unpack into a new list so mixed sequence types can be combined.
        result = [*resolved_features, *resolved_labels]
        for loader in result:
            if loader.processed_data_path is None:
                loader.set_processed_data_path(self._processed_data_path)
        return result

    def __len__(self) -> int:
        """Return the number of rows flagged as usable by the index filters."""
        return int(self._index_data[IndexFilter.FILTERING_COL_NAME].sum())

    def _fetch_frames(self, current_frame_idx: int) -> dict[str, TransformedFrameData]:
        """
        Load the frames of every loader for one index label.

        The per-loader results are merged into a single dict; if two loaders
        produce the same key, the later loader's entry wins.
        """
        frames: dict[str, TransformedFrameData] = {}
        for frame_loader in self._frame_loaders:
            frame = frame_loader.load_frame(
                current_frame_idx=current_frame_idx,
                index_data=self._index_data,
                modality_defaults=self._modality_defaults,
            )
            frames.update(frame)
        return frames

    def _augment_frames(
        self, frames: dict[str, TransformedFrameData]
    ) -> dict[str, TransformedFrameData]:
        """Apply the configured augmentations in order under the current regime."""
        for augmentation in self._augmentations:
            frames = augmentation.augment(frames, regime=self._regime)
        return frames

    def __getitem__(self, requested_ord_idx: int) -> dict[str, TransformedFrameData]:
        """
        Return the augmented frames for the ``requested_ord_idx``-th usable row.

        The ordinal position counts only rows that passed the index filters
        and is translated to the corresponding index label before loading.
        """
        current_frame_idx = self._iloc_to_idx[requested_ord_idx]
        frames: dict[str, TransformedFrameData] = self._fetch_frames(current_frame_idx)
        frames = self._augment_frames(frames)
        return frames

    @classmethod
    def collate_fn(
        cls,
        batch: list[dict[str, TransformedFrameData]],
        device: Optional[torch.device] = None,
    ) -> dict[str, TransformedFrameDataBatch]:
        """
        Collate a list of {key: TransformedFrameData} into
        {key: TransformedFrameDataBatch}.

        TransformedFrameDataBatch internally uses PyTorch's collate with a
        Trajectory override.

        Parameters
        ----------
        batch : list[dict[str, TransformedFrameData]]
            Samples as returned by ``__getitem__``. The keys of the first
            sample define the keys of the result, so every sample must
            contain them.
        device : torch.device | None, optional
            Device handed to each TransformedFrameDataBatch; defaults to CPU.

        Returns
        -------
        dict[str, TransformedFrameDataBatch]
            One batch object per key.

        Raises
        ------
        IndexError
            If ``batch`` is empty.
        KeyError
            If a sample lacks a key present in the first sample.
        """
        device = device or torch.device("cpu")
        collated: Dict[str, TransformedFrameDataBatch] = {}
        keys = batch[0].keys()
        for key in keys:
            collated[key] = TransformedFrameDataBatch(
                [sample[key] for sample in batch], device=device
            )
        return collated