from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
import numpy as np
import torch
from standard_e2e.enums import TrajectoryComponent
from standard_e2e.utils import _assert_strictly_increasing
# ======================= Helpers (module-level) =======================
Array1DNP = Union[
np.ndarray, List[float], List[int], Tuple[float, ...], Tuple[int, ...]
]
def _to_numpy_1d_float32(values: Array1DNP) -> np.ndarray:
"""Convert ndarray/list/tuple to 1D np.float32. Accepts (N,) or (N,1)."""
if isinstance(values, np.ndarray):
array = values
elif isinstance(values, (list, tuple)):
array = np.asarray(values)
else:
raise TypeError(f"Expected NumPy array or list/tuple, got {type(values)!r}")
if array.ndim == 1:
pass
elif array.ndim == 2 and array.shape[1] == 1:
array = array.reshape(-1)
else:
raise ValueError(f"Expected shape (N,) or (N,1); got {tuple(array.shape)}")
return array.astype(np.float32, copy=False)
def _pad_numpy_1d(array: np.ndarray, target_length: int, side: str) -> np.ndarray:
"""Zero-pad a 1D array to target_length on 'left' or 'right'."""
length = array.shape[0]
if length >= target_length:
return array
pad = target_length - length
zeros = np.zeros((pad,), dtype=np.float32)
if side == "left":
return np.concatenate([zeros, array])
if side == "right":
return np.concatenate([array, zeros])
raise ValueError("side must be 'left' or 'right'")
def _validate_components_arg(
    components: Union[TrajectoryComponent, Sequence[TrajectoryComponent]],
) -> List[TrajectoryComponent]:
    """Normalize a component argument into a validated list.

    Args:
        components: A single ``TrajectoryComponent`` or a non-empty sequence
            of them.

    Returns:
        A new list containing the requested components in the given order.

    Raises:
        TypeError: If ``components`` is neither a ``TrajectoryComponent`` nor
            a non-empty sequence, or if any element of the sequence is not a
            ``TrajectoryComponent``.
    """
    # A lone component is wrapped so callers always receive a list.
    if isinstance(components, TrajectoryComponent):
        return [components]

    if not (isinstance(components, Sequence) and components):
        raise TypeError(
            "components must be a TrajectoryComponent or a non-empty sequence of them"
        )

    candidates = list(components)
    for candidate in candidates:
        if not isinstance(candidate, TrajectoryComponent):
            valid = ", ".join(member.name for member in TrajectoryComponent)
            raise TypeError(
                f"All elements must be TrajectoryComponent. Valid: [{valid}]"
            )
    return candidates
# ======================= Trajectory (NumPy) =======================
class Trajectory:
    """Single-trajectory container backed by NumPy arrays.

    Every component is stored as a 1D ``np.float32`` array and all components
    share the same length ``N``.

    - Inputs may be ndarray/list/tuple shaped ``(N,)`` or ``(N, 1)``; they are
      stored as ``(N,)``.
    - ``IS_VALID`` is created automatically (all ones) the first time another
      component is set; an explicitly provided ``IS_VALID`` replaces it.
    - ``score`` is an optional scalar (stored as ``float``).
    - ``get(one or many, strict=True)`` returns an ``(N, K)`` float32 array.
    - ``pad`` / ``trim`` / ``pad_or_trim`` modify the trajectory in place and
      return ``self``; padded rows are zero for every component, including
      ``IS_VALID``, so they read as invalid.
    """

    def __init__(
        self,
        data: Optional[Dict[TrajectoryComponent, Array1DNP]] = None,
        score: Optional[float] = None,
        time_lattice: Optional[Array1DNP] = None,
    ) -> None:
        """Build a trajectory, optionally resampling it right away.

        Args:
            data: Mapping from component to its values; may be omitted or
                empty to create an empty trajectory.
            score: Optional scalar score; converted with ``float``.
            time_lattice: If given, the freshly built trajectory is resampled
                in place onto these timestamps.
        """
        self._data: Dict[TrajectoryComponent, np.ndarray] = {}
        self.score: Optional[float] = None if score is None else float(score)
        if data:
            self.set_many(data)
        if time_lattice is not None:
            self.resample(time_lattice, inplace=True)

    # ----- basics -----
    @property
    def length(self) -> int:
        """Return the number of timesteps stored in this trajectory.

        If no data has been added yet, this returns 0.

        Returns:
            int: Number of timesteps (rows) in the trajectory.
        """
        if not self._data:
            return 0
        first_key = next(iter(self._data))
        return int(self._data[first_key].shape[0])

    @property
    def isEmpty(self) -> bool:
        """Return True if there are no timesteps or ``IS_VALID`` sums to zero.

        A trajectory without an ``IS_VALID`` component also counts as empty,
        because the missing component is read back as zeros.
        """
        if self.length == 0:
            return True
        return bool(self.get(TrajectoryComponent.IS_VALID, strict=False).sum() == 0)

    def components(self) -> List[TrajectoryComponent]:
        """List the trajectory components currently stored."""
        return list(self._data.keys())

    def has(self, component: TrajectoryComponent) -> bool:
        """Check if a trajectory component is present."""
        return component in self._data

    # ----- set/get -----
    def set(self, component: TrajectoryComponent, values: Array1DNP) -> "Trajectory":
        """Add or replace a component, enforcing length consistency.

        Args:
            component: Component to store.
            values: Values shaped ``(N,)`` or ``(N, 1)``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If the trajectory already holds data of a different
                length.
            TypeError: If ``values`` is not an ndarray, list or tuple.

        ``TIMESTAMP`` values are additionally passed through
        ``_assert_strictly_increasing`` before being stored.
        """
        array = _to_numpy_1d_float32(values)
        if component is TrajectoryComponent.TIMESTAMP:
            _assert_strictly_increasing(array)
        # Compare against stored data whenever any exists. Testing
        # ``self.length`` for truthiness instead would skip the check for
        # zero-length trajectories and let components of different lengths
        # coexist.
        if self._data and array.shape[0] != self.length:
            raise ValueError(
                f"Length mismatch for {component.name}: "
                f"{array.shape[0]} != {self.length}"
            )
        self._data[component] = array
        # The first real component implies "every row is valid" until the
        # caller provides an explicit IS_VALID.
        if (
            component is not TrajectoryComponent.IS_VALID
            and TrajectoryComponent.IS_VALID not in self._data
        ):
            self._data[TrajectoryComponent.IS_VALID] = np.ones(
                array.shape[0], dtype=np.float32
            )
        return self

    def set_many(
        self, mapping: Mapping[TrajectoryComponent, Array1DNP]
    ) -> "Trajectory":
        """Set several components in one call.

        ``IS_VALID`` is applied last so that a user-provided mask replaces the
        automatically created all-ones mask.
        """
        for component, values in mapping.items():
            if component is not TrajectoryComponent.IS_VALID:
                self.set(component, values)
        if TrajectoryComponent.IS_VALID in mapping:
            self.set(
                TrajectoryComponent.IS_VALID, mapping[TrajectoryComponent.IS_VALID]
            )
        return self

    def get(
        self,
        components: Union[TrajectoryComponent, Sequence[TrajectoryComponent]],
        *,
        strict: bool = True,
    ) -> np.ndarray:
        """Fetch one or more components as a stacked ``(N, K)`` array.

        Args:
            components: Single component or non-empty sequence of components.
            strict: If True, missing components raise ``KeyError``; otherwise
                a zero column is returned for each missing component.

        Returns:
            float32 array of shape ``(N, K)`` with columns in request order.

        Raises:
            KeyError: If ``strict`` and at least one component is missing.
            TypeError: If ``components`` is not a valid component argument.
        """
        components_list = _validate_components_arg(components)
        length = self.length
        columns: List[np.ndarray] = []
        missing_components: List[TrajectoryComponent] = []
        for component in components_list:
            if component in self._data:
                columns.append(self._data[component].reshape(length, 1))
            elif strict:
                missing_components.append(component)
            else:
                columns.append(np.zeros((length, 1), dtype=np.float32))
        if missing_components:
            available = ", ".join(c.name for c in self._data)
            needed = ", ".join(c.name for c in missing_components)
            raise KeyError(f"Missing component(s): {needed}. Available: [{available}]")
        if not columns:
            return np.zeros((length, 0), dtype=np.float32)
        out = np.concatenate(columns, axis=1).astype(np.float32, copy=False)
        return cast(np.ndarray, out)

    # ----- pad/trim -----
    def pad(self, target_length: int, *, side: str = "right") -> "Trajectory":
        """Zero-pad all components in place up to ``target_length``.

        Args:
            target_length: Desired length; nothing happens if the trajectory
                is already at least this long.
            side: ``"left"`` or ``"right"``.

        Returns:
            ``self``.

        Raises:
            ValueError: If padding is needed and ``side`` is invalid.
        """
        if target_length <= self.length:
            return self
        # Validate before touching state so a bad ``side`` cannot leave a
        # half-initialised IS_VALID entry behind.
        if side not in {"left", "right"}:
            raise ValueError("side must be 'left' or 'right'")
        if TrajectoryComponent.IS_VALID not in self._data:
            self._data[TrajectoryComponent.IS_VALID] = np.zeros(
                self.length, dtype=np.float32
            )
        self._data = {
            component: _pad_numpy_1d(array, target_length, side)
            for component, array in self._data.items()
        }
        return self

    def trim(self, target_length: int, *, side: str = "right") -> "Trajectory":
        """Trim all components in place down to ``target_length``.

        Args:
            target_length: Desired length; nothing happens if the trajectory
                is already at most this long.
            side: Side to cut from: ``"left"`` drops the first rows,
                ``"right"`` drops the last rows.

        Returns:
            ``self``.

        Raises:
            ValueError: If trimming is needed and ``side`` is invalid.
        """
        current_length = self.length
        if target_length >= current_length:
            return self
        if side not in {"left", "right"}:
            raise ValueError("side must be 'left' or 'right'")
        cut = current_length - target_length
        self._data = {
            component: (array[cut:] if side == "left" else array[:-cut])
            for component, array in self._data.items()
        }
        return self

    def pad_or_trim(self, target_length: int, *, side: str = "right") -> "Trajectory":
        """Pad or trim to ``target_length`` depending on current length."""
        if target_length > self.length:
            return self.pad(target_length, side=side)
        return self.trim(target_length, side=side)

    # ----- resample -----
    def resample(self, time_lattice: Array1DNP, inplace: bool = False) -> "Trajectory":
        """Resample all components onto a new time lattice.

        Components are linearly interpolated between valid original samples.
        New timestamps that fall outside every run of consecutive valid
        samples are marked invalid and their values are set to zero.

        Special cases:
            - Empty trajectory: every component becomes zeros, ``TIMESTAMP``
              becomes the lattice and ``IS_VALID`` is all zeros.
            - Single-sample trajectory: lattice points close to the original
              timestamp (``np.isclose`` with ``rtol=atol=1e-6``) receive the
              original value and are marked valid; all others are zero.

        Args:
            time_lattice: New timestamps (1D array-like).
            inplace: If True, replace this trajectory's data and return
                ``self``. If False (default), leave this trajectory untouched
                and return a new ``Trajectory`` carrying the same score.

        Returns:
            The resampled trajectory (``self`` when ``inplace`` is True).

        Raises:
            ValueError: If ``time_lattice`` is None or empty, or if a
                non-empty trajectory has no ``TIMESTAMP`` component.

        Both the lattice and the stored timestamps are also checked with
        ``_assert_strictly_increasing``.
        """
        if time_lattice is None:
            raise ValueError("time_lattice cannot be None")
        new_times = _to_numpy_1d_float32(time_lattice)
        # Reject an empty lattice before running the monotonicity helper.
        if len(new_times) == 0:
            raise ValueError("time_lattice cannot be empty")
        _assert_strictly_increasing(new_times)

        if self.isEmpty:
            new_data = self._resample_empty(new_times)
        else:
            if TrajectoryComponent.TIMESTAMP not in self._data:
                raise ValueError(
                    "Cannot resample trajectory without TIMESTAMP component"
                )
            original_times = self._data[TrajectoryComponent.TIMESTAMP]
            # Re-validate stored timestamps on every call (in case the object
            # was populated without going through ``set``).
            _assert_strictly_increasing(original_times)
            if len(original_times) == 1:
                new_data = self._resample_single_point(new_times)
            else:
                new_data = self._resample_interpolated(new_times)

        # Single exit point so that every branch honours ``inplace``.
        if inplace:
            self._data = new_data
            return self
        return Trajectory(
            data=cast(Dict[TrajectoryComponent, Array1DNP], new_data),
            score=self.score,
        )

    def _resample_empty(
        self, new_times: np.ndarray
    ) -> Dict[TrajectoryComponent, np.ndarray]:
        """Build all-zero, all-invalid data on ``new_times`` for an empty trajectory."""
        count = len(new_times)
        new_data: Dict[TrajectoryComponent, np.ndarray] = {
            component: np.zeros(count, dtype=np.float32) for component in self._data
        }
        new_data[TrajectoryComponent.TIMESTAMP] = new_times.copy()
        new_data[TrajectoryComponent.IS_VALID] = np.zeros(count, dtype=np.float32)
        return new_data

    def _resample_single_point(
        self, new_times: np.ndarray
    ) -> Dict[TrajectoryComponent, np.ndarray]:
        """Place the only stored sample on lattice points matching its timestamp."""
        anchor_time = self._data[TrajectoryComponent.TIMESTAMP][0]
        matches = np.isclose(new_times, anchor_time, rtol=1e-6, atol=1e-6)
        new_data: Dict[TrajectoryComponent, np.ndarray] = {}
        for component, values in self._data.items():
            if component == TrajectoryComponent.TIMESTAMP:
                new_data[component] = new_times.copy()
                continue
            column = np.zeros(len(new_times), dtype=np.float32)
            # Matching lattice points become valid / take the original value;
            # everything else stays zero.
            if component == TrajectoryComponent.IS_VALID:
                column[matches] = 1.0
            else:
                column[matches] = values[0]
            new_data[component] = column
        return new_data

    @staticmethod
    def _valid_time_intervals(
        times: np.ndarray, valid_mask: np.ndarray
    ) -> List[Tuple[float, float]]:
        """Return ``(start_time, end_time)`` for each run of consecutive valid samples."""
        # Pad with zeros so runs touching either end still produce an edge.
        padded = np.concatenate(([0], valid_mask.astype(np.int8), [0]))
        edges = np.diff(padded)
        starts = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1) - 1
        return [
            (float(times[start]), float(times[end]))
            for start, end in zip(starts, ends)
        ]

    def _resample_interpolated(
        self, new_times: np.ndarray
    ) -> Dict[TrajectoryComponent, np.ndarray]:
        """Interpolate a multi-sample trajectory onto ``new_times``."""
        original_times = self._data[TrajectoryComponent.TIMESTAMP]
        full_span = (float(np.min(original_times)), float(np.max(original_times)))

        new_is_valid = np.zeros(len(new_times), dtype=np.float32)
        bounds = full_span
        valid_mask: Optional[np.ndarray] = None
        any_valid = False

        if TrajectoryComponent.IS_VALID in self._data:
            valid_mask = self._data[TrajectoryComponent.IS_VALID] == 1.0
            any_valid = bool(np.any(valid_mask))
            if any_valid:
                # A new timestamp is valid only inside a run of consecutive
                # valid samples, so gaps in the original data are preserved.
                for start, end in self._valid_time_intervals(
                    original_times, valid_mask
                ):
                    new_is_valid[(new_times >= start) & (new_times <= end)] = 1.0
                valid_times = original_times[valid_mask]
                bounds = (float(np.min(valid_times)), float(np.max(valid_times)))
            # Otherwise nothing is valid: keep the zeros and default bounds.
        else:
            # No IS_VALID stored: treat the full original span as valid.
            new_is_valid = (
                (new_times >= full_span[0]) & (new_times <= full_span[1])
            ).astype(np.float32)

        in_bounds = (new_times >= bounds[0]) & (new_times <= bounds[1])
        keep = in_bounds & new_is_valid.astype(bool)

        new_data: Dict[TrajectoryComponent, np.ndarray] = {
            TrajectoryComponent.TIMESTAMP: new_times.copy(),
            TrajectoryComponent.IS_VALID: new_is_valid,
        }
        for component, values in self._data.items():
            if component in (
                TrajectoryComponent.TIMESTAMP,
                TrajectoryComponent.IS_VALID,
            ):
                continue  # Already handled above.
            if valid_mask is None:
                interpolated = np.interp(new_times, original_times, values)
            elif any_valid:
                # Interpolate through valid samples only.
                interpolated = np.interp(
                    new_times, original_times[valid_mask], values[valid_mask]
                )
            else:
                interpolated = np.zeros_like(new_times)
            # Zero everything outside the valid extent or not marked valid.
            interpolated[~keep] = 0.0
            new_data[component] = interpolated.astype(np.float32)
        return new_data

    def __len__(self) -> int:
        return self.length

    def __repr__(self) -> str:
        shapes = ", ".join(f"{k.name}:{tuple(v.shape)}" for k, v in self._data.items())
        return (
            f"Trajectory(N={self.length}, isEmpty={self.isEmpty}, "
            f"score={self.score}, comps=[{shapes}])"
        )
# ======================= BatchedTrajectory (Torch) =======================
class BatchedTrajectory:
    """Batch of trajectories stored as torch tensors on a single device.

    - Built from a NON-EMPTY sequence of ``Trajectory`` objects; ``None``
      entries are treated as empty trajectories.
    - Shorter trajectories are zero-padded on ``side`` up to the longest one.
      The input ``Trajectory`` objects are not modified.
    - ``strict=True``: all **non-empty** trajectories must have the same
      component set (``IS_VALID`` always included).
      ``strict=False``: the union of components is used; components missing
      from a sample are zero-filled.
    - Each component is stored as a float32 tensor of shape
      ``(batch_size, sequence_length, 1)``.
    - Exposes:
        - ``get(..., strict=...)`` -> ``(batch_size, sequence_length, K)``
        - ``scores``: ``(batch_size,)`` float32 (NaN where missing)
        - ``is_empty_mask``: ``(batch_size,)`` bool
        - ``.to(device)`` / ``.cuda()`` (move in place, return ``self``)
    """

    def _replace_none_with_empty_trajs(
        self, trajectories: Sequence[Optional[Trajectory]]
    ) -> Sequence[Trajectory]:
        """Replace None trajectories with empty ones."""
        return [t if t is not None else Trajectory() for t in trajectories]

    def __init__(
        self,
        trajectories: Sequence[Optional[Trajectory]],
        device: Optional[torch.device] = None,
        side: str = "right",
        strict: bool = False,
    ) -> None:
        """Stack trajectories into per-component tensors.

        Args:
            trajectories: Non-empty sequence of trajectories (``None`` entries
                become empty trajectories).
            device: Target device; defaults to CPU.
            side: Side on which shorter trajectories are zero-padded
                (``"left"`` or ``"right"``).
            strict: Whether non-empty trajectories must share one component
                set.

        Raises:
            ValueError: If ``trajectories`` is empty, if ``strict`` and the
                component sets differ, or if padding is needed and ``side``
                is invalid.
        """
        if not trajectories:
            raise ValueError(
                "BatchedTrajectory requires a non-empty list of trajectories."
            )
        trajectories = self._replace_none_with_empty_trajs(trajectories)
        # Keep the None-free list so ``resample`` can call methods on every
        # entry.
        self._original_trajectories = trajectories
        self._side = side
        self._strict = strict
        self._device = torch.device(device) if device is not None else torch.device("cpu")
        self._data: Dict[TrajectoryComponent, torch.Tensor] = {}

        # Meta: scores and emptiness
        self._scores = torch.tensor(
            [
                np.float32(t.score) if t.score is not None else np.float32("nan")
                for t in trajectories
            ],
            dtype=torch.float32,
            device=self._device,
        )
        empty_flags = [t.isEmpty for t in trajectories]
        self._is_empty_mask = torch.tensor(
            empty_flags, dtype=torch.bool, device=self._device
        )

        max_length = max(t.length for t in trajectories)

        # Determine the component set to use.
        component_sets = [set(t.components()) for t in trajectories]
        always_present = {TrajectoryComponent.IS_VALID}
        if strict:
            non_empty_indices = [i for i, empty in enumerate(empty_flags) if not empty]
            if non_empty_indices:
                expected_components = component_sets[non_empty_indices[0]] | always_present
                for idx in non_empty_indices:
                    got = component_sets[idx] | always_present
                    if got != expected_components:
                        expected_names = [
                            c.name
                            for c in sorted(expected_components, key=lambda c: c.name)
                        ]
                        got_names = [
                            c.name for c in sorted(got, key=lambda c: c.name)
                        ]
                        raise ValueError(
                            "Strict mode: component mismatch at sample "
                            f"{idx}. Expected: {expected_names}, got: {got_names}"
                        )
            else:
                expected_components = set(always_present)
            selected = expected_components
        else:
            selected = set().union(*component_sets) | always_present
        components_to_use = sorted(selected, key=lambda c: c.name)

        # Pad copies of the columns instead of calling ``Trajectory.pad``:
        # that method pads in place, which would alter the caller's objects
        # and append zeros to their TIMESTAMP arrays.
        for component in components_to_use:
            columns = [
                _pad_numpy_1d(
                    trajectory.get(component, strict=False).reshape(-1),
                    max_length,
                    side,
                ).reshape(max_length, 1)
                for trajectory in trajectories
            ]
            self._data[component] = torch.tensor(
                np.stack(columns, axis=0), dtype=torch.float32, device=self._device
            )

    # ---- basics ----
    @property
    def device(self) -> torch.device:
        """Return the device on which the trajectory data is stored."""
        return self._device

    @property
    def batch_size(self) -> int:
        """Return the number of trajectories in the batch."""
        return int(self._scores.shape[0])

    @property
    def length(self) -> int:
        """Return the (padded) sequence length shared by the batch."""
        return 0 if not self._data else int(next(iter(self._data.values())).shape[1])

    @property
    def scores(self) -> torch.Tensor:
        """Return the scores associated with each trajectory in the batch."""
        return self._scores

    @property
    def is_empty_mask(self) -> torch.Tensor:
        """Return a mask indicating which trajectories in the batch are empty."""
        return self._is_empty_mask

    def components(self) -> List[TrajectoryComponent]:
        """List the trajectory components currently stored."""
        return list(self._data.keys())

    def has(self, component: TrajectoryComponent) -> bool:
        """Check if a trajectory component is present."""
        return component in self._data

    # ---- get ----
    def get(
        self,
        components: Union[TrajectoryComponent, Sequence[TrajectoryComponent]],
        *,
        strict: bool = True,
    ) -> torch.Tensor:
        """Retrieve stacked component tensors of shape ``(B, T, K)``.

        Args:
            components: Single component or sequence thereof.
            strict: If True, missing components raise ``KeyError``; otherwise
                zeros are returned for missing entries.

        Raises:
            KeyError: If ``strict`` and at least one component is missing.
            TypeError: If ``components`` is not a valid component argument.
        """
        components_list = _validate_components_arg(components)
        batch_size = self.batch_size
        sequence_length = self.length
        columns: List[torch.Tensor] = []
        missing_components: List[TrajectoryComponent] = []
        for component in components_list:
            if component in self._data:
                # Stored as (batch_size, sequence_length, 1).
                columns.append(self._data[component])
            elif strict:
                missing_components.append(component)
            else:
                columns.append(
                    torch.zeros(
                        (batch_size, sequence_length, 1),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
        if missing_components:
            available = ", ".join(c.name for c in self._data)
            needed = ", ".join(c.name for c in missing_components)
            raise KeyError(f"Missing component(s): {needed}. Available: [{available}]")
        if not columns:
            return torch.zeros(
                (batch_size, sequence_length, 0),
                dtype=torch.float32,
                device=self.device,
            )
        return torch.cat(columns, dim=-1)

    # ---- device moves ----
    def to(
        self, device: Optional[Union[torch.device, str]] = None
    ) -> "BatchedTrajectory":
        """Move all tensors to ``device`` in place and return ``self``.

        Passing ``None`` or the current device is a no-op.
        """
        if device is None:
            return self
        # Normalise so that ``self.device`` is always a ``torch.device``,
        # even when a string such as "cuda:0" is passed.
        target = torch.device(device)
        if target == self.device:
            return self
        self._data = {
            k: v.to(device=target, non_blocking=True) for k, v in self._data.items()
        }
        self._scores = self._scores.to(device=target, non_blocking=True)
        self._is_empty_mask = self._is_empty_mask.to(device=target, non_blocking=True)
        self._device = target
        return self

    def cuda(self, device: Optional[int] = None) -> "BatchedTrajectory":
        """Move the batch to a CUDA device (optionally a specific index)."""
        dev = torch.device(f"cuda:{device}" if device is not None else "cuda")
        return self.to(dev)

    # ---- batch-level pad/trim ----
    def trim(self, target_length: int, *, side: str = "right") -> "BatchedTrajectory":
        """Trim all sequences in the batch to ``target_length`` in-place.

        Raises:
            ValueError: If trimming is needed and ``side`` is neither
                ``"left"`` nor ``"right"``.
        """
        current_length = self.length
        if target_length >= current_length:
            return self
        if side not in {"left", "right"}:
            raise ValueError("side must be 'left' or 'right'")
        cut = current_length - target_length
        self._data = {
            k: (v[:, cut:] if side == "left" else v[:, :-cut])
            for k, v in self._data.items()
        }
        return self

    def pad(self, target_length: int, *, side: str = "right") -> "BatchedTrajectory":
        """Pad all sequences in the batch to ``target_length`` with zeros in-place.

        Raises:
            ValueError: If padding is needed and ``side`` is neither
                ``"left"`` nor ``"right"``.
        """
        current_length = self.length
        if target_length <= current_length:
            return self
        if side not in {"left", "right"}:
            raise ValueError("side must be 'left' or 'right'")
        extra = target_length - current_length
        padded: Dict[TrajectoryComponent, torch.Tensor] = {}
        for component, tensor in self._data.items():
            zeros = torch.zeros(
                (tensor.shape[0], extra, *tensor.shape[2:]),
                dtype=torch.float32,
                device=self.device,
            )
            parts = [zeros, tensor] if side == "left" else [tensor, zeros]
            padded[component] = torch.cat(parts, dim=1)
        self._data = padded
        return self

    def pad_or_trim(
        self, target_length: int, *, side: str = "right"
    ) -> "BatchedTrajectory":
        """Pad or trim batch to ``target_length`` depending on current length."""
        if target_length > self.length:
            return self.pad(target_length, side=side)
        return self.trim(target_length, side=side)

    # ---- resample ----
    def resample(self, time_lattice: Array1DNP) -> "BatchedTrajectory":
        """Return a new BatchedTrajectory resampled onto a new time lattice.

        Each per-sample ``Trajectory`` kept from construction is resampled
        with ``Trajectory.resample`` and a new batch is built from the
        results using the same device, padding side and strictness. The
        tensors of the current object are not modified, and batch-level
        ``pad``/``trim`` calls made on this object are not carried over.

        Args:
            time_lattice: 1D array-like of new timestamps (passed to each
                underlying ``Trajectory.resample`` call).

        Returns:
            A new ``BatchedTrajectory`` instance on the same device.
        """
        new_trajs: List[Trajectory] = [
            traj.resample(time_lattice) for traj in self._original_trajectories
        ]
        return BatchedTrajectory(
            new_trajs, device=self.device, side=self._side, strict=self._strict
        )

    def __repr__(self) -> str:
        shapes = ", ".join(f"{k.name}:{tuple(v.shape)}" for k, v in self._data.items())
        return (
            f"BatchedTrajectory(batch_size={self.batch_size}, "
            f"sequence_length={self.length}, device={self.device}, "
            f"components=[{shapes}], scores_shape={tuple(self._scores.shape)}, "
            f"empty_mask={self._is_empty_mask.tolist()})"
        )