#!/usr/bin/env python3
"""
Summarize and sanity-check valence experiment outputs.

Reads the outputs produced by:
- probes_and_causality:   budget + causality sweeps
- localize:               ablation, alignment, head OV project-out/scale
- decisions_and_robustness: decision metrics (+ robustness)

Emits:
- {save_dir}/valence_report.md
- {save_dir}/summary_by_layer.csv
- {save_dir}/causality_slopes.csv              (if present)
- {save_dir}/head_ranking_L{layer}.csv         (per layer, if present)
- {save_dir}/localization_correlations_L{layer}.csv (per layer, if present)
- {save_dir}/decision_metrics_L{layer}.csv     (per layer, if present)

Usage:
  python experiments/valence/summarize_valence_results.py \
      --root_out experiments/valence_suite_all_layers \
      --save_dir experiments/valence_suite_all_layers/summary
"""

from __future__ import annotations

import argparse
import glob
import json
import math
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

# -------------------------
# Small utilities
# -------------------------


def _p(path: str) -> str:
    return os.path.normpath(path)


def _exists(path: str) -> bool:
    return path is not None and os.path.isfile(path)


def _first_existing(paths: List[str]) -> Optional[str]:
    """Return the first entry of *paths* that is an existing file, or None."""
    return next((candidate for candidate in paths if _exists(candidate)), None)


def _read_csv(path: str) -> Optional[pd.DataFrame]:
    """
    Read a CSV into a DataFrame, or return None if it is missing/unreadable.

    Errors raised while checking or parsing the file are printed as a
    warning rather than propagated, so one bad input does not abort a report.
    """
    frame = None
    try:
        if _exists(path):
            frame = pd.read_csv(path)
    except Exception as exc:
        print(f"[warn] failed to read CSV {path}: {exc}")
    return frame


def _read_json(path: str):
    """
    Load JSON from *path*, or return None if it is missing or unreadable.

    The file is decoded explicitly as UTF-8 (the encoding JSON is defined
    in) instead of the platform's locale default, so parsing does not depend
    on the OS the summary runs on. Read/parse errors are printed as a
    warning rather than raised.
    """
    try:
        if _exists(path):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
    except Exception as e:
        print(f"[warn] failed to read JSON {path}: {e}")
    return None


def _slope(x: np.ndarray, y: np.ndarray) -> Tuple[float, float]:
    """
    Least squares slope and R^2.
    Returns (slope, r2). If degenerate, returns (0.0, 0.0).
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    if x.size < 2 or np.allclose(x, x[0]):
        return 0.0, 0.0
    A = np.stack([x, np.ones_like(x)], axis=1)
    m, b = np.linalg.lstsq(A, y, rcond=None)[0]
    y_hat = m * x + b
    ss_res = float(np.sum((y - y_hat) ** 2))
    ss_tot = float(np.sum((y - np.mean(y)) ** 2)) + 1e-12
    r2 = 1.0 - ss_res / ss_tot
    return float(m), float(max(0.0, min(1.0, r2)))


def _pearson(x: np.ndarray, y: np.ndarray) -> float:
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    if x.size < 2 or y.size != x.size:
        return float("nan")
    if np.allclose(x, x[0]) or np.allclose(y, y[0]):
        return float("nan")
    return float(np.corrcoef(x, y)[0, 1])


def _safe_mean(x):
    x = np.asarray(x, float)
    return float(np.mean(x)) if x.size else 0.0


# -------------------------
# Column normalization (new names vs legacy)
# -------------------------


def normalize_proj_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy with legacy symbol columns renamed to dlogit_single/double/triple.

    A legacy column (``dlogit_-``, ``dlogit_=``, ``dlogit_#``) is renamed only
    when its normalized counterpart is not already present; the input frame
    is left untouched.
    """
    legacy_to_normalized = {
        "dlogit_-": "dlogit_single",
        "dlogit_=": "dlogit_double",
        "dlogit_#": "dlogit_triple",
    }
    present = set(df.columns)
    renames = {
        old: new
        for old, new in legacy_to_normalized.items()
        if old in present and new not in present
    }
    return df.copy().rename(columns=renames)


def normalize_bootstrap_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy whose ``metric`` values use single/double/triple names.

    Legacy symbols ``-``, ``=``, ``#`` are mapped; other values and frames
    without a ``metric`` column pass through unchanged (as a copy).
    """
    out = df.copy()
    if "metric" not in out.columns:
        return out
    symbol_to_name = {"-": "single", "=": "double", "#": "triple"}
    out["metric"] = out["metric"].replace(symbol_to_name)
    return out


# -------------------------
# Discovery of files
# -------------------------


def discover_layers(root_out: str) -> List[int]:
    """
    Find the layer indices that have an ``L{layer}`` results directory.

    Only directories whose name is ``L`` followed by an integer are counted.
    Plain files such as ``L5`` or ``L5.csv`` and non-numeric names such as
    ``Logs`` are ignored. Returns the indices sorted ascending without
    duplicates, or an empty list when nothing matches.
    """
    layers = set()
    for path in glob.glob(os.path.join(root_out, "L*")):
        name = os.path.basename(path)
        if not name.startswith("L") or not os.path.isdir(path):
            continue
        try:
            layers.add(int(name[1:]))
        except ValueError:
            # Not a layer directory (e.g. "Logs"); skip it.
            continue
    return sorted(layers)


def budget_files(root_out: str) -> Dict[str, Optional[str]]:
    """
    Resolve the budget-stage input files under *root_out*.

    Each file is looked for first directly in *root_out*, then in its
    ``budget/`` subdirectory. The value is the first existing path, or None.
    """
    wanted = (
        ("probe_by_layer", "valence_probe_by_layer.csv"),
        ("events_summary", "valence_events_summary.csv"),
        ("examples_debug", "examples_debug.json"),
    )
    resolved: Dict[str, Optional[str]] = {}
    for key, filename in wanted:
        candidates = [
            _p(os.path.join(root_out, filename)),
            _p(os.path.join(root_out, "budget", filename)),
        ]
        resolved[key] = _first_existing(candidates)
    return resolved


def per_layer_files(root_out: str, L: int) -> Dict[str, Optional[str]]:
    """
    Build the expected per-layer output paths for layer *L*.

    Paths live under ``{root_out}/L{L}/<stage>/``. They are returned whether
    or not the files exist; the readers check existence themselves.
    """
    layer_dir = _p(os.path.join(root_out, f"L{L}"))
    layout = {
        "causality": ("causality", f"valence_causality_L{L}.json"),
        "abl": ("localize", f"localize_head_ablation_L{L}.csv"),
        "align": ("localize", f"localize_head_alignment_L{L}.csv"),
        "proj": ("localize", f"localize_project_out_L{L}.csv"),
        "dec": ("decisions", f"decision_metrics_L{L}.csv"),
        "boot": ("robustness", f"bootstrap_effects_L{L}.csv"),
        "null": ("robustness", f"null_injection_L{L}.csv"),
        "projout_collapse": ("robustness", f"projectout_collapse_L{L}.csv"),
    }
    return {
        key: _p(os.path.join(layer_dir, stage, filename))
        for key, (stage, filename) in layout.items()
    }


# -------------------------
# Analysis helpers
# -------------------------


def analyze_budget(root_out: str) -> Dict:
    """
    Load the budget-stage outputs and pick the best probe layer.

    Returns a dict with these keys:

    - ``probe_df``: the probe table sorted by macro-F1 then accuracy, best
      first (None if absent).
    - ``events_df``: the events summary table.
    - ``examples``: the parsed debug-examples JSON.
    - ``peak``: the top probe row as a dict with an int ``layer``, or None
      when there is no probe table.
    - ``files``: the resolved input paths.
    """
    files = budget_files(root_out)
    probe_df = _read_csv(files["probe_by_layer"])
    events_df = _read_csv(files["events_summary"])
    examples = _read_json(files["examples_debug"])

    peak = None
    has_probes = probe_df is not None and len(probe_df) > 0
    if has_probes:
        probe_df = probe_df.sort_values(by=["f1_macro", "acc"], ascending=False)
        peak = dict(probe_df.iloc[0])
        peak["layer"] = int(peak["layer"])
    return dict(
        probe_df=probe_df,
        events_df=events_df,
        examples=examples,
        peak=peak,
        files=files,
    )


def analyze_causality(path_json: str) -> Optional[pd.DataFrame]:
    """
    Fit per-token ∆logit-vs-α slopes from a causality sweep JSON.

    Expects a JSON object with an optional ``"layer"`` and a ``"results"``
    list of rows. Each row carries ``alpha`` plus either the legacy
    ``dlogit_-``/``dlogit_=``/``dlogit_#`` columns or the normalized
    ``dlogit_single``/``dlogit_double``/``dlogit_triple`` columns, and
    optionally ``n_positions``.

    Returns one row per token (single/double/triple) with columns
    ``layer, token, slope, r2, n``. ``layer`` is -1 when the JSON does not
    record it. Returns None when the file is missing, is not a JSON object,
    has no results, or lacks ``alpha`` or a complete set of dlogit columns.
    """
    obj = _read_json(path_json)
    if not isinstance(obj, dict) or not obj:
        return None
    raw_layer = obj.get("layer")
    L = int(raw_layer) if raw_layer is not None else -1
    rows = obj.get("results", [])
    if not rows:
        return None
    df = pd.DataFrame(rows)
    if "alpha" not in df.columns:
        return None

    # Accept legacy or normalized columns (legacy checked first).
    if {"dlogit_-", "dlogit_=", "dlogit_#"}.issubset(df.columns):
        token_cols = [
            ("single", "dlogit_-"),
            ("double", "dlogit_="),
            ("triple", "dlogit_#"),
        ]
    elif {"dlogit_single", "dlogit_double", "dlogit_triple"}.issubset(df.columns):
        token_cols = [
            ("single", "dlogit_single"),
            ("double", "dlogit_double"),
            ("triple", "dlogit_triple"),
        ]
    else:
        return None

    # n is taken from the first result row only; presumably constant across
    # the sweep — confirm against the producer. A missing/NaN value gives None.
    n = None
    if "n_positions" in df.columns and pd.notna(df["n_positions"].iloc[0]):
        n = int(df["n_positions"].iloc[0])

    alpha = df["alpha"].values
    out = []
    for tok, col in token_cols:
        m, r2 = _slope(alpha, df[col].values)
        out.append(dict(layer=L, token=tok, slope=m, r2=r2, n=n))
    return pd.DataFrame(out)


def analyze_localization(
    abl_path: str, align_path: str
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[float]]:
    """
    Load the head-ablation and head-alignment tables and correlate them.

    Whenever the ablation table has rows, a composite ``score`` column is
    added: ``drop_eq + drop_hash - drop_minus``. Missing columns and NaNs
    count as 0. A positive score means the head supports the higher bond
    order (the eq + hash drop exceeds the minus drop). The score is added
    even when the alignment table is absent, because callers read
    ``abl["score"]`` unconditionally.

    Returns ``(abl, align, r)``:

    - ``abl``, ``align``: the loaded tables (None if missing).
    - ``r``: the Pearson correlation between ``score`` and the alignment
      ``cos_mean`` over heads present in both tables. It is NaN when
      undefined, and None when either table is missing or empty.
    """
    abl = _read_csv(abl_path)
    align = _read_csv(align_path)

    if abl is not None and len(abl):
        abl = abl.copy()
        # A zero Series (not a bare 0) so .fillna works when a column is absent.
        zeros = pd.Series(0.0, index=abl.index)
        abl["score"] = (
            abl.get("drop_eq", zeros).fillna(0)
            + abl.get("drop_hash", zeros).fillna(0)
            - abl.get("drop_minus", zeros).fillna(0)
        )

    if abl is None or align is None or len(abl) == 0 or len(align) == 0:
        return abl, align, None

    # Use suffixes to avoid 'n' column collision
    merged = pd.merge(
        abl, align, on=["layer", "head"], how="inner", suffixes=("_abl", "_align")
    )
    r = (
        _pearson(merged["score"].values, merged["cos_mean"].values)
        if "cos_mean" in merged
        else float("nan")
    )
    return abl, align, r


def select_top_heads(
    abl: pd.DataFrame, align: Optional[pd.DataFrame], k: int = 3
) -> pd.DataFrame:
    """
    Return the *k* highest-scoring heads from an ablation table.

    The score is ``drop_eq + drop_hash - drop_minus``, with missing columns
    and NaNs treated as 0; the input frame is not modified. When *align* is
    given and non-empty, its columns are left-joined on ``(layer, head)``.
    Returns an empty DataFrame when *abl* is None or empty.
    """
    if abl is None or len(abl) == 0:
        return pd.DataFrame()
    abl = abl.copy()
    # A zero Series (not a bare 0) so .fillna works when a column is absent.
    zeros = pd.Series(0.0, index=abl.index)
    abl["score"] = (
        abl.get("drop_eq", zeros).fillna(0)
        + abl.get("drop_hash", zeros).fillna(0)
        - abl.get("drop_minus", zeros).fillna(0)
    )
    top = abl.sort_values("score", ascending=False).head(k)
    if align is not None and len(align):
        top = pd.merge(
            top, align, on=["layer", "head"], how="left", suffixes=("_abl", "_align")
        )
    return top


def analyze_project_out(proj_path: str) -> Optional[pd.DataFrame]:
    """
    Fit per-head slopes from an OV project-out/scale sweep CSV.

    Column names are matched case-sensitively:

    - head column: ``head``, ``Head`` or ``HEAD`` (first found wins)
    - sweep variable: ``alpha``
    - per-token delta columns, in priority order (first hit per token wins):
      ``dlogit_single/double/triple``, ``dlogit_-``/``dlogit_=``/``dlogit_#``,
      ``dlogit_minus``/``dlogit_eq``/``dlogit_hash``,
      ``delta_single/double/triple``, ``delta_-``/``delta_=``/``delta_#``

    Returns a long-form DataFrame with columns ``head, token, slope, r2``
    (token in single/double/triple). Returns None when the file is missing
    or empty, a required column is absent, or no head has at least two
    sweep points.
    """
    table = _read_csv(proj_path)
    if table is None or len(table) == 0:
        return None

    head_col = next(
        (name for name in ("head", "Head", "HEAD") if name in table.columns), None
    )
    if head_col is None:
        return None

    # (source column, token) in priority order; setdefault keeps the first hit.
    candidates = (
        ("dlogit_single", "single"),
        ("dlogit_double", "double"),
        ("dlogit_triple", "triple"),
        ("dlogit_-", "single"),
        ("dlogit_=", "double"),
        ("dlogit_#", "triple"),
        ("dlogit_minus", "single"),
        ("dlogit_eq", "double"),
        ("dlogit_hash", "triple"),
        ("delta_single", "single"),
        ("delta_double", "double"),
        ("delta_triple", "triple"),
        ("delta_-", "single"),
        ("delta_=", "double"),
        ("delta_#", "triple"),
    )
    present = set(table.columns)
    col_map: Dict[str, str] = {}
    for src, tok in candidates:
        if src in present:
            col_map.setdefault(tok, src)

    if "alpha" not in table.columns or not col_map:
        return None

    records = []
    for head, group in table.groupby(head_col):
        alpha = group["alpha"].to_numpy()
        for tok, src in col_map.items():
            values = group[src].to_numpy()
            if values.size < 2:
                continue
            slope, r2 = _slope(alpha, values)
            records.append({"head": int(head), "token": tok, "slope": slope, "r2": r2})
    return pd.DataFrame(records) if records else None


def analyze_decisions(dec_path: str) -> Optional[pd.DataFrame]:
    """
    Load decision metrics and add event-minus-control contrast columns.

    Adds ``dmargin3_contrast`` (dmargin3_event - dmargin3_ctrl) and
    ``switch_contrast`` (switch_rate_event - switch_rate_ctrl), each only
    when both of its source columns exist. Returns None if the CSV is
    missing or empty.
    """
    table = _read_csv(dec_path)
    if table is None or len(table) == 0:
        return None
    table = table.copy()
    contrasts = (
        ("dmargin3_contrast", "dmargin3_event", "dmargin3_ctrl"),
        ("switch_contrast", "switch_rate_event", "switch_rate_ctrl"),
    )
    for target, event_col, ctrl_col in contrasts:
        if event_col in table.columns and ctrl_col in table.columns:
            table[target] = table[event_col] - table[ctrl_col]
    return table


def analyze_bootstrap(boot_path: str) -> Optional[pd.DataFrame]:
    """
    Load bootstrap effect CIs and flag cells whose 95% CI excludes zero.

    ``metric`` values are normalized to single/double/triple. A boolean
    ``significant`` column is added: True when ``lo95`` and ``hi95`` are both
    strictly positive or both strictly negative. NaN or non-numeric bounds
    give False. If either bound column is missing, a warning is printed and
    every cell is marked not significant instead of raising.

    Returns None if the CSV is missing or empty.
    """
    df = _read_csv(boot_path)
    if df is None or len(df) == 0:
        return None
    df = normalize_bootstrap_cols(df)
    if not {"lo95", "hi95"}.issubset(df.columns):
        print(
            f"[warn] bootstrap CSV {boot_path} lacks lo95/hi95; "
            "no cells marked significant"
        )
        df["significant"] = False
        return df
    lo = pd.to_numeric(df["lo95"], errors="coerce")
    hi = pd.to_numeric(df["hi95"], errors="coerce")
    # Compare signs directly: lo95 * hi95 > 0 can underflow to 0 for tiny
    # same-sign bounds and wrongly report "not significant".
    df["significant"] = ((lo > 0) & (hi > 0)) | ((lo < 0) & (hi < 0))
    return df


def analyze_null(null_path: str) -> Optional[pd.DataFrame]:
    """Load the null-injection control table; None if missing or empty."""
    table = _read_csv(null_path)
    if table is not None and len(table) > 0:
        return table
    return None


def analyze_projout_collapse(path: str) -> Optional[pd.DataFrame]:
    """
    Load the project-out collapse table as-is (None if missing/unreadable).

    Unlike the other loaders, an empty table is returned as an empty
    DataFrame rather than None.
    """
    return _read_csv(path)


# -------------------------
# Report assembly
# -------------------------


def build_report(root_out: str, save_dir: str) -> None:
    """
    Assemble ``valence_report.md`` and the summary CSV tables in *save_dir*.

    Inputs:

    - The budget probe table, ``valence_probe_by_layer.csv``, found directly
      in *root_out* or in ``root_out/budget``.
    - For every ``L{layer}`` directory returned by ``discover_layers``, the
      causality, localization, decision and robustness outputs listed by
      ``per_layer_files``.

    Writes into *save_dir* (created if needed):

    - ``valence_report.md``, written as UTF-8 because the report contains
      non-ASCII symbols such as ∆, α and →.
    - ``summary_by_layer.csv``: the per-layer key-metric table, or the raw
      probe table when no layer directories exist.
    - ``causality_slopes.csv``, ``head_ranking_L*.csv``,
      ``localization_correlations_L*.csv`` and ``decision_metrics_L*.csv``,
      when the corresponding inputs are present.

    These optional columns are skipped rather than aborting the report:
    alignment ``cos_mean``/``cos_median``/``frac_positive``, ablation
    ``score``, decision ``alpha`` and bootstrap ``significant``.
    """
    os.makedirs(save_dir, exist_ok=True)
    layers = discover_layers(root_out)

    # Budget (probe by layer)
    probe_file = _first_existing(
        [
            _p(os.path.join(root_out, "valence_probe_by_layer.csv")),
            _p(os.path.join(root_out, "budget", "valence_probe_by_layer.csv")),
        ]
    )
    probe_df = _read_csv(probe_file)

    report = []
    report.append("# Valence suite summary\n")
    report.append(f"- Root results dir: `{root_out}`\n")

    best_layer = None
    if probe_df is None or len(probe_df) == 0:
        report.append("**Budget:** No probe_by_layer file found.\n")
    else:
        pb = probe_df.sort_values(
            ["f1_macro", "acc"], ascending=[False, False]
        ).reset_index(drop=True)
        best_row = pb.iloc[0]
        best_layer = int(best_row["layer"])
        report.append("**Budget probes (layer summary):**\n\n")
        report.append(probe_df.sort_values("layer").to_markdown(index=False))
        report.append("\n")
        report.append(
            f"**Peak layer (macro-F1):** L{best_layer}  "
            f"(acc={best_row['acc']:.3f}, F1={best_row['f1_macro']:.3f})\n"
        )

    # Written now so the probe table survives when no L* directories exist;
    # it is overwritten below by the per-layer key-metric table otherwise.
    if probe_df is not None:
        probe_df.to_csv(os.path.join(save_dir, "summary_by_layer.csv"), index=False)

    # Per-layer analyses
    causality_rows = []
    per_layer_rows = []

    for L in layers:
        files = per_layer_files(root_out, L)

        # Causality slopes (∆logit vs α)
        caus = analyze_causality(files["causality"])
        if caus is not None and len(caus):
            # analyze_causality reports layer -1 when the JSON omits it; fall
            # back to the directory's layer so different layers stay distinct.
            caus = caus.copy()
            caus.loc[caus["layer"] == -1, "layer"] = L
            causality_rows.append(caus)

        # Localization
        abl, align, corr = analyze_localization(files["abl"], files["align"])
        top_heads = (
            select_top_heads(abl, align, k=3) if abl is not None else pd.DataFrame()
        )
        has_score = abl is not None and "score" in abl.columns
        has_cos = align is not None and "cos_mean" in align.columns
        has_frac = align is not None and "frac_positive" in align.columns

        # OV project-out/scale (per-head)
        proj = analyze_project_out(files["proj"])

        # Decisions
        dec = analyze_decisions(files["dec"])
        dec_has_alpha = dec is not None and "alpha" in dec.columns

        # Optional robustness
        boot = analyze_bootstrap(files["boot"])
        null = analyze_null(files["null"])
        coll = analyze_projout_collapse(files["projout_collapse"])

        # Key layer row
        row = dict(layer=L)
        if probe_df is not None and "layer" in probe_df.columns:
            mrow = probe_df[probe_df["layer"] == L]
            if len(mrow):
                row.update(
                    acc=float(mrow["acc"].iloc[0]),
                    f1_macro=float(mrow["f1_macro"].iloc[0]),
                )
        if abl is not None and len(abl):
            row["n_heads_abl"] = int(abl["head"].nunique())
            if has_score:
                row["abl_top_score"] = float(abl["score"].max())
        if align is not None and len(align) and has_cos:
            row["align_cos_mean_med"] = float(np.median(align["cos_mean"]))
        if corr is not None and not math.isnan(corr):
            row["abl_vs_align_corr"] = float(corr)
        if dec is not None and len(dec) and dec_has_alpha:
            pos = dec[dec["alpha"] > 0]
            neg = dec[dec["alpha"] < 0]
            for name, sub in [("pos", pos), ("neg", neg)]:
                if len(sub):
                    if "dmargin3_contrast" in sub.columns:
                        row[f"dmargin3_contrast_{name}"] = float(
                            sub["dmargin3_contrast"].mean()
                        )
                    if "unsat_shift_event" in sub.columns:
                        row[f"unsat_shift_event_{name}"] = float(
                            sub["unsat_shift_event"].mean()
                        )
        per_layer_rows.append(row)

        # Persist layer-only CSVs
        if top_heads is not None and len(top_heads):
            top_heads.to_csv(
                os.path.join(save_dir, f"head_ranking_L{L}.csv"), index=False
            )
        if (
            abl is not None
            and align is not None
            and len(abl)
            and len(align)
            and has_score
        ):
            # Keep only the alignment statistics that are actually present.
            align_cols = [
                c
                for c in ("layer", "head", "cos_mean", "cos_median", "frac_positive")
                if c in align.columns
            ]
            merged = pd.merge(
                abl[["layer", "head", "score"]],
                align[align_cols],
                on=["layer", "head"],
                how="inner",
                suffixes=("_abl", "_align"),
            )
            merged.to_csv(
                os.path.join(save_dir, f"localization_correlations_L{L}.csv"),
                index=False,
            )
        if dec is not None and len(dec):
            dec.to_csv(
                os.path.join(save_dir, f"decision_metrics_L{L}.csv"), index=False
            )

        # Report section per layer
        section = []
        section.append(f"\n## Layer L{L}\n")
        if caus is not None and len(caus):
            s_map = dict(zip(caus["token"], caus["slope"]))
            s_single = float(s_map.get("single", 0.0))
            s_double = float(s_map.get("double", 0.0))
            s_triple = float(s_map.get("triple", 0.0))
            section.append(
                f"- **Causality (∆logit slope vs α):** single={s_single:+.4f}, double={s_double:+.4f}, triple={s_triple:+.4f}\n"
            )
        else:
            section.append("- Causality: _no data_.\n")

        if abl is not None and len(abl):
            if has_score:
                section.append(
                    f"- **Head ablation:** {len(abl)} rows, best score={abl['score'].max():.4f}\n"
                )
            else:
                section.append(f"- **Head ablation:** {len(abl)} rows\n")
        else:
            section.append("- Head ablation: _no data_.\n")

        if align is not None and len(align):
            details = []
            if has_cos:
                details.append(f"median cos={np.median(align['cos_mean']):.4f}")
            if has_frac:
                details.append(
                    f"mean frac_positive={np.mean(align['frac_positive']):.3f}"
                )
            summary = ", ".join(details) if details else f"{len(align)} rows"
            section.append(f"- **Head–direction alignment:** {summary}\n")
        else:
            section.append("- Head–direction alignment: _no data_.\n")

        if corr is not None and not math.isnan(corr):
            section.append(
                f"- **Ablation–alignment correlation (Pearson):** r={corr:.3f}\n"
            )

        if proj is not None and len(proj):
            proj_sum = proj.groupby("token")["slope"].mean()
            s_single = float(proj_sum.get("single", 0.0))
            s_double = float(proj_sum.get("double", 0.0))
            s_triple = float(proj_sum.get("triple", 0.0))
            section.append(
                f"- **Head OV scaling (avg slope):** single={s_single:+.4f}, double={s_double:+.4f}, triple={s_triple:+.4f}\n"
            )
        else:
            section.append("- Head OV scaling: _no data_.\n")

        if dec is not None and len(dec):
            section.append("- **Decision metrics (event vs control):**\n")
            if "dmargin3_contrast" in dec.columns and dec_has_alpha:
                pos_mean = (
                    float(dec[dec["alpha"] > 0]["dmargin3_contrast"].mean())
                    if (dec["alpha"] > 0).any()
                    else 0.0
                )
                neg_mean = (
                    float(dec[dec["alpha"] < 0]["dmargin3_contrast"].mean())
                    if (dec["alpha"] < 0).any()
                    else 0.0
                )
                section.append(
                    f"  - dmargin3 contrast: α>0 → {pos_mean:+.4f}; α<0 → {neg_mean:+.4f}\n"
                )
            if "unsat_shift_event" in dec.columns:
                section.append(
                    f"  - unsat_shift_event (mean over α): {float(dec['unsat_shift_event'].mean()):+.4f}\n"
                )
        else:
            section.append("- Decision metrics: _no data_.\n")

        if boot is not None and len(boot) and "significant" in boot.columns:
            sig = int(boot[boot["significant"]].shape[0])
            section.append(
                f"- **Bootstrap effects:** significant cells (95% CI exclude 0): {sig}\n"
            )
        if null is not None and len(null):
            section.append(
                f"- **Null injections:** {len(null)} rows (orthogonal-random control present)\n"
            )
        if coll is not None and len(coll):
            section.append(
                f"- **Project-out collapse (decision-t):** {len(coll)} rows\n"
            )

        report.extend(section)

    # Causality slopes table
    if causality_rows:
        caus_all = pd.concat(causality_rows, axis=0).sort_values(["layer", "token"])
        caus_all.to_csv(os.path.join(save_dir, "causality_slopes.csv"), index=False)
        report.append("\n## Causality slopes across layers\n")
        # pivot_table rather than pivot: duplicate (layer, token) pairs are
        # averaged instead of raising and aborting the whole report.
        slope_table = caus_all.pivot_table(
            index="layer", columns="token", values="slope", aggfunc="mean"
        )
        report.append(slope_table.to_markdown(index=True))
        report.append("\n")

    # Per-layer key numbers table
    if per_layer_rows:
        agg = pd.DataFrame(per_layer_rows).sort_values("layer")
        agg.to_csv(os.path.join(save_dir, "summary_by_layer.csv"), index=False)
        report.append("\n## Per-layer key metrics\n")
        show_cols = [
            c
            for c in [
                "layer",
                "acc",
                "f1_macro",
                "n_heads_abl",
                "abl_top_score",
                "align_cos_mean_med",
                "abl_vs_align_corr",
                "dmargin3_contrast_pos",
                "dmargin3_contrast_neg",
                "unsat_shift_event_pos",
                "unsat_shift_event_neg",
            ]
            if c in agg.columns
        ]
        if show_cols:
            report.append(agg[show_cols].to_markdown(index=False))
            report.append("\n")

    if probe_df is not None and best_layer is not None:
        report.append(
            f"\n---\n\n**Recommended focus layer:** L{best_layer} (peak macro-F1 in budget probes)\n"
        )

    md_path = os.path.join(save_dir, "valence_report.md")
    # Explicit UTF-8: the report contains ∆/α/→/–, which locale encodings
    # such as cp1252 cannot represent.
    with open(md_path, "w", encoding="utf-8") as f:
        f.write("\n".join(report))
    print(f"[ok] wrote → {md_path}")


# -------------------------
# CLI
# -------------------------


def main():
    """Command-line entry point: parse the CLI flags and build the report."""
    parser = argparse.ArgumentParser(
        description="Summarize valence experiment outputs."
    )
    parser.add_argument(
        "--root_out",
        default="experiments/valence_suite_all_layers",
        help="Root directory created by the runner.",
    )
    parser.add_argument(
        "--save_dir",
        default="summary",
        help="Where to write the report & tables.",
    )
    opts = parser.parse_args()
    build_report(root_out=opts.root_out, save_dir=opts.save_dir)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
