"""
ICML Plots (v4): Story-driven figures for speedups and tendril scaling

Reads benchmark_results.csv and generates clean, ICML-friendly plots:

Core plots:
  1) sim_speedup_by_family.(png/pdf)
     Median simulation speedup per graph family (with IQR whiskers).

  2) sim_time_scatter_with_headline.(png/pdf)
     Original simulation time vs reduced simulation time (log-log), with headline stats.

  3) tendril_percent_speedup_scaling.(png/pdf)
     Percent speedup vs tendril length for synthetic loopy-core+tendrils graphs,
     with separate lines per core size. Uses simulation speedup when available,
     otherwise falls back to compile speedup.

Additional storytelling plots:
  4) sim_speedup_vs_orig_species.(png/pdf)
     How speedup grows with problem size: binned median speedup vs original #species
     (optionally shows raw scatter).

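  5) speedup_vs_species_reduction_<family>.(png/pdf)
     One figure per graph family: % simulation speedup vs % species reduction.

  6) tendril_log_runtime_ratio.(png/pdf)
     log10(orig / reduced runtime) vs tendril length, with separate lines per
     core size.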

No internal titles (use LaTeX captions).

Usage:
  python benchmarks/plots_icml4.py
or:
  from plots_icml4 import generate_icml4_plots
  generate_icml4_plots(csv_file, output_dir)
"""

from __future__ import annotations

import csv
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt

# Allow running from benchmarks/: prepend the parent of this file's directory
# (two dirname() levels above the absolute path of __file__) to the import
# search path.
# NOTE(review): no project-local import is visible in this file, so this may be
# vestigial — confirm before removing.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))



# =============================================================================
# CSV parsing (robust to strings / NaN)
# =============================================================================

def _to_float(x: Any) -> float:
    """
    Coerce a CSV cell (or an already-numeric value) to ``float``.

    ``None``, blank strings, the case-insensitive markers
    ``nan`` / ``none`` / ``n/a`` / ``na`` and any text that ``float()`` rejects
    all map to NaN, so callers can filter with ``np.isfinite``.
    """
    missing = float("nan")
    if x is None:
        return missing
    if isinstance(x, (int, float, np.floating)):
        return float(x)

    text = str(x).strip()
    if not text or text.lower() in ("nan", "none", "n/a", "na"):
        return missing

    try:
        value = float(text)
    except ValueError:
        value = missing
    return value


def _to_int(x: Any) -> Optional[int]:
    """
    Coerce a CSV cell to ``int``, returning ``None`` when no integer applies.

    ``None``, blank strings and the case-insensitive markers
    ``nan`` / ``none`` / ``n/a`` / ``na`` yield ``None``. Other text is parsed
    via ``float`` first so that cells such as ``"12.0"`` are accepted, then
    truncated with ``int``. Values that are already ``int`` are returned as-is.
    """
    if x is None:
        return None
    if isinstance(x, int):
        return x
    s = str(x).strip()
    if s == "" or s.lower() in {"nan", "none", "n/a", "na"}:
        return None
    try:
        return int(float(s))
    except (ValueError, OverflowError):
        # ValueError: unparseable text (or a NaN spelling such as "-nan").
        # OverflowError: "inf"/"-inf" parse as floats but cannot become ints.
        return None


def load_results_from_csv(filename: str) -> List[Dict[str, Any]]:
    """
    Load benchmark rows from a CSV file into a list of dicts.

    Each output dict always has a string ``"name"`` entry (empty when the
    column is absent or the cell is missing). Columns listed in ``int_keys``
    are converted with ``_to_int`` (``None`` when missing); every other column
    is converted with ``_to_float`` (NaN when missing or unparseable).

    Args:
        filename: Path of the CSV file; the first row is used as the header.

    Returns:
        One dict per data row, in file order.
    """
    int_keys = {
        "orig_vars", "reduced_vars", "orig_factors", "reduced_factors",
        "orig_edges", "reduced_edges", "orig_species", "reduced_species",
        "orig_reactions", "reduced_reactions", "n_reduction_steps",
        "n_linear_steps", "n_colinear_steps",
    }
    results: List[Dict[str, Any]] = []
    with open(filename, "r", newline="") as f:
        for row in csv.DictReader(f):
            # DictReader fills the fields of a short row with None; keep
            # "name" a string so downstream str methods cannot fail.
            out: Dict[str, Any] = {"name": row.get("name") or ""}
            for k, v in row.items():
                # DictReader collects surplus fields of an over-long row under
                # the key None; that is not a real column, so skip it.
                if k is None or k == "name":
                    continue
                out[k] = _to_int(v) if k in int_keys else _to_float(v)
            results.append(out)
    return results



# =============================================================================
# Helpers
# =============================================================================

def categorize_name(name: str) -> str:
    """
    Map a benchmark name to its graph family.

    The family is the first of ``chain``, ``tree``, ``loopy``, ``grid``,
    ``random`` that the name starts with (case-sensitive); anything else is
    reported as ``"other"``.
    """
    for family in ("chain", "tree", "loopy", "grid", "random"):
        if name.startswith(family):
            return family
    return "other"


def _speedup(orig: np.ndarray, reduced: np.ndarray) -> np.ndarray:
    """Return ``orig / reduced`` element-wise with NumPy divide/invalid warnings silenced."""
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = orig / reduced
    return ratio


def _pct_decrease(orig: np.ndarray, reduced: np.ndarray) -> np.ndarray:
    """
    Return the percentage by which ``reduced`` is smaller than ``orig``:
    ``100 * (1 - reduced / orig)``, element-wise, with NumPy divide/invalid
    warnings silenced.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        remaining = reduced / orig
        return 100.0 * (1.0 - remaining)


def _pct_reduction(reduced: np.ndarray, orig: np.ndarray) -> np.ndarray:
    """
    Return ``100 * (1 - reduced / orig)`` element-wise, with NumPy
    divide/invalid warnings silenced.

    Note the argument order: the reduced quantity comes first, the original
    second.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        remaining = reduced / orig
        return 100.0 * (1.0 - remaining)


def _save(fig, output_dir: str, stem: str, dpi: int = 220):
    """
    Write ``fig`` to ``<output_dir>/<stem>.png`` and ``<output_dir>/<stem>.pdf``.

    The output directory is created if needed. The figure is closed afterwards
    even when saving raises, so a failed write does not leave the figure open.

    Args:
        fig: Figure object exposing ``savefig``.
        output_dir: Destination directory.
        stem: File name without extension.
        dpi: Resolution used for the PNG only.
    """
    os.makedirs(output_dir, exist_ok=True)
    png = os.path.join(output_dir, f"{stem}.png")
    pdf = os.path.join(output_dir, f"{stem}.pdf")
    try:
        fig.savefig(png, dpi=dpi, facecolor="white", edgecolor="none")
        fig.savefig(pdf, facecolor="white", edgecolor="none")
    finally:
        # Always release the figure; pyplot keeps figures alive until closed.
        plt.close(fig)
    print("Saved:", png)
    print("Saved:", pdf)



# =============================================================================
# Plot 1: Simulation speedup by family (median + IQR)
# =============================================================================

def plot_sim_speedup_by_family(results: List[Dict[str, Any]], output_dir: str):
    """
    Bar chart of the median simulation speedup (orig / reduced) per graph family.

    Only rows whose original and reduced simulation times are both finite and
    positive are used. Bars show the per-family median, whiskers span the
    inter-quartile range, each bar is annotated with its sample count, and a
    text box reports the overall median. The figure is written via ``_save``
    under the stem ``sim_speedup_by_family``. Prints a message and returns
    without plotting when there is nothing to draw.

    Args:
        results: Benchmark rows (dicts with ``name``, ``orig_sim_time`` and
            ``reduced_sim_time`` entries).
        output_dir: Directory the figure files are written to.
    """
    fams = [categorize_name(r["name"]) for r in results]
    orig_sim = np.array([_to_float(r.get("orig_sim_time")) for r in results], dtype=float)
    red_sim = np.array([_to_float(r.get("reduced_sim_time")) for r in results], dtype=float)

    mask = np.isfinite(orig_sim) & np.isfinite(red_sim) & (orig_sim > 0) & (red_sim > 0)
    if not np.any(mask):
        print("plot_sim_speedup_by_family: no valid simulation rows.")
        return

    fams_m = np.array([fams[i] for i in np.where(mask)[0]])
    spd = _speedup(orig_sim[mask], red_sim[mask])

    present = set(fams_m)
    order = [f for f in ["chain", "tree", "loopy", "grid", "random", "other"] if f in present]

    # Labels are collected together with the bar statistics so that a family
    # skipped below can never shift the tick labels relative to the bars.
    labels, meds, yerr_lo, yerr_hi, ns = [], [], [], [], []
    for f in order:
        vals = spd[fams_m == f]
        vals = vals[np.isfinite(vals)]
        if len(vals) == 0:
            continue
        q25, q50, q75 = np.percentile(vals, [25, 50, 75])
        labels.append(f.capitalize())
        meds.append(float(q50))
        yerr_lo.append(float(q50 - q25))
        yerr_hi.append(float(q75 - q50))
        ns.append(int(len(vals)))

    if not meds:
        print("plot_sim_speedup_by_family: no valid simulation rows.")
        return

    x = np.arange(len(meds))
    yerr = np.vstack([yerr_lo, yerr_hi])

    fig, ax = plt.subplots(1, 1, figsize=(6.3, 3.8), constrained_layout=True)
    ax.bar(x, meds, yerr=yerr, capsize=3, alpha=0.85, edgecolor="black", linewidth=0.6)

    ax.axhline(1.0, linestyle="--", linewidth=1.0, alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=11)
    ax.set_ylabel("simulation speedup (orig / reduced)", fontsize=12)
    ax.grid(True, axis="y", alpha=0.18, linewidth=0.5)
    # Extra headroom so the count labels above the whiskers stay inside the axes.
    ax.margins(y=0.15)

    ymax = max(meds)
    for i, (m, hi, n) in enumerate(zip(meds, yerr_hi, ns)):
        # Anchor above the upper whisker; the whisker is drawn at the same x
        # as the centred label and would otherwise run through the text.
        ax.text(i, m + hi + 0.04 * ymax, f"n={n}", ha="center", va="bottom", fontsize=9)

    ax.text(0.02, 0.98, f"overall median {np.median(spd):.2f}×  (n={len(spd)})",
            transform=ax.transAxes, ha="left", va="top", fontsize=10,
            bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9))

    _save(fig, output_dir, "sim_speedup_by_family")



# =============================================================================
# Plot 2: Orig-vs-reduced simulation scatter + headline stats
# =============================================================================

def plot_sim_time_scatter_with_headline(results: List[Dict[str, Any]], output_dir: str):
    """
    Log-log scatter of original vs reduced simulation time with headline stats.

    Only rows whose two simulation times are finite and positive are plotted.
    A dashed y = x guide is drawn across the range of the plotted data, and a
    text box reports the median speedup, the median percentage time decrease
    and the number of points. The figure is written via ``_save`` under the
    stem ``sim_time_scatter_with_headline``.

    Args:
        results: Benchmark rows (dicts with ``orig_sim_time`` and
            ``reduced_sim_time`` entries).
        output_dir: Directory the figure files are written to.
    """
    orig_sim = np.array([_to_float(r.get("orig_sim_time")) for r in results], dtype=float)
    red_sim = np.array([_to_float(r.get("reduced_sim_time")) for r in results], dtype=float)

    mask = np.isfinite(orig_sim) & np.isfinite(red_sim) & (orig_sim > 0) & (red_sim > 0)
    if not np.any(mask):
        print("plot_sim_time_scatter_with_headline: no valid simulation rows.")
        return

    o = orig_sim[mask]
    rr = red_sim[mask]
    spd = _speedup(o, rr)
    dec = _pct_decrease(o, rr)

    fig, ax = plt.subplots(1, 1, figsize=(6.3, 4.0), constrained_layout=True)
    ax.scatter(o, rr, s=30, alpha=0.8, edgecolors="white", linewidth=0.4)

    # Span the y = x guide over the data range. A fixed lower end would drag
    # the log axes far below the data when all times are much larger than it,
    # and would fall short of the data when times are smaller.
    minv = min(float(np.min(o)), float(np.min(rr)))
    maxv = max(float(np.max(o)), float(np.max(rr)))
    ax.plot([minv, maxv], [minv, maxv], linestyle="--", linewidth=1.0, alpha=0.7)

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("original simulation time (s)", fontsize=12)
    ax.set_ylabel("reduced simulation time (s)", fontsize=12)
    ax.grid(True, which="both", alpha=0.18, linewidth=0.5)

    ax.text(
        0.02, 0.98,
        f"median speedup: {np.median(spd):.2f}×\nmedian time decrease: {np.median(dec):.0f}%\n(n={len(o)})",
        transform=ax.transAxes, ha="left", va="top", fontsize=10,
        bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9)
    )

    _save(fig, output_dir, "sim_time_scatter_with_headline")



# =============================================================================
# Plot 3: Tendril scaling (speedup vs tendril length)
# =============================================================================

# Benchmark names of the form "loopy_c<core>_t<tendril>" (nothing after the
# tendril number); the named groups capture the two integers.
_TENDRIL_RE = re.compile(r"loopy_c(?P<c>\d+)_t(?P<t>\d+)$")


def _parse_core_tendril(name: str) -> Optional[Tuple[int, int]]:
    """
    Extract ``(core, tendril)`` from a ``loopy_c<core>_t<tendril>`` name.

    Surrounding whitespace is ignored. Returns ``None`` when the name does not
    match the pattern from its first character to its end.
    """
    hit = _TENDRIL_RE.match(name.strip())
    return (int(hit.group("c")), int(hit.group("t"))) if hit else None


def plot_tendril_percent_speedup_scaling(results: List[Dict[str, Any]], output_dir: str):
    """
    Plot median percent speedup against tendril length, one line per core size.

    Only rows named ``loopy_c<core>_t<tendril>`` are used. For each row
      % speedup = 100 * (orig / reduced - 1)
    is computed from the simulation times when both are finite and positive,
    otherwise from the compile times; rows with neither usable pair are
    dropped. A text box reports the median over all rows at the largest
    tendril length. Saved via ``_save`` as ``tendril_percent_speedup_scaling``.
    """
    def _usable(numer: float, denom: float) -> bool:
        # Both timings must be finite and strictly positive.
        return bool(np.isfinite(numer) and np.isfinite(denom) and numer > 0 and denom > 0)

    # core size -> tendril length -> list of percent speedups
    by_core: Dict[int, Dict[int, List[float]]] = {}
    for rec in results:
        key = _parse_core_tendril(rec["name"])
        if key is None:
            continue
        core, tendril = key

        t_orig = _to_float(rec.get("orig_sim_time"))
        t_red = _to_float(rec.get("reduced_sim_time"))
        if not _usable(t_orig, t_red):
            # No usable simulation timing: fall back to compile timing.
            t_orig = _to_float(rec.get("orig_compile_time"))
            t_red = _to_float(rec.get("reduced_compile_time"))
            if not _usable(t_orig, t_red):
                continue

        gain = float(100.0 * (t_orig / t_red - 1.0))
        by_core.setdefault(core, {}).setdefault(tendril, []).append(gain)

    if not by_core:
        print("plot_tendril_percent_speedup_scaling: no loopy_c*_t* rows found.")
        return

    fig, ax = plt.subplots(1, 1, figsize=(6.6, 4.0), constrained_layout=True)

    for core in sorted(by_core):
        per_length = by_core[core]
        lengths = sorted(per_length)
        medians = []
        for length in lengths:
            sample = np.array(per_length[length], dtype=float)
            sample = sample[np.isfinite(sample)]
            medians.append(float(np.median(sample)) if len(sample) else float("nan"))
        ax.plot(lengths, medians, marker="o", linewidth=1.7, markersize=6, label=f"core {core}")

    ax.axhline(0.0, linestyle="--", linewidth=1.0, alpha=0.8)
    ax.set_xlabel("tendril length", fontsize=12)
    ax.set_ylabel("percent speedup (%)", fontsize=12)
    ax.grid(True, alpha=0.18, linewidth=0.5)
    ax.legend(fontsize=10, frameon=True, loc="best")

    # Headline: median over every core at the largest tendril length seen.
    longest = max(max(per_length) for per_length in by_core.values())
    at_longest = np.array(
        [v for per_length in by_core.values() for v in per_length.get(longest, [])],
        dtype=float,
    )
    if len(at_longest) > 0:
        ax.text(
            0.02, 0.98,
            f"at tendril {longest}: median {np.median(at_longest):.0f}% (n={len(at_longest)})",
            transform=ax.transAxes, ha="left", va="top", fontsize=10,
            bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9)
        )

    _save(fig, output_dir, "tendril_percent_speedup_scaling")


# =============================================================================
# Plot 4: speedup vs original size (# species), binned medians
# =============================================================================

def plot_sim_speedup_vs_orig_species(results: List[Dict[str, Any]], output_dir: str, n_bins: int = 6, show_scatter: bool = True):
    """
    Plot simulation speedup against original system size (# species).

    Rows need a finite, positive ``orig_species`` and finite, positive
    simulation times. The rows are split into ``n_bins`` equal-width bins of
    log10(orig_species); for every non-empty bin the median speedup is drawn at
    the bin's log-space midpoint, with a shaded band between the 25th and 75th
    percentiles. With ``show_scatter`` the individual rows are drawn faintly
    underneath. Saved via ``_save`` as ``sim_speedup_vs_orig_species``.
    """
    def _column(key: str) -> np.ndarray:
        return np.array([_to_float(rec.get(key)) for rec in results], dtype=float)

    orig_species = _column("orig_species")
    orig_sim = _column("orig_sim_time")
    red_sim = _column("reduced_sim_time")

    finite = np.isfinite(orig_species) & np.isfinite(orig_sim) & np.isfinite(red_sim)
    positive = (orig_species > 0) & (orig_sim > 0) & (red_sim > 0)
    mask = finite & positive
    if not np.any(mask):
        print("plot_sim_speedup_vs_orig_species: no valid simulation rows.")
        return

    s = orig_species[mask]
    spd = _speedup(orig_sim[mask], red_sim[mask])

    # Equal-width bins in log10 space; the clip folds the maximum value (which
    # digitize places past the last edge) into the final bin.
    log_size = np.log10(s)
    edges = np.linspace(log_size.min(), log_size.max(), n_bins + 1)
    bin_id = np.clip(np.digitize(log_size, edges, right=False) - 1, 0, n_bins - 1)

    # One (center, median, q25, q75) record per non-empty bin.
    per_bin = []
    for b in range(n_bins):
        in_bin = spd[bin_id == b]
        if in_bin.size == 0:
            continue
        q25, q50, q75 = np.percentile(in_bin, [25, 50, 75])
        per_bin.append((10 ** (0.5 * (edges[b] + edges[b + 1])), q50, q25, q75))

    table = np.array(per_bin, dtype=float).reshape(-1, 4)
    centers, meds, q25s, q75s = table.T

    fig, ax = plt.subplots(1, 1, figsize=(6.6, 4.0), constrained_layout=True)

    if show_scatter:
        ax.scatter(s, spd, s=22, alpha=0.25, edgecolors="none")

    ax.plot(centers, meds, marker="o", linewidth=2.0, markersize=6)
    ax.fill_between(centers, q25s, q75s, alpha=0.18)

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("original # species", fontsize=12)
    ax.set_ylabel("simulation speedup (orig / reduced)", fontsize=12)
    ax.grid(True, which="both", alpha=0.18, linewidth=0.5)
    ax.axhline(1.0, linestyle="--", linewidth=1.0, alpha=0.8)

    ax.text(0.02, 0.98, f"overall median {np.median(spd):.2f}×  (n={len(spd)})",
            transform=ax.transAxes, ha="left", va="top", fontsize=10,
            bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9))

    _save(fig, output_dir, "sim_speedup_vs_orig_species")


# =============================================================================
# Plot 5: one figure per family — % speedup vs % species reduction
# =============================================================================

def plot_speedup_vs_species_reduction_per_family(results: List[Dict[str, Any]], output_dir: str):
    """
    Save one scatter plot per graph family:
      x = percent species reduction   = 100*(1 - reduced_species/orig_species)
      y = simulation percent speedup  = 100*(orig_sim/reduced_sim - 1)

    Rows are used only when both species counts and both simulation times are
    finite and positive. Each figure carries a text box with the family name,
    the median percent speedup and the number of points, and is written via
    ``_save`` as ``speedup_vs_species_reduction_<family>``.
    """
    def _column(key: str) -> np.ndarray:
        return np.array([_to_float(rec.get(key)) for rec in results], dtype=float)

    family_of = np.array([categorize_name(rec["name"]) for rec in results])

    n_orig = _column("orig_species")
    n_red = _column("reduced_species")
    t_orig = _column("orig_sim_time")
    t_red = _column("reduced_sim_time")

    finite = np.isfinite(n_orig) & np.isfinite(n_red) & np.isfinite(t_orig) & np.isfinite(t_red)
    positive = (n_orig > 0) & (n_red > 0) & (t_orig > 0) & (t_red > 0)
    usable = finite & positive
    if not np.any(usable):
        print("plot_speedup_vs_species_reduction_per_family: no valid simulation rows.")
        return

    species_cut = 100.0 * (1.0 - (n_red[usable] / n_orig[usable]))
    time_gain = 100.0 * ((t_orig[usable] / t_red[usable]) - 1.0)
    family_kept = family_of[usable]

    # Known families in a fixed display order; fall back to whatever is present.
    seen = set(family_kept)
    families = [name for name in ["chain", "tree", "loopy", "grid", "random", "other"] if name in seen]
    if not families:
        families = sorted(seen)

    os.makedirs(output_dir, exist_ok=True)

    for family in families:
        pick = family_kept == family
        count = np.sum(pick)
        if count == 0:
            continue

        xs = species_cut[pick]
        ys = time_gain[pick]

        fig, ax = plt.subplots(1, 1, figsize=(6.2, 4.0), constrained_layout=True)
        ax.scatter(xs, ys, s=28, alpha=0.75, edgecolors="white", linewidth=0.4)

        ax.axhline(0.0, linestyle="--", linewidth=1.0, alpha=0.8)
        ax.axvline(0.0, linestyle="--", linewidth=1.0, alpha=0.35)
        ax.grid(True, alpha=0.18, linewidth=0.5)

        ax.set_xlabel("% species reduction", fontsize=12)
        ax.set_ylabel("% speedup (simulation)", fontsize=12)

        headline = f"{family.capitalize()}  |  median {np.median(ys):.0f}%  (n={count})"
        ax.text(
            0.02, 0.98,
            headline,
            transform=ax.transAxes, ha="left", va="top", fontsize=10,
            bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9)
        )

        _save(fig, output_dir, f"speedup_vs_species_reduction_{family}")

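# =============================================================================
# Plot 6: Tendril scaling on a log scale (log10 runtime ratio vs tendril length)
# =============================================================================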
def plot_tendril_log_runtime_ratio(results: List[Dict[str, Any]], output_dir: str):
    """
    Plot median log10(orig / reduced runtime) against tendril length, one line
    per core size.

    Only rows named ``loopy_c<core>_t<tendril>`` are used. The runtime ratio
    comes from the simulation times when both are finite and positive,
    otherwise from the compile times; rows with neither usable pair, or with a
    non-finite ratio, are dropped. Dotted guides mark 10× and 100×. Saved via
    ``_save`` as ``tendril_log_runtime_ratio``.
    """
    def _usable(numer: float, denom: float) -> bool:
        # Both timings must be finite and strictly positive.
        return bool(np.isfinite(numer) and np.isfinite(denom) and numer > 0 and denom > 0)

    # core size -> tendril length -> list of log10 ratios
    by_core: Dict[int, Dict[int, List[float]]] = {}
    for rec in results:
        key = _parse_core_tendril(rec["name"])
        if key is None:
            continue
        core, tendril = key

        t_orig = _to_float(rec.get("orig_sim_time"))
        t_red = _to_float(rec.get("reduced_sim_time"))
        if not _usable(t_orig, t_red):
            # No usable simulation timing: fall back to compile timing.
            t_orig = _to_float(rec.get("orig_compile_time"))
            t_red = _to_float(rec.get("reduced_compile_time"))
            if not _usable(t_orig, t_red):
                continue

        ratio = t_orig / t_red
        if ratio <= 0 or not np.isfinite(ratio):
            continue

        by_core.setdefault(core, {}).setdefault(tendril, []).append(float(np.log10(ratio)))

    if not by_core:
        print("plot_tendril_log_runtime_ratio: no loopy_c*_t* rows found.")
        return

    fig, ax = plt.subplots(1, 1, figsize=(6.6, 4.0), constrained_layout=True)

    for core in sorted(by_core):
        per_length = by_core[core]
        lengths = sorted(per_length)
        medians = []
        for length in lengths:
            sample = np.array(per_length[length], dtype=float)
            sample = sample[np.isfinite(sample)]
            medians.append(float(np.median(sample)) if len(sample) else float("nan"))
        ax.plot(lengths, medians, marker="o", linewidth=1.7, markersize=6, label=f"core {core}")

    ax.axhline(0.0, linestyle="--", linewidth=1.0, alpha=0.8)
    ax.set_xlabel("tendril length", fontsize=12)
    ax.set_ylabel(r"$\log_{10}(t_{\mathrm{orig}}/t_{\mathrm{reduced}})$", fontsize=12)
    ax.grid(True, alpha=0.18, linewidth=0.5)
    ax.legend(fontsize=10, frameon=True, loc="best")

    # Dotted guides at log10 = 1 and 2, i.e. 10x and 100x faster, each labelled
    # at the right edge (x in axes fraction, y in data units).
    for level in (1.0, 2.0):
        ax.axhline(level, linestyle=":", linewidth=1.0, alpha=0.6)
    ax.text(0.98, 1.0, "10×", transform=ax.get_yaxis_transform(),
            ha="right", va="bottom", fontsize=9, alpha=0.8)
    ax.text(0.98, 2.0, "100×", transform=ax.get_yaxis_transform(),
            ha="right", va="bottom", fontsize=9, alpha=0.8)

    _save(fig, output_dir, "tendril_log_runtime_ratio")

# =============================================================================
# Entrypoint
# =============================================================================

def generate_icml4_plots(csv_file: str, output_dir: str):
    """
    Load the benchmark CSV and produce every figure in this module.

    Args:
        csv_file: Path to the benchmark results CSV.
        output_dir: Directory that receives the figures (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    results = load_results_from_csv(csv_file)
    print(f"Loaded {len(results)} benchmark rows from {csv_file}")

    # Each plotter takes (results, output_dir); they run in this order.
    plotters = (
        plot_sim_speedup_by_family,
        plot_sim_time_scatter_with_headline,
        plot_tendril_percent_speedup_scaling,
        plot_sim_speedup_vs_orig_species,
        plot_speedup_vs_species_reduction_per_family,
        plot_tendril_log_runtime_ratio,
    )
    for make_plot in plotters:
        make_plot(results, output_dir)

    print(f"All plots saved to {output_dir}")


if __name__ == "__main__":
    # Usage: python plots_icml4.py [CSV_FILE [OUTPUT_DIR]]
    # Without arguments the original hard-coded locations are used, so existing
    # invocations keep working; positional arguments override them.
    csv_file = (
        sys.argv[1] if len(sys.argv) > 1
        else "/home/mauwork/factor_graph_project/results/benchmark_results.csv"
    )
    output_dir = (
        sys.argv[2] if len(sys.argv) > 2
        else "/home/mauwork/factor_graph_project/results/plots_icml4"
    )
    if os.path.exists(csv_file):
        generate_icml4_plots(csv_file, output_dir)
    else:
        print(f"CSV file not found: {csv_file}")
        print("Run benchmark_runner.py first to generate results.")
