from torch.utils.data import DataLoader
from trajdata import AgentBatch, UnifiedDataset
from trajdata.visualization.interactive_animation import (
    InteractiveAnimation,
    animate_agent_batch_interactive,
)

def main() -> None:
    """Load a trajdata dataset and interactively animate one random agent batch.

    Builds an agent-centric ``UnifiedDataset`` over ``nusc_mini-boston`` with
    rasterized maps and prints the sample count. It then draws a single
    shuffled batch of size 1 and shows it with ``InteractiveAnimation``.
    If the dataset yields no batches, a message is printed instead.
    """
    dataset = UnifiedDataset(
        desired_data=["nusc_mini-boston"],
        desired_dt=0.1,
        centric="agent",
        history_sec=(1.0, 3.0),
        future_sec=(4.0, 4.0),
        # NOTE(review): the empty strings look like placeholders; presumably
        # they must be set to the local dataset roots before running -- TODO
        # confirm. The "sdd" entry is not used by desired_data above.
        data_dirs={"nusc_mini": "", "sdd": ""},
        incl_raster_map=True,
        raster_map_params={
            "px_per_m": 2,
            "map_size_px": 224,
            "offset_frac_xy": (-0.5, 0.0),
        },
    )  # These settings were used to create Figure 2.

    print(f"# Data Samples: {len(dataset):,}")

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=dataset.get_collate_fn(),
        num_workers=0,
    )

    # Only the first (randomly shuffled) batch is needed; next() with a
    # default avoids StopIteration and lets the empty case be reported
    # explicitly instead of silently doing nothing.
    batch: AgentBatch | None = next(iter(dataloader), None)
    if batch is None:
        print("No samples available; nothing to animate.")
        return

    animation = InteractiveAnimation(
        animate_agent_batch_interactive,
        batch=batch,
        batch_idx=0,
        cache_path=dataset.cache_path,
    )
    animation.show()


# Guard the entry point so importing this module does not load the dataset
# or open the interactive visualization as a side effect.
if __name__ == "__main__":
    main()
