from __future__ import annotations
import argparse, pickle
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, Tuple

# ──────────────────────────────────────────────────────────────────────────────
# Style / constants
# ──────────────────────────────────────────────────────────────────────────────
# Pastel palette reused across figures: red = brain-alignment series,
# green = next-word-prediction series, orange = mean / intersection series.
PASTEL_ORANGE = '#FDBE87'
PASTEL_RED    = '#FF9999'
PASTEL_GREEN  = '#99CC99'

# Cache-file model key -> display label used in the Figure-1 legend
# (keys missing from this table are shown verbatim).
model_names = dict([
    ("mamba",       "Mamba-1.4B"),
    ("falcon3",     "Falcon3-1B"),
    ("llama3.2-1B", "Llama3.2-1B"),
    ("gemma",       "Gemma-2B"),
    ("zamba",       "Zamba2-1.2B"),
])

# Global font sizes: 10 pt base text, 18 pt titles and 14 pt for axis
# labels, tick labels and legends.
plt.rcParams.update({
    "font.size": 10,
    "axes.titlesize": 18,
    **dict.fromkeys(
        ("axes.labelsize", "xtick.labelsize", "ytick.labelsize", "legend.fontsize"),
        14,
    ),
})

# ──────────────────────────────────────────────────────────────────────────────
# Figure 1  (IoU vs threshold, mean line + per-model symbols)
# ──────────────────────────────────────────────────────────────────────────────
def plot_figure1(cache_root: Path, out_dir: Path) -> None:
    """Regenerate Figure 1: IoU versus top-% attribution threshold.

    Reads every ``<cache_root>/figure1/<model>_iou_data.npz`` (arrays ``x`` =
    threshold, ``y`` = IoU), draws the across-model mean as a line plus each
    model's points as hollow markers, and saves PNG + PDF into *out_dir*.

    Parameters
    ----------
    cache_root : Path
        Root cache directory containing the ``figure1`` sub-folder.
    out_dir : Path
        Output directory (created if missing).
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    files = sorted((cache_root / "figure1").glob("*_iou_data.npz"))
    if not files:
        print("[Fig-1] - no cache files found")
        return

    rows = []
    for f in files:
        model = f.stem.replace("_iou_data", "")
        # An .npz load returns an NpzFile that keeps the archive open;
        # close it as soon as the arrays have been read.
        with np.load(f) as d:
            rows.extend((model, int(x), float(y)) for x, y in zip(d["x"], d["y"]))
    df = pd.DataFrame(rows, columns=["model", "xpos", "iou"])

    # Union of all thresholds seen across models, in ascending order.
    thresholds = np.sort(df["xpos"].unique())
    mean_iou = df.groupby("xpos")["iou"].mean().reindex(thresholds)

    markers = ['o', 's', 'D', '^', 'v', 'P', 'X', '*', '+']
    marker_map = {m: markers[i % len(markers)] for i, m in enumerate(df["model"].unique())}

    fig, ax = plt.subplots(figsize=(9, 4))

    # Mean line (no error bars)
    ax.plot(thresholds, mean_iou, '-',
            color=PASTEL_ORANGE, linewidth=2, label='Mean across models')

    # Scatter the individual-model points
    for model, model_df in df.groupby('model'):
        ax.scatter(model_df["xpos"], model_df["iou"],
                   marker=marker_map[model],
                   s=25,
                   edgecolor='darkgray',
                   facecolor='none',   # hollow for clarity
                   linewidth=0.8,
                   label=model_names.get(model, model))

    ax.set_xlabel('Top % attribution threshold')
    ax.set_ylabel('Intersection over Union (IoU)')
    ax.set_ylim(0, 1)
    # Ticks must come from the same array as the labels; previously they were
    # taken from whichever model happened to be iterated last.
    ax.set_xticks(thresholds)
    ax.set_xticklabels(thresholds)
    ax.grid(True, ls=':')

    # Move legend outside the plot
    ax.legend(loc='center left', fontsize=14, frameon=False, bbox_to_anchor=(1, 0.5), ncol=1)

    fig.tight_layout()
    fig.savefig(out_dir / 'figure1_iou_vs_threshold.png', bbox_inches='tight')
    fig.savefig(out_dir / 'figure1_iou_vs_threshold.pdf', dpi=1200, bbox_inches='tight')
    plt.close(fig)

# ──────────────────────────────────────────────────────────────────────────────
# Figure 3  (per-layer curves + AUC bar chart)
# ──────────────────────────────────────────────────────────────────────────────
def _annotate_significance(ax, x_pos: float, bar_height: float, p_val: float):
    """Place significance stars at (x_pos, 0.9 * bar_height) on *ax*.

    One star for p < 0.05, two for p < 0.01 and three for p < 0.001.
    Nothing is drawn when p >= 0.05 (or when p is NaN, since every
    comparison with NaN is false).
    """
    if not p_val < 0.05:
        return
    if p_val < 0.001:
        stars = "∗∗∗"
    elif p_val < 0.01:
        stars = "∗∗"
    else:
        stars = "∗"
    label_y = bar_height * 0.9
    ax.text(x_pos, label_y, stars, ha="center", va="bottom", fontsize=14)

def plot_figure3(cache_root: Path, out_dir: Path) -> None:
    """Regenerate Figure 3: per-layer attribution curves plus AUC bars.

    For every ``<cache_root>/figure3/figure3_<model>_caches.pkl`` the pickle
    must provide ``curves_cache``, ``aucs_cache``, ``t_res_cache``,
    ``global_ymax`` and ``global_auc_ymax``. One PNG + PDF per
    (model, layer) pair is written to *out_dir* as
    ``figure3_<model>_layer<layer>``.

    NOTE: the caches are unpickled, which can execute arbitrary code; only
    point this at cache directories you produced yourself.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    prefix, suffix = "figure3_", "_caches"
    for cache_file in sorted((cache_root / "figure3").glob("figure3_*_caches.pkl")):
        # Strip the fixed prefix/suffix rather than split("_")[1], so model
        # keys that themselves contain "_" are kept whole.
        model = cache_file.stem[len(prefix):-len(suffix)]
        with open(cache_file, "rb") as f:
            cache: Dict = pickle.load(f)

        curves_cache: Dict[int, Tuple[pd.DataFrame, pd.DataFrame]] = cache["curves_cache"]
        aucs_cache: Dict[int, Dict[str, Tuple[float, float]]] = cache["aucs_cache"]
        # p-value = first row of each t-test result's "p-val" column.
        pvals_cache: Dict[int, float] = {
            k: v["p-val"].iloc[0] for k, v in cache["t_res_cache"].items()
        }
        global_ymax, global_auc_ymax = cache["global_ymax"], cache["global_auc_ymax"]
        if not curves_cache:
            print(f"[Fig-3] - {cache_file.name}: empty curves_cache, skipped")
            continue
        # All layers share the x-axis; take the thresholds from the first one.
        thresholds = next(iter(curves_cache.values()))[0].index.values

        for layer, (curves_b, curves_l) in curves_cache.items():
            auc_stats = aucs_cache[layer]
            p_val = pvals_cache[layer]

            fig, (ax_line, ax_bar) = plt.subplots(
                ncols=2, figsize=(7.5, 4), gridspec_kw={"width_ratios": [3, 1]}
            )
            ax_line.errorbar(
                thresholds, curves_b["mean"], yerr=curves_b["sem"],
                color=PASTEL_RED, label="Brain alignment", lw=1,
                ecolor="black", elinewidth=1.2, capsize=4, capthick=1.2
            )
            ax_line.errorbar(
                thresholds, curves_l["mean"], yerr=curves_l["sem"],
                color=PASTEL_GREEN, label="Next‑word prediction", lw=1,
                ecolor="black", elinewidth=1.2, capsize=4, capthick=1.2
            )

            ax_line.set_xlabel("Top % attribution", fontsize=18)
            ax_line.set_ylabel("# unique important words", fontsize=18)
            ax_line.grid(ls=":")
            ax_line.legend(loc="upper left", frameon=False)
            ax_line.set_ylim(0, global_ymax)

            # AUC bars: (mean, error) pairs for brain alignment vs. LM.
            ax_bar.bar(
                [0, 1],
                [auc_stats["brain"][0], auc_stats["lm"][0]],
                yerr=[auc_stats["brain"][1], auc_stats["lm"][1]],
                color=[PASTEL_RED, PASTEL_GREEN],
                width=0.6, linewidth=1.2, capsize=4
            )
            ax_bar.set_ylim(0, global_auc_ymax)
            ax_bar.set_xticks([0, 1], ["BA", "NWP"], fontsize=18)
            ax_bar.set_ylabel("AUC", fontsize=18)
            ax_bar.grid(True, axis="y", ls=":")

            # Stars go over the higher bar: 0 if brain AUC is higher, else 1.
            higher_idx = int(auc_stats["brain"][0] < auc_stats["lm"][0])
            _annotate_significance(ax_bar, higher_idx, global_auc_ymax, p_val)

            fig.tight_layout()
            fig.savefig(out_dir / f"figure3_{model}_layer{layer}.png")
            fig.savefig(out_dir / f"figure3_{model}_layer{layer}.pdf", dpi=1200)
            plt.close(fig)

# ──────────────────────────────────────────────────────────────────────────────
# Figure 4  (Early/Middle/Late + AUC bars)
# ──────────────────────────────────────────────────────────────────────────────
def plot_figure4(cache_root: Path, out_dir: Path) -> None:
    """Regenerate Figure 4: Early / Middle / Late curves plus AUC bars.

    For every ``<cache_root>/figure4/figure4_<model>_caches.pkl`` the pickle
    must provide ``curves_cache``, ``aucs_cache``, ``p_val_cache``,
    ``global_ymax`` and ``global_auc_ymax``. Their keys are expected to be
    0/1/2 (mapped to Early/Middle/Late by ``ORD``; other keys are shown
    verbatim). One PNG + PDF per group is written to *out_dir* as
    ``figure4_<model>_<group>_layer``. The model key is part of the name so
    that several cached models no longer overwrite each other's panels.

    NOTE: the caches are unpickled, which can execute arbitrary code; only
    load trusted files.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    ORD = {0: "Early", 1: "Middle", 2: "Late"}
    prefix, suffix = "figure4_", "_caches"
    for cache_file in sorted((cache_root / "figure4").glob("figure4_*_caches.pkl")):
        # Strip the fixed prefix/suffix so model keys containing "_" survive.
        model = cache_file.stem[len(prefix):-len(suffix)]
        with open(cache_file, "rb") as f:
            cache: Dict = pickle.load(f)
        curves_cache: Dict[int, Tuple[pd.DataFrame, pd.DataFrame]] = cache["curves_cache"]
        aucs_cache: Dict[int, Dict[str, Tuple[float, float]]] = cache["aucs_cache"]
        pval_cache: Dict[int, float] = cache["p_val_cache"]
        global_ymax, global_auc_ymax = cache["global_ymax"], cache["global_auc_ymax"]
        if not curves_cache:
            print(f"[Fig-4] - {cache_file.name}: empty curves_cache, skipped")
            continue
        # All groups share the x-axis; take the thresholds from the first one.
        thresholds = next(iter(curves_cache.values()))[0].index.values

        for ord_idx, (curves_brain, curves_lm) in curves_cache.items():
            auc_stats = aucs_cache[ord_idx]
            p_val = pval_cache[ord_idx]
            group = ORD.get(ord_idx, str(ord_idx))

            fig, (ax_line, ax_bar) = plt.subplots(
                ncols=2, figsize=(7.5, 4), gridspec_kw={"width_ratios": [3, 1]})

            ax_line.errorbar(thresholds, curves_brain["mean"], yerr=curves_brain["sem"],
                             color=PASTEL_RED, label="Brain alignment", lw=1.5,
                             ecolor="black", elinewidth=1.2, capsize=4, capthick=1.2)
            ax_line.errorbar(thresholds, curves_lm["mean"], yerr=curves_lm["sem"],
                             color=PASTEL_GREEN, label="Next‑word prediction", lw=1.5,
                             ecolor="black", elinewidth=1.2, capsize=4, capthick=1.2)

            ax_line.set_xlabel("Top % attribution", fontsize=18)
            ax_line.set_ylabel("# unique important words", fontsize=18)
            ax_line.grid(ls=":")
            ax_line.legend(loc="upper left", frameon=False)
            ax_line.set_ylim(0, global_ymax)

            # AUC bars: (mean, error) pairs for brain alignment vs. LM.
            ax_bar.bar([0, 1],
                       [auc_stats["brain"][0], auc_stats["lm"][0]],
                       yerr=[auc_stats["brain"][1], auc_stats["lm"][1]],
                       color=[PASTEL_RED, PASTEL_GREEN], width=0.6,
                       linewidth=1.2, capsize=4)
            ax_bar.set_ylim(0, global_auc_ymax)
            ax_bar.set_xticks([0, 1], ["BA", "NWP"], fontsize=18)
            ax_bar.set_ylabel("AUC", fontsize=18)
            ax_bar.grid(True, axis="y", ls=":")

            # Stars go over the higher bar: 0 if brain AUC is higher, else 1.
            higher_idx = int(auc_stats["brain"][0] < auc_stats["lm"][0])
            _annotate_significance(ax_bar, higher_idx, global_auc_ymax, p_val)

            fig.tight_layout()
            stem = f"figure4_{model}_{group}_layer"
            fig.savefig(out_dir / f"{stem}.png")
            fig.savefig(out_dir / f"{stem}.pdf", dpi=1200)
            plt.close(fig)

# ──────────────────────────────────────────────────────────────────────────────
# Figure 5  (feature analysis)
# ──────────────────────────────────────────────────────────────────────────────
def plot_figure5(summary_dir: Path,
                 out_dir: Path,
                 legend_on_first: bool = True) -> None:
    """
    Regenerate the Figure-5 feature-analysis bar charts, one per threshold.

    Parameters
    ----------
    summary_dir : Path
        Folder that contains files named  'thr-<T>_summary.pkl'. Each pickle
        must provide 'threshold', 'feature_names', 'bars' and 'sems'. The
        last two are indexed as [:, 0..2] = brain alignment / next-word
        prediction / intersection.
    out_dir : Path
        Where to save the regenerated PNG/PDF plots.
    legend_on_first : bool
        Show the legend only for the lowest (numeric) threshold, which is
        tidier for multi-panel layouts. Set to False if you always want a
        legend.

    NOTE: summaries are unpickled, which can execute arbitrary code; only
    load trusted files.
    """
    import re  # local: only needed to sort summary files by threshold

    out_dir.mkdir(parents=True, exist_ok=True)

    def _threshold_key(path: Path):
        """Sort key: numeric <T> from 'thr-<T>_…'; unparsable names sort last."""
        m = re.match(r"thr-(\d+(?:\.\d+)?)_", path.name)
        return (0, float(m.group(1)), path.name) if m else (1, 0.0, path.name)

    # Sort numerically (thr-5 before thr-10) and only once. A plain sorted()
    # is lexicographic, so it would pick thr-10 as the "first" threshold.
    summary_paths = sorted(summary_dir.glob("thr-*_*summary.pkl"), key=_threshold_key)
    legend_drawn = False

    for summ_path in summary_paths:
        with summ_path.open("rb") as f:
            payload = pickle.load(f)

        thresh = payload["threshold"]
        feats = payload["feature_names"]
        bars = np.asarray(payload["bars"])
        sems = np.asarray(payload["sems"])

        if len(feats) == 0:
            # A zero-width figure and bars.max() on an empty array would both fail.
            print(f"[Fig-5] - {summ_path.name}: no features, skipped")
            continue

        # Plotting: three bars per feature, offset by one bar width.
        x = np.arange(len(feats))
        width = 0.25

        fig, ax = plt.subplots(figsize=(2 * len(feats), 4))

        ax.bar(x - width, bars[:, 0], width,
               yerr=sems[:, 0], capsize=3,
               color=PASTEL_RED,    label="Brain alignment")

        ax.bar(x,          bars[:, 1], width,
               yerr=sems[:, 1], capsize=3,
               color=PASTEL_GREEN,  label="Next-word prediction")

        ax.bar(x + width,  bars[:, 2], width,
               yerr=sems[:, 2], capsize=3,
               color=PASTEL_ORANGE, label="Intersection")

        # One call sets positions, labels and all label text properties.
        ax.set_xticks(x, feats, fontsize=16, rotation=45, ha="right")
        ax.set_ylabel("Percentage of important words")

        # Small head-room above the largest bar
        ymax = max(bars.max(), ax.get_ylim()[1])
        ax.set_ylim(0, ymax + 5)

        if not legend_on_first or not legend_drawn:
            ax.legend(frameon=False, loc="upper left", fontsize=16)
            legend_drawn = True

        fig.tight_layout()
        print(f"saving to {out_dir / f'thr-{thresh}_mean_across_models.png'}")
        fig.savefig(out_dir / f"thr-{thresh}_mean_across_models.png", dpi=300)
        fig.savefig(out_dir / f"thr-{thresh}_mean_across_models.pdf", dpi=1200)
        plt.close(fig)

# ──────────────────────────────────────────────────────────────────────────────
# Figure 6  (distance distributions)
# ──────────────────────────────────────────────────────────────────────────────
def plot_figure6(cache_root: Path, out_dir: Path) -> None:
    """
    Re-create every Figure-6 panel found under ``<cache_root>/figure6``.

    Two cache naming schemes are recognised (presumably written by
    `_figure6` elsewhere in the project):
      • layer-averaged:   figure6_<model>_thr<thr>_dists.npz
      • per-subject:      figure6_<model>_subj<id>_thr<thr>.npz
    Each archive must provide ``x``, ``brain``, ``lm``, ``inter``,
    ``com_brain``, ``com_lm`` and ``com_int``. One PNG + PDF per archive is
    written to *out_dir*. Files whose names cannot be parsed are reported
    and skipped.
    """
    import re  # local: only used for cache-name parsing below

    out_dir.mkdir(parents=True, exist_ok=True)
    fig6_dir = cache_root / "figure6"

    # Model keys may contain "_", so parse with anchored regexes instead of
    # positional split("_") tokens.
    avg_name = re.compile(r"figure6_(?P<model>.+)_thr(?P<thr>[^_]+)_dists")
    subj_name = re.compile(r"figure6_(?P<model>.+)_subj(?P<subj>[^_]+)_thr(?P<thr>[^_]+)")

    def _make_panel(x: np.ndarray,
                    brain: np.ndarray,
                    lm: np.ndarray,
                    inter: np.ndarray,
                    com_b: float,
                    com_l: float,
                    com_i: float,
                    save_stem: str) -> None:
        """Draw three stacked bar-plots (BA / NWP / INT) with a dashed
        centre-of-mass line on each, then save as PNG + PDF."""
        dists = [("Brain alignment",      brain, com_b, PASTEL_RED),
                 ("Next-word prediction", lm,    com_l, PASTEL_GREEN),
                 ("Intersection",         inter, com_i, PASTEL_ORANGE)]

        fig, axes = plt.subplots(nrows=3, figsize=(10, 8), sharex=False)
        for ax, (title, y, com, col) in zip(axes, dists):
            ax.bar(x, y, color=col, label=title)
            ax.axvline(com, ls='--', color='k')
            # Print the CoM value just right of its line, in axis-data units.
            ax.text(com + 2, 0.66, f"{com:.1f}", rotation=90)
            ax.set_ylim(0, 1)
            ax.grid(ls=':')
            ax.legend(frameon=False, loc='upper right')
        axes[1].set_ylabel("Proportion of important words")
        axes[-1].set_xlabel("Distance from most recent word")
        fig.tight_layout()

        fig.savefig(out_dir / f"{save_stem}.png")
        fig.savefig(out_dir / f"{save_stem}.pdf", dpi=1200)
        plt.close(fig)

    def _render(path: Path, save_stem: str) -> None:
        """Load one .npz cache and draw its panel."""
        # An NpzFile keeps the archive open until closed; `with` closes it.
        with np.load(path) as d:
            _make_panel(d["x"],
                        d["brain"], d["lm"], d["inter"],
                        float(d["com_brain"]), float(d["com_lm"]), float(d["com_int"]),
                        save_stem=save_stem)

    # Layer-averaged panels
    for f in sorted(fig6_dir.glob("figure6_*_thr*_dists.npz")):
        m = avg_name.fullmatch(f.stem)
        if m is None:
            print(f"[Fig-6] - cannot parse {f.name}, skipped")
            continue
        _render(f, f"figure6_{m['model']}_thr{m['thr']}")

    # Per-subject panels
    for f in sorted(fig6_dir.glob("figure6_*_subj*_thr*.npz")):
        m = subj_name.fullmatch(f.stem)
        if m is None:
            print(f"[Fig-6] - cannot parse {f.name}, skipped")
            continue
        _render(f, f"figure6_{m['model']}_subj{m['subj']}_thr{m['thr']}")


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: rebuild every figure from its cached data.

    ``--cache-dir`` is expected to contain ``figure1/``, ``figure3/``,
    ``figure4/``, ``figure6/`` and ``df_analysis/`` (Figure 5). Each figure
    is written to its own sub-folder of ``--plots-dir``.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]``; useful for tests and
        programmatic calls.
    """
    ap = argparse.ArgumentParser(description="Re-generate figures from .pkl/.npz caches")
    ap.add_argument("--cache-dir", required=True,
                    help="root directory containing figure1/3/4/6 and df_analysis subfolders")
    ap.add_argument("--plots-dir", default="regenerated_plots",
                    help="output directory for the recreated PNG/PDF figures")
    args = ap.parse_args(argv)

    cache_root = Path(args.cache_dir).expanduser().resolve()
    plots_root = Path(args.plots_dir).expanduser().resolve()
    plots_root.mkdir(parents=True, exist_ok=True)

    plot_figure1(cache_root, plots_root / "figure1")
    plot_figure3(cache_root, plots_root / "figure3")
    plot_figure4(cache_root, plots_root / "figure4")
    plot_figure5(cache_root / "df_analysis", plots_root / "figure5")
    plot_figure6(cache_root, plots_root / "figure6")
    print("✓  All requested figures regenerated.")

if __name__ == "__main__":
    main()
