Source code for standard_e2e.caching.source_dataset_processor

import logging
import os
from abc import ABC, abstractmethod
from typing import Any, final

from standard_e2e.caching.adapters import AbstractAdapter
from standard_e2e.caching.segment_context import SegmentContextAggregator
from standard_e2e.data_structures import (
    FrameIndexData,
    StandardFrameData,
    TransformedFrameData,
)
from standard_e2e.enums import StandardFrameDataField
from standard_e2e.indexing import IndexDataGenerator
from standard_e2e.utils import _check_list_of_objects_or_none


class SourceDatasetProcessor(ABC):
    """Abstract base class for processing source datasets.

    A processor converts raw, dataset-specific frames into
    ``StandardFrameData`` (via the subclass hook
    ``_prepare_standardized_frame_data``), runs every registered adapter on
    the result, and packs the adapter outputs into a
    ``TransformedFrameData`` together with its ``FrameIndexData``.

    Subclasses must provide ``dataset_name``, ``allowed_splits`` and
    ``_prepare_standardized_frame_data``; ``_get_default_adapters`` must be
    provided as well unless adapters are always passed explicitly.
    """

    # Identifier / index fields that are treated as needed regardless of the
    # adapter chain, since they are required for the cache + index.
    _ALWAYS_NEEDED_ATTRS = frozenset(
        {
            StandardFrameDataField.DATASET_NAME,
            StandardFrameDataField.SPLIT,
            StandardFrameDataField.SEGMENT_ID,
            StandardFrameDataField.FRAME_ID,
            StandardFrameDataField.TIMESTAMP,
            StandardFrameDataField.GLOBAL_POSITION,
        }
    )

    def __init__(
        self,
        common_output_path: str,
        split: str,
        index_data_generator: IndexDataGenerator | None = None,
        adapters: list[AbstractAdapter] | None = None,
        context_aggregators: list[SegmentContextAggregator] | None = None,
    ):
        """Validate the configuration and prepare the output directory.

        Args:
            common_output_path: Root directory shared by all processed
                datasets; the dataset/split specific directory is created
                beneath it.
            split: Dataset split to process; must be one of
                ``allowed_splits``.
            index_data_generator: Generator used to build per-frame index
                data. A default ``IndexDataGenerator()`` is used when
                ``None``.
            adapters: Adapters applied to every frame. Falls back to
                ``_get_default_adapters()`` when ``None``.
            context_aggregators: Segment context aggregators. Falls back to
                ``_get_default_context_aggregators()`` when ``None``.

        Raises:
            TypeError: If ``index_data_generator`` is neither ``None`` nor
                an ``IndexDataGenerator``.
            ValueError: If ``split`` is not in ``allowed_splits``.
        """
        _check_list_of_objects_or_none(adapters, AbstractAdapter)
        _check_list_of_objects_or_none(context_aggregators, SegmentContextAggregator)
        if index_data_generator is not None and not isinstance(
            index_data_generator, IndexDataGenerator
        ):
            raise TypeError(
                "index_data_generator must be an instance of IndexDataGenerator "
                f"or None, got {type(index_data_generator)}"
            )

        self._split = split
        # Reject an invalid split *before* touching the filesystem so that a
        # bad argument does not leave a stray output directory behind.
        if self._split not in self.allowed_splits:
            raise ValueError(
                f"Invalid split: {self._split}. Must be one of {self.allowed_splits}."
            )

        self._common_output_path = common_output_path
        self._specific_output_path = self._prepare_output_directory()
        self._inner_path = os.path.relpath(
            self._specific_output_path, common_output_path
        )

        self._adapters = self._get_default_adapters() if adapters is None else adapters
        self._context_aggregators = (
            self._get_default_context_aggregators()
            if context_aggregators is None
            else context_aggregators
        )

        # Union of ``StandardFrameData`` attributes the registered adapter
        # chain reads. Per-dataset ``_prepare_standardized_frame_data``
        # implementations consult ``self.needs_attr(...)`` to skip
        # building modalities no adapter consumes (lazy load).
        self._consumed_attrs: set[StandardFrameDataField] = set()
        for _adapter in self._adapters:
            self._consumed_attrs |= _adapter.consumes_attrs

        # Compare against None explicitly: a caller-supplied generator must
        # never be swapped for the default just because it evaluates falsy.
        self._index_data_generator = (
            index_data_generator
            if index_data_generator is not None
            else IndexDataGenerator()
        )

        logging.info("Initialized %s processor", self.dataset_name)
        logging.info("Using adapters: %s", [a.name for a in self._adapters])
        logging.info("Consumed SFD attrs: %s", sorted(self._consumed_attrs))
        logging.info("Specific output path: %s", self._specific_output_path)

    def needs_attr(self, attr: StandardFrameDataField) -> bool:
        """Whether at least one registered adapter reads this
        ``StandardFrameData`` field.

        Used by per-dataset processors to skip expensive modality builds
        (cameras, lidar, hd_map, detections, …) when no adapter would
        consume them.

        ``True`` when ``attr`` is in the consumed-attrs union, plus a
        hard-coded special case: the identifier / index fields are always
        treated as needed since they are required for the cache + index
        regardless of adapter chain.
        """
        return attr in self._ALWAYS_NEEDED_ATTRS or attr in self._consumed_attrs

    def _get_default_adapters(self) -> list[AbstractAdapter]:
        """Return the adapters used when none are passed to ``__init__``."""
        raise NotImplementedError("Subclasses must implement this method.")

    def _get_default_context_aggregators(self) -> list[SegmentContextAggregator]:
        """Return the default context aggregators (none in the base class)."""
        return []

    @final
    def _prepare_output_directory(self) -> str:
        """Prepare the output directory for specific processed data.

        The directory is ``<common_output_path>/<dataset_name>/<split>``.

        Returns:
            The path of the dataset/split specific output directory.
        """
        specific_output_path = os.path.join(
            self._common_output_path, self.dataset_name, self.split
        )
        if os.path.exists(specific_output_path):
            logging.warning("Output directory already exists: %s", specific_output_path)
        else:
            # exist_ok guards against another process creating the directory
            # between the existence check above and this call.
            os.makedirs(specific_output_path, exist_ok=True)
            logging.info("Created output directory: %s", specific_output_path)
        return specific_output_path

    @property
    @abstractmethod
    def dataset_name(self) -> str:
        """Return the name of the dataset."""
        raise NotImplementedError("Subclasses must implement this method.")

    @property
    def allowed_splits(self) -> list[str]:
        """Return the list of allowed splits for the dataset."""
        raise NotImplementedError("Subclasses must implement this method.")

    @property
    def context_aggregators(self) -> list[SegmentContextAggregator]:
        """Return the segment context aggregators in use."""
        return self._context_aggregators

    @final
    def process_frame(
        self, raw_frame_data: Any
    ) -> tuple[TransformedFrameData, FrameIndexData]:
        """Standardize one raw frame, run all adapters and build index data.

        Args:
            raw_frame_data: Dataset-specific raw frame, passed unchanged to
                ``_prepare_standardized_frame_data``.

        Returns:
            A ``(transformed_frame_data, frame_index_data)`` tuple.

        Raises:
            TypeError: If ``_prepare_standardized_frame_data`` does not
                return a ``StandardFrameData`` instance.
        """
        standard_frame_data = self._prepare_standardized_frame_data(raw_frame_data)
        if not isinstance(standard_frame_data, StandardFrameData):
            raise TypeError(
                "_prepare_standardized_frame_data must return StandardFrameData, "
                f"got {type(standard_frame_data)}"
            )

        # Adapters run in registration order; a later adapter overwrites any
        # modality key an earlier one produced.
        transformed_modalities = {}
        for adapter in self._adapters:
            transformed_modalities.update(adapter.transform(standard_frame_data))

        # Merge each adapter's per-frame metadata into aux_data so the .npz
        # carries adapter-side configuration (e.g. the HD-map BEV channel list)
        # that downstream consumers need to interpret modality outputs.
        # The frame's own aux_data is copied first so it is never mutated, and
        # the result stays None when there is nothing at all to store.
        source_aux_data = standard_frame_data.aux_data
        merged_aux_data: dict | None = (
            None if source_aux_data is None else dict(source_aux_data)
        )
        for adapter in self._adapters:
            adapter_meta = adapter.metadata
            if not adapter_meta:
                continue
            if merged_aux_data is None:
                merged_aux_data = {}
            merged_aux_data.update(adapter_meta)

        transformed_frame_data = TransformedFrameData(
            dataset_name=standard_frame_data.dataset_name,
            segment_id=standard_frame_data.segment_id,
            frame_id=standard_frame_data.frame_id,
            timestamp=standard_frame_data.timestamp,
            split=standard_frame_data.split,
            global_position=standard_frame_data.global_position,
            aux_data=merged_aux_data,
            extra_index_data=standard_frame_data.extra_index_data,
            _modality_data=transformed_modalities,
        )
        frame_index_data = self._index_data_generator.generate_index_data(
            transformed_frame_data
        )
        return transformed_frame_data, frame_index_data

    @abstractmethod
    def _prepare_standardized_frame_data(
        self, raw_frame_data: Any
    ) -> StandardFrameData:
        """Convert one raw, dataset-specific frame into ``StandardFrameData``.

        ``process_frame`` rejects any return value that is not a
        ``StandardFrameData`` instance.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    @final
    def process_frame_and_save_data(self, raw_frame_data: Any) -> FrameIndexData:
        """Process a single raw frame, save it to disk and return its index data.

        The transformed frame is written with ``to_npz`` to its ``filename``
        joined onto the common output path.

        Raises:
            ValueError: If the transformed frame has no filename.
        """
        frame_data: TransformedFrameData
        frame_index_data: FrameIndexData
        frame_data, frame_index_data = self.process_frame(raw_frame_data)
        filename = frame_data.filename
        if filename is None:
            raise ValueError("Frame data must have a filename before saving.")
        frame_data.to_npz(os.path.join(self._common_output_path, filename))
        return frame_index_data

    @property
    def split(self) -> str:
        """Return the dataset split."""
        return self._split

    @property
    def output_path(self) -> str:
        """Return the output path for the processed dataset."""
        return self._common_output_path

    @property
    def inner_path(self) -> str:
        """Return the inner path relative to the common output path."""
        return self._inner_path

    @property
    def specific_output_path(self) -> str:
        """Return the specific output path for the dataset."""
        return self._specific_output_path