"""
plot_icml5.py — ICML inference-trajectory plots + settling-time speedups for Napp–Adams CRNs

This script auto-generates the requested loopy/tendril examples, reduces each one,
compiles both the ORIGINAL and the REDUCED factor graph to Napp–Adams CRNs, simulates
both CRNs, and produces trajectory plots of inferred marginal mass versus time for
several variables at once.

Requested defaults:
  - loopy core size 3, tendril length 10
  - loopy core size 4, tendril length 10

(You can override via CLI.)

Each output figure:
  - one subplot per variable (limited by --max_vars to avoid unreadable clutter)
  - within each subplot, all K states of the variable's marginal are plotted
  - original trajectories: solid lines
  - reduced trajectories: dashed lines
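  - dotted vertical markers show each variable's settling time in the original and reduced
    runs (settling of the whole K-state marginal vector, not of individual states)
  - the figure header reports the median settling time across plotted variables for both
    runs and the resulting percentage decrease in settling time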

Outputs per instance:
  - napp_inference_loopy_c{core}_t{tendril}.png and .pdf
  - napp_inference_loopy_c{core}_t{tendril}.json (numbers for paper)

Run (from repo root):
  python3 benchmarks/plot_icml5.py --outdir results/plots_icml5

"""

from __future__ import annotations

import argparse
import json
import os
from dataclasses import dataclass
from typing import Optional, Tuple, List

import numpy as np
import matplotlib.pyplot as plt


# -------------------------
# Repo import helper
# -------------------------
def _try_import_repo():
    """
    Make the repository root importable, idempotently.

    Assumes this file lives one directory below the repo root
    (benchmarks/plot_icml5.py), so the root is two parents up from this file.
    The root is prepended to ``sys.path`` only when it is not already present.
    This function is called before every project import (once per build and
    once per compile step), so it must not keep growing ``sys.path``.
    """
    import sys
    import pathlib

    repo_root = str(pathlib.Path(__file__).resolve().parent.parent)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)


# -------------------------
# Settling time on a vector-valued marginal
# -------------------------
@dataclass
class VecSettlingResult:
    """Outcome of :func:`compute_vector_settling_time` for one trajectory."""

    # Reported settling time; the last sample time when no settling window was found.
    t_settle: float
    # Reference vector the trajectory is compared against (clipped, renormalized median).
    final_vec: np.ndarray
    # Max-abs tolerance that defined the acceptance band.
    tol_used: float
    # Start time of the first qualifying window, or None if the trajectory never settled.
    entered_at: Optional[float] = None


def compute_vector_settling_time(
    t: np.ndarray,
    P: np.ndarray,
    *,
    tol: float = 0.002,
    window_sec: float = 10.0,
    final_window_sec: float = 50.0,
) -> VecSettlingResult:
    """
    Settling time of a vector-valued trajectory such as a K-state marginal.

    P has shape (T, K) and each row should sum to 1 (or close). t has shape (T,)
    and need not be sorted; samples are sorted by time first.

    final_vec = median over the last final_window_sec seconds (componentwise),
                clipped at 0 and renormalized (uniform if it sums to 0).
    error(t)  = max_k |P(t,k) - final_vec(k)|
    settling time = earliest sample t_i such that error <= tol at every sample
                from t_i through the first sample at or after t_i + window_sec.

    If no such window fits inside the data, t_settle is the last sample time and
    entered_at is None.

    Raises
    ------
    ValueError
        If t is empty or not 1-D, or P is not a (len(t), K) array with K >= 1.
    """
    t = np.asarray(t, dtype=float)
    P = np.asarray(P, dtype=float)
    if t.ndim != 1 or t.size == 0:
        raise ValueError("t must be a non-empty 1-D array of sample times")
    if P.ndim != 2 or P.shape[0] != t.size or P.shape[1] == 0:
        raise ValueError(
            f"P must have shape (len(t), K) with K >= 1; got {P.shape} for len(t)={t.size}"
        )

    # Stable sort keeps the original order among duplicate timestamps.
    order = np.argsort(t, kind="stable")
    t = t[order]
    P = P[order]
    n = t.size

    t_end = float(t[-1])
    final_mask = t >= (t_end - final_window_sec)
    if not np.any(final_mask):
        # Only reachable for a negative/NaN final window; fall back to all samples.
        final_mask = np.ones_like(t, dtype=bool)

    final_vec = np.maximum(np.median(P[final_mask], axis=0), 0.0)
    total = float(final_vec.sum())
    if total > 0:
        final_vec = final_vec / total
    else:
        final_vec = np.ones(P.shape[1], dtype=float) / P.shape[1]

    err = np.max(np.abs(P - final_vec[None, :]), axis=1)
    inside = err <= tol

    # For each candidate start i, end[i] is the first sample at or after
    # t[i] + window_sec (never before i). The window [i, end[i]] qualifies when it
    # fits in the data and holds no out-of-band sample. A prefix sum of
    # out-of-band flags answers the second test in O(1) per start. end[] is
    # nondecreasing, so the first qualifying start is the same one a left-to-right
    # scan that stops at the first window running past the data would find.
    start = np.arange(n)
    end = np.maximum(start, np.searchsorted(t, t + window_sec, side="left"))
    fits = end < n
    outside_before = np.concatenate(([0], np.cumsum(~inside)))
    end_clipped = np.minimum(end, n - 1)
    clean = (outside_before[end_clipped + 1] - outside_before[start]) == 0

    hits = np.flatnonzero(fits & clean)
    if hits.size:
        t_first = float(t[hits[0]])
        return VecSettlingResult(
            t_settle=t_first,
            final_vec=final_vec,
            tol_used=float(tol),
            entered_at=t_first,
        )

    return VecSettlingResult(
        t_settle=float(t[-1]),
        final_vec=final_vec,
        tol_used=float(tol),
        entered_at=None,
    )


def percent_time_decrease(t_orig: float, t_red: float) -> float:
    """
    Percentage by which ``t_red`` is shorter than ``t_orig``.

    Positive means the reduced run settled sooner; negative means it was slower.
    The ratio is undefined for a non-positive ``t_orig``, so NaN is returned then.
    """
    if t_orig > 0:
        ratio = t_red / t_orig
        return float(100.0 * (1.0 - ratio))
    return float("nan")


# -------------------------
# Extract marginals from SimulationResult
# -------------------------
def marginal_matrix_from_sim(sim, var_name: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extract normalized marginal trajectories for one variable.

    Species are matched by name:
        Marginal_<var_tokens...>_<state>
    where <var_tokens...> may contain underscores and <state> is an integer.

    Example:
        Marginal_c0_2   -> var='c0',   state=2
        Marginal_t0_0_2 -> var='t0_0', state=2

    Columns are ordered by state index. Each row is divided by its total (floored
    at 1e-15), so rows sum to 1 wherever any mass is present.

    Parameters
    ----------
    sim : object exposing ``times``, ``crn.species`` and a ``concentrations``
        mapping from species name to trajectory.
    var_name : variable to extract.

    Returns
    -------
    (times, P) with times of shape (T,) and P of shape (T, K).

    Raises
    ------
    ValueError
        If no Marginal_* species exist for ``var_name``, or a matching species has
        no trajectory or one whose shape differs from ``times``.
    """
    times = np.asarray(sim.times, dtype=float)

    # Collect (state, species) pairs for this variable.
    pairs = []
    for sp in sim.crn.species:
        if not sp.startswith("Marginal_"):
            continue
        parts = sp.split("_")
        if len(parts) < 3:
            continue
        try:
            state = int(parts[-1])
        except ValueError:
            continue
        # Everything between 'Marginal' and the state token is the variable name.
        if "_".join(parts[1:-1]) == var_name:
            pairs.append((state, sp))

    if not pairs:
        raise ValueError(f"No Marginal_* species found for var='{var_name}'")

    pairs.sort(key=lambda pair: pair[0])

    cols = []
    denom = np.zeros_like(times, dtype=float)
    for _, sp in pairs:
        raw = sim.concentrations.get(sp)
        # A missing trajectory would otherwise become a NaN scalar and poison every row.
        if raw is None:
            raise ValueError(f"No trajectory recorded for species '{sp}'")
        y = np.asarray(raw, dtype=float)
        if y.shape != times.shape:
            raise ValueError(
                f"Trajectory for species '{sp}' has shape {y.shape}, expected {times.shape}"
            )
        cols.append(y)
        denom += y

    denom = np.maximum(denom, 1e-15)
    P = np.stack([c / denom for c in cols], axis=1)  # (T, K)
    return times, P

# -------------------------
# List the variables that have Marginal_* species in a simulation
# -------------------------

def list_marginal_vars(sim) -> List[str]:
    """
    Return the sorted, de-duplicated variable names that have Marginal_* species.

    A species counts when it is named ``Marginal_<var_tokens...>_<state>`` with at
    least one var token and an integer ``<state>``. The variable name is
    everything between the ``Marginal`` prefix and the state token, so it may
    contain underscores (``Marginal_t0_0_2`` -> ``t0_0``).
    """
    vars_found = set()
    for sp in sim.crn.species:
        if not sp.startswith("Marginal_"):
            continue
        parts = sp.split("_")
        if len(parts) < 3:
            continue
        try:
            int(parts[-1])  # state index must be an integer
        except ValueError:
            continue
        vars_found.add("_".join(parts[1:-1]))
    return sorted(vars_found)

# -------------------------
# Build graphs, reduce, compile, simulate
# -------------------------
def build_loopy_and_reduce(core_size: int, tendril_len: int):
    """
    Generate a loopy core-with-tendrils factor graph and its reduced counterpart.

    The graph is converted to a poset, ``reduce_to_core_spb`` is applied to that
    poset (its return value is not used; presumably it reduces in place, TODO
    confirm), and the result is converted back to a factor graph.

    Returns
    -------
    (original_fg, reduced_fg)

    Raises
    ------
    RuntimeError
        If converting back yields None or a graph with no variables.
    """
    _try_import_repo()
    from benchmarks.graph_generators import generate_loopy_core_with_tendrils
    from reduction.poset_reduction import (
        from_factor_graph,
        to_factor_graph_if_possible,
        reduce_to_core_spb,
    )

    original = generate_loopy_core_with_tendrils(core_size, tendril_len)
    partial_order = from_factor_graph(original)
    reduce_to_core_spb(partial_order)
    reduced = to_factor_graph_if_possible(partial_order)

    is_trivial = reduced is None or reduced.num_variables == 0
    if is_trivial:
        raise RuntimeError(f"Reduction produced trivial graph for loopy_c{core_size}_t{tendril_len}")
    return original, reduced


def compile_and_simulate(
    fg,
    red_fg,
    *,
    kappa_r: float,
    kappa_prod: float,
    t_end: float,
    n_points: int,
    verbose: bool = False,
):
    """
    Compile the original and reduced factor graphs to CRNs and simulate both.

    Both CRNs use the same rate constants, horizon and sample count, so their
    trajectories can be compared directly.

    Parameters
    ----------
    fg, red_fg : original and reduced factor graphs.
    kappa_r, kappa_prod : forwarded to ``compile_factor_graph_to_crn``.
    t_end, n_points : forwarded to ``simulate_crn``.
    verbose : print a species-name dump for debugging. The dump is also printed
        automatically when either CRN has no ``Marginal_*`` species, because the
        plotting code looks for exactly those.

    Returns
    -------
    (crn_o, crn_r, sim_o, sim_r)
    """
    _try_import_repo()
    from crn import compile_factor_graph_to_crn, simulate_crn

    crn_o = compile_factor_graph_to_crn(fg, kappa_r=kappa_r, kappa_prod=kappa_prod)
    crn_r = compile_factor_graph_to_crn(red_fg, kappa_r=kappa_r, kappa_prod=kappa_prod)

    sim_o = simulate_crn(crn_o, t_end=t_end, n_points=n_points)
    sim_r = simulate_crn(crn_r, t_end=t_end, n_points=n_points)

    has_marginals_o = any(s.startswith("Marginal_") for s in crn_o.species)
    has_marginals_r = any(s.startswith("Marginal_") for s in crn_r.species)
    if verbose or not (has_marginals_o and has_marginals_r):
        print("Example species names (orig):", list(crn_o.species)[:30])
        print("Example species names (red):",  list(crn_r.species)[:30])
        print("Any Marginal_ species (orig)?", has_marginals_o)
        print("Any Marginal_ species (red)?", has_marginals_r)
        # Species that might hold beliefs under a different naming scheme.
        cands = [s for s in crn_o.species if "arg" in s.lower() or "belief" in s.lower() or s.startswith(("P_", "S_"))]
        print("Candidate belief-ish species (orig):", cands[:30])
    return crn_o, crn_r, sim_o, sim_r


# -------------------------
# Plotting
# -------------------------
def _json_number(value) -> Optional[float]:
    """Return ``value`` as a float, or None when it is NaN or infinite.

    ``json.dump`` would otherwise emit the non-standard tokens ``NaN`` and
    ``Infinity``, which strict JSON parsers reject.
    """
    value = float(value)
    return value if np.isfinite(value) else None


def _collect_settling_series(sim_o, sim_r, *, max_vars, tol, window, final_window) -> List[dict]:
    """Gather up to ``max_vars`` variables present in both runs, aligned and with settling times.

    Variables are taken in sorted order from the reduced run. Any that the
    original run lacks are skipped before the cap is applied, so the cap counts
    plottable variables only. The reduced trajectory is linearly resampled onto
    the original time grid when the two grids differ.
    """
    orig_vars = set(list_marginal_vars(sim_o))
    series: List[dict] = []
    for vn in list_marginal_vars(sim_r):
        if len(series) >= max_vars:
            break
        if vn not in orig_vars:
            continue
        try:
            t_o, P_o = marginal_matrix_from_sim(sim_o, vn)
            t_r, P_r = marginal_matrix_from_sim(sim_r, vn)
        except Exception as exc:  # best effort: skip what cannot be extracted, but report it
            print(f"[warn] Skipping var '{vn}': {exc}")
            continue
        if len(t_o) == 0 or len(t_r) == 0:
            print(f"[warn] Skipping var '{vn}': empty trajectory")
            continue
        if P_o.shape[1] != P_r.shape[1]:
            print(f"[warn] Skipping var '{vn}': {P_o.shape[1]} states in original vs {P_r.shape[1]} in reduced")
            continue

        # Put both runs on the original time grid (np.interp assumes t_r is increasing).
        if len(t_o) != len(t_r) or np.max(np.abs(t_o - t_r)) > 1e-9:
            P_r = np.column_stack([np.interp(t_o, t_r, P_r[:, k]) for k in range(P_r.shape[1])])

        settle_kwargs = dict(tol=tol, window_sec=window, final_window_sec=final_window)
        series.append({
            "var": vn,
            "t": t_o,
            "P_orig": P_o,
            "P_red": P_r,
            "settle_orig": compute_vector_settling_time(t_o, P_o, **settle_kwargs),
            "settle_red": compute_vector_settling_time(t_o, P_r, **settle_kwargs),
        })
    return series


def _render_inference_figure(series: List[dict], *, header: str, png: str, pdf: str) -> None:
    """Draw one subplot per variable and save the figure as PNG and PDF.

    State k uses colour ``C{k}`` in both runs (solid = original, dashed =
    reduced), so each pair of curves is directly comparable. Dotted vertical
    lines mark the settling time of the original (black) and reduced (grey) run.
    The figure is always closed, even if saving fails.
    """
    from matplotlib.lines import Line2D

    n = len(series)
    ncols = 2 if n > 1 else 1
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(7.0, 2.6 * nrows), constrained_layout=True)
    try:
        axes = np.atleast_1d(axes).ravel()
        for ax, s in zip(axes, series):
            t = s["t"]
            for k in range(s["P_orig"].shape[1]):
                # Explicit colour: relying on the cycle would give orig/reduced different colours.
                color = f"C{k}"
                ax.plot(t, s["P_orig"][:, k], color=color, linewidth=1.6, alpha=0.9)
                ax.plot(t, s["P_red"][:, k], color=color, linewidth=1.6, alpha=0.9, linestyle="--")

            ax.set_title(s["var"], fontsize=11)
            ax.set_xlabel("time (s)", fontsize=10)
            ax.set_ylabel("marginal mass", fontsize=10)
            ax.grid(True, alpha=0.18, linewidth=0.5)

            ax.axvline(s["settle_orig"].t_settle, color="k", linestyle=":", linewidth=1.1, alpha=0.8)
            ax.axvline(s["settle_red"].t_settle, color="0.5", linestyle=":", linewidth=1.1, alpha=0.8)

        for ax in axes[n:]:
            ax.axis("off")

        proxies = [
            Line2D([0], [0], color="k", linestyle="-", linewidth=2.0, label="original"),
            Line2D([0], [0], color="k", linestyle="--", linewidth=2.0, label="reduced"),
            Line2D([0], [0], color="k", linestyle=":", linewidth=1.1, label="settle (orig)"),
            Line2D([0], [0], color="0.5", linestyle=":", linewidth=1.1, label="settle (red)"),
        ]
        fig.legend(handles=proxies, loc="upper right", ncol=2, frameon=True, fontsize=9)

        fig.text(
            0.02, 0.99, header,
            ha="left", va="top", fontsize=10,
            bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9),
        )

        fig.savefig(png, dpi=220, facecolor="white", edgecolor="none")
        fig.savefig(pdf, facecolor="white", edgecolor="none")
    finally:
        plt.close(fig)


def plot_inference_trajectories(
    core_size: int,
    tendril_len: int,
    fg,
    red_fg,
    crn_o,
    crn_r,
    sim_o,
    sim_r,
    *,
    outdir: str,
    max_vars: int,
    tol: float,
    window: float,
    final_window: float,
):
    """
    Plot original vs reduced marginal trajectories and write the settling report.

    Writes ``napp_inference_loopy_c{core}_t{tendril}.{png,pdf,json}`` into
    ``outdir``. If no variable can be extracted from both runs, a warning is
    printed and nothing is written.

    ``fg`` and ``red_fg`` are not used here; they are kept for call-site
    compatibility. Non-finite numbers (an undefined percentage or an unbounded
    speedup) are written as ``null`` so the JSON is standards-compliant.
    """
    series = _collect_settling_series(
        sim_o, sim_r, max_vars=max_vars, tol=tol, window=window, final_window=final_window,
    )
    if not series:
        print(f"[warn] No variables plotted for loopy_c{core_size}_t{tendril_len}")
        return

    t_settle_o = float(np.median([s["settle_orig"].t_settle for s in series]))
    t_settle_r = float(np.median([s["settle_red"].t_settle for s in series]))
    pct_dec = percent_time_decrease(t_settle_o, t_settle_r)

    os.makedirs(outdir, exist_ok=True)
    stem = f"napp_inference_loopy_c{core_size}_t{tendril_len}"
    png = os.path.join(outdir, f"{stem}.png")
    pdf = os.path.join(outdir, f"{stem}.pdf")

    header = (
        f"loopy core={core_size}, tendril={tendril_len}   |   species {len(crn_o.species)} → {len(crn_r.species)}\n"
        f"median settling: {t_settle_o:.0f}s → {t_settle_r:.0f}s   ({pct_dec:.0f}% time decrease)"
    )
    _render_inference_figure(series, header=header, png=png, pdf=pdf)
    print("Saved:", png)
    print("Saved:", pdf)

    speedup = t_settle_o / t_settle_r if t_settle_r > 0 else float("inf")
    report = {
        "instance": {"family": "loopy", "core_size": core_size, "tendril_len": tendril_len},
        "sizes": {
            "orig_species": len(crn_o.species),
            "reduced_species": len(crn_r.species),
            "orig_reactions": len(crn_o.reactions),
            "reduced_reactions": len(crn_r.reactions),
        },
        "settling_median_over_vars": {
            "orig": _json_number(t_settle_o),
            "reduced": _json_number(t_settle_r),
            "percent_time_decrease": _json_number(pct_dec),
            # null when the reduced median is 0 (unbounded speedup).
            "speedup_factor": _json_number(speedup),
        },
        "per_variable": [
            {
                "var": s["var"],
                "t_settle_orig": _json_number(s["settle_orig"].t_settle),
                "t_settle_reduced": _json_number(s["settle_red"].t_settle),
                "percent_time_decrease": _json_number(
                    percent_time_decrease(s["settle_orig"].t_settle, s["settle_red"].t_settle)
                ),
                # False: no full settling window was found, so t_settle is the last sample time.
                "settled_orig": s["settle_orig"].entered_at is not None,
                "settled_reduced": s["settle_red"].entered_at is not None,
            }
            for s in series
        ],
        "params": {
            "tol": tol,
            "window_sec": window,
            "final_window_sec": final_window,
            "t_end": float(sim_o.times[-1]) if len(sim_o.times) else None,
            "n_points": int(len(sim_o.times)),
            "max_vars": int(max_vars),
        },
    }
    json_path = os.path.join(outdir, f"{stem}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print("Saved:", json_path)


def main():
    """
    Command-line entry point: render every (core size, tendril length) instance.

    For each combination of ``--cores`` and ``--tendrils``, this:
    - builds and reduces the loopy graph,
    - compiles and simulates both CRNs,
    - writes the figure and JSON report into ``--outdir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="results/plots_icml5")

    # Which instances to render (all core x tendril combinations).
    parser.add_argument("--cores", type=int, nargs="+", default=[3, 4])
    parser.add_argument("--tendrils", type=int, nargs="+", default=[10])

    # Simulation parameters.
    parser.add_argument("--t_end", type=float, default=5000.0)
    parser.add_argument("--n_points", type=int, default=400)
    parser.add_argument("--kappa_r", type=float, default=0.02)
    parser.add_argument("--kappa_prod", type=float, default=50.0)

    # Settling-time parameters.
    parser.add_argument("--tol", type=float, default=0.002)
    parser.add_argument("--window", type=float, default=10.0)
    parser.add_argument("--final_window", type=float, default=50.0)

    # Plot parameters.
    parser.add_argument("--max_vars", type=int, default=6)

    opts = parser.parse_args()

    sim_kwargs = dict(
        kappa_r=opts.kappa_r,
        kappa_prod=opts.kappa_prod,
        t_end=opts.t_end,
        n_points=opts.n_points,
    )
    plot_kwargs = dict(
        outdir=opts.outdir,
        max_vars=opts.max_vars,
        tol=opts.tol,
        window=opts.window,
        final_window=opts.final_window,
    )

    instances = [(core, tendril) for core in opts.cores for tendril in opts.tendrils]
    for core, tendril in instances:
        fg, red = build_loopy_and_reduce(core, tendril)
        crn_o, crn_r, sim_o, sim_r = compile_and_simulate(fg, red, **sim_kwargs)
        plot_inference_trajectories(core, tendril, fg, red, crn_o, crn_r, sim_o, sim_r, **plot_kwargs)


if __name__ == "__main__":
    # Script entry point: importing this module only defines functions; it renders nothing.
    main()
