"""
icml5.py — Convergence-time speedup for Napp–Adams CRN trajectories

Given two marginal trajectories (original vs reduced) of the form (t, p(t)),
compute a robust "settling time" and report:
  - t_settle_orig, t_settle_reduced
  - speedup factor (t_orig / t_reduced)
  - percent time decrease 100*(1 - t_reduced/t_orig)

Settling time definition (standard in control / numerical analysis):
  1) Estimate final_value as the median over the last FINAL_WINDOW seconds.
  2) Define a tolerance band [final_value - band_tol, final_value + band_tol],
     where band_tol = tol, or max(tol, rel_tol * |final_value|) when rel_tol is given.
  3) The settling time is the earliest sample time t_i such that the trajectory
     stays inside the band from t_i until at least WINDOW seconds later.
     If no such t_i exists in the data, the last sample time is reported instead.

This is designed to quantify plots like your napp_reduction_convergence:
e.g., reduced (9 species) settles at 189s vs original (45 species) at 434s.

Inputs
------
(A) Two CSVs, each with columns: time,value
    python icml5.py --orig orig.csv --reduced reduced.csv --outdir results

(B) One CSV with columns: time,value_orig,value_reduced
    python icml5.py --both both.csv --outdir results

Optional knobs
--------------
--tol 0.002          absolute tolerance (default 0.002)
--rel_tol 0.01       relative tolerance (default None)
--window 10.0        sustain window in seconds (default 10)
--final_window 50.0  window for final value estimation in seconds (default 50)
--smooth 0           rolling mean window in samples to reduce noise (default 0)
--stem name          output filename stem (default napp_convergence_annotated)

Outputs
-------
- Prints metrics to stdout
- Saves JSON report + annotated plot (png/pdf) when --outdir is provided
"""

from __future__ import annotations

import argparse
import csv
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt


# -------------------------
# IO helpers
# -------------------------

def _read_two_col_csv(path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Read a ``time,value`` CSV and return both columns sorted by time.

    Only the first two fields of each row are used. A row is skipped when it
    has fewer than two fields or when either of its first two fields does not
    parse as a float; this is also how an optional header row is tolerated.

    Args:
        path: Path of the CSV file to read.

    Returns:
        ``(t, y)`` float arrays of equal length, ordered by ascending ``t``.

    Raises:
        ValueError: If the file contains no rows at all.
    """
    times, vals = [], []
    saw_row = False
    # "utf-8-sig" drops a leading UTF-8 byte-order mark. Without it the mark
    # is glued to the first field, so a headerless file would lose its first
    # data row because that field no longer parses as a number.
    with open(path, "r", newline="", encoding="utf-8-sig") as f:
        for row in csv.reader(f):
            saw_row = True
            if len(row) < 2:
                continue
            # Parse both fields before appending anything so that a row with
            # a good time but a bad value cannot leave the lists misaligned.
            try:
                time_val, y_val = float(row[0]), float(row[1])
            except ValueError:
                continue
            times.append(time_val)
            vals.append(y_val)

    if not saw_row:
        raise ValueError(f"Empty CSV: {path}")

    t = np.asarray(times, dtype=float)
    y = np.asarray(vals, dtype=float)
    # Stable sort keeps file order for samples that share a timestamp.
    order = np.argsort(t, kind="stable")
    return t[order], y[order]


def _read_three_col_csv(path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read a CSV holding time, original and reduced values in named columns.

    The header row is required. Each column is located by trying a list of
    accepted names: first as an exact match, then ignoring surrounding
    whitespace and letter case. Rows in which any of the three selected fields
    is missing or not a float are skipped as a whole.

    Args:
        path: Path of the CSV file to read.

    Returns:
        ``(t, orig, reduced)`` float arrays of equal length, ordered by
        ascending ``t``.

    Raises:
        ValueError: If one of the three columns cannot be found in the header.
    """
    times, orig_vals, red_vals = [], [], []
    # "utf-8-sig" drops a leading UTF-8 byte-order mark, which would otherwise
    # become part of the first header name and make the lookup below fail.
    with open(path, "r", newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        fields = reader.fieldnames or []

        # Fallback lookup table: trimmed, lower-cased header -> header as written.
        normalized = {}
        for name in fields:
            normalized.setdefault(name.strip().lower(), name)

        def pick(*cands: str) -> Optional[str]:
            # Exact matches win so files that already worked resolve the same way.
            for c in cands:
                if c in fields:
                    return c
            for c in cands:
                hit = normalized.get(c.lower())
                if hit is not None:
                    return hit
            return None

        k_t = pick("time", "t", "Time", "seconds")
        k_o = pick("value_orig", "orig", "original", "p_orig", "p_original")
        k_r = pick("value_reduced", "reduced", "p_reduced")

        if k_t is None or k_o is None or k_r is None:
            raise ValueError(f"Could not find required columns in {path}. Found: {fields}")

        for row in reader:
            # Parse all three fields before appending so a partly bad row
            # cannot leave the lists with different lengths. A short row
            # yields None for the missing field, hence TypeError.
            try:
                parsed = (float(row[k_t]), float(row[k_o]), float(row[k_r]))
            except (TypeError, ValueError):
                continue
            times.append(parsed[0])
            orig_vals.append(parsed[1])
            red_vals.append(parsed[2])

    t = np.asarray(times, dtype=float)
    o = np.asarray(orig_vals, dtype=float)
    r = np.asarray(red_vals, dtype=float)
    # Stable sort keeps file order for samples that share a timestamp.
    order = np.argsort(t, kind="stable")
    return t[order], o[order], r[order]


# -------------------------
# Core computation
# -------------------------

@dataclass
class SettlingResult:
    """Outcome of a settling-time search, as built by ``compute_settling_time``."""

    # Time of the first sample that starts a sustained in-band window, or the
    # last sample time when no such window was found.
    t_settle: float
    # Median of the (optionally smoothed) signal over the final window.
    final_value: float
    # Half-width of the tolerance band that was actually applied.
    tol_used: float
    # (lower, upper) limits of the band: final_value -/+ tol_used.
    band: Tuple[float, float]
    # Equal to t_settle when a sustained window was found; None otherwise.
    entered_at: Optional[float] = None


def _rolling_mean(x: np.ndarray, w: int) -> np.ndarray:
    """Return the trailing rolling mean of *x* over *w* samples.

    The result has the same length as *x*. Element ``i`` (for ``i >= w - 1``)
    is the mean of ``x[i - w + 1 : i + 1]``; the first ``w - 1`` elements
    repeat the first full-window mean. When ``w`` is at least ``len(x)`` every
    element is the mean of the whole array.

    Args:
        x: 1-D samples.
        w: Window length in samples. Values ``<= 1`` disable smoothing and
            return *x* itself, unchanged.

    Returns:
        A float array of smoothed values (or *x* itself when ``w <= 1``).
    """
    if w <= 1:
        return x
    w = int(w)
    # Work in float so integer input is not truncated by full_like below.
    x = np.asarray(x, dtype=float)
    if w >= len(x):
        return np.full_like(x, np.mean(x))
    # Prefix sums with a leading 0 turn every window sum into one subtraction.
    c = np.cumsum(np.insert(x, 0, 0.0))
    y = (c[w:] - c[:-w]) / w
    pad_left = np.full(w - 1, y[0])
    return np.concatenate([pad_left, y])


def compute_settling_time(
    t: np.ndarray,
    y: np.ndarray,
    *,
    tol: float = 0.002,
    rel_tol: Optional[float] = None,
    window_sec: float = 10.0,
    final_window_sec: float = 50.0,
    smooth_window_samples: int = 0,
) -> SettlingResult:
    """Find the earliest time after which a trajectory stays near its final value.

    The final value is the median of the (optionally smoothed) signal over the
    last ``final_window_sec`` time units. A sample ``i`` qualifies as the
    settling point when every sample from ``i`` up to and including the first
    sample at or after ``t[i] + window_sec`` lies inside
    ``[final_value - band_tol, final_value + band_tol]``. Candidates whose
    window would extend past the last sample are not accepted.

    Args:
        t: Sample times (any 1-D array-like; sorted internally).
        y: Sample values, same length as ``t``.
        tol: Absolute half-width of the tolerance band.
        rel_tol: If given, the half-width becomes
            ``max(tol, rel_tol * abs(final_value))``.
        window_sec: How long the signal must remain inside the band.
        final_window_sec: Length of the tail used to estimate the final value.
        smooth_window_samples: Rolling-mean window in samples; values of 0 or
            1 disable smoothing.

    Returns:
        A ``SettlingResult``. When no qualifying sample exists, ``t_settle``
        is the last sample time and ``entered_at`` is ``None``.

    Raises:
        ValueError: If the inputs are not 1-D, differ in length, or hold
            fewer than 5 samples.
    """
    # Accept lists/tuples as well as arrays; index arrays below need ndarrays.
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    if t.ndim != 1 or y.ndim != 1:
        raise ValueError("t and y must be one-dimensional.")
    if len(t) != len(y) or len(t) < 5:
        raise ValueError("t and y must have same length and at least 5 samples.")

    # Stable sort keeps input order for samples that share a timestamp.
    order = np.argsort(t, kind="stable")
    t = t[order]
    y = y[order]

    y_use = _rolling_mean(y, smooth_window_samples) if smooth_window_samples and smooth_window_samples > 1 else y

    t_end = float(t[-1])
    final_mask = t >= (t_end - final_window_sec)
    if not np.any(final_mask):
        # Only reachable for a negative final window; fall back to all samples.
        final_mask = np.ones_like(t, dtype=bool)

    final_value = float(np.median(y_use[final_mask]))

    abs_tol = float(tol)
    if rel_tol is not None:
        abs_tol = max(abs_tol, float(rel_tol) * abs(final_value))

    lo, hi = final_value - abs_tol, final_value + abs_tol
    inside = (y_use >= lo) & (y_use <= hi)

    n = len(t)
    sample_idx = np.arange(n)
    # window_end[i]: first sample at or after t[i] + window_sec, never before
    # i itself; equals n when the data ends before the window does.
    window_end = np.maximum(np.searchsorted(t, t + window_sec, side="left"), sample_idx)
    # next_exit[i]: first index >= i that lies outside the band (n if none).
    # A reverse running minimum gives this for every i in a single pass.
    next_exit = np.minimum.accumulate(np.where(inside, n, sample_idx)[::-1])[::-1]
    # Settled at i: in band at i, the window fits in the data, and the band
    # is not left anywhere in [i, window_end[i]].
    settled = inside & (window_end < n) & (next_exit > window_end)
    hits = np.flatnonzero(settled)

    if hits.size:
        t_hit = float(t[int(hits[0])])
        return SettlingResult(
            t_settle=t_hit,
            final_value=final_value,
            tol_used=abs_tol,
            band=(lo, hi),
            entered_at=t_hit,
        )

    # No sustained window found: report the end of the data, flagged by
    # entered_at=None so callers can tell it apart from a real settling time.
    return SettlingResult(
        t_settle=float(t[-1]),
        final_value=final_value,
        tol_used=abs_tol,
        band=(lo, hi),
        entered_at=None,
    )


def summarize_pair(t_orig: float, t_red: float) -> Dict[str, float]:
    """Summarise how much faster the reduced trajectory settles.

    Args:
        t_orig: Settling time of the original trajectory.
        t_red: Settling time of the reduced trajectory.

    Returns:
        A dict with keys ``t_settle_orig``, ``t_settle_reduced``,
        ``speedup_factor`` (``t_orig / t_red``) and ``percent_time_decrease``
        (``100 * (1 - t_red / t_orig)``). ``speedup_factor`` is ``inf`` when
        only ``t_red`` is non-positive and ``nan`` when both times are
        non-positive; ``percent_time_decrease`` is ``nan`` when ``t_orig`` is
        non-positive.
    """
    if t_red > 0:
        speedup = float(t_orig / t_red)
    elif t_orig > 0:
        speedup = float("inf")
    else:
        # 0 / 0 has no meaningful ratio; match the NaN used for the percentage.
        speedup = float("nan")

    if t_orig > 0:
        pct_time_decrease = float(100.0 * (1.0 - t_red / t_orig))
    else:
        pct_time_decrease = float("nan")

    return {
        "t_settle_orig": float(t_orig),
        "t_settle_reduced": float(t_red),
        "speedup_factor": speedup,
        "percent_time_decrease": pct_time_decrease,
    }


# -------------------------
# Plotting
# -------------------------

def plot_annotated(
    t: np.ndarray,
    y_o: np.ndarray,
    y_r: np.ndarray,
    res_o: SettlingResult,
    res_r: SettlingResult,
    outdir: str,
    stem: str = "napp_convergence_annotated",
    title: str = "",
):
    """Plot both trajectories with their tolerance bands and settling times.

    Each trajectory's band (shaded) and settling time (dashed vertical line)
    are drawn in the same colour as its curve. A text box lists both settling
    times and the percent time decrease; a time that is only the end-of-data
    fallback (``entered_at is None``) is marked "(not settled)".

    Args:
        t: Shared time axis for both curves.
        y_o: Original trajectory sampled on ``t``.
        y_r: Reduced trajectory sampled on ``t``.
        res_o: Settling result for the original trajectory.
        res_r: Settling result for the reduced trajectory.
        outdir: Directory for the output files; created if missing.
        stem: File name without extension; ``.png`` and ``.pdf`` are written.
        title: Optional axes title; omitted when empty.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6.6, 4.0), constrained_layout=True)
    try:
        (line_o,) = ax.plot(t, y_o, linewidth=2.0, label="original")
        (line_r,) = ax.plot(t, y_r, linewidth=2.0, label="reduced")

        # Reuse each curve's colour so the band and the dashed line can be
        # attributed to the right trajectory.
        for line, res in ((line_o, res_o), (line_r, res_r)):
            color = line.get_color()
            ax.axhspan(res.band[0], res.band[1], alpha=0.10, color=color)
            ax.axvline(res.t_settle, linestyle="--", linewidth=1.4, alpha=0.9, color=color)

        ax.set_xlabel("time (s)", fontsize=12)
        ax.set_ylabel("marginal mass", fontsize=12)
        ax.grid(True, alpha=0.18, linewidth=0.5)
        if title:
            ax.set_title(title)

        stats = summarize_pair(res_o.t_settle, res_r.t_settle)
        # entered_at is None when t_settle is only the end-of-data fallback.
        flag_o = "" if res_o.entered_at is not None else " (not settled)"
        flag_r = "" if res_r.entered_at is not None else " (not settled)"
        ax.text(
            0.02, 0.98,
            f"settling time orig: {stats['t_settle_orig']:.0f}s{flag_o}\n"
            f"settling time red:  {stats['t_settle_reduced']:.0f}s{flag_r}\n"
            f"time decrease: {stats['percent_time_decrease']:.0f}%",
            transform=ax.transAxes, ha="left", va="top", fontsize=10,
            bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9),
        )

        ax.legend(frameon=True, fontsize=10, loc="best")

        os.makedirs(outdir, exist_ok=True)
        png = os.path.join(outdir, f"{stem}.png")
        pdf = os.path.join(outdir, f"{stem}.pdf")
        fig.savefig(png, dpi=220, facecolor="white", edgecolor="none")
        fig.savefig(pdf, facecolor="white", edgecolor="none")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print("Saved:", png)
    print("Saved:", pdf)


# -------------------------
# CLI
# -------------------------

def _json_safe(value: Any) -> Any:
    """Return *value* with non-finite floats replaced by ``None``, recursively."""
    if isinstance(value, dict):
        return {k: _json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(v) for v in value]
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


def main():
    """Command-line entry point: load trajectories, report settling metrics.

    Reads either one three-column CSV (``--both``) or two two-column CSVs
    (``--orig`` plus ``--reduced``), computes a settling time for each
    trajectory on its own time grid, prints the comparison, and, when
    ``--outdir`` is given, writes a JSON report and an annotated plot there.
    """
    ap = argparse.ArgumentParser()
    g = ap.add_mutually_exclusive_group(required=True)
    g.add_argument("--orig", type=str, help="CSV with columns: time,value for ORIGINAL trajectory")
    g.add_argument("--both", type=str, help="CSV with columns: time,value_orig,value_reduced")
    ap.add_argument("--reduced", type=str, help="CSV with columns: time,value for REDUCED trajectory (required if --orig)")
    ap.add_argument("--outdir", type=str, default="", help="Output directory for report + annotated plot")
    ap.add_argument("--tol", type=float, default=0.002, help="Absolute tolerance around final value")
    ap.add_argument("--rel_tol", type=float, default=None, help="Relative tolerance (fraction of final value)")
    ap.add_argument("--window", type=float, default=10.0, help="Sustain window length in seconds")
    ap.add_argument("--final_window", type=float, default=50.0, help="Final value estimated using last N seconds")
    ap.add_argument("--smooth", type=int, default=0, help="Rolling mean window in samples (0 disables)")
    ap.add_argument("--title", type=str, default="", help="Optional plot title (usually leave empty; use LaTeX caption)")
    ap.add_argument("--stem", type=str, default="napp_convergence_annotated", help="Filename stem for output plots")
    args = ap.parse_args()

    if args.both:
        t_o, y_o, y_r = _read_three_col_csv(args.both)
        t_r = t_o
    else:
        if not args.reduced:
            raise SystemExit("--reduced is required when using --orig")
        t_o, y_o = _read_two_col_csv(args.orig)
        t_r, y_r = _read_two_col_csv(args.reduced)

    # Analyse each trajectory on its own samples. Resampling the reduced run
    # onto the original grid first would clamp or truncate it whenever the two
    # runs cover different time ranges, which distorts its final value and
    # settling time.
    settle_kwargs = dict(
        tol=args.tol, rel_tol=args.rel_tol, window_sec=args.window,
        final_window_sec=args.final_window, smooth_window_samples=args.smooth,
    )
    res_o = compute_settling_time(t_o, y_o, **settle_kwargs)
    res_r = compute_settling_time(t_r, y_r, **settle_kwargs)

    stats = summarize_pair(res_o.t_settle, res_r.t_settle)

    print("\n=== Convergence / Settling Time Report ===")
    print(f"orig settling time:    {stats['t_settle_orig']:.3f} s")
    print(f"reduced settling time: {stats['t_settle_reduced']:.3f} s")
    print(f"speedup factor:        {stats['speedup_factor']:.3f}x")
    print(f"% time decrease:       {stats['percent_time_decrease']:.2f}%")

    # entered_at is None when no sustained in-band window was found and
    # t_settle is merely the last sample time.
    for label, res in (("orig", res_o), ("reduced", res_r)):
        if res.entered_at is None:
            print(
                f"WARNING: {label} trajectory did not stay inside its band for "
                f"{args.window:g} s within the data; the time shown is the last sample time."
            )

    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)
        report = {
            "orig": {
                "t_settle": res_o.t_settle, "final_value": res_o.final_value,
                "tol": res_o.tol_used, "band": res_o.band,
                "settled": res_o.entered_at is not None,
            },
            "reduced": {
                "t_settle": res_r.t_settle, "final_value": res_r.final_value,
                "tol": res_r.tol_used, "band": res_r.band,
                "settled": res_r.entered_at is not None,
            },
            "pair": stats,
            "params": {
                "tol": args.tol,
                "rel_tol": args.rel_tol,
                "window_sec": args.window,
                "final_window_sec": args.final_window,
                "smooth_window_samples": args.smooth,
            }
        }
        json_path = os.path.join(args.outdir, f"{args.stem}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            # inf/nan are not valid JSON; write them as null instead.
            json.dump(_json_safe(report), f, indent=2, allow_nan=False)
        print("Saved:", json_path)

        # The plot needs one shared time axis. When the grids differ, use
        # their union and leave each curve blank (NaN) outside its own range
        # rather than drawing extrapolated values.
        same_grid = t_o.shape == t_r.shape and np.allclose(t_o, t_r, rtol=0.0, atol=1e-9)
        if same_grid:
            t_plot, y_o_plot, y_r_plot = t_o, y_o, y_r
        else:
            t_plot = np.union1d(t_o, t_r)
            y_o_plot = np.interp(t_plot, t_o, y_o, left=np.nan, right=np.nan)
            y_r_plot = np.interp(t_plot, t_r, y_r, left=np.nan, right=np.nan)

        plot_annotated(t_plot, y_o_plot, y_r_plot, res_o, res_r, args.outdir, stem=args.stem, title=args.title)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
