Source code for standard_e2e.dataset_utils.augmentation.trajectory_resampling
from typing import List
from standard_e2e.data_structures import TransformedFrameData
from standard_e2e.data_structures.trajectory_data import Trajectory
from standard_e2e.enums import Modality
from .augmentation import FrameAugmentation
[docs]
class TrajectoryResampling(FrameAugmentation):
    """Resample past/future/preference trajectories onto target timestamps.

    Each selected frame has its ``PAST_STATES`` trajectory resampled onto the
    history timestamps and its ``FUTURE_STATES`` / ``PREFERENCE_TRAJECTORY``
    data resampled onto the future timestamps. A timestamp list left as
    ``None`` disables the corresponding resampling step.
    """

    def __init__(
        self,
        history_target_timestamps: list[float] | None = None,
        future_target_timestamps: list[float] | None = None,
        target_frame_names: list[str] | None = None,
    ):
        """Store the resampling configuration.

        Args:
            history_target_timestamps: Timestamps handed to ``resample`` for
                the ``PAST_STATES`` trajectory; ``None`` leaves it untouched.
            future_target_timestamps: Timestamps handed to ``resample`` for
                ``FUTURE_STATES`` and, when present,
                ``PREFERENCE_TRAJECTORY``; ``None`` leaves both untouched.
            target_frame_names: Keys of the frames to process; ``None`` means
                every frame passed to ``_augment``.
        """
        super().__init__()
        self._target_frame_names = target_frame_names
        self._history_target_timestamps = history_target_timestamps
        self._future_target_timestamps = future_target_timestamps

    def _update_trajectory(
        self,
        frame: TransformedFrameData,
        modality: Modality,
        target_timestamps: list[float],
    ) -> TransformedFrameData:
        """Replace one state trajectory of ``frame`` with its resampled version.

        Args:
            frame: Frame whose modality data is read and overwritten in place.
            modality: Either ``Modality.PAST_STATES`` or
                ``Modality.FUTURE_STATES``.
            target_timestamps: Timestamps passed to ``Trajectory.resample``.

        Returns:
            The same ``frame`` object, after the modality has been rewritten.
        """
        # Internal invariant: this helper only handles single-trajectory
        # state modalities.
        assert modality in (Modality.PAST_STATES, Modality.FUTURE_STATES)
        current: Trajectory = frame.get_modality_data(modality)
        frame.set_modality_data(modality, current.resample(target_timestamps))
        return frame

    def _update_preference_trajectories(
        self,
        frame: TransformedFrameData,
        target_timestamps: list[float],
    ) -> TransformedFrameData:
        """Resample every preference trajectory stored on ``frame``.

        Args:
            frame: Frame whose ``PREFERENCE_TRAJECTORY`` data is rewritten in
                place when it is not ``None``.
            target_timestamps: Timestamps passed to each trajectory's
                ``resample`` call.

        Returns:
            The same ``frame`` object; unchanged when the frame carries no
            preference trajectories.
        """
        candidates: List[Trajectory] | None = frame.get_modality_data(
            Modality.PREFERENCE_TRAJECTORY
        )
        # Nothing to do when the modality is absent.
        if candidates is None:
            return frame

        frame.set_modality_data(
            Modality.PREFERENCE_TRAJECTORY,
            [candidate.resample(target_timestamps) for candidate in candidates],
        )
        return frame

    def _augment(
        self, frames: dict[str, TransformedFrameData], regime: str
    ) -> dict[str, TransformedFrameData]:
        """Apply the configured resampling to the selected frames.

        Args:
            frames: Mapping from frame name to frame data; its entries are
                updated in place.
            regime: Not used by this augmentation.

        Returns:
            The ``frames`` mapping that was passed in.
        """
        history_ts = self._history_target_timestamps
        future_ts = self._future_target_timestamps

        if self._target_frame_names is None:
            selected_names = list(frames.keys())
        else:
            selected_names = self._target_frame_names

        for name in selected_names:
            if history_ts is not None:
                frames[name] = self._update_trajectory(
                    frames[name], Modality.PAST_STATES, history_ts
                )
            if future_ts is not None:
                frames[name] = self._update_trajectory(
                    frames[name], Modality.FUTURE_STATES, future_ts
                )
                # Preference trajectories share the future time axis.
                frames[name] = self._update_preference_trajectories(
                    frames[name], future_ts
                )
        return frames