"""
ICML-style plots for SP-B reduction benchmarks.

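Produces (each figure is written to --outdir as both PNG and PDF):
1) Summary table by family (family name prefixed by a colored dot) with columns:
   - n (number of rows in the family)
   - variables down % (median)
   - species down % (median)
   - sim speedup (x, median)
2) Median compression by family (species down %, colored bars)
3) Median simulation speedup by family (colored bars, log y-scale)
4) BP marginal preservation: max marginal difference between BP on FG_orig
   and BP on FG_red, vs. original # variables (scatter, log y-scale)
5) aggregate_by_family.csv with the per-family medians (handy for LaTeX tables)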

Usage:
  python3 make_icml_benchmark_plots.py \
    --csv /mnt/data/benchmark_results_cleaned.csv \
    --outdir ./icml_figs
"""

from __future__ import annotations

import argparse
import os
from typing import Dict, List, Optional, Sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Canonical left-to-right order of benchmark families in every table and plot.
FAMILY_ORDER: List[str] = [
    "chain",
    "tree",
    "loopy",
    "grid",
    "random",
]

# One hex color per family (blue / green / orange / purple / red), shared by
# all figures so a family keeps the same color across plots.
FAMILY_COLORS: Dict[str, str] = dict(
    chain="#2563eb",
    tree="#059669",
    loopy="#d97706",
    grid="#7c3aed",
    random="#dc2626",
)

# Bullet glyph (U+25CF BLACK CIRCLE) prefixed to family names in the summary table.
DOT = "\u25cf"


def safe_div(a: pd.Series, b: pd.Series) -> pd.Series:
    """Return the elementwise ratio ``a / b`` as a float Series.

    Both operands are coerced with ``pd.to_numeric(errors="coerce")``, so a
    non-numeric cell (e.g. an ``"error"``/``"timeout"`` marker in a CSV column)
    becomes NaN instead of raising ``ValueError``.  Any position where either
    operand is NaN/inf, or where ``b`` is 0, yields NaN rather than inf.

    Args:
        a: Numerators.
        b: Denominators; assumed to share ``a``'s index.

    Returns:
        Float Series indexed like ``a``.
    """
    num = pd.to_numeric(a, errors="coerce").astype(float)
    den = pd.to_numeric(b, errors="coerce").astype(float)
    out = pd.Series(np.nan, index=a.index, dtype=float)
    mask = np.isfinite(num) & np.isfinite(den) & (den != 0)
    out.loc[mask] = num.loc[mask] / den.loc[mask]
    return out


def ensure_outdir(path: str) -> None:
    """Create directory *path* plus any missing parents; no-op if it already is a directory."""
    os.makedirs(name=path, mode=0o777, exist_ok=True)


def save_fig(fig: plt.Figure, outdir: str, stem: str) -> None:
    """Write *fig* to ``<outdir>/<stem>.png`` (300 dpi) and ``<outdir>/<stem>.pdf``.

    Both files are saved with ``bbox_inches="tight"``; the written paths are
    printed only after both saves have succeeded.  The figure is not closed.
    """
    targets = [
        (os.path.join(outdir, f"{stem}.png"), {"dpi": 300}),
        (os.path.join(outdir, f"{stem}.pdf"), {}),
    ]
    for target, extra in targets:
        fig.savefig(target, bbox_inches="tight", **extra)
    for target, _ in targets:
        print(f"[saved] {target}")


def compute_derived(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with family labels and derived benchmark metrics.

    Columns added (the input frame itself is left untouched):
      - ``family`` -- only when absent: the first of ``chain``/``tree``/
        ``loopy``/``grid``/``random`` that ``str(name)`` starts with, else
        ``"unknown"``.
      - ``var_down_pct`` -- 100 * (1 - reduced_vars / orig_vars).
      - ``species_down_pct`` -- 100 * (1 - reduced_species / orig_species).
      - ``sim_speedup_x`` -- orig_sim_time / reduced_sim_time.
      - ``bp_diff`` -- ``marginal_max_diff`` coerced to numeric (NaN when the
        column is missing or a cell is not a number).

    Ratios are computed with ``safe_div``, so a zero denominator or a
    non-finite operand produces NaN.
    """
    result = df.copy()

    if "family" not in result.columns:
        prefixes = ("chain", "tree", "loopy", "grid", "random")

        def _family_of(raw_name) -> str:
            text = str(raw_name)
            return next((p for p in prefixes if text.startswith(p)), "unknown")

        result["family"] = result["name"].map(_family_of)

    ratio_metrics = (
        ("var_down_pct", "reduced_vars", "orig_vars"),
        ("species_down_pct", "reduced_species", "orig_species"),
    )
    for out_col, reduced_col, orig_col in ratio_metrics:
        result[out_col] = 100.0 * (1.0 - safe_div(result[reduced_col], result[orig_col]))
    result["sim_speedup_x"] = safe_div(result["orig_sim_time"], result["reduced_sim_time"])

    # Max marginal difference used by the BP preservation scatter.
    result["bp_diff"] = (
        pd.to_numeric(result["marginal_max_diff"], errors="coerce")
        if "marginal_max_diff" in result.columns
        else np.nan
    )
    return result


def aggregate_by_family(
    df: pd.DataFrame,
    families: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
    """Summarize the derived metrics per benchmark family.

    Args:
        df: Frame with ``family``, ``var_down_pct``, ``species_down_pct`` and
            ``sim_speedup_x`` columns (as produced by ``compute_derived``).
        families: Families to report, in output order.  Defaults to
            ``FAMILY_ORDER``.  Listed families with no rows in *df* are
            skipped; rows whose family is not listed are ignored.

    Returns:
        One row per reported family with columns ``family``, ``n`` (row count),
        and the NaN-skipping medians ``var_down_median``,
        ``species_down_median`` and ``sim_speedup_median`` (NaN when a family
        has no non-NaN value for that metric).  The columns exist even when no
        family matches, so callers can index them on an empty result.
    """
    order = FAMILY_ORDER if families is None else families
    metrics = {
        "var_down_median": "var_down_pct",
        "species_down_median": "species_down_pct",
        "sim_speedup_median": "sim_speedup_x",
    }
    rows = []
    for fam in order:
        sub = df[df["family"] == fam]
        if sub.empty:
            continue
        row = {"family": fam, "n": len(sub)}
        for out_col, src_col in metrics.items():
            # Series.median skips NaN like np.nanmedian, but returns NaN for an
            # all-NaN group without emitting a RuntimeWarning.
            row[out_col] = sub[src_col].astype(float).median()
        rows.append(row)
    # Explicit columns keep the schema stable when `rows` is empty.
    return pd.DataFrame(rows, columns=["family", "n", *metrics])


def plot_family_summary_table(agg: pd.DataFrame, outdir: str) -> None:
    """Render the per-family aggregate as a table figure.

    One table row per row of *agg*: a ``● family`` label drawn bold in the
    family's color (black if not in FAMILY_COLORS), ``n``, and the three
    medians formatted with one decimal.  Saved via ``save_fig`` under the stem
    ``icml_table_family_summary``.

    When *agg* has no rows a notice is printed and nothing is written, since
    ``Axes.table`` indexes ``cellText[0]`` and would raise on an empty table.

    Args:
        agg: Output of ``aggregate_by_family`` (``family``, ``n``,
            ``var_down_median``, ``species_down_median``,
            ``sim_speedup_median``).
        outdir: Directory passed to ``save_fig``.
    """
    if agg.empty:
        print("[skip] icml_table_family_summary: no families to tabulate")
        return

    families = [str(f) for f in agg["family"]]
    table_rows = [
        [
            f"{DOT} {fam}",
            int(n),
            f"{var_down:.1f}",
            f"{species_down:.1f}",
            f"{speedup:.1f}",
        ]
        for fam, n, var_down, species_down, speedup in zip(
            families,
            agg["n"],
            agg["var_down_median"],
            agg["species_down_median"],
            agg["sim_speedup_median"],
        )
    ]

    col_labels = ["Family", "n", "Variables ↓ % (median)", "Species ↓ % (median)", "Sim speedup × (median)"]

    fig, ax = plt.subplots(figsize=(9.5, 2.8))
    ax.axis("off")

    tbl = ax.table(
        cellText=table_rows,
        colLabels=col_labels,
        cellLoc="center",
        colLoc="center",
        loc="center",
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1, 1.35)

    # Header row (row 0): bold text on a light-gray background.
    for j in range(len(col_labels)):
        header = tbl[(0, j)]
        header.set_text_props(weight="bold")
        header.set_facecolor("#f2f2f2")

    # Data rows start at index 1 because row 0 holds the column labels.
    for i, fam in enumerate(families, start=1):
        cell = tbl[(i, 0)]
        text = cell.get_text()
        text.set_color(FAMILY_COLORS.get(fam, "black"))
        text.set_fontweight("bold")
        # Left-align the label.  Current matplotlib positions cell text from
        # the Text's horizontal alignment, so assigning only the private
        # ``Cell._loc`` has no visible effect there; older releases read
        # ``_loc`` instead, so both are set.
        text.set_horizontalalignment("left")
        cell._loc = "left"

    fig.tight_layout()
    save_fig(fig, outdir, "icml_table_family_summary")
    plt.close(fig)


def plot_median_compression_by_family(agg: pd.DataFrame, outdir: str) -> None:
    """Bar chart of the median species reduction (%) per family.

    Bars are colored from FAMILY_COLORS (gray for unlisted families).  Each
    finite median gets a rounded ``N%`` label; NaN medians get none, matching
    the speedup plot.  The y-axis spans 0-100 % and is extended below 0 when
    a median is negative (more species after reduction), so such bars are
    visible instead of clipped.  Saved as
    ``icml_median_species_reduction_by_family``.
    """
    families = [str(f) for f in agg["family"]]
    y = agg["species_down_median"].to_numpy(dtype=float)
    x = np.arange(len(families))

    fig, ax = plt.subplots(figsize=(6.5, 3.6))
    ax.bar(x, y, color=[FAMILY_COLORS.get(f, "gray") for f in families])

    ax.set_xticks(x)
    ax.set_xticklabels([f.capitalize() for f in families])
    ax.set_ylabel("Median species reduction (%)")
    finite = y[np.isfinite(y)]
    bottom = min(0.0, float(finite.min()) - 5.0) if finite.size else 0.0
    ax.set_ylim(bottom, 100)
    ax.set_title("Median compression by family")

    # Value labels: just above the bar top (or above the zero line for
    # negative bars), kept below the 100 % ceiling.
    for i, val in enumerate(y):
        if not np.isfinite(val):
            continue
        ax.text(i, min(max(val, 0.0) + 2, 98), f"{val:.0f}%", ha="center", va="bottom", fontsize=9)

    fig.tight_layout()
    save_fig(fig, outdir, "icml_median_species_reduction_by_family")
    plt.close(fig)


def plot_median_sim_speedup_by_family(agg: pd.DataFrame, outdir: str) -> None:
    """Bar chart of the median simulation speedup per family (log y-axis).

    Bars are colored from FAMILY_COLORS (gray for unlisted families).  Finite
    positive medians get an ``N.N×`` label, and a dashed gray line marks 1×
    (no speedup).  The y-range always covers [0.8, 2.0] and widens to fit
    every finite positive median, so slowdowns below 0.8× stay visible.
    Empty, all-NaN or inf medians fall back to that default range instead of
    raising.  Saved as ``icml_median_sim_speedup_by_family``.
    """
    families = [str(f) for f in agg["family"]]
    y = agg["sim_speedup_median"].to_numpy(dtype=float)
    x = np.arange(len(families))

    fig, ax = plt.subplots(figsize=(6.5, 3.6))
    ax.bar(x, y, color=[FAMILY_COLORS.get(f, "gray") for f in families])

    ax.set_xticks(x)
    ax.set_xticklabels([f.capitalize() for f in families])
    ax.set_ylabel("Median simulation speedup (×, log scale)")
    ax.set_yscale("log")

    # Only finite, positive values can be placed on a log axis.
    positive = y[np.isfinite(y) & (y > 0)]
    lo, hi = 0.8, 2.0
    if positive.size:
        lo = min(lo, float(positive.min()) / 1.5)
        hi = max(hi, float(positive.max()) * 1.8)
    ax.set_ylim(lo, hi)
    # Gray rather than the default C0 blue, which is close to the "chain" color.
    ax.axhline(1.0, color="gray", linewidth=1.0, linestyle="--", alpha=0.6)

    ax.set_title("Median simulation speedup by family")

    # Value labels
    for i, val in enumerate(y):
        if np.isfinite(val) and val > 0:
            ax.text(i, val * 1.15, f"{val:.1f}×", ha="center", va="bottom", fontsize=9)

    fig.tight_layout()
    save_fig(fig, outdir, "icml_median_sim_speedup_by_family")
    plt.close(fig)


def plot_bp_marginal_preservation(
    df: pd.DataFrame,
    outdir: str,
    zero_floor: float = 1e-16,
) -> None:
    """Scatter of BP marginal error against problem size, colored by family.

    x = ``orig_vars`` (original number of variables); y = ``bp_diff`` (max
    marginal difference, BP on original vs. BP on reduced), log y-axis.  Only
    families in FAMILY_ORDER are drawn.  A dashed gray line marks 1e-10.

    A difference of exactly 0 (marginals reproduced exactly) cannot sit on a
    log axis.  Such points are drawn with hollow markers at ``zero_floor``, or
    a tenth of the smallest positive difference if that is lower, rather than
    being dropped.  Rows with a missing/non-finite ``orig_vars`` or
    ``bp_diff``, or a negative ``bp_diff``, are skipped.

    Args:
        df: Frame from ``compute_derived`` (needs ``family``, ``orig_vars``,
            ``bp_diff``).
        outdir: Directory passed to ``save_fig``.
        zero_floor: y-position used for exact-zero differences.
    """
    dff = df.copy()
    dff["orig_vars"] = pd.to_numeric(dff["orig_vars"], errors="coerce")
    dff["bp_diff"] = pd.to_numeric(dff["bp_diff"], errors="coerce")
    # Zeros are kept on purpose: they are exact preservation, the best outcome.
    dff = dff[np.isfinite(dff["orig_vars"]) & np.isfinite(dff["bp_diff"]) & (dff["bp_diff"] >= 0)]

    is_zero = dff["bp_diff"] == 0
    positive = dff.loc[~is_zero, "bp_diff"]
    # Place zeros below every real difference so they remain distinguishable.
    floor = min(zero_floor, float(positive.min()) / 10.0) if len(positive) else zero_floor

    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    marker_style = dict(s=35, alpha=0.85)
    any_points = False
    n_zero_plotted = 0
    for fam in FAMILY_ORDER:
        in_fam = dff["family"] == fam
        if not in_fam.any():
            continue
        color = FAMILY_COLORS.get(fam, "gray")
        inexact = dff[in_fam & ~is_zero]
        exact = dff[in_fam & is_zero]
        label = fam  # legend label goes on the first artist drawn for this family
        if len(inexact):
            ax.scatter(
                inexact["orig_vars"].values,
                inexact["bp_diff"].values,
                label=label,
                color=color,
                edgecolors="none",
                **marker_style,
            )
            label = "_nolegend_"
        if len(exact):
            ax.scatter(
                exact["orig_vars"].values,
                np.full(len(exact), floor),
                label=label,
                facecolors="none",
                edgecolors=color,
                **marker_style,
            )
            n_zero_plotted += len(exact)
        any_points = True

    ax.set_yscale("log")
    ax.set_xlabel("Original # variables")
    ax.set_ylabel("Max marginal difference (BP orig vs BP reduced)")
    ax.set_title("BP marginal preservation under SP-B reduction")

    # Reference level at 1e-10.
    ax.axhline(1e-10, color="gray", linestyle="--", linewidth=1.0, alpha=0.7)

    if n_zero_plotted:
        ax.text(
            0.01,
            0.01,
            f"hollow markers: {n_zero_plotted} exact-zero difference(s), drawn at {floor:.0e}",
            transform=ax.transAxes,
            ha="left",
            va="bottom",
            fontsize=8,
        )

    # Without any labelled artist, legend() would only emit a warning.
    if any_points:
        ax.legend(frameon=False, ncol=3, fontsize=9)
    fig.tight_layout()
    save_fig(fig, outdir, "icml_bp_marginal_preservation")
    plt.close(fig)


def main() -> None:
    """Command-line entry point: build every figure and the aggregate CSV.

    Reads ``--csv``, derives metrics with ``compute_derived``, aggregates them
    per family in FAMILY_ORDER, and writes the four figures plus
    ``aggregate_by_family.csv`` into ``--outdir`` (created if missing).
    Rows whose family is not in FAMILY_ORDER are reported on stdout and left
    out of the figures and the aggregate.  If no row has a known family, the
    script exits via ``ArgumentParser.error`` instead of failing with a
    ``KeyError``.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--csv", required=True, help="Path to benchmark_results CSV (cleaned or raw).")
    ap.add_argument("--outdir", default="./icml_figs", help="Output directory for figures.")
    args = ap.parse_args()

    ensure_outdir(args.outdir)

    df = compute_derived(pd.read_csv(args.csv))

    unknown = df.loc[~df["family"].isin(FAMILY_ORDER), "family"]
    if len(unknown):
        labels = sorted({str(f) for f in unknown})
        print(f"[warn] ignoring {len(unknown)} row(s) with unrecognized family: {labels}")

    agg = aggregate_by_family(df)
    if agg.empty:
        ap.error(f"no rows in {args.csv} belong to a known family ({', '.join(FAMILY_ORDER)})")

    # Keep order stable
    agg["family"] = pd.Categorical(agg["family"], categories=FAMILY_ORDER, ordered=True)
    agg = agg.sort_values("family").reset_index(drop=True)

    plot_family_summary_table(agg, args.outdir)
    plot_median_compression_by_family(agg, args.outdir)
    plot_median_sim_speedup_by_family(agg, args.outdir)
    plot_bp_marginal_preservation(df, args.outdir)

    # Also save aggregated CSV (handy for LaTeX tables)
    agg_path = os.path.join(args.outdir, "aggregate_by_family.csv")
    agg.to_csv(agg_path, index=False)
    print(f"[saved] {agg_path}")


if __name__ == "__main__":
    main()
