#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Publication-ready WSD plots (Nature-like style)

Generates:
  Fig1_top_fragments_wsd_bar.{pdf,svg,png}
  Fig2_wsd_heatmap_fragment_by_layer.{pdf,svg,png}
  Fig3_layer_profile_[fragment].{pdf,svg,png}    (one per selected fragment)
  Fig4_wsd_by_sae_[fragment].{pdf,svg,png}       (one per selected fragment)

Usage:
  uv run plot_wsd_publication_figs.py \
    --summary-dir fragment_scan_summaries_wsd \
    --out-dir figures_wsd \
    --top 15 --heatmap-top 25 --profile-top 8
"""

from __future__ import annotations

import argparse
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt


# =========================
# ======= STYLE ===========
# =========================


def _font_is_available(name: str) -> bool:
    """Return True if matplotlib's font manager lists a font called *name*.

    Any error raised while querying the font manager is swallowed and
    reported as "not available", so callers can always fall back safely.
    """
    try:
        known_fonts = mpl.font_manager.fontManager.ttflist
        for entry in known_fonts:
            if entry.name == name:
                return True
    except Exception:
        # Best-effort probe: a broken font cache must not abort styling.
        pass
    return False


def set_nature_style():
    """
    Apply a Nature-like minimalist look to matplotlib's global rcParams.

    - Sans-serif face: Arial, else Helvetica, else DejaVu Sans
    - 8 pt base font, 7 pt tick and legend text
    - 0.6 pt spines and ticks, no minor ticks, light y-only grid,
      top/right spines hidden
    - 300 dpi, tight bounding box, text kept as text in PDF/PS/SVG output
    """
    # First preferred face that the font manager knows about wins.
    preferred_faces = ("Arial", "Helvetica")
    base_font = next(
        (face for face in preferred_faces if _font_is_available(face)),
        "DejaVu Sans",
    )

    typography = {
        "font.family": "sans-serif",
        "font.sans-serif": [base_font],
        "font.size": 8,
        "axes.titlesize": 8,
        "axes.labelsize": 8,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7,
        "legend.fontsize": 7,
    }
    strokes_and_ticks = {
        "axes.linewidth": 0.6,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
        "xtick.major.size": 2.8,
        "ytick.major.size": 2.8,
        "xtick.minor.size": 0,
        "ytick.minor.size": 0,
    }
    grid_and_spines = {
        "axes.grid": True,
        "grid.linewidth": 0.4,
        "grid.color": "#E0E0E0",
        "axes.grid.axis": "y",
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    export = {
        "savefig.dpi": 300,
        "savefig.bbox": "tight",
        "pdf.fonttype": 42,  # keep text as text
        "ps.fonttype": 42,
        "svg.fonttype": "none",
        "figure.dpi": 300,
    }

    mpl.rcParams.update(
        {**typography, **strokes_and_ticks, **grid_and_spines, **export}
    )


def mm_to_in(mm: float) -> float:
    """Convert a length in millimetres to inches (25.4 mm per inch)."""
    millimetres_per_inch = 25.4
    return mm / millimetres_per_inch


def fig_size(width_mm: float, height_mm: float) -> Tuple[float, float]:
    """Return a ``(width, height)`` tuple in inches for dimensions given in mm."""
    width_in, height_in = (mm_to_in(side) for side in (width_mm, height_mm))
    return width_in, height_in


def save_figure(fig: mpl.figure.Figure, out_base: Path):
    """Write *fig* as PDF, SVG and PNG next to each other.

    *out_base* is the target path without extension; its parent directory
    is created if needed. Each file name is derived with
    ``Path.with_suffix``, which replaces any suffix *out_base* already has.
    """
    out_base.parent.mkdir(parents=True, exist_ok=True)
    for extension in (".pdf", ".svg", ".png"):
        fig.savefig(out_base.with_suffix(extension))


# =========================
# ======= IO / UTIL =======
# =========================


def load_csv_safe(path: Path) -> Optional[pd.DataFrame]:
    """Read *path* as CSV, returning None (with a warning) instead of raising.

    A missing file and a file that pandas cannot parse both print a
    ``[warn]`` line and yield None.
    """
    if path.exists():
        try:
            frame = pd.read_csv(path)
        except Exception as err:
            print(f"[warn] failed to read {path}: {err}")
            frame = None
        return frame

    print(f"[warn] missing file: {path}")
    return None


def guess_wsd_global_col(df: pd.DataFrame) -> Optional[str]:
    """
    Pick the column holding the global WSD difficulty score.

    The first column (in frame order) whose name contains 'wsd', 'global'
    and 'med'/'median' wins, e.g. 'wsd_top10_med_global_median'. If none
    matches, a short list of legacy column names is tried. Returns None
    when *df* is None or nothing fits.
    """
    if df is None:
        return None

    def looks_like_global_wsd(col: str) -> bool:
        return (
            "wsd" in col
            and "global" in col
            and ("med" in col or "median" in col)
        )

    matches = [col for col in df.columns if looks_like_global_wsd(col)]
    if matches:
        return matches[0]

    # Names used by older summary schemas, in order of preference.
    legacy_names = (
        "wsd_top10_med_global_median",
        "wsd_top10_median_global",
        "pb_top10_med_global_median",
    )
    return next((name for name in legacy_names if name in df.columns), None)


def guess_wsd_layer_col(df: pd.DataFrame) -> Optional[str]:
    """
    Pick the per-layer WSD summary column (typically 'wsd_top10_med__median').

    Preference order:
      1. a column containing 'wsd' and 'top10' whose name contains '__median';
      2. a column containing 'wsd' and 'top10' whose name ends in '_median';
      3. any column starting with 'wsd' that contains '__median' or ends in
         '_median'.
    Within each tier the first column in frame order wins. Returns None when
    *df* is None or nothing fits.
    """
    if df is None:
        return None

    def is_median_col(col: str) -> bool:
        return "__median" in col or col.endswith("_median")

    top10_medians = [
        col
        for col in df.columns
        if "wsd" in col and "top10" in col and is_median_col(col)
    ]
    if top10_medians:
        # Double-underscore names come from groupby aggregation; prefer them.
        for col in top10_medians:
            if "__median" in col:
                return col
        return top10_medians[0]

    # Fallback: any WSD median column at all.
    any_wsd_medians = [
        col for col in df.columns if col.startswith("wsd") and is_median_col(col)
    ]
    if any_wsd_medians:
        return any_wsd_medians[0]
    return None


def short_label(s: str, maxlen: int = 32) -> str:
    """Return *s* unchanged if it fits in *maxlen* characters, else truncate.

    A truncated label keeps the first ``maxlen - 1`` characters and ends
    with an ellipsis character.
    """
    if len(s) > maxlen:
        return s[: maxlen - 1] + "…"
    return s


# =========================
# ======= PLOTS ===========
# =========================


def plot_top_fragments_bar(
    df_frag: pd.DataFrame,
    out_dir: Path,
    top_n: int = 15,
    title: str = "Top fragments (global WSD)",
) -> Optional[Path]:
    """Draw Fig 1: horizontal bars of the *top_n* fragments by global WSD.

    Args:
        df_frag: Fragment-level table with a 'fragment' column and a global
            WSD column located via ``guess_wsd_global_col``.
        out_dir: Directory that receives ``Fig1_top_fragments_wsd_bar.*``.
        top_n: Number of highest-scoring fragments to show.
        title: Axes title.

    Returns:
        The extension-less output path, or None when the figure is skipped
        (no data, or no recognisable WSD column).
    """
    if df_frag is None or df_frag.empty:
        print("[warn] no fragment_difficulty data; skipping Fig 1")
        return None

    value_col = guess_wsd_global_col(df_frag)
    if not value_col:
        print("[warn] cannot find a global WSD difficulty column; skipping Fig 1")
        return None

    ranked = (
        df_frag.copy()
        .sort_values(value_col, ascending=False)
        .head(top_n)
        .rename(columns={value_col: "wsd"})
    )
    n_bars = len(ranked)
    positions = range(n_bars)

    # Height grows with the number of bars so labels never overlap.
    fig, ax = plt.subplots(figsize=fig_size(85, 70 + 4.8 * n_bars))
    bars = ax.barh(positions, ranked["wsd"], color="#4C72B0")
    ax.set_yticks(positions)
    ax.set_yticklabels([short_label(name) for name in ranked["fragment"]])
    ax.invert_yaxis()  # highest score at the top
    ax.set_xlabel("Global median of per-run top‑10 WSD")
    ax.set_title(title)

    # Numeric value just to the right of each bar.
    for bar, raw_value in zip(bars, ranked["wsd"]):
        ax.text(
            bar.get_width() + 0.002,
            bar.get_y() + bar.get_height() / 2,
            f"{float(raw_value):.3f}",
            va="center",
            ha="left",
            fontsize=7,
        )

    out_path = out_dir / "Fig1_top_fragments_wsd_bar"
    save_figure(fig, out_path)
    plt.close(fig)
    return out_path


def plot_heatmap_fragment_by_layer(
    df_layer: pd.DataFrame,
    out_dir: Path,
    top_from_frag_df: Optional[pd.DataFrame] = None,
    top_n_heatmap: int = 25,
    title: str = "WSD (median top‑10) by fragment × layer",
) -> Optional[Path]:
    """Draw Fig 2: a fragment × layer heatmap of the per-layer WSD summary.

    Args:
        df_layer: Per-(fragment, layer) table with 'fragment' and 'layer'
            columns and a WSD column located via ``guess_wsd_layer_col``.
            Each (fragment, layer) pair must occur at most once, otherwise
            ``DataFrame.pivot`` raises ValueError.
        out_dir: Directory that receives
            ``Fig2_wsd_heatmap_fragment_by_layer.*``.
        top_from_frag_df: Optional fragment-level table. When it has a
            recognisable global WSD column, the *top_n_heatmap* fragments
            with the highest value are shown; otherwise the fragments with
            the most rows in *df_layer* are used.
        top_n_heatmap: Maximum number of fragments (rows) in the heatmap.
        title: Axes title.

    Returns:
        The extension-less output path, or None when the figure is skipped
        (no data, no recognisable WSD column, or none of the selected
        fragments present in *df_layer*).
    """
    if df_layer is None or df_layer.empty:
        print("[warn] no per_fragment_layer data; skipping Fig 2")
        return None
    ycol = guess_wsd_layer_col(df_layer)
    if not ycol:
        print("[warn] cannot find a per-layer WSD column; skipping Fig 2")
        return None

    # Pick top N fragments by global difficulty if available; else the most
    # complete ones (most rows in the per-layer table).
    gcol = None
    if top_from_frag_df is not None and not top_from_frag_df.empty:
        gcol = guess_wsd_global_col(top_from_frag_df)
    if gcol:
        top_frags = (
            top_from_frag_df.sort_values(gcol, ascending=False)
            .head(top_n_heatmap)["fragment"]
            .tolist()
        )
    else:
        top_frags = (
            df_layer["fragment"].value_counts().head(top_n_heatmap).index.tolist()
        )

    sub = df_layer[df_layer["fragment"].isin(top_frags)]
    if sub.empty:
        # The two tables can disagree on fragment names; an empty pivot
        # cannot be rendered as an image, so skip instead of failing later.
        print(
            "[warn] none of the selected fragments appear in "
            "per_fragment_layer data; skipping Fig 2"
        )
        return None

    piv = sub.pivot(index="fragment", columns="layer", values=ycol)
    # Order rows by their maximum across layers (all-NaN rows sort last).
    row_max = np.nan_to_num(piv.max(axis=1).values, nan=-1.0)
    piv = piv.iloc[np.argsort(-row_max, kind="stable")]

    nrows = len(piv)
    fig_h = 60 + 6.5 * nrows  # mm
    fig, ax = plt.subplots(figsize=fig_size(85, fig_h))
    try:
        # Missing cells stay NaN; imshow leaves non-finite cells uncoloured
        # rather than letting them distort the colour scale.
        im = ax.imshow(piv.values.astype(float), aspect="auto", cmap="viridis")
        # The global style enables a y-grid, which would otherwise be drawn
        # across every heatmap row.
        ax.grid(False)

        ax.set_yticks(range(nrows))
        ax.set_yticklabels([short_label(s) for s in piv.index], fontsize=7)
        ax.set_xticks(range(len(piv.columns)))
        ax.set_xticklabels(piv.columns)
        ax.set_xlabel("Layer")
        ax.set_title(title)

        cbar = fig.colorbar(im, ax=ax, fraction=0.025, pad=0.02)
        cbar.ax.set_ylabel("Median of per‑run top‑10 WSD", rotation=90, va="center")

        out_path = out_dir / "Fig2_wsd_heatmap_fragment_by_layer"
        save_figure(fig, out_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return out_path


def plot_layer_profiles(
    df_layer: pd.DataFrame,
    out_dir: Path,
    selected_fragments: List[str],
    title_prefix: str = "WSD vs layer",
) -> List[Path]:
    """Draw Fig 3: one WSD-versus-layer line plot per selected fragment.

    Args:
        df_layer: Per-(fragment, layer) table with 'fragment' and 'layer'
            columns and a WSD column located via ``guess_wsd_layer_col``.
        out_dir: Directory that receives ``Fig3_layer_profile_<fragment>.*``.
        selected_fragments: Fragment names to plot; names absent from
            *df_layer* are silently skipped.
        title_prefix: Text placed before the fragment name in each title.

    Returns:
        Extension-less output paths, one per figure written (possibly empty).
    """
    import re  # local import keeps this block self-contained

    paths: List[Path] = []
    if df_layer is None or df_layer.empty or not selected_fragments:
        print("[warn] no per_fragment_layer data or no fragments; skipping Fig 3")
        return paths
    ycol = guess_wsd_layer_col(df_layer)
    if not ycol:
        print("[warn] cannot find a per-layer WSD column; skipping Fig 3")
        return paths

    # ':' is replaced as before. '.' must go too because save_figure derives
    # file names with Path.with_suffix, which would cut the name at the last
    # dot; path separators would silently create nested directories.
    unsafe_chars = re.compile(r"[:/\\.]")

    for frag in selected_fragments:
        sub = df_layer[df_layer["fragment"] == frag]
        if sub.empty:
            continue
        sub = sub.sort_values("layer")
        fig, ax = plt.subplots(figsize=fig_size(85, 55))
        try:
            ax.plot(sub["layer"], sub[ycol], marker="o", linewidth=1.2)
            ax.set_xlabel("Layer")
            ax.set_ylabel("Median of per‑run top‑10 WSD")
            ax.set_title(f"{title_prefix}: {frag}")
            safe_name = unsafe_chars.sub("_", str(frag))
            out_path = out_dir / f"Fig3_layer_profile_{safe_name}"
            save_figure(fig, out_path)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        paths.append(out_path)
    return paths


def plot_wsd_by_sae(
    df_sae: pd.DataFrame,
    out_dir: Path,
    selected_fragments: List[str],
    title_prefix: str = "WSD by SAE",
) -> List[Path]:
    """Draw Fig 4: one bar chart of WSD per SAE tag for each selected fragment.

    Args:
        df_sae: Per-(fragment, SAE) table with 'fragment' and 'sae_tag'
            columns and at least one column that starts with 'wsd' and
            contains '__median' or ends in '_median' (a 'top10' variant is
            preferred).
        out_dir: Directory that receives ``Fig4_wsd_by_sae_<fragment>.*``.
        selected_fragments: Fragment names to plot; names absent from
            *df_sae* are silently skipped.
        title_prefix: Text placed before the fragment name in each title.

    Returns:
        Extension-less output paths, one per figure written (possibly empty).
    """
    import re  # local import keeps this block self-contained

    paths: List[Path] = []
    if df_sae is None or df_sae.empty or not selected_fragments:
        print("[warn] no per_fragment_sae data or no fragments; skipping Fig 4")
        return paths

    # guess a median column for WSD by SAE
    ycols = [
        c
        for c in df_sae.columns
        if c.startswith("wsd") and ("__median" in c or c.endswith("_median"))
    ]
    if not ycols:
        print(
            "[warn] cannot find WSD median columns in per_fragment_sae; skipping Fig 4"
        )
        return paths
    # choose the top-10‑median flavor if present
    ycol = next((c for c in ycols if "top10" in c), ycols[0])

    # Compiled once instead of importing/compiling on every key call.
    width_pattern = re.compile(r"relu_(\d+)x")
    # ':' is replaced as before. '.' must go too because save_figure derives
    # file names with Path.with_suffix, which would cut the name at the last
    # dot; path separators would silently create nested directories.
    unsafe_chars = re.compile(r"[:/\\.]")

    def sae_key(tag) -> Tuple[int, str]:
        # e.g. relu_4x_abcd → (4, "relu_4x_abcd"); tags without the pattern
        # sort first. Non-string tags (None, NaN) are treated as empty.
        text = tag if isinstance(tag, str) else ""
        m = width_pattern.search(text)
        return (int(m.group(1)) if m else 0, text)

    for frag in selected_fragments:
        sub = df_sae[df_sae["fragment"] == frag]
        if sub.empty:
            continue
        sub = sub.sort_values("sae_tag", key=lambda col: col.map(sae_key))
        x = np.arange(len(sub))
        fig, ax = plt.subplots(figsize=fig_size(85, 55))
        try:
            ax.bar(x, sub[ycol], width=0.6, color="#4C72B0")
            ax.set_xticks(x)
            # Drop the 'relu_' prefix and break the rest over several lines.
            ax.set_xticklabels(
                [
                    str(tag).replace("relu_", "").replace("_", "\n")
                    for tag in sub["sae_tag"]
                ],
                fontsize=7,
            )
            ax.set_ylabel("Median of per‑run top‑10 WSD")
            ax.set_title(f"{title_prefix}: {frag}")
            safe_name = unsafe_chars.sub("_", str(frag))
            out_path = out_dir / f"Fig4_wsd_by_sae_{safe_name}"
            save_figure(fig, out_path)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        paths.append(out_path)
    return paths


# =========================
# ========= MAIN ==========
# =========================


def main():
    """Command-line entry point: load the WSD summary CSVs and write Figs 1–4.

    Reads ``fragment_difficulty_wsd.csv``, ``per_fragment_layer_wsd.csv`` and
    ``per_fragment_sae_wsd.csv`` from ``--summary-dir``, writes the figures
    into ``--out-dir`` and prints every output file that exists afterwards.
    """
    cli = argparse.ArgumentParser(description="Make Nature-style WSD plots.")
    cli.add_argument(
        "--summary-dir", required=True, help="Directory with *wsd*.csv summaries"
    )
    cli.add_argument("--out-dir", required=True, help="Where to write figures")
    cli.add_argument("--top", type=int, default=15, help="Top-N fragments in Fig 1")
    cli.add_argument(
        "--heatmap-top", type=int, default=25, help="Top-N fragments in heatmap"
    )
    cli.add_argument(
        "--profile-top",
        type=int,
        default=8,
        help="How many fragments to show as layer profiles & SAE bars",
    )
    opts = cli.parse_args()

    set_nature_style()

    summary_root = Path(opts.summary_dir)
    figures_dir = Path(opts.out_dir)
    figures_dir.mkdir(parents=True, exist_ok=True)

    # Input tables; each is None when its file is missing or unreadable.
    frag_table = load_csv_safe(summary_root / "fragment_difficulty_wsd.csv")
    layer_table = load_csv_safe(summary_root / "per_fragment_layer_wsd.csv")
    sae_table = load_csv_safe(summary_root / "per_fragment_sae_wsd.csv")

    # Fig 1: top fragments (global WSD)
    bar_path = plot_top_fragments_bar(frag_table, figures_dir, top_n=opts.top)

    # Fragments used for the per-fragment figures (Figs 3 and 4).
    chosen = []
    if frag_table is not None and not frag_table.empty:
        rank_col = guess_wsd_global_col(frag_table)
        if rank_col:
            ranked = frag_table.sort_values(rank_col, ascending=False)
            chosen = ranked.head(opts.profile_top)["fragment"].tolist()
    if not chosen and layer_table is not None:
        # Fallback: fragments with the most rows in the per-layer table.
        row_counts = layer_table["fragment"].value_counts()
        chosen = row_counts.head(opts.profile_top).index.tolist()

    # Fig 2: heatmap fragment × layer
    heatmap_path = plot_heatmap_fragment_by_layer(
        layer_table,
        figures_dir,
        top_from_frag_df=frag_table,
        top_n_heatmap=opts.heatmap_top,
    )

    # Fig 3: layer profiles (one per selected fragment)
    profile_paths = plot_layer_profiles(
        layer_table, figures_dir, selected_fragments=chosen
    )

    # Fig 4: SAE effects (one per selected fragment)
    sae_paths = plot_wsd_by_sae(sae_table, figures_dir, selected_fragments=chosen)

    # Console summary: list only the files that are actually on disk.
    print("Wrote figures:")
    for base in [bar_path, heatmap_path, *profile_paths, *sae_paths]:
        if base is None:
            continue
        for extension in (".pdf", ".svg", ".png"):
            written = base.with_suffix(extension)
            if written.exists():
                print("  •", written)


if __name__ == "__main__":
    main()
