

from __future__ import annotations

import argparse
import multiprocessing as mp
from pathlib import Path
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import cKDTree

# Per-process state for multiprocessing workers. These are filled in once per
# worker by _init_pair_prob_worker (passed as the Pool initializer), so each
# task, _pair_probs_for_trajectory(m), can be dispatched with only a
# trajectory index. The serial code path never reads them.
_WORKER_TRAJECTORIES: np.ndarray | None = None
_WORKER_TYPES: np.ndarray | None = None
_WORKER_RC: float | None = None


def _init_pair_prob_worker(
    trajectories: np.ndarray,
    types: np.ndarray,
    rc: float,
) -> None:
    """Pool initializer: stash the shared inputs in this process's module globals.

    Runs once in each worker process so that _pair_probs_for_trajectory can
    look the data up by trajectory index instead of receiving it per task.
    """
    global _WORKER_TRAJECTORIES, _WORKER_TYPES, _WORKER_RC
    _WORKER_TRAJECTORIES, _WORKER_TYPES, _WORKER_RC = trajectories, types, rc


def _pair_probs_for_trajectory(m: int) -> np.ndarray:
    """Pool task: compute [P_AB, P_BA] for every frame of trajectory ``m``.

    Reads the inputs installed by _init_pair_prob_worker and returns a
    (T, 2) float64 array, where T is the length of the trajectory's first axis.

    Raises RuntimeError if the worker globals have not been initialised.
    """
    shared = (_WORKER_TRAJECTORIES, _WORKER_TYPES, _WORKER_RC)
    # `is None` (not `in`) because `None in (...)` would compare arrays with ==.
    if any(value is None for value in shared):
        raise RuntimeError("Worker globals not initialized.")

    frames = _WORKER_TRAJECTORIES[m]
    particle_types = _WORKER_TYPES[m]
    rows = [
        _pair_probs_for_frame(frame, particle_types, _WORKER_RC)
        for frame in frames
    ]
    # reshape keeps the (0, 2) shape for an empty trajectory.
    return np.asarray(rows, dtype=np.float64).reshape(len(frames), 2)


def _pair_probs_for_frame(xy: np.ndarray, types: np.ndarray, rc: float) -> tuple[float, float]:
    """
    Compute the unlike-neighbour fractions for a single frame.

    xy:    particle positions, one row per particle (passed straight to cKDTree).
    types: per-particle labels aligned with the rows of ``xy``. Only labels 1
           and 2 are counted; particles with any other label are ignored as
           centres but still count as neighbours in the denominators.
    rc:    neighbour cutoff radius for ``cKDTree.query_ball_point``.

    Returns (p_ab, p_ba), where

        p_ab = (# neighbours of type-1 particles that are type 2)
               / (# neighbours of type-1 particles)

    pooled over all type-1 particles in the frame (a ratio of sums, not a
    mean of per-particle ratios). p_ba is the same with types 1 and 2
    swapped. A ratio is NaN when its denominator is zero.
    """
    labels = np.asarray(types)
    tree = cKDTree(xy)
    neighbors = tree.query_ball_point(xy, r=rc)
    n_points = len(neighbors)

    # Flatten the ragged neighbour lists into parallel (centre, other) index
    # arrays so the counting below is vectorised rather than a per-particle
    # Python loop with tiny numpy reductions.
    lengths = np.fromiter(
        (len(neigh) for neigh in neighbors), dtype=np.intp, count=n_points
    )
    centre = np.repeat(np.arange(n_points, dtype=np.intp), lengths)
    other = np.fromiter(
        (j for neigh in neighbors for j in neigh),
        dtype=np.intp,
        count=int(lengths.sum()),
    )

    # query_ball_point reports each point as its own neighbour; drop those.
    not_self = centre != other
    centre_types = labels[centre[not_self]]
    other_types = labels[other[not_self]]

    is_a = centre_types == 1
    is_b = centre_types == 2
    a_total = int(np.count_nonzero(is_a))
    a_unlike = int(np.count_nonzero(is_a & (other_types == 2)))
    b_total = int(np.count_nonzero(is_b))
    b_unlike = int(np.count_nonzero(is_b & (other_types == 1)))

    p_ab = (a_unlike / a_total) if a_total > 0 else float("nan")
    p_ba = (b_unlike / b_total) if b_total > 0 else float("nan")
    return p_ab, p_ba


def compute_pair_probs(
    trajectories: np.ndarray,
    types: np.ndarray,
    rc: float,
    workers: int | None = None,
) -> np.ndarray:
    """
    Compute [P_AB, P_BA] for every frame of every trajectory.

    trajectories: (M, T, N, 2) particle positions.
    types:        (M, N) per-particle labels; trajectory m uses types[m].
    rc:           neighbour cutoff radius.
    workers:      number of worker processes. None or a value <= 1 runs
                  serially in this process. The pool never gets more
                  processes than there are trajectories, and with at most
                  one trajectory the work also runs serially.

    Returns: (M, T, 2) float64 array with [P_AB, P_BA]. Entries are NaN where
    the per-frame ratio is undefined (see _pair_probs_for_frame).

    Raises: ValueError if ``trajectories`` is not 4-dimensional.
    """
    if trajectories.ndim != 4:
        # Typically a ragged object array (different N per trajectory).
        raise ValueError(
            f"trajectories must have shape (M, T, N, 2); got shape {trajectories.shape}"
        )
    n_traj, n_steps, _, _ = trajectories.shape
    out = np.empty((n_traj, n_steps, 2), dtype=np.float64)

    # Work is handed out one trajectory per task, so extra processes would sit
    # idle. With 0 or 1 trajectories a pool only adds start-up cost.
    n_workers = 1 if workers is None else min(workers, n_traj)

    if n_workers <= 1:
        for m in tqdm(range(n_traj), desc="Computing pair probabilities"):
            ttypes = types[m]
            for t in range(n_steps):
                p_ab, p_ba = _pair_probs_for_frame(trajectories[m, t], ttypes, rc)
                out[m, t, 0] = p_ab
                out[m, t, 1] = p_ba
    else:
        with mp.Pool(
            processes=n_workers,
            initializer=_init_pair_prob_worker,
            initargs=(trajectories, types, rc),
        ) as pool:
            # Pool.imap preserves input order, keeping trajectory order consistent.
            results = tqdm(
                pool.imap(_pair_probs_for_trajectory, range(n_traj)),
                total=n_traj,
                desc="Computing pair probabilities",
            )
            for m, traj_probs in enumerate(results):
                out[m] = traj_probs

    return out


def plot_pair_probs(pair_probs: np.ndarray, output: Path) -> None:
    """
    Plot P_AB and P_BA against time step, one curve per trajectory, and save it.

    pair_probs: (M, T, 2) with [P_AB, P_BA] along the last axis.
    output:     image path; missing parent directories are created.
    """
    _, n_steps, _ = pair_probs.shape
    steps = np.arange(n_steps)

    fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)
    panels = tuple(zip(axes, ("P_AB", "P_BA")))

    # One translucent line per trajectory: left panel P_AB, right panel P_BA.
    for trajectory in pair_probs:
        for column, (ax, _title) in enumerate(panels):
            ax.plot(steps, trajectory[:, column], alpha=0.7, linewidth=1.0)

    for ax, title in panels:
        ax.set_title(title)
        ax.set_xlabel("time step")
        ax.set_ylabel("probability")
        ax.set_ylim(0.0, 1.0)

    fig.tight_layout()
    output.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output, dpi=200)
    plt.close(fig)


def main() -> None:
    """
    Command-line entry point.

    Loads ``trajectories`` and ``types`` from an NPZ file, computes
    [P_AB, P_BA] for every frame, and saves the (M, T, 2) result as an NPY
    file. With ``--plot``, also writes a figure to ``--figure``.
    """
    script_dir = Path(__file__).resolve().parent
    # Defaults target the "diffN" test split. Other splits used previously
    # (trajectories.npz, trajectories_large.npz, trajectories_test_right.npz,
    # trajectories_inDistribution_test.npz) are selected via --npz/--output/--figure.
    default_npz = script_dir / "dataset" / "trajectories_diffN_test.npz"
    default_out = script_dir / "dataset" / "macro_feature_diffN_test.npy"
    default_fig = script_dir / "dataset" / "pair_probs_diffN_test.png"

    parser = argparse.ArgumentParser(description="Compute P_AB and P_BA for each frame.")
    parser.add_argument("--npz", type=Path, default=default_npz, help="Input NPZ file.")
    parser.add_argument("--output", type=Path, default=default_out, help="Output NPY file.")
    parser.add_argument(
        "--figure",
        type=Path,
        default=default_fig,
        help="Output figure path (written only with --plot).",
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        help="Also plot the pair probabilities to --figure.",
    )
    parser.add_argument("--rc", type=float, default=2.5, help="Neighbor cutoff radius.")
    parser.add_argument(
        "--workers",
        type=int,
        default=10,
        help="Number of worker processes (0/1 disables multiprocessing).",
    )
    args = parser.parse_args()

    # allow_pickle=True permits object arrays, but unpickling can execute
    # arbitrary code: only point --npz at trusted files. The context manager
    # closes the underlying zip file once both arrays have been read.
    with np.load(args.npz, allow_pickle=True) as data:
        trajectories = data["trajectories"]
        types = data["types"]
    print(f"trajectories shape: {trajectories.shape}, types shape: {types.shape}")

    pair_probs = compute_pair_probs(trajectories, types, args.rc, args.workers)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    np.save(args.output, pair_probs)

    message = f"Saved pair_probs with shape {pair_probs.shape} to {args.output}"
    if args.plot:
        plot_pair_probs(pair_probs, args.figure)
        message += f" and plot to {args.figure}"
    print(message)


# Guard keeps main() from running when this module is imported, e.g. by
# multiprocessing worker processes that re-import it under the "spawn"
# start method.
if __name__ == "__main__":
    main()
