

from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt


def _compute_bounds(xy: np.ndarray, pad_frac: float = 0.05) -> tuple[float, float, float, float]:
    """Return padded ``(x_min, x_max, y_min, y_max)`` limits enclosing every point in *xy*.

    The last axis of *xy* is indexed as (x, y); all leading axes are pooled
    together. Each axis is widened on both sides by ``pad_frac`` of its span,
    or by a fixed 1.0 when the span is not positive (e.g. all points share one
    coordinate), so a degenerate cloud still gets a visible window.
    """

    def _padded(coords: np.ndarray) -> tuple[float, float]:
        # Min/max of one coordinate, widened by the relative (or fallback) margin.
        lo = float(np.min(coords))
        hi = float(np.max(coords))
        span = hi - lo
        margin = span * pad_frac if span > 0 else 1.0
        return lo - margin, hi + margin

    left, right = _padded(xy[..., 0])
    bottom, top = _padded(xy[..., 1])
    return left, right, bottom, top


def _parse_frames(frames: str | None, num_frames: int, max_frame: int) -> list[int]:
    if frames is None:
        return np.linspace(0, max_frame, num=num_frames, dtype=int).tolist()
    frame_list = [int(item.strip()) for item in frames.split(",") if item.strip()]
    if len(frame_list) != num_frames:
        raise ValueError(f"--frames must contain exactly {num_frames} entries.")
    if any(frame < 0 or frame > max_frame for frame in frame_list):
        raise ValueError(f"--frames entries must be within [0, {max_frame}].")
    return frame_list


def _parse_range(range_str: str | None, name: str) -> tuple[float, float] | None:
    if range_str is None:
        return None
    parts = [item.strip() for item in range_str.split(",") if item.strip()]
    if len(parts) != 2:
        raise ValueError(f"--{name} must be formatted as min,max.")
    lower, upper = (float(parts[0]), float(parts[1]))
    if lower >= upper:
        raise ValueError(f"--{name} must satisfy min < max.")
    return lower, upper


def plot_npz(
    npz_path: Path,
    traj_id: int,
    output: Path,
    dpi: int,
    point_size: float,
    fig_width: float,
    fig_height: float,
    num_panels: int,
    frames: str | None,
    xlim: str | None,
    ylim: str | None,
) -> None:
    """Render a 1xN row of scatter panels for one trajectory and save it to *output*.

    Args:
        npz_path: NPZ archive containing ``trajectories`` and ``types`` arrays.
        traj_id: Index into the first axis of ``trajectories`` / ``types``.
        output: Destination figure path; missing parent directories are created.
        dpi: Figure resolution.
        point_size: Marker size passed to ``Axes.scatter``.
        fig_width: Figure width in inches.
        fig_height: Figure height in inches.
        num_panels: Number of panels (frames); must be positive.
        frames: Optional comma-separated frame indices with exactly
            ``num_panels`` entries; evenly spaced frames are used when ``None``.
        xlim: Optional ``"min,max"`` override for the x-axis limits.
        ylim: Optional ``"min,max"`` override for the y-axis limits.

    Raises:
        ValueError: If ``traj_id``, ``num_panels``, ``frames``, ``xlim`` or
            ``ylim`` is invalid.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code from a crafted file. Only load trusted archives, and
    # drop the flag if the stored arrays turn out to be plain numeric dtypes.
    # The `with` block closes the archive's file handle; indexing a key reads the
    # array into memory, so the arrays remain usable afterwards.
    with np.load(npz_path, allow_pickle=True) as data:
        trajectories = data["trajectories"]
        types = data["types"]
    print(f"trajectories shape: {trajectories.shape}, types shape: {types.shape}")

    if traj_id < 0 or traj_id >= trajectories.shape[0]:
        raise ValueError(f"traj_id must be in [0, {trajectories.shape[0] - 1}]")

    traj = trajectories[traj_id]  # (T, N, 2)
    ttypes = types[traj_id]       # (N,)

    if num_panels <= 0:
        raise ValueError("--num-panels must be a positive integer.")
    frame_list = _parse_frames(frames, num_frames=num_panels, max_frame=traj.shape[0] - 1)
    # Validate axis overrides before any figure exists, so a bad value cannot
    # leave an orphaned open figure behind.
    x_bounds = _parse_range(xlim, "xlim")
    y_bounds = _parse_range(ylim, "ylim")

    # Limits are computed over all frames so every panel shares one window.
    x_min, x_max, y_min, y_max = _compute_bounds(traj)
    if x_bounds is not None:
        x_min, x_max = x_bounds
    if y_bounds is not None:
        y_min, y_max = y_bounds

    # Particles whose type equals 1 are blue; every other type is orange.
    colors = np.where(ttypes == 1, "#1f77b4", "#ff7f0e")

    style = {
        "font.size": 6,
        "axes.titlesize": 6,
        "axes.linewidth": 0.3,
    }
    # rc_context scopes the compact style to this figure instead of mutating
    # the process-wide rcParams for every later plot.
    with plt.rc_context(style):
        # squeeze=False always yields a 2-D Axes array, so a single panel
        # (num_panels == 1) is iterable like the multi-panel case.
        fig, axes = plt.subplots(
            1,
            len(frame_list),
            figsize=(fig_width, fig_height),
            dpi=dpi,
            constrained_layout=True,
            sharey=True,
            squeeze=False,
        )
        try:
            for ax, frame_idx in zip(axes[0], frame_list):
                xy = traj[frame_idx]
                ax.scatter(
                    xy[:, 0],
                    xy[:, 1],
                    s=point_size,
                    c=colors,
                    edgecolors="none",
                    rasterized=True,
                )
                ax.set_xlim(x_min, x_max)
                ax.set_ylim(y_min, y_max)
                ax.set_aspect("equal", adjustable="box")
                ax.tick_params(axis="both", labelsize=5.5, length=2, width=0.4, pad=0.6)
                ax.set_title(f"time step {frame_idx}")
                for spine in ax.spines.values():
                    spine.set_linewidth(0.3)

            output.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output)
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)


def main() -> None:
    """Command-line entry point: parse arguments and render the trajectory panels.

    Default input and output paths are resolved relative to this script's
    directory: ``dataset/trajectories.npz`` and ``dataset/trajectory_panels.pdf``.
    """
    script_dir = Path(__file__).resolve().parent
    default_npz = script_dir / "dataset" / "trajectories.npz"
    default_out = script_dir / "dataset" / "trajectory_panels.pdf"

    # The panel count is configurable (--num-panels), so the description no
    # longer hard-codes a frame count.
    parser = argparse.ArgumentParser(
        description="Plot a 1xN row of frames from one trajectory in an NPZ file."
    )
    parser.add_argument("--npz", type=Path, default=default_npz, help="Input NPZ file.")
    parser.add_argument("--traj-id", type=int, default=110, help="Trajectory index.")
    parser.add_argument("--output", type=Path, default=default_out, help="Output figure path.")
    parser.add_argument("--dpi", type=int, default=300, help="Output DPI.")
    # Hyphenated spelling matches the other flags; --point_size stays accepted
    # so existing invocations keep working.
    parser.add_argument(
        "--point-size",
        "--point_size",
        dest="point_size",
        type=float,
        default=2.0,
        help="Scatter point size.",
    )
    parser.add_argument(
        "--fig-width",
        type=float,
        default=6.7,
        help="Figure width in inches (conference-ready width is 6.7).",
    )
    parser.add_argument(
        "--fig-height",
        type=float,
        default=1.6,
        help="Figure height in inches.",
    )
    parser.add_argument(
        "--num-panels",
        type=int,
        default=6,
        help="Number of panels (frames) in the 1xN layout.",
    )
    parser.add_argument(
        "--frames",
        type=str,
        default=None,
        help="Comma-separated list of frame indices to plot (must match --num-panels).",
    )
    parser.add_argument(
        "--xlim",
        type=str,
        default=None,
        help="Override x-axis range as min,max.",
    )
    parser.add_argument(
        "--ylim",
        type=str,
        default=None,
        help="Override y-axis range as min,max.",
    )
    args = parser.parse_args()

    plot_npz(
        npz_path=args.npz,
        traj_id=args.traj_id,
        output=args.output,
        dpi=args.dpi,
        point_size=args.point_size,
        fig_width=args.fig_width,
        fig_height=args.fig_height,
        num_panels=args.num_panels,
        frames=args.frames,
        xlim=args.xlim,
        ylim=args.ylim,
    )

    print(f"Saved figure to {args.output}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
