Source code for standard_e2e.dataset_utils.modality_defaults.modality_defaults

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, final

import numpy as np

from standard_e2e.data_structures.containers import LidarPointCloud
from standard_e2e.data_structures.trajectory_data import Trajectory
from standard_e2e.enums import Intent, LidarComponent, Modality

if TYPE_CHECKING:
    from standard_e2e.caching.adapters.detections_bev_adapter import (
        Detections3DBEVAdapter,
    )
    from standard_e2e.caching.adapters.hdmap_adapter import HDMapBEVAdapter
    from standard_e2e.caching.adapters.lidar_adapter import LidarBEVAdapter


class ModalityDefaults(ABC):
    """Abstract base for handlers that fill in modality defaults on the fly.

    Subclasses declare the modalities they handle through
    ``allowed_modalities`` and implement the substitution in ``_normalize``.
    Callers use :meth:`normalize`, which validates the modality before
    delegating.
    """

    @property
    @abstractmethod
    def allowed_modalities(self) -> list[Modality]:
        """Return a list of allowed modalities for this defaults handler."""

    @abstractmethod
    def _normalize(self, raw_value: Any, modality: Modality) -> Any:
        """Return the normalized payload for an already-validated modality."""

    @final
    def normalize(self, raw_value: Any, modality: Modality) -> Any:
        """Validate ``modality`` and delegate to the subclass ``_normalize``.

        Args:
            raw_value: Raw modality payload (may be ``None``).
            modality: Modality to normalize; must be listed in
                ``allowed_modalities``.

        Returns:
            Any: Whatever ``_normalize`` returns for the given payload.

        Raises:
            ValueError: If ``modality`` is not a ``Modality`` enum member or
                is not allowed for this defaults handler.
        """
        if not isinstance(modality, Modality):
            raise ValueError("modality must be an instance of Modality enum")
        if modality in self.allowed_modalities:
            return self._normalize(raw_value, modality)
        raise ValueError(
            f"modality {modality} is not allowed for this defaults handler"
        )
class PreferredTrajectoryDefaults(ModalityDefaults):
    """Substitute missing preference trajectories with ``n`` empty trajectories.

    Args:
        n: Number of empty ``Trajectory`` objects returned when the payload
            is missing. Must be at least 1; the value is converted with
            ``int()`` before being stored.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """

    def __init__(self, n: int = 3) -> None:
        # ``int()`` truncates toward zero, so checking ``n <= 0`` would let a
        # fractional value such as 0.5 through and store 0, which silently
        # produces an empty default list. Comparing against 1 rejects it.
        if n < 1:
            raise ValueError("n must be positive")
        self._n = int(n)

    def _normalize(self, raw_value: Any, modality: Modality) -> Any:
        """Return ``n`` empty trajectories for a missing or empty payload.

        A payload counts as missing when it is ``None`` or an empty list or
        tuple; any other value is returned unchanged.
        """
        if modality is not Modality.PREFERENCE_TRAJECTORY:
            return raw_value
        is_empty_sequence = (
            isinstance(raw_value, (list, tuple)) and len(raw_value) == 0
        )
        if raw_value is None or is_empty_sequence:
            return [Trajectory() for _ in range(self._n)]
        return raw_value

    @property
    def allowed_modalities(self) -> list[Modality]:
        """Return the single modality handled here: ``PREFERENCE_TRAJECTORY``."""
        return [Modality.PREFERENCE_TRAJECTORY]
class IntentDefaults(ModalityDefaults):
    """Provide default handling for the ``INTENT`` modality (fallback UNKNOWN)."""

    @property
    def allowed_modalities(self) -> list[Modality]:
        """Return the single modality handled here: ``INTENT``."""
        return [Modality.INTENT]

    def _normalize(self, raw_value: Any, modality: Modality) -> Any:
        """Return ``Intent.UNKNOWN`` for a missing intent, else the payload."""
        needs_fallback = modality is Modality.INTENT and raw_value is None
        return Intent.UNKNOWN if needs_fallback else raw_value
class LidarPointCloudDefaults(ModalityDefaults):
    """Substitute a missing ``LIDAR_PC`` modality with an empty
    ``LidarPointCloud`` (xyz components, zero points).

    Lets datasets without lidar coexist with lidar-bearing ones in a single
    ``UnifiedE2EDataset``.
    """

    # Component layout of the empty placeholder cloud.
    _COMPONENTS = [LidarComponent.X, LidarComponent.Y, LidarComponent.Z]

    def _normalize(self, raw_value: Any, modality: Modality) -> Any:
        """Return an empty point cloud when the payload is ``None``.

        The placeholder holds a ``(0, 3)`` float32 array and its own copy of
        the component list; any non-``None`` payload is returned unchanged.
        """
        if modality is Modality.LIDAR_PC and raw_value is None:
            return LidarPointCloud(
                points=np.zeros((0, 3), dtype=np.float32),
                # Pass a fresh list: ``_COMPONENTS`` is shared class state, so
                # handing out the same object would let a mutation of one
                # placeholder's components leak into every later default.
                components=list(self._COMPONENTS),
            )
        return raw_value

    @property
    def allowed_modalities(self) -> list[Modality]:
        """Return the single modality handled here: ``LIDAR_PC``."""
        return [Modality.LIDAR_PC]
class LidarBEVDefaults(ModalityDefaults):
    """Substitute a missing ``LIDAR_BEV`` modality with a zero-filled float32
    array of the given ``(C, H, W)`` shape.

    Use :meth:`from_adapter` so the empty shape stays in sync with the BEV
    adapter's output.
    """

    def __init__(self, shape: tuple[int, int, int]) -> None:
        """Store the placeholder shape after validating it.

        Raises:
            ValueError: If ``shape`` does not have exactly three entries or
                any entry is not positive.
        """
        wrong_rank = len(shape) != 3
        if wrong_rank or any(dim <= 0 for dim in shape):
            raise ValueError(f"shape must be (C, H, W) of positive ints; got {shape}")
        self._shape = shape

    @classmethod
    def from_adapter(cls, adapter: "LidarBEVAdapter") -> "LidarBEVDefaults":
        """Build defaults whose empty shape matches ``adapter.output_shape``."""
        return cls(shape=adapter.output_shape)

    @property
    def allowed_modalities(self) -> list[Modality]:
        """Return the single modality handled here: ``LIDAR_BEV``."""
        return [Modality.LIDAR_BEV]

    def _normalize(self, raw_value: Any, modality: Modality) -> Any:
        """Return a zero array for a ``None`` payload, else the payload."""
        if raw_value is None and modality is Modality.LIDAR_BEV:
            return np.zeros(self._shape, dtype=np.float32)
        return raw_value
class HDMapBEVDefaults(ModalityDefaults):
    """Substitute a missing ``HD_MAP_BEV`` modality with a zero-filled float32
    array of the given ``(C, H, W)`` shape.

    Use :meth:`from_adapter` so the empty shape stays in sync with the BEV
    adapter's output.
    """

    def __init__(self, shape: tuple[int, int, int]) -> None:
        """Store the placeholder shape after validating it.

        Raises:
            ValueError: If ``shape`` does not have exactly three entries or
                any entry is not positive.
        """
        wrong_rank = len(shape) != 3
        if wrong_rank or any(dim <= 0 for dim in shape):
            raise ValueError(f"shape must be (C, H, W) of positive ints; got {shape}")
        self._shape = shape

    @classmethod
    def from_adapter(cls, adapter: "HDMapBEVAdapter") -> "HDMapBEVDefaults":
        """Build defaults whose empty shape matches ``adapter.output_shape``."""
        return cls(shape=adapter.output_shape)

    @property
    def allowed_modalities(self) -> list[Modality]:
        """Return the single modality handled here: ``HD_MAP_BEV``."""
        return [Modality.HD_MAP_BEV]

    def _normalize(self, raw_value: Any, modality: Modality) -> Any:
        """Return a zero array for a ``None`` payload, else the payload."""
        if raw_value is None and modality is Modality.HD_MAP_BEV:
            return np.zeros(self._shape, dtype=np.float32)
        return raw_value
class Detections3DBEVDefaults(ModalityDefaults):
    """Substitute a missing ``DETECTIONS_3D_BEV`` modality with a zero-filled
    float32 array of the given ``(C, H, W)`` shape.

    Use :meth:`from_adapter` so the empty shape stays in sync with the BEV
    adapter's output.
    """

    def __init__(self, shape: tuple[int, int, int]) -> None:
        """Store the placeholder shape after validating it.

        Raises:
            ValueError: If ``shape`` does not have exactly three entries or
                any entry is not positive.
        """
        wrong_rank = len(shape) != 3
        if wrong_rank or any(dim <= 0 for dim in shape):
            raise ValueError(f"shape must be (C, H, W) of positive ints; got {shape}")
        self._shape = shape

    @classmethod
    def from_adapter(
        cls, adapter: "Detections3DBEVAdapter"
    ) -> "Detections3DBEVDefaults":
        """Build defaults whose empty shape matches ``adapter.output_shape``."""
        return cls(shape=adapter.output_shape)

    @property
    def allowed_modalities(self) -> list[Modality]:
        """Return the single modality handled here: ``DETECTIONS_3D_BEV``."""
        return [Modality.DETECTIONS_3D_BEV]

    def _normalize(self, raw_value: Any, modality: Modality) -> Any:
        """Return a zero array for a ``None`` payload, else the payload."""
        if raw_value is None and modality is Modality.DETECTIONS_3D_BEV:
            return np.zeros(self._shape, dtype=np.float32)
        return raw_value