import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind, ttest_rel

# ---------- GLOBAL STYLE (print-friendly, clutter-free) ----------
# Applied once at import time; every figure created afterwards inherits it.
plt.rcParams.update({
    # Typography
    "font.family": "serif",
    "font.serif": ["Times New Roman"],
    "font.size": 12,
    "axes.labelsize": 12,
    "axes.titlesize": 12,
    "xtick.labelsize": 7,        # kept small so category labels do not overlap
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    # Strokes
    "lines.linewidth": 1.8,
    "axes.linewidth": 1.2,
    # Resolution
    "figure.dpi": 300,
    "savefig.dpi": 300,
    # Grid: dashed, faint, horizontal major lines only
    "axes.grid": True,
    "axes.grid.axis": "y",
    "axes.grid.which": "major",
    "grid.linestyle": "--",      # previously listed twice in this dict
    "grid.alpha": 0.4,
})

# Okabe–Ito palette (color-blind safe, prints well)
OKABE = ["#0072B2", "#E69F00", "#009E73"]

# ---------- HELPER UTILITIES ----------
def _format_labels(cfg_series: pd.Series) -> list[str]:
    """Translate raw config identifiers into short tick labels.

    Identifiers without an entry in the lookup table are returned unchanged.
    """
    # Canonical config names.
    canonical = {
        "FAST_TRACES_BOTH_UPDATES":   "Fast",
        "MEDIUM_TRACES_BOTH_UPDATES": "Medium",
        "SLOW_TRACES_BOTH_UPDATES":   "Slow",
        "DUAL_TRACES_BOTH_UPDATES":   "Dual",
        "DUAL_TRACES_FAST_ONLY":      "Fast-only",
        "DUAL_TRACES_SLOW_ONLY":      "Slow-only",
        "DUAL_TRACES_NO_UPDATES":     "Frozen",
        "DELTA_RULE":                 "Delta Rule",
        "HEBBIAN":                    "Hebbian",
    }
    # Simplified synonyms sometimes used in Zenodo CSVs.
    shorthand = {
        "FAST_ONLY":   "Fast",
        "MEDIUM_ONLY": "Medium",
        "SLOW_ONLY":   "Slow",
        "DUAL":        "Dual",
    }
    lookup = {**canonical, **shorthand}

    pretty: list[str] = []
    for raw_name in cfg_series:
        pretty.append(lookup.get(raw_name, raw_name))
    return pretty

def _ci(series: pd.Series) -> float:
    """Approx. 95 % CI half-width (±1.96 · SEM).

    The SEM uses the number of non-missing observations: ``Series.std`` skips
    NaN, so dividing by ``len(series)`` would understate the interval whenever
    the data contain missing values.

    Returns NaN when fewer than two valid observations are available.
    """
    n_valid = series.count()  # excludes NaN, matching what std() uses
    if n_valid < 2:
        return np.nan
    sem = series.std(ddof=1) / np.sqrt(n_valid)
    return 1.96 * sem

def add_star(ax, x, y_top, p, pad=0.01):
    """Write a significance marker centred at ``x``, just above ``y_top``.

    Marker text: ``***`` for p < .001, ``**`` for p < .01, ``*`` for p < .05
    and ``ns`` otherwise. A NaN p-value (test could not be computed) draws
    nothing; previously it fell through every comparison and was rendered
    as a single star.

    Args:
        ax: Object exposing a matplotlib-style ``text`` method.
        x: Horizontal position of the marker.
        y_top: Height the marker should sit above (e.g. top of the error bar).
        p: p-value deciding the marker.
        pad: Vertical gap added to ``y_top``.
    """
    if np.isnan(p):
        return

    if p >= .05:
        stars = "ns"
    elif p < .001:
        stars = "***"
    elif p < .01:
        stars = "**"
    else:
        stars = "*"

    ax.text(x, y_top + pad, stars,
            ha="center", va="bottom", fontsize=12, clip_on=False)


def compute_paired_effect(a: pd.Series, b: pd.Series) -> tuple[float, float]:
    """Return Cohen's dz and the mean difference for paired samples (b - a).

    dz = mean(diff) / sd(diff) with ddof=1.

    Args:
        a: First sample of each pair.
        b: Second sample of each pair; pandas aligns it with ``a`` on the index
            when subtracting.

    Returns:
        ``(dz, mean_diff)``. ``(nan, nan)`` when there are fewer than two pairs
        or the inputs cannot be subtracted/converted to float;
        ``(nan, mean_diff)`` when the differences have zero (or undefined)
        standard deviation.
    """
    # Only the arithmetic/conversion step is expected to fail on bad input;
    # catching just that keeps genuine programming errors visible.
    try:
        diffs = (b - a).astype(float)
    except (TypeError, ValueError):
        return np.nan, np.nan

    if diffs.shape[0] < 2:
        return np.nan, np.nan

    mean_diff = diffs.mean()
    sd = diffs.std(ddof=1)
    if np.isnan(sd) or sd == 0:
        return np.nan, mean_diff
    return mean_diff / sd, mean_diff


def add_effect_label(ax, x, y_top, d: float, pad=0.03):
    """Write a signed ``d=±0.00`` effect-size caption centred at ``x``.

    The caption sits ``pad`` above ``y_top``; the explicit sign shows the
    direction of the effect. Nothing is drawn when ``d`` is NaN.
    """
    if not np.isnan(d):
        text_style = {
            "ha": "center",
            "va": "bottom",
            "fontsize": 8,
            "color": "#444444",
            "clip_on": False,
        }
        ax.text(x, y_top + pad, f"d={d:+.2f}", **text_style)

# NOTE(review): this module-level flag is never read or assigned anywhere in
# the visible code; presumably leftover state — confirm before removing.
thing = False

# ---------- MAIN PLOTTING FUNCTION ----------
# Config that serves as the reference bar in the update-mechanism figures.
_DUAL_BASELINE = "DUAL_TRACES_BOTH_UPDATES"

# Bar order and additional pairwise tests for the update-mechanism figures.
_UPDATE_MECHANISM_CONFIGS = [
    _DUAL_BASELINE,
    "DUAL_TRACES_FAST_ONLY",
    "DUAL_TRACES_SLOW_ONLY",
    "DUAL_TRACES_NO_UPDATES",
]
_UPDATE_MECHANISM_PAIRS = [
    (_DUAL_BASELINE, "DUAL_TRACES_FAST_ONLY"),
    (_DUAL_BASELINE, "DUAL_TRACES_SLOW_ONLY"),
    (_DUAL_BASELINE, "DUAL_TRACES_NO_UPDATES"),
    ("DUAL_TRACES_FAST_ONLY", "DUAL_TRACES_SLOW_ONLY"),
]
# Configs that only occur in update-mechanism ablations (used for detection).
_UPDATE_ONLY_KEYS = frozenset(_UPDATE_MECHANISM_CONFIGS[1:])

# Ablation groups whose bar order is the same for both datasets.
# The first entry of each list is used as the comparison baseline.
_SHARED_GROUPS = {
    "hebbian": ["HEBBIAN", "DELTA_RULE"],
    "recurrent": ["RECURRENT", "FEEDFORWARD"],
    "meta": ["META_LEARNING", "FIXED_LR"],
    "rms": ["WITH_RMS", "PARTIAL_RMS", "WITHOUT_RMS"],
}
# NOTE(review): McMaze lists DUAL first (DUAL is the trace baseline) whereas
# Zenodo lists FAST first (FAST is the baseline). Preserved as found — confirm
# the difference is intended.
_MCMAZE_GROUPS = {
    "trace": [
        "DUAL_TRACES_BOTH_UPDATES",
        "SLOW_TRACES_BOTH_UPDATES",
        "MEDIUM_TRACES_BOTH_UPDATES",
        "FAST_TRACES_BOTH_UPDATES",
    ],
    **_SHARED_GROUPS,
}
_ZENODO_GROUPS = {
    "trace": [
        "FAST_TRACES_BOTH_UPDATES",
        "SLOW_TRACES_BOTH_UPDATES",
        "MEDIUM_TRACES_BOTH_UPDATES",
        "DUAL_TRACES_BOTH_UPDATES",
    ],
    **_SHARED_GROUPS,
}

# CSV base name -> ablation group.
_MCMAZE_FILES = {
    "mcmaze_ablation4.txt.csv": "trace",
    "mcmaze_ablation6.txt.csv": "update_mechanism",
    "mcmaze_ablation1.txt.csv": "hebbian",
    "mcmaze_ablation3.txt.csv": "recurrent",
    "mcmaze_ablation2.txt.csv": "meta",
    "mcmaze_ablation5.txt.csv": "rms",
}
_ZENODO_FILES = {
    "zenodo_ablation4_1.txt.csv": "trace",             # expected FAST/SLOW/MEDIUM/DUAL *_BOTH_UPDATES
    "zenodo_ablation4_2.txt.csv": "update_mechanism",  # expected FAST_ONLY/SLOW_ONLY/NO_UPDATES
    "zenodo_ablation1.txt.csv": "hebbian",
    "zenodo_ablation3.txt.csv": "recurrent",
    "zenodo_ablation2.txt.csv": "meta",
    "zenodo_ablation5.txt.csv": "rms",
}

# Simplified config names found in some Zenodo CSVs -> canonical names.
_ZENODO_SYNONYMS = {
    "FAST_ONLY":   "FAST_TRACES_BOTH_UPDATES",
    "SLOW_ONLY":   "SLOW_TRACES_BOTH_UPDATES",
    "MEDIUM_ONLY": "MEDIUM_TRACES_BOTH_UPDATES",
    "DUAL":        "DUAL_TRACES_BOTH_UPDATES",
}


def _load_zenodo_dual_baseline(search_dir: Path) -> pd.DataFrame:
    """Return DUAL-baseline rows from the first Zenodo 4_* CSV in ``search_dir`` that has any."""
    for cand_name in ("zenodo_ablation4_1.txt.csv", "zenodo_ablation4_2.txt.csv"):
        cand = search_dir / cand_name
        try:
            tmp = pd.read_csv(cand)
        except FileNotFoundError:
            print(f"Warning: {cand} not found while searching for DUAL baseline.")
            continue
        # Normalise synonyms first so a plain "DUAL" row is recognised too.
        tmp["config"] = tmp["config"].replace(_ZENODO_SYNONYMS)
        hit = tmp[tmp["config"] == _DUAL_BASELINE]
        if not hit.empty:
            print(f"Loaded DUAL baseline from {cand} with {hit.shape[0]} rows.")
            return hit.reset_index(drop=True)
        print(f"No DUAL baseline found in {cand}. Available configs: {tmp['config'].unique()}")
    return pd.DataFrame()


def _mcmaze_plan(df: pd.DataFrame, csv_path: Path):
    """Resolve (df, group, configs, baseline, extra_pairs) for a McMaze CSV."""
    group = _MCMAZE_FILES.get(csv_path.name)
    if group is None:
        raise KeyError(f"{csv_path.name} is not a known McMaze ablation CSV")

    if group == "update_mechanism":
        # The DUAL reference lives in the trace-ablation CSV; look for it next
        # to the input file so the script also works outside the data folder.
        dual_df = pd.read_csv(csv_path.with_name("mcmaze_ablation4.txt.csv"))
        dual_rows = dual_df[dual_df["config"] == _DUAL_BASELINE]
        df = pd.concat([df, dual_rows], ignore_index=True)
        return (df, group, list(_UPDATE_MECHANISM_CONFIGS),
                _DUAL_BASELINE, list(_UPDATE_MECHANISM_PAIRS))

    configs = list(_MCMAZE_GROUPS[group])
    return df, group, configs, configs[0], None


def _zenodo_plan(df: pd.DataFrame, csv_path: Path):
    """Resolve (df, group, configs, baseline, extra_pairs) for a Zenodo CSV, or None to skip."""
    df["config"] = df["config"].replace(_ZENODO_SYNONYMS)
    present = set(df["config"].unique())
    has_update = bool(present & _UPDATE_ONLY_KEYS)
    has_trace = bool(present & set(_ZENODO_GROUPS["trace"]))

    group = _ZENODO_FILES.get(csv_path.name)
    if group is None:
        # Unknown file name: infer the group from the configs it contains.
        group = "update_mechanism" if has_update else "trace" if has_trace else None
    elif group == "trace" and has_update and not has_trace:
        # File name says trace but content is update-mechanism only.
        group = "update_mechanism"
    elif group == "update_mechanism" and has_trace and not has_update:
        group = "trace"

    if group == "update_mechanism":
        if _DUAL_BASELINE not in present:
            dual_rows = _load_zenodo_dual_baseline(csv_path.parent)
            if dual_rows.empty:
                print("Failed to load DUAL baseline from any 4_* CSV; proceeding without it.")
            else:
                df = pd.concat([df, dual_rows], ignore_index=True)
                present = set(df["config"].unique())
        configs = [c for c in _UPDATE_MECHANISM_CONFIGS if c in present]
        pairs = list(_UPDATE_MECHANISM_PAIRS)
    elif group in _ZENODO_GROUPS:
        # Generic path for hebbian, recurrent, meta, rms, and trace.
        configs = [c for c in _ZENODO_GROUPS[group] if c in present]
        pairs = None
    else:
        print(f"No matching configs found for {csv_path} in named list {group if group else 'unknown'}; skipping plot.")
        return None

    if not configs:
        print(f"None of the expected '{group}' configs are present in {csv_path}; skipping plot.")
        return None

    # Lists are ordered baseline-first, so the first config present is DUAL
    # whenever DUAL is available and the next-best reference otherwise.
    return df, group, configs, configs[0], pairs


def plot_ablation_results(csv_file: str,
                          mcmaze: bool = False,
                          zenodo: bool = False,
                          out_dir: str | Path = "./figures") -> None:
    """Read one ablation CSV and draw its grouped-bar figure via ``_plot_group``.

    The ablation group is chosen from the CSV's base name (so ``csv_file`` may
    include a directory); for Zenodo files it is additionally validated
    against, or inferred from, the configs present in the file. Companion
    CSVs holding the DUAL baseline are looked up next to ``csv_file``.

    Args:
        csv_file: Path to a CSV with at least a ``config`` column.
        mcmaze: Treat the file as a McMaze ablation. Takes precedence over
            ``zenodo`` when both are set.
        zenodo: Treat the file as a Zenodo ablation.
        out_dir: Directory for the figure; created if missing.

    Raises:
        KeyError: ``mcmaze`` is set and the file name is not a known McMaze
            ablation CSV, or the CSV has no ``config`` column.
        FileNotFoundError: ``csv_file`` (or the McMaze trace CSV needed for the
            update-mechanism figure) does not exist.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = Path(csv_file)

    df = pd.read_csv(csv_file)
    print(f"Unique configs in CSV: {df['config'].unique()}")

    if mcmaze:
        print(csv_file)
        prefix = "mcmaze"
        plan = _mcmaze_plan(df, csv_path)
    elif zenodo:
        prefix = "zenodo"
        plan = _zenodo_plan(df, csv_path)
    else:
        print(f"Neither mcmaze nor zenodo selected for {csv_file}; nothing to plot.")
        return

    if plan is None:
        return

    df, group, plot_configs, baseline_config, extra_comparisons = plan
    _plot_group(
        df, plot_configs,
        baseline=baseline_config,
        fname=out_dir / f"{prefix}_{group}_ablation.png",
        extra_pairs=extra_comparisons,
    )



def _paired_session_test(sub: pd.DataFrame, cfg_a: str, cfg_b: str,
                         metric: str) -> tuple[float, float]:
    """Return (p-value, Cohen's dz) of a session-paired t-test on ``metric`` between two configs."""
    if cfg_a == cfg_b:
        return np.nan, np.nan

    rows = sub.loc[sub["config"].isin([cfg_a, cfg_b]), ["session", "config", metric]]
    if rows.empty:
        return np.nan, np.nan

    # Plain string labels keep the unstacked column index simple to look up.
    rows = rows.assign(config=rows["config"].astype(str))
    # One value per (session, config); duplicate rows are averaged.
    per_session = rows.groupby(["session", "config"])[metric].mean().unstack("config")
    if cfg_a not in per_session.columns or cfg_b not in per_session.columns:
        return np.nan, np.nan

    # Keep only sessions measured under both configs.
    paired = per_session[[cfg_a, cfg_b]].dropna()
    if len(paired) < 2:
        return np.nan, np.nan

    first, second = paired[cfg_a], paired[cfg_b]
    p_value = ttest_rel(first, second).pvalue
    dz, _ = compute_paired_effect(first, second)
    return p_value, dz


def _plot_group(df: pd.DataFrame, configs: list[str], *, 
                baseline: str,
                fname: Path,
                extra_pairs: list[tuple[str, str]] | None = None) -> None:
    """Draw a grouped X/Y bar chart with 95 % CIs and session-paired significance stars.

    Args:
        df: Table with ``config``, ``session``, ``X`` and ``Y`` columns.
        configs: Configs to plot, in bar order; configs absent from ``df`` are
            silently omitted.
        baseline: Config every other bar is tested against (paired by session).
            No stars are drawn if it has no rows.
        fname: Output path; the figure is always written with a ``.pdf``
            suffix regardless of the suffix given.
        extra_pairs: Additional config pairs whose paired-test results are
            printed (not drawn).

    Raises:
        ValueError: ``df`` has no ``session`` column.
    """
    if 'session' not in df.columns:
        raise ValueError("CSV must include a 'session' column for paired tests.")

    sub = df[df["config"].isin(configs)].copy()
    sub["config"] = pd.Categorical(sub["config"], configs)
    sub = sub.sort_values(["config", "session"])  # stable order

    # Mean and CI of the X and Y correlations, aggregated across sessions.
    stats = (sub.groupby("config", observed=True)
                   .agg(X_mean=('X', 'mean'), X_ci=('X', _ci),
                        Y_mean=('Y', 'mean'), Y_ci=('Y', _ci))
                   .reset_index())
    print(stats)

    present_cfgs = [str(c) for c in stats["config"]]
    labels = _format_labels(stats["config"])
    xpos = np.arange(len(labels))
    width = 0.35  # width of each bar
    # (metric column prefix, horizontal offset, colour, series label)
    series = [
        ("X", -width / 2, OKABE[0], "X Correlation"),
        ("Y", +width / 2, OKABE[1], "Y Correlation"),
    ]

    fig, ax = plt.subplots(figsize=(3.4, 2.8))  # sized for subfigure panels
    try:
        # NOTE(review): series labels are set but no legend is drawn here —
        # confirm that is intentional.
        for metric, offset, colour, series_label in series:
            ax.bar(
                xpos + offset,
                stats[f"{metric}_mean"],
                yerr=stats[f"{metric}_ci"],
                width=width,
                capsize=4,
                color=colour,
                edgecolor="black",
                label=series_label,
            )

        # Axis cosmetics
        ax.set_xticks(xpos)
        ax.set_xticklabels(labels)
        ax.set_ylabel("Pearson R (mean ± 95 % CI)")
        ax.set_ylim(0, 1.05)

        # ---------- Paired significance tests by session ----------
        # Each non-baseline bar gets a marker placed above its own error bar.
        if baseline in present_cfgs:
            for i, cfg in enumerate(present_cfgs):
                if cfg == baseline:
                    continue
                for metric, offset, _, _ in series:
                    p_value, _ = _paired_session_test(sub, baseline, cfg, metric)
                    if np.isnan(p_value):
                        continue
                    y_top = stats.loc[i, f"{metric}_mean"] + stats.loc[i, f"{metric}_ci"]
                    add_star(ax, xpos[i] + offset, y_top, p_value)

        # Optional extra comparisons: also paired by session, printed only.
        if extra_pairs:
            print(f"\n--- Extra comparisons for {fname.stem} ---")
            for cfg1, cfg2 in extra_pairs:
                if cfg1 not in present_cfgs or cfg2 not in present_cfgs:
                    print(f"  Skipping comparison: {cfg1} vs {cfg2} (one or both not in data)")
                    continue

                p_x, d_x = _paired_session_test(sub, cfg1, cfg2, "X")
                p_y, d_y = _paired_session_test(sub, cfg1, cfg2, "Y")
                lab1, lab2 = _format_labels(pd.Series([cfg1, cfg2]))
                print(f"  {lab1} vs. {lab2}: X p={p_x if not np.isnan(p_x) else 'NA'}, d={d_x if not np.isnan(d_x) else 'NA'}; Y p={p_y if not np.isnan(p_y) else 'NA'}, d={d_y if not np.isnan(d_y) else 'NA'}")

        out_path = fname.with_suffix('.pdf')
        print(f"SAVING TO {out_path}")
        fig.tight_layout()
        fig.savefig(out_path, transparent=False, bbox_inches="tight", pad_inches=0.02)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


if __name__ == "__main__":
    # Render every McMaze figure first, then every Zenodo figure; the input
    # CSVs are expected in the current working directory.
    mcmaze_files = [f"mcmaze_ablation{idx}.txt.csv" for idx in (4, 6, 1, 2, 3, 5)]
    zenodo_files = [f"zenodo_ablation{tag}.txt.csv" for tag in ("1", "2", "3", "4_1", "4_2", "5")]

    for path in mcmaze_files:
        plot_ablation_results(csv_file=path, mcmaze=True)
    for path in zenodo_files:
        plot_ablation_results(csv_file=path, zenodo=True)
