#!/usr/bin/env python3
"""
Plot aggregated traces produced by experiment runs in
<root>/<param>=*/trace.csv and save PDFs that match experiment.py styling.
"""
import os, glob, argparse, math
from typing import Dict, List, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib

# Headless-safe backend: when the DISPLAY environment variable is unset or
# empty, select the non-interactive "Agg" backend. This runs before pyplot is
# imported on the line below.
if not os.environ.get("DISPLAY"):
    matplotlib.use("Agg")
import matplotlib.pyplot as plt

# seaborn is optional. _HAS_SNS records whether the import succeeded so that
# _setup_figure() can skip the seaborn theme when it is unavailable.
try:
    import seaborn as sns
    _HAS_SNS = True
except Exception:
    # Any failure while importing seaborn (not only ImportError) disables it.
    _HAS_SNS = False


# Default canvas size for figures created without an explicit figsize.
plt.rcParams["figure.figsize"] = (9, 6)

# Typography: tick labels share the base font size, axis labels are larger
# and legend text is slightly smaller.
plt.rcParams.update(
    {
        "font.size": 18,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
        "axes.labelsize": 24,
        "legend.fontsize": 16,
    }
)

# Draw the top and right spines as well, so the axes are fully boxed.
plt.rcParams.update({"axes.spines.top": True, "axes.spines.right": True})

# Small positive constant: plotted curves are clamped to at least this value
# and it is the lower bound of the normalisation denominators.
EPS = 1e-12

# Maps a swept parameter name (as used in folder names and --param) to the
# LaTeX fragment shown in legend labels. Names not listed here are used as-is
# by main().
PARAM_SYMBOLS = dict(
    r="r",
    c="c",
    sigma=r"\sigma",
    sigma1=r"\sigma_1",
    sigma2=r"\sigma_2",
    alpha=r"\alpha",
    beta=r"\beta",
    gamma=r"\gamma",
    rho=r"\rho",
    v="v",
    U_res=r"U_{\mathrm{res}}",
    ell=r"\ell",
    s="s",
    a0="a_0",
    w_min=r"w_{\min}",
    a_peer=r"a_{\mathrm{peer}}",
)

# ===================== Figure scaffolding =====================

def _setup_figure():
    """Open a new 9x6 figure, applying seaborn's "ticks" style and "talk"
    context first when seaborn could be imported."""
    canvas_size = (9, 6)
    if _HAS_SNS:
        sns.set_style("ticks")
        sns.set_context("talk")
    plt.figure(figsize=canvas_size)


def _style_and_save(ylabel: str, outpath: str):
    """Style the current matplotlib figure, save it to ``outpath`` and release it.

    Applies log-log axes, a dashed major grid, a framed legend in the lower
    left corner and black boxed spines, then writes the figure with
    ``plt.savefig``. The parent directory of ``outpath`` is created when it
    does not exist yet.

    Args:
        ylabel: Text for the y axis; the x axis is always labelled "Iteration".
        outpath: Destination file. A bare filename (no directory component)
            is saved into the current working directory.

    Side effects:
        Prints the saved path. Shows the figure when the DISPLAY environment
        variable is set, otherwise closes it.
    """
    plt.xlabel("Iteration", fontsize=24)
    plt.ylabel(ylabel, fontsize=24)
    plt.xscale("log"); plt.yscale("log")
    plt.grid(True, which="major", linestyle="--", linewidth=0.8, alpha=0.6)
    plt.grid(False, which="minor")

    leg = plt.legend(fontsize=16, frameon=True, loc="lower left")
    if leg is not None:
        fr = leg.get_frame()
        fr.set_edgecolor("black"); fr.set_linewidth(1.2); fr.set_alpha(1.0)

    ax = plt.gca()
    for spine in ax.spines.values():
        spine.set_visible(True); spine.set_linewidth(1.5); spine.set_color("black")
    ax.patch.set_edgecolor("black"); ax.patch.set_linewidth(1.5)
    ax.tick_params(axis="both", labelsize=18, width=1.2)

    plt.tight_layout()
    # os.path.dirname() returns "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    out_dir = os.path.dirname(outpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(outpath, bbox_inches="tight", pad_inches=0.05)
    print(f"✅ Saved: {outpath}")

    if os.environ.get("DISPLAY"):
        plt.show()
    else:
        plt.close()

# ===================== Utilities =====================

def _safe_steps(steps: np.ndarray) -> np.ndarray:
    """Return a float copy of ``steps`` that is usable on a logarithmic x axis.

    * An empty input is returned as an empty float array.
    * If the smallest (NaN-ignoring) step is not strictly positive, the whole
      array is shifted so that this minimum becomes 1.
    * If the result is not strictly increasing (ties or decreasing values),
      the steps are replaced by the row positions ``1..len(steps)``.

    NaN entries are not repaired: they are ignored when finding the minimum
    and never trigger the monotonicity fallback.

    Args:
        steps: Step values, one per trace row.

    Returns:
        A new float array with the same length as ``steps``.
    """
    steps = np.array(steps, dtype=float)
    # np.nanmin raises ValueError on a zero-size array (e.g. a trace file that
    # only contains a header), so short-circuit that case.
    if steps.size == 0:
        return steps
    m = np.nanmin(steps)
    if not np.isfinite(m) or m <= 0:
        steps = steps - (m if np.isfinite(m) else 0) + 1.0
    if np.any(np.diff(steps) <= 0):
        steps = np.arange(1, len(steps) + 1, dtype=float)
    return steps


def parse_val_from_folder(folder_name: str, prefix: str) -> float:
    """Decode the numeric value that follows ``prefix`` in a folder name.

    Every ``p`` in the remainder is read as a decimal point, so ``r=0p5``
    with prefix ``r=`` yields ``0.5``. NaN is returned when the name does not
    start with ``prefix`` or the remainder is not a valid float literal.
    """
    if not folder_name.startswith(prefix):
        return math.nan
    encoded = folder_name[len(prefix):]
    decoded = encoded.replace("p", ".")
    try:
        return float(decoded)
    except ValueError:
        return math.nan


def discover_runs(root: str, param: str) -> List[Tuple[float, str]]:
    """Find one trace CSV per ``<root>/<param>=<value>`` entry.

    For each match, ``trace.csv`` is preferred. If it is missing, the most
    recently modified ``trace*.csv`` inside the entry is used instead, and
    entries without any such file are skipped. Entries whose name does not
    decode to a number (see ``parse_val_from_folder``) are skipped too.

    Args:
        root: Directory that contains the ``<param>=<value>`` folders.
        param: Name of the swept parameter.

    Returns:
        ``(value, trace_path)`` pairs sorted by value. Pairs with equal
        values keep the alphabetical order of their folder paths.
    """
    prefix = f"{param}="
    # root, param and the matched folder are literal text, not patterns:
    # escape them so characters such as "[", "*" or "?" in a path do not
    # change (or silently empty) the match.
    pattern = os.path.join(glob.escape(root), f"{glob.escape(param)}=*")
    paths: List[Tuple[float, str]] = []
    for run_dir in sorted(glob.glob(pattern)):
        trace_path = os.path.join(run_dir, "trace.csv")
        if not os.path.isfile(trace_path):
            cands = glob.glob(os.path.join(glob.escape(run_dir), "trace*.csv"))
            if not cands:
                continue
            # Several candidate traces: take the newest one.
            trace_path = max(cands, key=os.path.getmtime)
        val = parse_val_from_folder(os.path.basename(run_dir), prefix=prefix)
        if not math.isnan(val):
            paths.append((val, trace_path))
    # list.sort is stable, so ties keep the alphabetical order from above.
    paths.sort(key=lambda item: item[0])
    return paths


def load_trace(csv_path: str) -> pd.DataFrame:
    """Read a trace CSV and return its rows ordered by ``step``.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        The file's contents sorted by the ``step`` column with a fresh
        0..n-1 index. Rows that share a step keep their order from the file.

    Raises:
        ValueError: If any of the columns ``step``, ``u1`` or ``u2`` is
            missing.
    """
    df = pd.read_csv(csv_path)
    needed = ["step", "u1", "u2"]
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise ValueError(f"{csv_path}: missing required column(s) {missing}")
    # The default sort algorithm is not stable. Use mergesort so that rows
    # with equal steps stay in file order; that order matters when steps are
    # later replaced by row positions.
    return df.sort_values("step", kind="mergesort").reset_index(drop=True)


def _final_or_tail_median(arr: np.ndarray) -> float:
    """Return the last finite entry of ``arr``, or NaN when there is none.

    Despite the name, no tail median is ever computed. The previous
    implementation only reached its median fallback when the last *finite*
    entry was itself non-finite, which cannot happen, so that branch was
    unreachable and has been removed. Results are unchanged.

    Args:
        arr: 1-D array of floats; may contain NaN or infinite entries.

    Returns:
        The value at the highest index whose entry is finite, as a Python
        float, or NaN if ``arr`` is empty or has no finite entry.
    """
    finite_idx = np.flatnonzero(np.isfinite(arr))
    if finite_idx.size == 0:
        return float("nan")
    return float(arr[finite_idx[-1]])


def _choose_star(df: pd.DataFrame, key_star: str, key_final: str) -> Optional[float]:
    """Pick the reference value for a trace, preferring ``key_star``.

    Each candidate column that is present is reduced to a scalar with
    ``_final_or_tail_median``. The first finite scalar wins, so a
    ``key_star`` column that exists but holds no finite value no longer
    hides a usable ``key_final`` column.

    Args:
        df: Trace data.
        key_star: Preferred column name.
        key_final: Fallback column name.

    Returns:
        The chosen scalar; NaN when at least one column is present but none
        yields a finite value; ``None`` when neither column is present.
    """
    chosen: Optional[float] = None
    for key in (key_star, key_final):
        if key not in df.columns:
            continue
        chosen = _final_or_tail_median(df[key].to_numpy(float))
        if math.isfinite(chosen):
            return chosen
    return chosen


def _first_present(df: pd.DataFrame, candidates) -> Optional[str]:
    """Return the first name in ``candidates`` that is a column of ``df``,
    or ``None`` when none of them is."""
    available = df.columns
    return next((name for name in candidates if name in available), None)


def _fmt_val(v: float) -> str:
    """Format ``v`` for a legend label.

    Zero and magnitudes in ``[1e-3, 1e3)`` use the compact ``g`` format;
    everything else uses scientific notation with two decimals.
    """
    magnitude = abs(v)
    compact = v == 0 or 1e-3 <= magnitude < 1e3
    spec = "g" if compact else ".2e"
    return format(v, spec)

def _denom_floor_from_stars(stars: List[float]) -> float:
    """Return a lower bound for normalisation denominators.

    The bound is ``1e-6`` times the median absolute value of the finite
    entries of ``stars``, but never below ``EPS``. When there is no finite
    entry the bound is ``1.0``.
    """
    magnitudes = [abs(star) for star in stars if np.isfinite(star)]
    if not magnitudes:
        return 1.0
    scaled_median = 1e-6 * float(np.median(magnitudes))
    return max(EPS, scaled_median)

# ===================== Main =====================

def _plot_gap_figure(runs, column, normalize, symbol, quantity, outpath):
    """Plot |reference - column| per run on one figure and save it to ``outpath``.

    ``runs`` is a list of ``(value, dataframe, reference)`` triples. With
    ``normalize`` the gap is divided by ``max(|reference|, floor, EPS)``,
    where the floor comes from ``_denom_floor_from_stars`` over all
    references. ``quantity`` is the LaTeX name used in the y-axis label.
    """
    denom_floor = (
        _denom_floor_from_stars([star for _, _, star in runs]) if normalize else 1.0
    )
    _setup_figure()
    for val, df, star in runs:
        steps = _safe_steps(df["step"].to_numpy(float))
        curve = np.abs(star - df[column].to_numpy(float))
        if normalize:
            curve = curve / max(abs(star), denom_floor, EPS)
        # Clamp to EPS so exact zeros stay drawable on the log axis; keep
        # non-finite samples as NaN so they are left out of the line.
        curve = np.where(np.isfinite(curve), np.maximum(curve, EPS), np.nan)
        plt.plot(steps, curve, linewidth=2, label=fr"${symbol}={_fmt_val(val)}$")
    _style_and_save(f"{quantity} relative gap" if normalize else f"{quantity} gap", outpath)


def _plot_distance_figure(runs, candidates, symbol, ylabel, outpath):
    """Plot the first available error column per run and save it to ``outpath``.

    ``runs`` is a list of ``(value, dataframe)`` pairs. Runs whose dataframe
    has none of ``candidates`` are skipped; if every run is skipped, the
    figure is closed and no file is written.
    """
    plotted = False
    _setup_figure()
    for val, df in runs:
        col = _first_present(df, candidates)
        if col is None:
            continue
        plotted = True
        steps = _safe_steps(df["step"].to_numpy(float))
        vals = np.asarray(df[col].to_numpy(float), dtype=float)
        vals = np.where(np.isfinite(vals), np.maximum(vals, EPS), np.nan)
        plt.plot(steps, vals, linewidth=2, label=fr"${symbol}={_fmt_val(val)}$")
    if plotted:
        _style_and_save(ylabel, outpath)
    else:
        plt.close()


def main():
    """Command-line entry point: plot every ``<root>/<param>=*`` trace.

    Arguments (all but ``--normalize`` are required):
        --root       directory holding the ``<param>=<value>`` run folders
        --outdir     directory the PDFs are written to
        --setting    label embedded in the output file names
        --param      name of the swept parameter
        --normalize  1 (default) for relative gaps, 0 for absolute gaps

    Writes ``sweep_gap_u1_*``, ``sweep_gap_u2_*`` and, when the matching
    error columns exist, ``sweep_dist_a_*`` and ``sweep_dist_t_*`` PDFs.

    Raises:
        SystemExit: If no trace file is found.
        ValueError: If a trace lacks both the ``*_star`` and ``*_final``
            column for ``u1`` or ``u2``.
    """
    ap = argparse.ArgumentParser(description="Plot trace.csv across <param>=* folders.")
    ap.add_argument("--root", required=True)
    ap.add_argument("--outdir", required=True)
    ap.add_argument("--setting", required=True)
    ap.add_argument("--param", required=True)
    ap.add_argument("--normalize", type=int, default=1, help="Use relative gaps if 1 (default)")
    args = ap.parse_args()

    print("[plot_traces] seaborn=", _HAS_SNS, " context=talk" if _HAS_SNS else " (rcParams only)")

    symbol = PARAM_SYMBOLS.get(args.param, args.param)
    normalize = bool(int(args.normalize))

    runs = discover_runs(args.root, args.param)
    if not runs:
        raise SystemExit(f"No trace*.csv files found under {args.root}/{args.param}=*/")

    # Keep one record per discovered run. Keying by the parsed value would
    # let folders that decode to the same number (e.g. "r=0.5" and "r=0p5")
    # overwrite each other, so one trace would be drawn twice.
    loaded = []
    for val, path in runs:
        df = load_trace(path)
        u1_star = _choose_star(df, "u1_star", "u1_final")
        u2_star = _choose_star(df, "u2_star", "u2_final")
        if u1_star is None or u2_star is None:
            raise ValueError(f"{path}: need u1_star/u2_star or u1_final/u2_final")
        loaded.append((val, df, float(u1_star), float(u2_star)))

    suffix = "_rel" if normalize else ""

    def outfile(kind: str) -> str:
        return os.path.join(args.outdir, f"sweep_{kind}_{args.setting}_{args.param}{suffix}.pdf")

    # ---- u1 / u2 gaps ----
    _plot_gap_figure(
        [(val, df, s1) for val, df, s1, _ in loaded],
        "u1", normalize, symbol, r"$u_1$", outfile("gap_u1"),
    )
    _plot_gap_figure(
        [(val, df, s2) for val, df, _, s2 in loaded],
        "u2", normalize, symbol, r"$u_2$", outfile("gap_u2"),
    )

    # ---- a / t distances ----
    frames = [(val, df) for val, df, _, _ in loaded]
    _plot_distance_figure(
        frames, ["err_a_l2", "err_a_abs", "err_a"], symbol,
        r"$a$ relative distance", outfile("dist_a"),
    )
    _plot_distance_figure(
        frames, ["err_t_l2", "err_t_abs", "err_t"], symbol,
        r"$t$ relative distance", outfile("dist_t"),
    )

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()