#!/usr/bin/env python3
"""
Animate a single trajectory from an NPZ file and save to MP4 (ffmpeg required).
"""

from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation


def _compute_bounds(xy: np.ndarray, pad_frac: float = 0.05) -> tuple[float, float, float, float]:
    """Return padded axis limits ``(x_lo, x_hi, y_lo, y_hi)`` enclosing ``xy``.

    Index 0 of the last axis of ``xy`` is read as x and index 1 as y; all
    leading axes are reduced over.  Each axis is widened on both sides by
    ``pad_frac`` times its extent.  An axis whose extent is not positive
    (e.g. every point shares the same coordinate) gets a fixed margin of 1.0
    instead, so the resulting interval never collapses to a single value.
    """
    limits: list[float] = []
    for axis in (0, 1):
        coords = xy[..., axis]
        low = float(np.min(coords))
        high = float(np.max(coords))
        extent = high - low
        if extent > 0:
            margin = extent * pad_frac
        else:
            # Degenerate extent: fall back to a unit margin.
            margin = 1.0
        limits.extend((low - margin, high + margin))
    x_lo, x_hi, y_lo, y_hi = limits
    return x_lo, x_hi, y_lo, y_hi


def animate_npz(
    npz_path: Path,
    traj_id: int,
    output: Path,
    fps: int,
    stride: int,
    dpi: int,
    point_size: float,
    color_a: str,
    color_b: str,
) -> None:
    """Render one trajectory from an NPZ archive as an MP4 scatter animation.

    The archive must contain the arrays ``"trajectories"`` and ``"types"``,
    both indexed by trajectory along their first axis.

    Args:
        npz_path: Path of the input NPZ archive.
        traj_id: Index of the trajectory to animate.
        output: Destination movie file; missing parent directories are created.
        fps: Frames per second of the written movie (must be positive).
        stride: Use every ``stride``-th frame (must be >= 1).
        dpi: Resolution of the figure.
        point_size: Marker size passed to ``Axes.scatter``.
        color_a: Color of points whose type equals 1.
        color_b: Color of all other points.

    Raises:
        ValueError: If ``fps`` is not positive, ``stride`` is below 1,
            ``traj_id`` is out of range, or the trajectory has no frames.
    """
    # Reject bad timing arguments up front: fps == 0 would otherwise surface as
    # a ZeroDivisionError further down, and a negative stride would silently
    # produce an empty frame list.
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    # NOTE(security): allow_pickle=True means a crafted archive can execute
    # arbitrary code while being loaded. It is kept for backward compatibility
    # with object-dtype archives; only open files from trusted sources.
    # The context manager closes the underlying file handle once both arrays
    # have been read into memory.
    with np.load(npz_path, allow_pickle=True) as data:
        trajectories = data["trajectories"]
        types = data["types"]
    print(f"trajectories shape: {trajectories.shape}, types shape: {types.shape}")

    if traj_id < 0 or traj_id >= trajectories.shape[0]:
        raise ValueError(f"traj_id must be in [0, {trajectories.shape[0] - 1}]")

    traj = trajectories[traj_id]  # (T, N, 2)
    ttypes = types[traj_id]       # (N,)
    if traj.shape[0] == 0:
        raise ValueError(f"Trajectory {traj_id} contains no frames")

    colors = np.where(ttypes == 1, color_a, color_b)
    # Fixed limits computed over the whole trajectory keep the view steady.
    x_min, x_max, y_min, y_max = _compute_bounds(traj)
    frames = list(range(0, traj.shape[0], stride))

    fig, ax = plt.subplots(figsize=(5, 5), dpi=dpi)
    try:
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("x")
        ax.set_ylabel("y")

        scatter = ax.scatter(
            traj[0, :, 0],
            traj[0, :, 1],
            s=point_size,
            c=colors,
            edgecolors="none",
        )
        title = ax.set_title(f"Trajectory {traj_id} | frame 0")

        def update(frame_idx: int):
            """Move the markers to frame ``frame_idx`` and refresh the title."""
            scatter.set_offsets(traj[frame_idx])
            title.set_text(f"Trajectory {traj_id} | frame {frame_idx}")
            return scatter, title

        anim = animation.FuncAnimation(
            fig,
            update,
            frames=frames,
            interval=1000 / fps,
            blit=False,
            repeat=False,
        )

        output.parent.mkdir(parents=True, exist_ok=True)
        writer = animation.FFMpegWriter(fps=fps)
        anim.save(output, writer=writer)
    finally:
        # Release the figure even when saving fails (e.g. ffmpeg is missing).
        plt.close(fig)


def main() -> None:
    """Command-line entry point: parse options, render the animation, report the path.

    Input and output default to ``dataset/trajectories.npz`` and
    ``dataset/trajectory.mp4`` next to this script.  A ``--stride`` below 1
    is clamped to 1 before rendering.
    """
    dataset_dir = Path(__file__).resolve().parent / "dataset"

    # One row per option: (flag, converter, default, help text).
    option_table = (
        ("--npz", Path, dataset_dir / "trajectories.npz", "Input NPZ file."),
        ("--traj-id", int, 0, "Trajectory index."),
        ("--output", Path, dataset_dir / "trajectory.mp4", "Output MP4 path."),
        ("--fps", int, 10, "Frames per second."),
        ("--stride", int, 1, "Use every k-th frame."),
        ("--dpi", int, 300, "Output DPI."),
        ("--point-size", float, 5.0, "Scatter point size."),
        ("--color-a", str, "#1f77b4", "Color for type==1."),
        ("--color-b", str, "#ff7f0e", "Color for type!=1."),
    )

    parser = argparse.ArgumentParser(description="Animate a trajectory from NPZ.")
    for flag, converter, default, help_text in option_table:
        parser.add_argument(flag, type=converter, default=default, help=help_text)
    args = parser.parse_args()

    # Never pass a stride smaller than one frame.
    effective_stride = args.stride if args.stride >= 1 else 1

    animate_npz(
        npz_path=args.npz,
        traj_id=args.traj_id,
        output=args.output,
        fps=args.fps,
        stride=effective_stride,
        dpi=args.dpi,
        point_size=args.point_size,
        color_a=args.color_a,
        color_b=args.color_b,
    )

    print(f"Saved animation to {args.output}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
