"""
ICML Plots (v3): Extra figures from benchmark_results.csv

Generates 3 additional kinds of ICML-friendly plots (4 figures) from the CSV:
  1) speedup_vs_reduction.(png/pdf)
     Bar chart of the median % reduction in simulation time, binned by
     % species reduction (each bar annotated with its row count)
  2) scaling_sim_time_vs_species.(png/pdf) and scaling_compile_time_vs_species.(png/pdf)
     Time vs size (orig & reduced) on log-log axes to show scaling improvement
  3) tradeoff_correctness_speed.(png/pdf)
     Marginal max diff vs simulation speedup (with 1% line)

Usage:
  python benchmarks/plot_icml3.py

Or:
  from plot_icml3 import generate_icml3_plots
  generate_icml3_plots(csv_file, output_dir)

Notes:
- No internal titles (use LaTeX captions).
- Robust CSV parsing (handles NaN / empty / strings).
- Uses matplotlib only (no seaborn).
"""

from __future__ import annotations

import csv
import os
import sys
from typing import Any, Dict, List, Optional

import numpy as np
import matplotlib.pyplot as plt

# Prepend the project root (two directory levels above this file) to the
# module search path, so project imports resolve even when the script is
# launched from inside benchmarks/.
sys.path[0:0] = [os.path.dirname(os.path.dirname(os.path.abspath(__file__)))]


# =============================================================================
# Robust CSV parsing
# =============================================================================

def _to_float(x: Any) -> float:
    if x is None:
        return float("nan")
    if isinstance(x, (int, float, np.floating)):
        return float(x)
    s = str(x).strip()
    if s == "" or s.lower() in {"nan", "none", "n/a", "na"}:
        return float("nan")
    try:
        return float(s)
    except ValueError:
        return float("nan")


def _to_int(x: Any) -> Optional[int]:
    if x is None:
        return None
    if isinstance(x, int):
        return x
    s = str(x).strip()
    if s == "" or s.lower() in {"nan", "none", "n/a", "na"}:
        return None
    try:
        return int(float(s))
    except ValueError:
        return None


def load_results_from_csv(filename: str) -> List[Dict[str, Any]]:
    """
    Load benchmark rows from a header-first CSV file.

    Each data row becomes a dict:

    - ``"name"`` is kept as a string ("" when the cell is absent).
    - The count/flag columns listed in ``int_keys`` are parsed with
      ``_to_int`` (``None`` when missing or unparseable).
    - Every other column is parsed with ``_to_float`` (NaN when missing or
      unparseable).

    Args:
        filename: path to the CSV file; the first line must be the header.

    Returns:
        One dict per data row, in file order.
    """
    results: List[Dict[str, Any]] = []
    int_keys = {
        "orig_vars", "reduced_vars", "orig_factors", "reduced_factors",
        "orig_edges", "reduced_edges", "orig_species", "reduced_species",
        "orig_reactions", "reduced_reactions", "n_reduction_steps",
        "bp_converged_orig", "bp_converged_reduced",
    }
    # "utf-8-sig" strips a leading byte-order mark if present (e.g. from a
    # spreadsheet export). Otherwise the BOM would be glued onto the first
    # header and that column (typically "name") would not be found.
    with open(filename, "r", newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # DictReader fills the missing cells of a short row with None.
            out: Dict[str, Any] = {"name": row.get("name") or ""}
            for k, v in row.items():
                # Cells beyond the header width are collected under the key
                # None; they belong to no column, so they are skipped.
                if k is None or k == "name":
                    continue
                out[k] = _to_int(v) if k in int_keys else _to_float(v)
            results.append(out)
    return results


# =============================================================================
# Helpers
# =============================================================================

def categorize_name(name: str) -> str:
    """
    Map a benchmark name to its graph family by prefix.

    Returns the first of "chain", "tree", "loopy", "grid", "random" that
    ``name`` starts with, or "other" when none of them matches.
    """
    for family in ("chain", "tree", "loopy", "grid", "random"):
        if name.startswith(family):
            return family
    return "other"


def _finite_mask(*arrays: np.ndarray) -> np.ndarray:
    mask = np.ones_like(arrays[0], dtype=bool)
    for a in arrays:
        mask &= np.isfinite(a)
    return mask


def _pct_reduction(reduced: np.ndarray, orig: np.ndarray) -> np.ndarray:
    with np.errstate(divide="ignore", invalid="ignore"):
        return 100.0 * (1.0 - reduced / orig)


def _speedup(orig: np.ndarray, reduced: np.ndarray) -> np.ndarray:
    with np.errstate(divide="ignore", invalid="ignore"):
        return orig / reduced


def _save(fig, output_dir: str, stem: str, dpi: int = 220):
    """
    Write ``fig`` to ``<output_dir>/<stem>.png`` and ``<output_dir>/<stem>.pdf``.

    The directory is created if needed. The PNG is rendered at ``dpi``. The
    PDF is saved without an explicit dpi. Both use a white face colour and
    no edge colour. The figure is then closed, and each written path is
    printed as "Saved: <path>".
    """
    os.makedirs(output_dir, exist_ok=True)
    style = {"facecolor": "white", "edgecolor": "none"}
    png_path, pdf_path = (os.path.join(output_dir, f"{stem}.{ext}") for ext in ("png", "pdf"))
    fig.savefig(png_path, dpi=dpi, **style)
    fig.savefig(pdf_path, **style)
    plt.close(fig)
    for written in (png_path, pdf_path):
        print("Saved:", written)


# =============================================================================
# Plot 1: Speedup vs Reduction (mechanism plot)
# =============================================================================

def plot_speedup_vs_reduction(
    results: List[Dict[str, Any]],
    output_dir: str,
    *,
    bins: Optional[List[float]] = None,
    show_iqr: bool = False,
):
    """
    Bar chart: median % simulation-time reduction per % species-reduction bin.

    Rows are used only when both species counts and both simulation times
    are finite and positive. For each such row:

    - x metric: ``100 * (1 - reduced_species / orig_species)``
    - y metric: ``100 * (1 - reduced_sim_time / orig_sim_time)``, i.e. the
      percentage of simulation wall time saved (50 % here is a 2x speedup).
      It is *not* the ``orig / reduced`` ratio. The y-axis label keeps the
      existing "% simulation speedup" wording.

    Rows are grouped by x into left-closed bins ``[lo, hi)`` and each
    non-empty bin becomes one bar at its median y, labelled with its row
    count. Rows whose x falls outside ``[bins[0], bins[-1])`` are not
    plotted.

    Args:
        results: rows as returned by ``load_results_from_csv``.
        output_dir: directory the figure files are written to.
        bins: increasing bin edges in %, at least two. Defaults to
            0/20/40/60/80/100.
        show_iqr: also draw half-IQR error bars ((q75 - q25) / 2) on each
            bar. Off by default, which keeps the figure as before.

    Raises:
        ValueError: if ``bins`` has fewer than two edges.

    Output: speedup_vs_reduction.(png/pdf)
    """
    edges = np.asarray([0, 20, 40, 60, 80, 100] if bins is None else bins, dtype=float)
    if edges.ndim != 1 or edges.size < 2:
        raise ValueError("bins must be a 1-D sequence of at least two edges")

    orig_species = np.array([float(r.get("orig_species") or np.nan) for r in results], dtype=float)
    red_species = np.array([float(r.get("reduced_species") or np.nan) for r in results], dtype=float)
    orig_sim = np.array([float(r.get("orig_sim_time", np.nan)) for r in results], dtype=float)
    red_sim = np.array([float(r.get("reduced_sim_time", np.nan)) for r in results], dtype=float)

    mask = (
        np.isfinite(orig_species) & np.isfinite(red_species)
        & np.isfinite(orig_sim) & np.isfinite(red_sim)
        & (orig_species > 0) & (red_species > 0)
        & (orig_sim > 0) & (red_sim > 0)
    )
    if not np.any(mask):
        print("plot_speedup_vs_reduction: no valid simulation rows.")
        return

    pct = _pct_reduction(red_species[mask], orig_species[mask])
    # Percent of simulation wall time saved (see docstring), not orig/reduced.
    saved = _pct_reduction(red_sim[mask], orig_sim[mask])

    n_bins = len(edges) - 1
    # digitize returns i with edges[i-1] <= v < edges[i]. Subtracting 1 gives
    # a 0-based bin index. Out-of-range values become -1 or n_bins and match
    # no bin below.
    bin_ids = np.digitize(pct, edges, right=False) - 1

    med = np.full(n_bins, np.nan)
    err = np.full(n_bins, np.nan)
    counts = np.zeros(n_bins, dtype=int)
    for b in range(n_bins):
        vals = saved[bin_ids == b]
        vals = vals[np.isfinite(vals)]
        counts[b] = vals.size
        if vals.size == 0:
            continue
        med[b] = float(np.median(vals))
        # Half the interquartile range: a spread measure that outlier runs
        # barely move.
        q25, q75 = np.percentile(vals, [25, 75])
        err[b] = float((q75 - q25) / 2.0)

    # Drop empty bins for a cleaner figure.
    keep = counts > 0
    if not np.any(keep):
        print("plot_speedup_vs_reduction: no rows fall inside the reduction bins.")
        return
    lo_edges = edges[:-1][keep]
    hi_edges = edges[1:][keep]
    med, err, counts = med[keep], err[keep], counts[keep]

    labels = [f"{lo:g}–{hi:g}%" for lo, hi in zip(lo_edges, hi_edges)]

    fig, ax = plt.subplots(1, 1, figsize=(6.6, 4.2), constrained_layout=True)

    x = np.arange(len(labels))
    bar_kw: Dict[str, Any] = {"alpha": 0.85, "edgecolor": "black", "linewidth": 0.6}
    if show_iqr:
        bar_kw.update(yerr=err, capsize=3)
    ax.bar(x, med, **bar_kw)
    if np.any(med < 0):
        # A negative bar means the reduced model simulated slower. Draw the
        # zero baseline so the sign is obvious.
        ax.axhline(0.0, color="black", linewidth=0.8)

    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=0)
    ax.set_xlabel("% species reduction", fontsize=12)
    ax.set_ylabel("% simulation speedup", fontsize=12)
    ax.grid(True, axis="y", alpha=0.18, linewidth=0.5)

    # Count labels go just above the highest drawn element of each bar (the
    # bar top, or the upper error cap when shown), and never below zero, so
    # labels on negative bars are not drawn inside the bar.
    pad = 0.02 * float(np.max(np.abs(med)))
    tops = np.maximum(med + err if show_iqr else med, 0.0)
    for i, (top, n) in enumerate(zip(tops, counts)):
        ax.text(i, top + pad, f"n={n}", ha="center", va="bottom", fontsize=9)

    _save(fig, output_dir, "speedup_vs_reduction")


# =============================================================================
# Plot 2: Scaling law (time vs size) before/after
# =============================================================================

def plot_scaling_time_vs_size(results: List[Dict[str, Any]], output_dir: str, which: str = "sim"):
    """
    Log-log scatter of run time against # species, original vs reduced models.

    Args:
        results: rows as returned by ``load_results_from_csv``.
        output_dir: directory the figure files are written to.
        which: ``"compile"`` plots the ``*_compile_time`` columns. Any other
            value plots the ``*_sim_time`` columns.

    A point is drawn only when its species count and time are finite and
    positive. For every series with at least 3 such points, the
    least-squares slope of log10(time) on log10(#species) is shown in a box
    in the upper-left corner.

    Output: scaling_sim_time_vs_species.(png/pdf) or
            scaling_compile_time_vs_species.(png/pdf)
    """
    metric = "compile" if which == "compile" else "sim"
    ylabel = "compile time (s)" if metric == "compile" else "simulation time (s)"
    stem = f"scaling_{metric}_time_vs_species"

    def as_sizes(key: str) -> np.ndarray:
        # Count columns may hold None; `or` turns None (and 0) into NaN.
        return np.array([float(r.get(key) or np.nan) for r in results], dtype=float)

    def as_times(key: str) -> np.ndarray:
        return np.array([float(r.get(key, np.nan)) for r in results], dtype=float)

    def usable(sizes: np.ndarray, times: np.ndarray) -> np.ndarray:
        # Both axes are logarithmic, so only finite, strictly positive pairs plot.
        return np.isfinite(sizes) & np.isfinite(times) & (sizes > 0) & (times > 0)

    # One entry per series: (sizes, times, mask, legend label, slope label).
    # The original model is listed first so it is drawn first.
    series = []
    for prefix, legend_label, slope_label in (("orig", "original", "orig"), ("reduced", "reduced", "red")):
        sizes = as_sizes(f"{prefix}_species")
        times = as_times(f"{prefix}_{metric}_time")
        series.append((sizes, times, usable(sizes, times), legend_label, slope_label))

    if not any(np.any(ok) for _, _, ok, _, _ in series):
        print(f"plot_scaling_time_vs_size({which}): no valid rows.")
        return

    fig, ax = plt.subplots(1, 1, figsize=(6.6, 4.6), constrained_layout=True)

    for sizes, times, ok, legend_label, _ in series:
        if np.any(ok):
            ax.scatter(sizes[ok], times[ok], s=30, alpha=0.75, edgecolors="white", linewidth=0.4, label=legend_label)

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("# species (ODE dimension)", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.grid(True, which="both", alpha=0.18, linewidth=0.5)
    ax.legend(fontsize=10, frameon=True, loc="best")

    def loglog_slope(xs: np.ndarray, ys: np.ndarray) -> float:
        # Least-squares fit of log10(y) = slope * log10(x) + intercept.
        log_x, log_y = np.log10(xs), np.log10(ys)
        design = np.vstack([log_x, np.ones_like(log_x)]).T
        slope, _intercept = np.linalg.lstsq(design, log_y, rcond=None)[0]
        return float(slope)

    summary = [
        f"{slope_label} slope ≈ {loglog_slope(sizes[ok], times[ok]):.2f}"
        for sizes, times, ok, _, slope_label in series
        if np.count_nonzero(ok) >= 3
    ]
    if summary:
        ax.text(0.02, 0.98, "\n".join(summary), transform=ax.transAxes,
                ha="left", va="top", fontsize=10,
                bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9))

    _save(fig, output_dir, stem)


# =============================================================================
# Plot 3: Correctness–speed tradeoff
# =============================================================================

def plot_tradeoff_correctness_speed(results: List[Dict[str, Any]], output_dir: str):
    """
    Scatter of max marginal difference (log y) against simulation speedup.

    Speedup is ``orig_sim_time / reduced_sim_time``. A row is used when its
    ``marginal_max_diff`` is finite and non-negative and both simulation
    times are finite and positive. The figure shows:

    - a dashed line at 0.01, labelled "1% threshold";
    - value labels on the three largest differences;
    - a text box with the median speedup, the max difference and n.

    A log axis cannot display a difference of exactly 0 (e.g. an exact
    reduction). Matplotlib's default log scale clips such points out of
    view while they still count towards n. They are therefore drawn at a
    floor one decade below the smallest positive difference (1e-4 if none
    is positive), and their number is reported in the text box.

    Output: tradeoff_correctness_speed.(png/pdf)
    """
    diff = np.array([float(r.get("marginal_max_diff", np.nan)) for r in results], dtype=float)
    orig_sim = np.array([float(r.get("orig_sim_time", np.nan)) for r in results], dtype=float)
    red_sim = np.array([float(r.get("reduced_sim_time", np.nan)) for r in results], dtype=float)

    mask = (
        np.isfinite(diff) & np.isfinite(orig_sim) & np.isfinite(red_sim)
        & (orig_sim > 0) & (red_sim > 0) & (diff >= 0)
    )
    if not np.any(mask):
        print("plot_tradeoff_correctness_speed: no valid rows.")
        return

    spd = _speedup(orig_sim[mask], red_sim[mask])
    d = diff[mask]

    threshold = 0.01
    # Exact zeros cannot sit on a log axis, so substitute a visible floor
    # for drawing only. Statistics and labels below still use the true values.
    positive = d[d > 0]
    floor = float(positive.min()) / 10.0 if positive.size else threshold / 100.0
    is_zero = d == 0
    d_plot = np.where(is_zero, floor, d)

    fig, ax = plt.subplots(1, 1, figsize=(6.6, 4.6), constrained_layout=True)
    ax.scatter(spd, d_plot, s=32, alpha=0.8, edgecolors="white", linewidth=0.4)

    ax.set_xlabel("simulation speedup (orig / reduced)", fontsize=12)
    ax.set_ylabel("max marginal difference", fontsize=12)
    ax.set_yscale("log")
    ax.grid(True, which="both", alpha=0.18, linewidth=0.5)

    ax.axhline(threshold, linestyle="--", linewidth=1.0, alpha=0.8)
    # x in axes coordinates, y in data coordinates: the label sits on the line.
    ax.text(0.98, threshold, "1% threshold", transform=ax.get_yaxis_transform(),
            ha="right", va="bottom", fontsize=10)

    # Label up to 3 largest diffs with their true value, at the drawn position.
    for i in np.argsort(d)[-3:]:
        ax.annotate(f"{d[i]:.2g}", (spd[i], d_plot[i]), textcoords="offset points", xytext=(5, 5), fontsize=9)

    summary = f"median speedup {np.median(spd):.2f}×\nmax diff {np.max(d):.2g}\n(n={len(d)})"
    n_zero = int(np.count_nonzero(is_zero))
    if n_zero:
        summary += f"\n{n_zero} exact (diff 0) shown at {floor:.0e}"
    ax.text(0.02, 0.98, summary,
            transform=ax.transAxes, ha="left", va="top", fontsize=10,
            bbox=dict(boxstyle="round,pad=0.35", facecolor="white", alpha=0.9))

    _save(fig, output_dir, "tradeoff_correctness_speed")


# =============================================================================
# Entrypoint
# =============================================================================

def generate_icml3_plots(csv_file: str, output_dir: str):
    """
    Load ``csv_file`` and write every ICML v3 figure into ``output_dir``.

    Figures are produced in this order:
    1. the speedup-vs-reduction bar chart;
    2. the simulation-time scaling scatter;
    3. the compile-time scaling scatter;
    4. the correctness/speed tradeoff scatter.

    Progress messages are printed to stdout.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = load_results_from_csv(csv_file)
    print(f"Loaded {len(results)} benchmark rows from {csv_file}")

    plot_speedup_vs_reduction(results, output_dir)
    for which in ("sim", "compile"):
        plot_scaling_time_vs_size(results, output_dir, which=which)
    plot_tradeoff_correctness_speed(results, output_dir)

    print(f"All plots saved to {output_dir}")


if __name__ == "__main__":
    import argparse

    # Defaults are the original hard-coded locations. Pass paths on the
    # command line to run this on another machine or results directory.
    _DEFAULT_CSV = "/home/mauwork/factor_graph_project/results/benchmark_results.csv"
    _DEFAULT_OUT = "/home/mauwork/factor_graph_project/results/plots_icml3"

    parser = argparse.ArgumentParser(
        description="Generate the ICML v3 benchmark figures from a results CSV."
    )
    parser.add_argument("csv_file", nargs="?", default=_DEFAULT_CSV,
                        help=f"benchmark results CSV (default: {_DEFAULT_CSV})")
    parser.add_argument("output_dir", nargs="?", default=_DEFAULT_OUT,
                        help=f"directory for the figures (default: {_DEFAULT_OUT})")
    args = parser.parse_args()

    if os.path.isfile(args.csv_file):
        generate_icml3_plots(args.csv_file, args.output_dir)
    else:
        print(f"CSV file not found: {args.csv_file}")
        print("Run benchmark_runner.py first to generate results.")
        # Non-zero exit status so shell pipelines can detect the failure.
        sys.exit(1)
