Source code for standard_e2e.data_structures.frame_data
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import pandas as pd
import torch
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_validator,
model_validator,
)
from torch.utils.data._utils.collate import collate as _torch_collate
from torch.utils.data._utils.collate import default_collate_fn_map
from standard_e2e.constants import INDEX_FILE_NAME
from standard_e2e.data_structures.containers import (
BatchedFrameDetections3D,
BatchedLidarPointCloud,
CameraData,
FrameDetections3D,
HDMap,
LidarData,
LidarPointCloud,
)
from standard_e2e.data_structures.trajectory_data import BatchedTrajectory, Trajectory
from standard_e2e.dataset_utils.modality_defaults import ModalityDefaults
from standard_e2e.enums import (
CameraDirection,
Intent,
Modality,
)
from standard_e2e.enums import TrajectoryComponent as TC
[docs]
class StandardFrameData(BaseModel):
    """Intermediate, dataset-agnostic representation of one frame.

    Conversion pipeline:
    raw frame data -> ``StandardFrameData`` -> ``TransformedFrameData``.

    Attributes:
        timestamp (float): Frame timestamp, in seconds.
        frame_id (int): Identifier of the frame inside its sequence.
        segment_id (str): Identifier of the segment that owns this frame.
        dataset_name (str): Name of the source dataset.
        split (str): Dataset split, e.g. "train", "val" or "test".
        global_position (Optional[Trajectory]): Ego pose at this frame,
            expressed in global coordinates.
        intent (Optional[Intent]): Predicted or annotated intent for the frame.
        cameras (dict[CameraDirection, CameraData]): Camera payloads keyed by
            camera direction; empty dict by default.
        lidar (Optional[LidarData]): LiDAR payload for the frame, if available.
        future_states (Optional[Trajectory]): Trajectory states after this
            frame.
        past_states (Optional[Trajectory]): Trajectory states leading up to
            this frame.
        hd_map (Optional[HDMap]): HD map snapshot in vehicle frame at this
            frame's timestamp; consumed by ``HDMapBEVAdapter`` (and future
            vector adapters). In-memory only — not persisted to ``.npz``.
        frame_detections_3d (Optional[FrameDetections3D]): 3D detections
            present in the frame.
        aux_data (Optional[Dict[str, Any]]): Free-form auxiliary data.
        extra_index_data (Optional[Dict[str, Any]]): Extra indexing or lookup
            data.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    timestamp: float
    frame_id: int
    segment_id: str
    dataset_name: str
    split: str
    global_position: Optional[Trajectory] = None
    intent: Optional[Intent] = None
    cameras: dict[CameraDirection, CameraData] = Field(default_factory=dict)
    lidar: Optional[LidarData] = None
    future_states: Optional[Trajectory] = None
    past_states: Optional[Trajectory] = None
    hd_map: Optional[HDMap] = None
    frame_detections_3d: Optional[FrameDetections3D] = None
    aux_data: Optional[Dict[str, Any]] = None
    extra_index_data: Optional[Dict[str, Any]] = None

    @field_validator("cameras")
    @classmethod
    def _validate_cameras(cls, v):
        """Reject camera maps whose keys/values are not the expected types.

        Items are checked in iteration order; for each item the key is
        checked before the value, and the first violation raises TypeError.
        """
        if not isinstance(v, dict):
            raise TypeError("cameras must be a dict")
        for direction, camera in v.items():
            problem = None
            if not isinstance(direction, CameraDirection):
                problem = "cameras keys must be CameraDirection"
            elif not isinstance(camera, CameraData):
                problem = "cameras values must be CameraData"
            if problem is not None:
                raise TypeError(problem)
        return v
[docs]
class FrameIndexData(BaseModel):
    """Frame index metadata (for parquet serialization).

    Each instance describes one frame and becomes one row of the index
    Parquet file written by :meth:`save_index_data`.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    dataset_name: str
    segment_id: str
    frame_id: int | str
    timestamp: float
    split: str
    filename: str
    extra_index_data: dict | None = None

    def to_index_dict(self) -> dict:
        """Convert index metadata to a flat dictionary used for Parquet writes.

        Extra fields stored in ``extra_index_data`` are prefixed with ``extra_`` to
        keep backward compatibility with legacy index files.

        Returns:
            dict: Every model field except ``extra_index_data``, plus one
            ``extra_<key>`` entry per item of ``extra_index_data``.
        """
        d = self.model_dump()
        # Flatten extra_index_data to prefixed keys as before
        extra = d.pop("extra_index_data", None)
        if extra:
            for k, v in extra.items():
                d[f"extra_{k}"] = v
        return d

    @classmethod
    def save_index_data(
        cls, index_data: list["FrameIndexData"], output_path: str
    ) -> pd.DataFrame:
        """Persist a list of index entries to a sorted Parquet file.

        Args:
            index_data: Sequence of frame index records to serialize. May be
                empty, in which case a zero-row file containing the base
                (non-``extra_``) columns is written.
            output_path: Directory where the Parquet file will be written; the
                file name is ``INDEX_FILE_NAME``.

        Returns:
            pd.DataFrame: The sorted index dataframe that was written to disk.
        """
        index_path = os.path.join(output_path, INDEX_FILE_NAME)
        rows = [entry.to_index_dict() for entry in index_data]
        # pd.DataFrame([]) has no columns, so sorting by segment_id/frame_id
        # would raise KeyError; give an empty index the base schema instead.
        columns = (
            None
            if rows
            else [name for name in cls.model_fields if name != "extra_index_data"]
        )
        df = pd.DataFrame(rows, columns=columns)
        # NOTE(review): frame_id is typed ``int | str``; a column mixing both
        # types cannot be sorted — confirm callers use a single type per index.
        df.sort_values(by=["segment_id", "frame_id"], inplace=True)
        df.to_parquet(index_path, index=False)
        logging.info("Index data saved to %s", index_path)
        return df
def _to_device_recursive(x: Any, device: torch.device) -> Any:
    """Recursively move tensors/BatchedTrajectory (and nested containers) to device.

    - ``torch.Tensor``: copied with ``.to(device=device, non_blocking=True)``.
    - ``BatchedTrajectory`` / ``BatchedLidarPointCloud``: moved through their
      own ``.to(device)`` (in place) and returned as the same object.
    - ``dict`` / ``list`` / ``tuple``: rebuilt with every element moved.
      Named tuples keep their concrete type, matching how torch's collate
      preserves them.
    - Anything else is returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device=device, non_blocking=True)
    if isinstance(x, (BatchedTrajectory, BatchedLidarPointCloud)):
        x.to(device)  # in-place; returns self
        return x
    if isinstance(x, dict):
        return {k: _to_device_recursive(v, device) for k, v in x.items()}
    if isinstance(x, list):
        return [_to_device_recursive(v, device) for v in x]
    if isinstance(x, tuple):
        moved = (_to_device_recursive(v, device) for v in x)
        # Named tuples are constructed from positional fields, not from a
        # single iterable; plain ``tuple(...)`` would silently drop the type.
        if hasattr(x, "_fields"):
            return type(x)(*moved)
        return tuple(moved)
    return x
[docs]
class TransformedFrameData(BaseModel):
    """Represents a single frame data with associated metadata, transformed by Adapters.

    A finalized, training-ready structure, loaded by Dataset:
    Raw frame data -> StandardFrameData -> TransformedFrameData.

    Attributes:
        dataset_name (str): Name of the dataset containing this frame.
        segment_id (str): Identifier of the sequence/segment this frame belongs to.
        frame_id (int): Unique identifier of the frame within the segment.
        timestamp (float): Timestamp of the frame in seconds.
        split (str): Dataset split (e.g., train/val/test) for this frame.
        global_position (Trajectory | None): World-frame position data. When
            not provided it defaults to a single pose at ``timestamp`` with
            x/y/z/heading = 0 and ``IS_VALID`` = 0.
        filename (str | None): File name of the frame npz; when missing or
            empty it is computed as
            ``{dataset_name}/{split}/{segment_id}_{frame_id}.npz``
            (joined with ``os.path.join``).
        aux_data (dict[str, Any] | None): Optional auxiliary metadata.
        extra_index_data (dict[str, Any] | None): Optional extra indexing
            metadata. Not written by ``to_npz``.
        timestamp_diff (float | None): Optional time delta to adjacent frames.
        modality_defaults (dict[Modality, ModalityDefaults] | None): Optional
            default handlers used to normalize modality data. Not written by
            ``to_npz``.

    Private Attributes:
        _modality_data (dict[Modality, Any]): Raw modality-specific payloads,
            stored privately for compatibility with legacy code.
            Use get_modality_data() to access.

    Methods:
        set_modality_data(modality, data): Store payload for a given modality.
        get_modality_data(modality, set_default=True): Retrieve payload for a
            modality, optionally normalizing via its default handler.
        get_present_modality_keys(): List the modalities currently stored.
        to_npz(path): Serialize the frame (including modality data) to a
            compressed ``.npz`` file.
        from_npz(path, required_modalities=None): Load a frame from ``.npz``,
            optionally restricting it to exactly the required modalities
            (missing ones are inserted as None).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    dataset_name: str
    segment_id: str
    frame_id: int
    timestamp: float
    split: str
    global_position: Optional[Trajectory] = None
    filename: str | None = None  # auto-filled
    # Private so Pydantic doesn't complain about leading underscore but we can
    # still expose same attribute name for existing code. We'll handle manual
    # (de)serialization.
    _modality_data: Dict[Modality, Any] = PrivateAttr(default_factory=dict)
    aux_data: Optional[Dict[str, Any]] = None
    extra_index_data: Optional[Dict[str, Any]] = None
    timestamp_diff: Optional[float] = None
    modality_defaults: dict[Modality, ModalityDefaults] | None = None

    def __init__(self, **data: Any):
        """Build the model, accepting an optional ``_modality_data`` keyword.

        ``_modality_data`` is removed before Pydantic validation and merged
        into the private attribute afterwards.
        """
        # Allow construction with _modality_data while keeping it private.
        modality_data = data.pop("_modality_data", None)
        super().__init__(**data)
        if modality_data is not None:
            # Assign after validation (no validation enforced on dict contents here)
            self._modality_data.update(modality_data)

    @model_validator(mode="after")
    def _set_filename(self):
        """Fill ``filename`` from dataset/split/segment/frame when it is empty."""
        # The guard also keeps this a no-op if the validator runs again after
        # the assignment below (validate_assignment=True).
        if not self.filename:
            self.filename = os.path.join(
                self.dataset_name, self.split, f"{self.segment_id}_{self.frame_id}.npz"
            )
        return self

    @model_validator(mode="after")
    def _set_default_global_position(self):
        """Provide a zeroed, invalid-flagged pose when none was supplied."""
        # ``is None`` rather than truthiness: a supplied Trajectory must be kept
        # regardless of how (or whether) Trajectory defines __bool__/__len__.
        if self.global_position is None:
            self.global_position = Trajectory(
                {
                    TC.TIMESTAMP: [self.timestamp],
                    TC.X: [0],
                    TC.Y: [0],
                    TC.Z: [0],
                    TC.HEADING: [0],
                    TC.IS_VALID: [0],
                }
            )
        return self

    def set_modality_data(self, modality: Modality, data: Any):
        """Store data associated with the given modality.

        Args:
            modality (Modality): The modality identifier used as the key.
            data (Any): The data object to associate with the modality; any
                previous value for the same key is replaced.
        """
        self._modality_data[modality] = data

    def get_modality_data(self, modality: Modality, set_default: bool = True) -> Any:
        """Retrieve modality-specific data from the stored modality map.

        Args:
            modality (Modality): The modality key used to locate the associated data.
            set_default (bool, optional): If True, normalizes the raw modality data
                using the configured default handler for the given modality when
                available. If False, returns the raw modality value without
                normalization. Defaults to True.

        Returns:
            Any: The raw or normalized modality data. Without a handler, None
            is returned when the modality key is not present.
        """
        modality_raw_value = self._modality_data.get(modality)
        if not set_default:
            return modality_raw_value
        modality_default_handler = (
            self.modality_defaults.get(modality) if self.modality_defaults else None
        )
        return (
            modality_default_handler.normalize(modality_raw_value, modality)
            if modality_default_handler
            else modality_raw_value
        )

    def get_present_modality_keys(self) -> List[Modality]:
        """Return the list of modality keys currently present in the frame data.

        Returns:
            List[Modality]: The modalities for which data is stored, in
            insertion order.
        """
        return list(self._modality_data.keys())

    def to_npz(self, path: str):
        """Serialize the frame data to a compressed NPZ file.

        Every field from ``model_dump()`` is written except
        ``modality_defaults`` and ``extra_index_data``; the private modality
        map is written under the ``_modality_data`` key. Python objects (dicts,
        Trajectory, modality payloads) end up in object arrays, which is why
        ``from_npz`` needs ``allow_pickle=True``. NumPy appends ``.npz`` to
        ``path`` if it is missing.

        Args:
            path: Destination file path for the NPZ archive.
        """
        payload = self.model_dump()
        # Manually inject private attr for backward compatibility
        payload["_modality_data"] = self._modality_data
        # Avoid persisting modality_defaults object reference (non-serializable)
        payload.pop("modality_defaults", None)
        payload.pop("extra_index_data", None)
        np.savez_compressed(path, **payload)

    @classmethod
    def from_npz(
        cls, path: str, required_modalities: list[Modality] | None = None
    ) -> "TransformedFrameData":
        """Load the frame data from a .npz file.

        Args:
            path: Path of an archive written by ``to_npz``.
            required_modalities: If given, the loaded modality map is
                restricted to exactly these modalities (each coerced with
                ``Modality(m)``): missing ones are inserted as None and all
                others are dropped.

        Returns:
            TransformedFrameData: The reconstructed frame. ``filename`` is
            recomputed by the model validator; ``aux_data`` and
            ``extra_index_data`` fall back to ``{}`` when absent from the file
            (``extra_index_data`` is never written by ``to_npz``).

        Note:
            Loading uses ``allow_pickle=True``; only load trusted files.
        """
        # np.load returns an NpzFile that holds the archive open until it is
        # closed; the ``with`` block releases the handle once every entry has
        # been materialized.
        # NOTE(review): to_npz also writes ``filename`` and ``timestamp_diff``,
        # which are not read back here (timestamp_diff stays None) — confirm
        # this is intended.
        with np.load(path, allow_pickle=True) as data:
            instance = cls(
                dataset_name=data["dataset_name"].item(),
                segment_id=data["segment_id"].item(),
                frame_id=int(data["frame_id"].item()),
                timestamp=float(data["timestamp"].item()),
                split=data["split"].item(),
                global_position=data["global_position"].item(),
                aux_data=data.get("aux_data", np.array({})).item(),
                extra_index_data=data.get("extra_index_data", np.array({})).item(),
            )
            if "_modality_data" in data:
                instance._modality_data = data["_modality_data"].item()

        if required_modalities is None:
            return instance

        required_modalities = [Modality(m) for m in required_modalities]
        for required_modality in required_modalities:
            if required_modality not in instance._modality_data:
                instance.set_modality_data(required_modality, None)
        # remove unwanted modalities (iterate over a snapshot while deleting)
        for modality in list(instance._modality_data.keys()):
            if modality not in required_modalities:
                del instance._modality_data[modality]
        return instance
def collate_trajectory_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    """Collate a batch of ``Trajectory`` instances into a ``BatchedTrajectory``.

    ``collate_fn_map`` is accepted only so the function matches the signature
    torch's ``collate`` uses for registered handlers; it is ignored.
    """
    del collate_fn_map  # unused: required by the collate-fn calling convention
    batched = BatchedTrajectory(batch)
    return batched
def collate_frame_detections_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    """Collate a batch of ``FrameDetections3D`` into ``BatchedFrameDetections3D``.

    ``collate_fn_map`` is accepted only so the function matches the signature
    torch's ``collate`` uses for registered handlers; it is ignored.
    """
    del collate_fn_map  # unused: required by the collate-fn calling convention
    batched = BatchedFrameDetections3D(batch)
    return batched
def collate_lidar_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    """Collate a batch of ``LidarPointCloud`` into ``BatchedLidarPointCloud``.

    ``collate_fn_map`` is accepted only so the function matches the signature
    torch's ``collate`` uses for registered handlers; it is ignored.
    """
    del collate_fn_map  # unused: required by the collate-fn calling convention
    batched = BatchedLidarPointCloud(batch)
    return batched
def collate_modalities(
    batch: List[Any], *, device: Optional[torch.device] = None
) -> Any:
    """Collate one modality's per-frame payloads into a batch.

    Behaves like torch's ``default_collate`` except that leaves of type
    ``Trajectory``, ``FrameDetections3D`` and ``LidarPointCloud`` become
    ``BatchedTrajectory``, ``BatchedFrameDetections3D`` and
    ``BatchedLidarPointCloud`` respectively. Everything else is native
    torch behavior.

    Args:
        batch: Per-frame payloads for a single modality.
        device: If given, the collated result is moved to this device
            (tensors, ``BatchedTrajectory`` and ``BatchedLidarPointCloud``,
            including ones nested in dicts, lists and tuples). If None, the
            collated output is returned where torch's collate produced it.

    Returns:
        Any: The collated batch.
    """
    # Project types are inserted first and torch's defaults appended after,
    # so the project handlers take precedence during torch's type lookup.
    extended_map: dict[Union[type, tuple[type, ...]], Callable[..., Any]] = {
        Trajectory: collate_trajectory_fn,
        FrameDetections3D: collate_frame_detections_fn,
        LidarPointCloud: collate_lidar_fn,
    }
    extended_map.update(default_collate_fn_map)
    collated = _torch_collate(batch, collate_fn_map=extended_map)
    if device is None:
        return collated
    # Honor the requested placement; callers such as TransformedFrameDataBatch
    # pass ``device`` expecting the collated modality data to land there.
    return _to_device_recursive(collated, device)
[docs]
class TransformedFrameDataBatch:
    """
    Data structure to hold a batch of frame data (PyTorch-friendly).

    Per-frame metadata is gathered into lists, timestamps (and
    ``timestamp_diff`` when every frame has one) into float32 tensors, and
    each modality is collated across frames with ``collate_modalities``.
    """

    dataset_name: list[str]
    segment_id: list[str]
    frame_id: list[int]
    timestamp: torch.Tensor
    split: list[str]
    filename: list[str]
    _modality_data: Dict[Modality, Any]
    aux_data: Optional[Dict[str, Any]] = None
    timestamp_diff: Optional[torch.Tensor] = None

    def __init__(
        self,
        frames: list[TransformedFrameData],
        *,
        device: Optional[torch.device] = None,
    ):
        """Create a batched view of multiple ``TransformedFrameData`` instances.

        Args:
            frames: Non-empty list of frames to batch.
            device: Optional device to place tensor-like fields on during
                initialization. Defaults to CPU when not provided.

        Raises:
            ValueError: If ``frames`` is empty or any frame has no filename.
        """
        if not frames:
            # Implicit concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            raise ValueError(
                "TransformedFrameDataBatch requires a non-empty "
                "list of TransformedFrameData."
            )
        device = device or torch.device("cpu")

        self.dataset_name = [frame.dataset_name for frame in frames]
        self.segment_id = [frame.segment_id for frame in frames]
        self.frame_id = [frame.frame_id for frame in frames]
        self.timestamp = torch.tensor(
            [frame.timestamp for frame in frames], dtype=torch.float32, device=device
        )
        self.split = [frame.split for frame in frames]

        self.filename = []
        for frame in frames:
            if frame.filename is None:
                raise ValueError("All frames must have a filename before batching.")
            self.filename.append(frame.filename)

        # Union of modality keys across the batch, sorted by enum name so the
        # resulting dict order is deterministic.
        all_modalities = sorted(
            set().union(*(f._modality_data.keys() for f in frames)),
            key=lambda m: m.name,
        )
        # Collate each modality across frames using PyTorch's
        # collate + Trajectory override.
        self._modality_data = {}
        for modality in all_modalities:
            values = [f.get_modality_data(modality) for f in frames]
            if all(value is None for value in values):
                # torch's collate rejects None; keep a modality that has no
                # payload in any frame as None instead of failing the batch.
                self._modality_data[modality] = None
            else:
                self._modality_data[modality] = collate_modalities(
                    values, device=device
                )

        # Aux data: only the first frame's aux_data is kept (None if unset);
        # other frames' aux_data is not merged or collated.
        self.aux_data = frames[0].aux_data

        # timestamp_diff: keep None if any missing; otherwise stack to tensor
        if any(f.timestamp_diff is None for f in frames):
            self.timestamp_diff = None
        else:
            self.timestamp_diff = torch.tensor(
                [f.timestamp_diff for f in frames], dtype=torch.float32, device=device
            )

    def get_modality_data(self, modality: Modality) -> Any:
        """Get the collated data for a specific modality (None if absent)."""
        return self._modality_data.get(modality)

    def cuda(self, device: Optional[int] = None):
        """Move batched tensors and nested modality payloads to a CUDA device.

        Args:
            device: CUDA device index; the default CUDA device when None.

        Returns:
            TransformedFrameDataBatch: ``self``, after moving (see ``to``).
        """
        dev = torch.device(f"cuda:{device}" if device is not None else "cuda")
        return self.to(dev)

    def to(self, device: torch.device):
        """Move all tensor-like fields to ``device`` (non-blocking where possible).

        Updates this batch in place and returns ``self`` for chaining.
        """
        self.timestamp = self.timestamp.to(device=device, non_blocking=True)
        if self.timestamp_diff is not None:
            self.timestamp_diff = self.timestamp_diff.to(
                device=device, non_blocking=True
            )
        self._modality_data = {
            k: _to_device_recursive(v, device) for k, v in self._modality_data.items()
        }
        return self