"""
ICML-Ready Plotting (v2): Story-first figures for SP-B Reduction Benchmarks

What this generates (main paper friendly):
1) reduction_2x2_numbers_plus_scatter.(png/pdf)
   A clean 2x2 grid that pairs each headline number with its evidence:
     - Top-left: average % variable reduction (with median and instance count)
     - Top-right: average % CRN species reduction (with median and instance count)
     - Bottom-left: original vs. reduced variables scatter with a y = x reference line
     - Bottom-right: original vs. reduced CRN species scatter with a y = x reference line
   (Compile-time and simulation-time speedup panels are not currently produced.)
   -> This replaces cluttered dashboards: the reader sees the number AND the evidence at a glance.

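2) napp_reduction_convergence.(png/pdf)
   A larger, clearer CRN trajectory comparison on a small 3-variable graph. It uses a
   shorter, auto-chosen horizon and defines convergence as "enter the epsilon band AND
   stay there" (settling time), avoiding the misleading "first crossing" story.

3) crn_trajectories/crn_traj_<instance>.(png/pdf)
   One original-vs-reduced CRN marginal trajectory figure per benchmark instance
   (chains, trees, loopy cores with tendrils, pruned grids, random planted cores).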

Design principles:
- No internal titles (LaTeX captions carry narrative)
- Color-blind friendly palette (blue/orange/gray/black)
- Minimal gridlines
- Keep aspect ratio / no stretching

Run:
    python benchmarks/plot_icml2.py
or import generate_icml2_plots(csv_file, output_dir)
"""

from __future__ import annotations

import csv
import os
import sys
from typing import Any, Dict, List, Optional
import re
import numpy as np
import matplotlib.pyplot as plt

# Ensure project root import works when invoked from benchmarks/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# =============================================================================
# Palette + typography (color-blind friendly)
# =============================================================================
BLUE = "#1f77b4"      # Original graph / CRN series (matplotlib tab10 blue)
ORANGE = "#ff7f0e"    # Reduced graph / CRN series (tab10 orange)
GRAY = "#7f7f7f"      # y = x diagonals, epsilon bands, BP reference lines, box edges, secondary text
BLACK = "#000000"     # BP target line in the convergence figure
RED = "#d62728"       # Reserved for threshold lines (sparingly); not referenced elsewhere in this module

# Font sizes (points) passed to matplotlib as ``fontsize`` / ``labelsize``.
LABEL_SIZE = 12       # axis labels
TICK_SIZE = 10        # tick labels
LEGEND_SIZE = 10      # legend entries


# =============================================================================
# Robust CSV parsing
# =============================================================================

def _to_float(x: Any) -> float:
    if x is None:
        return float("nan")
    if isinstance(x, (int, float, np.floating)):
        return float(x)
    s = str(x).strip()
    if s == "" or s.lower() in {"nan", "none", "n/a", "na"}:
        return float("nan")
    try:
        return float(s)
    except ValueError:
        return float("nan")


def _to_int(x: Any) -> Optional[int]:
    if x is None:
        return None
    if isinstance(x, int):
        return x
    s = str(x).strip()
    if s == "" or s.lower() in {"nan", "none", "n/a", "na"}:
        return None
    try:
        return int(float(s))
    except ValueError:
        return None


def load_results_from_csv(filename: str) -> List[Dict[str, Any]]:
    """Load benchmark rows from *filename* into a list of typed dicts.

    Each dict has ``"name"`` (a string, "" if missing). Columns listed in
    ``int_keys`` are parsed with ``_to_int`` (``None`` when missing). Every other
    column is parsed with ``_to_float`` (NaN when missing).

    The file is read as UTF-8, tolerating a leading BOM (as written by some
    spreadsheet exports) so the ``name`` header is still recognised. Surplus
    cells in over-long rows, which ``csv.DictReader`` stores under the ``None``
    key, are ignored rather than copied into the result.
    """
    int_keys = {
        "orig_vars", "reduced_vars", "orig_factors", "reduced_factors",
        "orig_edges", "reduced_edges", "orig_species", "reduced_species",
        "orig_reactions", "reduced_reactions", "n_reduction_steps",
        "bp_converged_orig", "bp_converged_reduced",
    }
    results: List[Dict[str, Any]] = []
    with open(filename, "r", newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            out: Dict[str, Any] = {"name": row.get("name") or ""}
            for k, v in row.items():
                if k is None or k == "name":
                    continue
                out[k] = _to_int(v) if k in int_keys else _to_float(v)
            results.append(out)
    return results


# =============================================================================
# Helpers: reductions/speedups and annotation text
# =============================================================================

def _finite(vals: np.ndarray) -> np.ndarray:
    return vals[np.isfinite(vals)]


def _pct_reduction(reduced: np.ndarray, orig: np.ndarray) -> np.ndarray:
    # percentage decrease (positive means reduced smaller)
    with np.errstate(divide="ignore", invalid="ignore"):
        return 100.0 * (1.0 - (reduced / orig))


def _speedup(orig: np.ndarray, reduced: np.ndarray) -> np.ndarray:
    with np.errstate(divide="ignore", invalid="ignore"):
        return orig / reduced


def _annot_box(ax, lines: List[str], loc: str = "upper left"):
    """Draw *lines* as one rounded, white-backed text box in a corner of *ax*.

    ``loc`` selects "upper left", "upper right" or "lower left"; any other value
    places the box in the lower-right corner. Positions are in axes-fraction
    coordinates, so the box does not move with the data limits.
    """
    corners = {
        "upper left": (0.03, 0.97, "left", "top"),
        "upper right": (0.97, 0.97, "right", "top"),
        "lower left": (0.03, 0.03, "left", "bottom"),
    }
    x, y, ha, va = corners.get(loc, (0.97, 0.03, "right", "bottom"))

    box_props = dict(boxstyle="round,pad=0.35", facecolor="white", edgecolor=GRAY, alpha=0.92)
    body = "\n".join(lines)
    ax.text(x, y, body, transform=ax.transAxes, ha=ha, va=va, fontsize=10, bbox=box_props)


def _scatter_with_diagonal(ax, x: np.ndarray, y: np.ndarray, color: str,
                          xlabel: str, ylabel: str, log_if_span: bool = True):
    """Scatter reduced counts (*y*) against original counts (*x*) with a y = x line.

    Points below the dashed diagonal are instances the reduction shrank.

    Args:
        ax: matplotlib Axes to draw into.
        x, y: paired 1-D arrays (original, reduced).
        color: marker colour.
        xlabel, ylabel: axis labels.
        log_if_span: when True and the data span more than ~50x, rescale both axes.
            Strictly positive data get a log scale. Data that reach zero (e.g. an
            instance reduced to nothing) get a symmetric-log scale with a linear
            band of width 1 around zero, because a plain log axis would silently
            drop those points.
    """
    ax.scatter(x, y, s=28, alpha=0.75, color=color, edgecolors="white", linewidth=0.4)

    x_arr = np.asarray(x, dtype=float)
    y_arr = np.asarray(y, dtype=float)
    if x_arr.size and y_arr.size:
        data_min = float(min(np.min(x_arr), np.min(y_arr)))
        data_max = float(max(np.max(x_arr), np.max(y_arr)))
        # Pad the reference line 10% beyond the data. When the data touch zero,
        # anchor it at the data minimum instead of an arbitrary 1e-3.
        lo = data_min * 0.9 if data_min > 0 else data_min
        hi = data_max * 1.1
        ax.plot([lo, hi], [lo, hi], linestyle="--", color=GRAY, linewidth=1.2)

        if log_if_span:
            if data_min > 0 and hi / lo > 50:
                ax.set_xscale("log")
                ax.set_yscale("log")
            elif data_min <= 0 and hi > 50:
                ax.set_xscale("symlog", linthresh=1.0)
                ax.set_yscale("symlog", linthresh=1.0)

    ax.set_xlabel(xlabel, fontsize=LABEL_SIZE)
    ax.set_ylabel(ylabel, fontsize=LABEL_SIZE)
    ax.tick_params(axis="both", labelsize=TICK_SIZE)
    ax.grid(True, alpha=0.18, linewidth=0.5)


# =============================================================================
# Main figure: 2x2 evidence + headline numbers in each panel
# =============================================================================

def _paired_counts(results: List[Dict[str, Any]], orig_key: str, red_key: str):
    """Return float arrays (orig, reduced) for rows whose original count is positive.

    Missing/None values count as 0, so rows without an original count are dropped
    and a missing reduced count is recorded as 0.
    """
    orig_vals: List[float] = []
    red_vals: List[float] = []
    for row in results:
        o = row.get(orig_key) or 0
        rr = row.get(red_key) or 0
        if o > 0 and rr >= 0:
            orig_vals.append(float(o))
            red_vals.append(float(rr))
    return np.array(orig_vals, dtype=float), np.array(red_vals, dtype=float)


def _headline_tile(ax, orig: np.ndarray, red: np.ndarray, color: str, noun: str):
    """Turn *ax* into a text tile: big mean % reduction, the median, and n."""
    ax.axis("off")
    pct = _finite(_pct_reduction(red, orig))
    if pct.size == 0:
        return  # nothing to summarise; leave the panel blank
    avg_pct = float(np.mean(pct))
    med_pct = float(np.median(pct))

    ax.text(0.5, 0.62, f"{avg_pct:.0f}%", ha="center", va="center",
            fontsize=44, color=color, fontweight="bold", transform=ax.transAxes)
    ax.text(0.5, 0.38, f"avg {noun} reduction\n(median {med_pct:.0f}%)",
            ha="center", va="center", fontsize=13, transform=ax.transAxes)
    ax.text(0.5, 0.14, f"n = {len(orig)} instances",
            ha="center", va="center", fontsize=11, color=GRAY, transform=ax.transAxes)


def plot_reduction_speedup_2x2(results: List[Dict[str, Any]], output_dir: str):
    """Save the 2x2 "numbers + evidence" reduction figure into *output_dir*.

    Despite the historical name, no timing/speedup panels are drawn:
      - top row: headline tiles with mean (and median) % reduction of factor-graph
        variables (left) and CRN species (right);
      - bottom row: the matching original-vs-reduced scatters with a y = x line.
    Panels without data are left blank. Writes
    ``reduction_2x2_numbers_plus_scatter.png`` and ``.pdf``.
    """
    os.makedirs(output_dir, exist_ok=True)

    var_orig, var_red = _paired_counts(results, "orig_vars", "reduced_vars")
    sp_orig, sp_red = _paired_counts(results, "orig_species", "reduced_species")

    fig, axes = plt.subplots(2, 2, figsize=(12.0, 7.6), constrained_layout=True)

    _headline_tile(axes[0, 0], var_orig, var_red, BLUE, "variable")
    _headline_tile(axes[0, 1], sp_orig, sp_red, ORANGE, "species")

    scatter_panels = (
        (axes[1, 0], var_orig, var_red, BLUE, "Original variables", "Reduced variables"),
        (axes[1, 1], sp_orig, sp_red, ORANGE, "Original CRN species", "Reduced CRN species"),
    )
    for ax, orig, red, color, xlabel, ylabel in scatter_panels:
        if len(orig) > 0:
            _scatter_with_diagonal(ax, orig, red, color, xlabel, ylabel, log_if_span=True)
        else:
            ax.axis("off")

    out_png = os.path.join(output_dir, "reduction_2x2_numbers_plus_scatter.png")
    out_pdf = os.path.join(output_dir, "reduction_2x2_numbers_plus_scatter.pdf")
    fig.savefig(out_png, dpi=220, facecolor="white", edgecolor="none")
    fig.savefig(out_pdf, facecolor="white", edgecolor="none")
    plt.close(fig)
    print("Saved:", out_png)
    print("Saved:", out_pdf)


# =============================================================================
# CRN trajectory figure: emphasize settling time (not first crossing)
# =============================================================================

def _first_crossing_time(t: np.ndarray, y: np.ndarray, target: float) -> Optional[float]:
    d = y - target
    s = np.sign(d)
    idx = np.where(np.diff(s) != 0)[0]
    if len(idx) == 0:
        return None
    return float(t[int(idx[0] + 1)])


def _settling_time(t: np.ndarray, y: np.ndarray, target: float, eps: float = 0.01, hold_fraction: float = 0.2) -> Optional[float]:
    """First time entering |y-target|<=eps and staying within eps thereafter.
    For robustness on long horizons, we require the last `hold_fraction` of points are within eps.
    """
    if len(t) < 5:
        return None
    err = np.abs(y - target)
    within = err <= eps

    hold_n = max(3, int(len(t) * hold_fraction))
    if not np.all(within[-hold_n:]):
        return None

    for i in range(len(t) - hold_n):
        if np.all(within[i:]):
            return float(t[i])
    return float(t[-hold_n])


def plot_napp_reduction_convergence(output_dir: str,
                                   sim_time: float = 5000,
                                   eps: float = 0.01,
                                   hold_fraction: float = 0.2):
    """Plot P(x3 = 0) over time for the original and reduced CRN of a toy graph.

    The graph has binary variables x1, x2, x3, pairwise factors f12 and f23, and
    unary factors u1 and u3. ``run_bp`` supplies the reference marginal of x3.
    The graph is reduced with ``reduce_to_core_spb``. Both versions are compiled
    (kappa_r=0.02, kappa_prod=50) and simulated to *sim_time* on 450 samples.

    Each trajectory's settling time (``_settling_time`` with *eps* and
    *hold_fraction*) is marked with a solid vertical line, and the epsilon band
    is drawn dotted. The x-axis is cut to 1.35x the latest settling or
    first-crossing time, clamped to [0.1, 1] x sim_time. If neither time exists,
    it is cut to 0.2 x sim_time.

    Saves ``napp_reduction_convergence.png`` and ``.pdf`` into *output_dir*.

    Raises:
        ValueError: if the reduction does not yield a factor graph.
    """
    os.makedirs(output_dir, exist_ok=True)

    from core import Variable, Factor, FactorGraph
    from crn import compile_factor_graph_to_crn, simulate_crn
    from reduction.poset_reduction import from_factor_graph, reduce_to_core_spb, to_factor_graph_if_possible
    from inference import run_bp

    # Representative graph (same idea as plot_icml.py)
    fg = FactorGraph("trajectory_test")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.9, 0.1], [0.2, 0.8]])))
    fg.add_factor(Factor("f23", [x2, x3], np.array([[0.7, 0.3], [0.4, 0.6]])))
    fg.add_factor(Factor("u1", [x1], np.array([0.8, 0.2])))
    fg.add_factor(Factor("u3", [x3], np.array([0.3, 0.7])))

    # Reference value: BP marginal of x3 at state index 0.
    bp = run_bp(fg, tolerance=1e-10, max_iterations=1000, damping=0.3)
    target = float(bp.get_marginal("x3")[0])

    orig_crn = compile_factor_graph_to_crn(fg, kappa_r=0.02, kappa_prod=50.0)
    poset = from_factor_graph(fg)
    reduce_to_core_spb(poset)
    red_fg = to_factor_graph_if_possible(poset)
    if red_fg is None:
        # Report clearly rather than handing None to compile_factor_graph_to_crn.
        raise ValueError("reduction did not yield a factor graph; cannot build the reduced CRN")
    red_crn = compile_factor_graph_to_crn(red_fg, kappa_r=0.02, kappa_prod=50.0)

    orig_sim = simulate_crn(orig_crn, t_end=sim_time, n_points=450)
    red_sim = simulate_crn(red_crn, t_end=sim_time, n_points=450)

    def get_traj(sim):
        """Return (t, P(x3 at state index 0)) for one simulation result."""
        if hasattr(sim, "get_marginal_trajectory"):
            arr = sim.get_marginal_trajectory("x3")
            return np.array(sim.times, dtype=float), np.array(arr[:, 0], dtype=float)
        # Fallback: normalise the two Marginal_x3_* species by hand.
        t = np.array(sim.times, dtype=float)
        m0 = np.array(sim.concentrations.get("Marginal_x3_1", np.zeros_like(t)), dtype=float)
        m1 = np.array(sim.concentrations.get("Marginal_x3_2", np.zeros_like(t)), dtype=float)
        tot = m0 + m1
        # Divide only where there is mass (no 0/0 warnings); empty samples read as 0.5.
        y = np.divide(m0, tot, out=np.full_like(tot, 0.5), where=tot > 0)
        return t, y

    t_o, y_o = get_traj(orig_sim)
    t_r, y_r = get_traj(red_sim)

    cross_o = _first_crossing_time(t_o, y_o, target)
    cross_r = _first_crossing_time(t_r, y_r, target)

    settle_o = _settling_time(t_o, y_o, target, eps=eps, hold_fraction=hold_fraction)
    settle_r = _settling_time(t_r, y_r, target, eps=eps, hold_fraction=hold_fraction)

    # Choose a shorter horizon for readability based on settle/cross times
    cand = [t for t in [settle_o, settle_r, cross_o, cross_r] if t is not None and np.isfinite(t)]
    if cand:
        x_max = min(sim_time, max(cand) * 1.35)
        x_max = max(x_max, sim_time * 0.1)
    else:
        x_max = sim_time * 0.2

    fig, ax = plt.subplots(1, 1, figsize=(7.2, 4.8), constrained_layout=True)

    ax.plot(t_o, y_o, color=BLUE, linewidth=2.4, label=f"Original ({len(orig_crn.species)} species)")
    ax.plot(t_r, y_r, color=ORANGE, linewidth=2.4, label=f"Reduced ({len(red_crn.species)} species)")
    ax.axhline(target, color=BLACK, linestyle="--", linewidth=1.4, label=f"BP target {target:.3f}")

    # epsilon band (visual)
    ax.axhline(target + eps, color=GRAY, linestyle=":", linewidth=1.0, alpha=0.6)
    ax.axhline(target - eps, color=GRAY, linestyle=":", linewidth=1.0, alpha=0.6)

    ax.set_xlim(0, x_max)

    # Tight y-limits around the region of interest
    y_min = min(np.min(y_o[t_o <= x_max]), np.min(y_r[t_r <= x_max]), target - 0.12)
    y_max = max(np.max(y_o[t_o <= x_max]), np.max(y_r[t_r <= x_max]), target + 0.12)
    ax.set_ylim(y_min, y_max)

    def vline(t, color, style, text):
        """Draw a labelled vertical marker at time *t* (no-op when *t* is None)."""
        if t is None:
            return
        ax.axvline(t, color=color, linestyle=style, linewidth=1.2, alpha=0.9)
        ax.text(t, y_min + 0.03*(y_max - y_min), text, color=color,
                fontsize=9, rotation=90, va="bottom", ha="right")

    # Only settling times are marked; first crossings only inform the horizon above.
    vline(settle_o, BLUE, "-", f"settle {settle_o:.0f}" if settle_o is not None else "")
    vline(settle_r, ORANGE, "-", f"settle {settle_r:.0f}" if settle_r is not None else "")

    ax.set_xlabel("Time", fontsize=LABEL_SIZE)
    ax.set_ylabel(r"$P(x_3=0)$", fontsize=LABEL_SIZE)
    ax.tick_params(axis="both", labelsize=TICK_SIZE)
    ax.grid(True, alpha=0.18, linewidth=0.5)
    ax.legend(fontsize=LEGEND_SIZE, loc="best", frameon=True)

    out_png = os.path.join(output_dir, "napp_reduction_convergence.png")
    out_pdf = os.path.join(output_dir, "napp_reduction_convergence.pdf")
    fig.savefig(out_png, dpi=220, facecolor="white", edgecolor="none")
    fig.savefig(out_pdf, facecolor="white", edgecolor="none")
    plt.close(fig)
    print("Saved:", out_png)
    print("Saved:", out_pdf)
# =============================================================================
# CRN trajectory plots for EACH benchmark instance (orig vs reduced)
# =============================================================================


def _crn_marginal_matrix(sim, var_name: str):
    """
    Return (t, P) where P is (T,K) normalized marginal over k=1..K for var_name.

    IMPORTANT: simulate_crn returns SimulationResult which does NOT include sim.crn.
    So we discover marginal species by scanning sim.concentrations keys.
    """
    t = np.array(sim.times, dtype=float)

    # Look for keys like "Marginal_x3_1", "Marginal_x3_2", ...
    pat = re.compile(rf"^Marginal_{re.escape(var_name)}_(\d+)$")

    ks = []
    for key in sim.concentrations.keys():
        m = pat.match(key)
        if m:
            k = int(m.group(1))
            # ignore the unassigned "k=0" bucket if present
            if k >= 1:
                ks.append(k)

    if not ks:
        raise ValueError(
            f"No marginal species found for variable '{var_name}'. "
            f"Example keys: {list(sim.concentrations.keys())[:10]}"
        )

    K = max(ks)

    raw = []
    denom = np.zeros_like(t, dtype=float)
    for k in range(1, K + 1):
        sp = f"Marginal_{var_name}_{k}"
        y = np.array(sim.concentrations.get(sp, np.zeros_like(t)), dtype=float)
        raw.append(y)
        denom += y

    denom = np.maximum(denom, 1e-15)
    P = np.stack([col / denom for col in raw], axis=1)  # (T,K)
    return t, P


def plot_crn_trajectories_for_instance(
    fg,
    name: str,
    output_dir: str,
    sim_time: float = 5000.0,
    n_points: int = 450,
    max_vars: int = 6,
    kappa_r: float = 0.02,
    kappa_prod: float = 50.0,
    include_bp_targets: bool = True,
):
    """
    Plot original-vs-reduced CRN marginal trajectories for one factor graph.

    Steps:
      - reduce ``fg`` with ``reduce_to_core_spb`` -> ``red_fg``. Skipped with a
        message if the result is None or reports zero variables.
      - compile both graphs to CRNs with *kappa_r* / *kappa_prod*.
      - simulate both CRNs to *sim_time* on *n_points* samples.
      - for up to *max_vars* variables that survive the reduction, plot every
        state's normalised marginal: original solid blue, reduced dashed orange.
        When *include_bp_targets* is set and BP reports ``converged``, add BP
        marginals as dotted gray reference lines.

    Variables whose marginal species cannot be read are skipped with a warning.
    The original and reduced CRNs may expose different numbers of states; each
    is plotted over its own states. The figure is always closed, even when
    plotting or saving fails, so batch callers do not accumulate open figures.

    Saves ``crn_traj_<name>.png`` and ``.pdf`` into *output_dir*.
    """
    from crn import compile_factor_graph_to_crn, simulate_crn
    from reduction.poset_reduction import from_factor_graph, reduce_to_core_spb, to_factor_graph_if_possible
    from matplotlib.lines import Line2D

    # Optional BP targets (reference lines)
    bp = None
    if include_bp_targets:
        from inference import run_bp
        bp = run_bp(fg, tolerance=1e-10, max_iterations=1000, damping=0.3)

    # Reduce FG
    poset = from_factor_graph(fg)
    reduce_to_core_spb(poset)
    red_fg = to_factor_graph_if_possible(poset)
    if red_fg is None or getattr(red_fg, "num_variables", 0) == 0:
        print(f"[skip] {name}: reduction produced trivial graph")
        return

    # Compile and simulate (ODE integration) both CRNs
    orig_crn = compile_factor_graph_to_crn(fg, kappa_r=kappa_r, kappa_prod=kappa_prod)
    red_crn = compile_factor_graph_to_crn(red_fg, kappa_r=kappa_r, kappa_prod=kappa_prod)
    orig_sim = simulate_crn(orig_crn, t_end=sim_time, n_points=n_points)
    red_sim = simulate_crn(red_crn, t_end=sim_time, n_points=n_points)

    # Which variables to plot: those surviving in the reduced FG
    vars_to_plot = [v.name for v in red_fg.variables][:max_vars]
    if not vars_to_plot:
        print(f"[skip] {name}: no vars to plot")
        return

    bp_usable = bp is not None and getattr(bp, "converged", False)

    series = []
    for vn in vars_to_plot:
        try:
            t_o, P_o = _crn_marginal_matrix(orig_sim, vn)
            t_r, P_r = _crn_marginal_matrix(red_sim, vn)
        except Exception as exc:
            # Best-effort: one unreadable variable should not sink the figure,
            # but report it instead of dropping it silently.
            print(f"[warn] {name}: skipping variable {vn}: {exc}")
            continue

        # Put the reduced trajectory on the original time grid if the grids differ.
        needs_interp = len(t_o) != len(t_r) or (
            len(t_o) > 0 and np.max(np.abs(t_o - t_r)) > 1e-9
        )
        if needs_interp:
            P_r = np.column_stack(
                [np.interp(t_o, t_r, P_r[:, k]) for k in range(P_r.shape[1])]
            )

        targets = None
        if bp_usable:
            try:
                targets = bp.get_marginal(vn)  # presumably one value per state — TODO confirm
            except Exception:
                targets = None  # no BP marginal for this variable: plot without a reference

        series.append((vn, t_o, P_o, P_r, targets))

    if not series:
        print(f"[skip] {name}: no valid marginal species found")
        return

    os.makedirs(output_dir, exist_ok=True)
    stem = f"crn_traj_{name}"
    out_png = os.path.join(output_dir, f"{stem}.png")
    out_pdf = os.path.join(output_dir, f"{stem}.pdf")

    # Plot: one subplot per variable
    n = len(series)
    ncols = 2 if n > 1 else 1
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(7.2, 2.8 * nrows), constrained_layout=True)
    try:
        axes = np.atleast_1d(axes).ravel()

        for ax, (vn, t, P_o, P_r, targets) in zip(axes, series):
            k_orig, k_red = P_o.shape[1], P_r.shape[1]
            for k in range(max(k_orig, k_red)):
                if k < k_orig:
                    ax.plot(t, P_o[:, k], color=BLUE, linewidth=1.7, alpha=0.9)
                if k < k_red:
                    ax.plot(t, P_r[:, k], color=ORANGE, linewidth=1.7, alpha=0.9, linestyle="--")
                # Optional BP reference line per (original) state
                if targets is not None and k < k_orig and k < len(targets):
                    ax.axhline(float(targets[k]), color=GRAY, linestyle=":", linewidth=0.9, alpha=0.7)

            ax.set_title(vn, fontsize=11)
            ax.set_xlabel("time (s)", fontsize=LABEL_SIZE)
            ax.set_ylabel("marginal mass", fontsize=LABEL_SIZE)
            ax.tick_params(axis="both", labelsize=TICK_SIZE)
            ax.grid(True, alpha=0.18, linewidth=0.5)

        for spare_ax in axes[len(series):]:
            spare_ax.axis("off")

        # Legend proxies
        fig.legend(
            handles=[
                Line2D([0], [0], color=BLUE, lw=2.0, linestyle="-", label=f"original ({len(orig_crn.species)} sp)"),
                Line2D([0], [0], color=ORANGE, lw=2.0, linestyle="--", label=f"reduced ({len(red_crn.species)} sp)"),
            ],
            loc="upper right",
            frameon=True,
            fontsize=LEGEND_SIZE,
        )

        fig.savefig(out_png, dpi=220, facecolor="white", edgecolor="none")
        fig.savefig(out_pdf, facecolor="white", edgecolor="none")
    finally:
        plt.close(fig)

    print("Saved:", out_png)
    print("Saved:", out_pdf)


def generate_crn_trajectory_suite(
    output_dir: str,
    sim_time: float = 5000.0,
    n_points: int = 450,
    max_vars: int = 6,
):
    """
    Rebuild every benchmark graph family and save one CRN trajectory plot per instance.

    Instances are generated in a fixed order: chains, binary trees, loopy cores
    with tendrils, grids with pruned leaves, then random graphs with a planted
    core. Plots go to ``<output_dir>/crn_trajectories``. A failure on one
    instance is printed and the suite moves on to the next.
    """
    from benchmarks.graph_generators import (
        generate_chain,
        generate_binary_tree,
        generate_loopy_core_with_tendrils,
        generate_grid_with_pruned_leaves,
        generate_random_with_planted_core,
    )

    chains = [(f"chain_{length}", generate_chain(length)) for length in (5, 10, 20, 50, 100)]
    trees = [(f"tree_d{d}", generate_binary_tree(d)) for d in (3, 4, 5, 6)]
    loopy = [
        (f"loopy_c{core}_t{tail}", generate_loopy_core_with_tendrils(core, tail))
        for core in (3, 4, 5)
        for tail in (1, 3, 5, 10)
    ]
    grids = [
        (f"grid_{side}x{side}", generate_grid_with_pruned_leaves(side, side, prune_fraction=0.5))
        for side in (3, 4, 5, 6)
    ]
    randoms = [
        (f"random_{total}_c{max(3, total // 5)}",
         generate_random_with_planted_core(total, max(3, total // 5)))
        for total in (15, 25, 40)
    ]

    traj_dir = os.path.join(output_dir, "crn_trajectories")
    os.makedirs(traj_dir, exist_ok=True)

    for instance_name, graph in chains + trees + loopy + grids + randoms:
        try:
            plot_crn_trajectories_for_instance(
                fg=graph,
                name=instance_name,
                output_dir=traj_dir,
                sim_time=sim_time,
                n_points=n_points,
                max_vars=max_vars,
                include_bp_targets=True,  # set False for CRN-only plots
            )
        except Exception as e:
            print(f"[error] {instance_name}: {e}")


# =============================================================================
# Entrypoint
# =============================================================================

def generate_icml2_plots(csv_file: str, output_dir: str):
    """Generate every figure of this module into *output_dir*.

    The 2x2 reduction figure comes from the benchmark CSV and any error there
    propagates. The two CRN figures (toy convergence plot, per-instance
    trajectory suite) do not need the CSV. Each is attempted independently, so
    a failure in one is reported as a warning and does not suppress the other.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = load_results_from_csv(csv_file)
    print(f"Loaded {len(results)} benchmark rows from {csv_file}")

    plot_reduction_speedup_2x2(results, output_dir)

    # Synthetic/illustrative CRN trajectory figure (no CSV needed)
    try:
        plot_napp_reduction_convergence(output_dir=output_dir, sim_time=5000, eps=0.01, hold_fraction=0.2)
    except Exception as e:
        print(f"Warning: could not generate napp convergence plot: {e}")

    # Individual CRN trajectory plots for ALL benchmark instances
    try:
        generate_crn_trajectory_suite(output_dir=output_dir, sim_time=5000.0, n_points=450, max_vars=6)
    except Exception as e:
        print(f"Warning: could not generate CRN trajectory suite: {e}")

    print(f"All plots saved to {output_dir}")


if __name__ == "__main__":
    # Resolve defaults against the project root (the parent of this file's
    # directory, the same root inserted into sys.path at import time) instead of
    # a hard-coded home directory. Optional overrides:
    #     python benchmarks/plot_icml2.py [CSV_FILE [OUTPUT_DIR]]
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    results_dir = os.path.join(project_root, "results")
    csv_file = sys.argv[1] if len(sys.argv) > 1 else os.path.join(results_dir, "benchmark_results.csv")
    output_dir = sys.argv[2] if len(sys.argv) > 2 else os.path.join(results_dir, "plots_icml2")
    if os.path.exists(csv_file):
        generate_icml2_plots(csv_file, output_dir)
    else:
        print(f"CSV file not found: {csv_file}")
        print("Run benchmark_runner.py first to generate results.")
