# Source code for standard_e2e.caching.source_dataset_converter

"""
Module for converting source datasets to target formats with optional parallel
processing and TFRecord support.
"""

import argparse
import logging
import multiprocessing
import os
from abc import ABC, abstractmethod
from typing import Any, Optional, Union, cast, final

import pandas as pd
import tensorflow as tf
from tqdm import tqdm

from standard_e2e.caching.segment_context import SegmentContextAggregator
from standard_e2e.caching.source_dataset_processor import SourceDatasetProcessor
from standard_e2e.data_structures import FrameIndexData

# Worker-local state for the parallel pool. Each worker process gets its own
# copy via ``_init_worker`` (see ``_convert_frames``); this module-level slot
# is what the per-task ``_process_frame_in_worker`` function reads. Sending
# the processor through Pool.initializer pickles it once per worker rather
# than once per task — the latter ships hundreds of MB on every dispatch
# when the processor carries any non-trivial state (e.g. Waymo Perception's
# prescanned HD-map cache).
_WORKER_PROCESSOR: Optional[SourceDatasetProcessor] = None


def _init_worker(processor: SourceDatasetProcessor) -> None:
    """Store ``processor`` in this process's module-level worker slot.

    Meant to be passed as the ``initializer`` of the worker pool so the
    processor is delivered once per worker process rather than once per
    task; ``_process_frame_in_worker`` reads the slot afterwards.

    Args:
        processor: The processor each task in this worker should use.
    """
    global _WORKER_PROCESSOR
    _WORKER_PROCESSOR = processor


def _process_frame_in_worker(raw_frame_data: Any) -> FrameIndexData:
    """Process one raw frame with the processor installed by ``_init_worker``.

    Args:
        raw_frame_data: A single item yielded by the source dataset iterator.

    Returns:
        Whatever ``process_frame_and_save_data`` returns for this frame.

    Raises:
        RuntimeError: If the worker slot was never populated, i.e. the pool
            was created without ``_init_worker`` as its initializer.
    """
    processor = _WORKER_PROCESSOR
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would turn this into an opaque AttributeError on
    # ``None``.
    if processor is None:
        raise RuntimeError("worker processor not initialized")
    return processor.process_frame_and_save_data(raw_frame_data)


class SourceDatasetConverter(ABC):
    """Base class orchestrating conversion from raw datasets to processed frames.

    Subclasses provide the raw-frame iterable through
    ``_get_source_dataset_iterator``. ``convert`` then feeds every raw frame
    to the source processor (optionally through a worker pool), hands the
    per-frame results to ``FrameIndexData.save_index_data`` and finally runs
    the processor's context aggregators.
    """

    def __init__(
        self,
        source_processor: "SourceDatasetProcessor",
        input_path: str,
        split: str,
        num_workers: int = 0,
        do_parallel_processing: bool = True,
        arguments: Optional[Union[argparse.Namespace, dict]] = None,
    ):
        """Store the configuration and build the source dataset iterator.

        Args:
            source_processor: Processor applied to every raw frame.
            input_path: Location of the source dataset.
            split: Name of the split to convert.
            num_workers: Worker count; ``0`` selects
                ``multiprocessing.cpu_count()``.
            do_parallel_processing: Whether frames are converted in a
                worker pool.
            arguments: Extra options, either as a namespace or as a dict
                (converted to a namespace); ``None`` yields an empty
                namespace.
        """
        self._source_processor = source_processor
        self._input_path = input_path
        self._split = split
        self._num_workers = num_workers
        if self._num_workers == 0:
            self._num_workers = multiprocessing.cpu_count()
        self._do_parallel_processing = do_parallel_processing
        # Normalize arguments into argparse.Namespace
        if arguments is None:
            self._args = argparse.Namespace()
        elif isinstance(arguments, dict):
            self._args = argparse.Namespace(**arguments)
        else:
            self._args = arguments
        # Built last: subclass implementations read the attributes set above.
        self._source_dataset_iterator = self._get_source_dataset_iterator()

    @property
    def dataset_name(self) -> str:
        """Return the name of the dataset."""
        return self._source_processor.dataset_name

    @classmethod
    def get_arg_parser(cls):
        """Return an argument parser for the converter."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--input_path",
            type=str,
            required=True,
            help="Path to the input directory containing the source dataset.",
        )
        parser.add_argument(
            "--output_path",
            type=str,
            required=True,
            help="Path to the output directory where the converted data will be saved.",
        )
        parser.add_argument(
            "--split", type=str, required=True, help="Split of the dataset to process."
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=multiprocessing.cpu_count(),
            help="Number of worker processes to use for data processing.",
        )
        parser.add_argument(
            "--do_parallel_processing",
            action="store_true",
            help="Whether to use parallel processing for data conversion.",
        )
        parser.add_argument(
            "--config_file",
            type=str,
            required=True,
            help="Path to the configuration file.",
        )
        return parser

    @abstractmethod
    def _get_source_dataset_iterator(self):
        """Return an iterator over the source dataset."""
        raise NotImplementedError("Subclasses must implement this method.")

    @property
    def max_workers(self) -> Optional[int]:
        """Optional cap on parallel-pool size; ``None`` means no cap.

        Used by datasets where pool throughput plateaus or regresses past a
        certain worker count -- typically because the processor carries
        large state (e.g. a prescanned HD-map cache) and ``Pool``'s per-task
        dispatch overhead grows with worker count. Subclasses whose
        processors are small can leave this at ``None``.
        """
        return None

    @property
    def multiprocessing_start_method(self) -> str:
        """Start method for the worker pool.

        Default ``"spawn"`` is the conservative choice: TensorFlow and
        OpenCV both keep global thread / mutex state that ``fork()``
        inherits in a deadlock-prone way (typically before the first frame
        completes). Spawn pays a per-worker import cost (~5 s per worker,
        dominated by TensorFlow) but is the safe pattern for any worker that
        may run TF or cv2 work post-fork.

        Subclasses whose worker hot path is fully TF-free (no
        ``tf.io.decode_image``, no ``frame_utils.*`` calls, etc.) may
        override to ``"fork"`` to avoid the spawn import tax. This is a very
        large speedup on small / DEBUG runs and a meaningful one on full
        splits.
        """
        return "spawn"

    def _progress_description(self) -> str:
        """Return the progress-bar label shared by both conversion paths."""
        return f"Processing {self._source_processor.dataset_name} dataset"

    @final
    def _run_context_aggregators(self, index_df: pd.DataFrame) -> None:
        """Run every context aggregator of the source processor on ``index_df``.

        Raises:
            TypeError: If an aggregator is not a ``SegmentContextAggregator``.
        """
        logging.info("Running context aggregators...")
        for context_aggregator in self._source_processor.context_aggregators:
            logging.info(
                "Processing with context aggregator: %s",
                context_aggregator.__class__.__name__,
            )
            if not isinstance(context_aggregator, SegmentContextAggregator):
                raise TypeError(
                    "context_aggregator must be an instance of SegmentContextAggregator"
                    f", got {type(context_aggregator)}"
                )
            context_aggregator.process(
                index_df,
                num_workers=self._num_workers,
                do_parallel=self._do_parallel_processing,
            )

    @final
    def _convert_frames(self) -> pd.DataFrame:
        """Convert the source dataset to the target format.

        Returns:
            The value returned by ``FrameIndexData.save_index_data`` for the
            collected per-frame results.
        """
        processor = self._source_processor
        logging.info("Processing input path: %s", self._input_path)
        logging.info("Output path: %s", processor.output_path)
        logging.info("Processing split: %s", self._split)
        description = self._progress_description()

        if self._do_parallel_processing:
            effective_workers = self._num_workers
            worker_cap = self.max_workers
            if worker_cap is not None and effective_workers > worker_cap:
                logging.info(
                    "Capping pool size from %d to %d for %s (see %s.max_workers)",
                    effective_workers,
                    worker_cap,
                    processor.dataset_name,
                    type(self).__name__,
                )
                effective_workers = worker_cap
            logging.info(
                "Using parallel processing with %d workers for dataset conversion.",
                effective_workers,
            )
            # Start method comes from the converter subclass — see
            # ``SourceDatasetConverter.multiprocessing_start_method`` for the
            # default ``"spawn"`` rationale and the ``"fork"`` opt-out.
            start_method = self.multiprocessing_start_method
            ctx = multiprocessing.get_context(start_method)
            logging.info("Pool start method: %s", start_method)
            # Ship the processor to each worker exactly once via
            # ``initializer`` instead of letting ``pool.imap`` pickle a bound
            # method on every dispatch. The bound-method form was a
            # hundreds-of-MB-per-task tax for processors carrying prescanned
            # state and capped throughput at ~1 fr/s on Waymo Perception.
            with ctx.Pool(
                effective_workers,
                initializer=_init_worker,
                initargs=(processor,),
            ) as pool:
                results = list(
                    tqdm(
                        pool.imap(
                            _process_frame_in_worker,
                            self._source_dataset_iterator,
                        ),
                        desc=description,
                    )
                )
        else:
            logging.info(
                "Processing dataset without parallelization for %s.",
                processor.dataset_name,
            )
            results = [
                processor.process_frame_and_save_data(raw_frame_data)
                for raw_frame_data in tqdm(
                    self._source_dataset_iterator, desc=description
                )
            ]

        logging.info("Conversion completed for %s dataset.", processor.dataset_name)
        logging.info("Converted %d frames from the source dataset.", len(results))
        logging.info("Data saved to %s", processor.specific_output_path)
        index_df = FrameIndexData.save_index_data(
            results, processor.specific_output_path
        )
        return index_df

    @final
    def convert(self) -> None:
        """Convert all frames then run any configured context aggregators."""
        try:
            index_df = self._convert_frames()
            self._run_context_aggregators(index_df)
        finally:
            self._cleanup_after_convert()

    def _cleanup_after_convert(self) -> None:
        """Hook for subclasses to remove transient artifacts created at init
        time (e.g. an HD-map prescan scratch dir).

        Called from ``convert()``'s ``finally`` block so it runs whether
        conversion succeeded or raised.
        """
        # Default no-op; subclasses override when they have something to
        # clean up.


class TFRecSourceDatasetConverter(SourceDatasetConverter, ABC):
    """Abstract converter for source datasets stored as TFRecord files.

    Extends ``SourceDatasetConverter`` with optional sharding of the input
    dataset.

    Methods
    -------
    get_arg_parser(cls):
        Returns the base argument parser extended with ``--n_shards`` and
        ``--shard_id``.
    _get_processing_files(self):
        Abstract method that must be implemented by subclasses to return the
        files to process.
    _get_source_dataset_iterator(self):
        Returns a ``tf.data.TFRecordDataset`` over the files to process,
        sharded when ``n_shards`` is greater than 1.

    Attributes
    ----------
    _args : argparse.Namespace
        Options namespace; ``n_shards`` and ``shard_id`` are read from it.
    _input_path : str
        Path to the input data directory.
    _split : str
        Name of the dataset split (e.g., 'train', 'val', 'test').
    """

    def __init__(self, *args, **kwargs):
        """Forward to the base initializer, defaulting to a single shard."""
        # Provide default sharding args if not supplied.
        # NOTE: only a keyword ``arguments`` is inspected here; passing it
        # positionally would make the base initializer receive it twice.
        if kwargs.get("arguments") is None:
            kwargs["arguments"] = {
                "n_shards": 1,
                "shard_id": 0,
            }
        super().__init__(*args, **kwargs)

    @classmethod
    def get_arg_parser(cls):
        """Return an argument parser for the converter."""
        parser = super().get_arg_parser()
        parser.add_argument(
            "--n_shards",
            type=int,
            default=1,
            help="Number of shards to split the output data into.",
        )
        parser.add_argument(
            "--shard_id", type=int, default=0, help="ID of the shard to process."
        )
        return parser

    @abstractmethod
    def _get_processing_files(self):
        """Return a list of files to process."""
        raise NotImplementedError("Subclasses must implement this method.")

    def _get_source_dataset_iterator(self):
        """Build the TFRecord dataset over the files selected for this run.

        When the ``STANDARD_E2E_DEBUG`` environment variable is ``"true"``
        (case-insensitive) only the first matched file is kept.

        Raises:
            FileNotFoundError: If no file is left to process.
        """
        processing_files = self._get_processing_files()
        processing_files = tf.io.matching_files(processing_files)
        if os.environ.get("STANDARD_E2E_DEBUG", "false").lower() == "true":
            logging.info(
                "STANDARD_E2E_DEBUG is set to true, processing only the first file."
            )
            processing_files = processing_files[:1]
        # ``len(...) == 0`` rather than truthiness: this is the result of
        # ``tf.io.matching_files``, not a plain list.
        if len(processing_files) == 0:
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message.
            raise FileNotFoundError(
                f"No files found in {self._input_path} with pattern "
                f"{self._split}*.tfrecord*"
            )
        logging.info(
            "Found %d files to process for split '%s' of dataset '%s' before sharding.",
            len(processing_files),
            self._split,
            self._source_processor.dataset_name,
        )
        # Use generic Dataset variable for typing compatibility
        # (mypy mismatch with subclass)
        dataset = cast(
            tf.data.Dataset,
            tf.data.TFRecordDataset(
                processing_files,
                compression_type="",
            ),
        )
        if getattr(self._args, "n_shards", 1) > 1:
            logging.info(
                "Sharding dataset into %d shards, processing shard %d.",
                self._args.n_shards,
                self._args.shard_id,
            )
            dataset = dataset.shard(self._args.n_shards, self._args.shard_id)
        else:
            logging.info("Processing without sharding.")
        return dataset