#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Summarize *causal pre-screen* outputs across many runs (layers, SAE sizes).
Reads every `final_metrics.npz` (with keys like "<frag>/pre_wsd_mean", "<frag>/pre_pb_corr"),
aggregates NaN-safely, and emits CSV/Parquet summaries and a Markdown overview.

USAGE: edit the CONFIG block and run:
    uv run lmkit/sparse/summarize_causal_pre_scans.py
"""

from __future__ import annotations

import os
import re
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

# Progress bars
try:
    from tqdm import tqdm
except Exception:  # tqdm is optional: degrade to a transparent pass-through

    def tqdm(x, **_unused):
        """Stand-in for ``tqdm.tqdm``: return the iterable untouched, ignore options."""
        return x

# =========================
# ======= CONFIG ==========
# =========================

# Top-level directories searched recursively (Path.rglob) for files named exactly
# "final_metrics.npz". Roots that do not exist are skipped silently. The layer id and
# SAE tag of each run are parsed from the run directory path ("layer<digits>" and
# "relu_<digits>x_<alnum>" respectively), so keep those markers in the folder names.
PRE_SCAN_ROOTS = [
    "fragment_causal_pre_8x_layer5",  # example: a parent folder containing many runs
    # "pre_scan_output_relu_4x_e9a211", # or add specific SAE/layer folders
    # "pre_scan_output_relu_8x_95aa15",
]

# Directory that receives every consolidated output (created if missing).
SUMMARY_OUTDIR = "pre_fragment_scan_summaries_8x"

# Metric names probed, in order, when estimating the feature count K for the run
# catalog. NOTE: which metrics are actually read from final_metrics.npz is decided by
# the name map inside load_final_npz, not by this tuple.
METRICS = (
    "pre_wsd_mean",
    "pre_pb_corr",
)  # include "pre_auroc_approx" if present in runs

# Ranking / summary knobs
# NOTE: the per-fragment statistics columns are named "top10_*" regardless of the
# value of TOPK_FEATURES; changing it changes the statistic but not the column name.
TOPK_FEATURES = 10  # top-k per (run, fragment, metric) used for stats and pooled thresholds
SHORTLIST_N = 25  # joint shortlist size per (run, fragment) (composite of pre_wsd + pre_pb)
PARETO_TAU = 0.80  # quantile in [0, 1] of the pooled top-k scores used as quadrant threshold

# Optional: write a long-form Parquet with the top-M features per (run, fragment, metric).
# This can be large.
WRITE_BIG_PARQUET = True
BIG_PARQUET_TOP_M = 50

# Optional: concatenate every *.csv found (recursively) under each run's "ablations/"
# subdirectory into one table (ΔNLL ablation results).
INGEST_ABLATIONS = True

# Optional: render a fragment × layer heatmap PDF (imports matplotlib lazily).
MAKE_PLOTS = False


# =========================
# ========= IO ============
# =========================


@dataclass
class RunMeta:
    """Location and parsed identifiers of a single pre-screen run.

    One instance is created by ``find_all_npz`` for every ``final_metrics.npz`` found.
    """

    # Directory containing final_metrics.npz (also searched for state_meta.json
    # and the optional "ablations/" subdirectory).
    path: Path
    # Layer index parsed from a "layer<digits>" marker in the directory path;
    # None when no such marker is present.
    layer: Optional[int]
    # SAE tag parsed from a "relu_<digits>x_<alnum>" marker in the directory path;
    # the literal string "unknown" when no such marker is present.
    sae_tag: str
    # Full path to the final_metrics.npz file itself.
    npz_path: Path


def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    An already existing directory is left alone; nothing is returned.
    """
    if p.is_dir():
        return
    p.mkdir(parents=True, exist_ok=True)


def find_all_npz(roots: List[str]) -> List[RunMeta]:
    """Locate every ``final_metrics.npz`` below the given roots.

    Args:
        roots: Directories to scan recursively. Entries that are not existing
            directories are skipped.

    Returns:
        One ``RunMeta`` per distinct ``final_metrics.npz`` file. Files are visited in
        sorted path order within each root so the result is reproducible across
        filesystems, and a file reachable from several (overlapping) roots is
        reported only once so it cannot be double-counted in aggregates.

    The layer id is the first ``layer<digits>`` occurrence in the run directory path
    (``None`` if absent); the SAE tag is the first ``relu_<digits>x_<alnum>``
    occurrence (``"unknown"`` if absent).
    """
    # Compile once; the patterns are applied to every discovered run directory.
    layer_re = re.compile(r"layer(\d+)")
    sae_re = re.compile(r"(relu_\d+x_[A-Za-z0-9]+)")

    metas: List[RunMeta] = []
    seen = set()  # resolved npz paths already emitted
    for root in roots:
        root_path = Path(root)
        if not root_path.is_dir():
            continue
        for npz in sorted(root_path.rglob("final_metrics.npz")):
            identity = npz.resolve()
            if identity in seen:
                continue
            seen.add(identity)

            run_dir = npz.parent
            dir_text = str(run_dir)
            m_layer = layer_re.search(dir_text)
            m_sae = sae_re.search(dir_text)
            metas.append(
                RunMeta(
                    path=run_dir,
                    layer=int(m_layer.group(1)) if m_layer else None,
                    sae_tag=m_sae.group(1) if m_sae else "unknown",
                    npz_path=npz,
                )
            )
    return metas


def load_final_npz(npz_path: Path) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Read one ``final_metrics.npz`` into ``{fragment: {metric: vector}}``.

    Archive keys have the form "<frag>/<metric>"; the split happens at the LAST "/".
    Both naming schemes are understood:
      - "pre_wsd_mean", "pre_pb_corr", "pre_auroc_approx"   (causal-pre names)
      - "wsd_mean",     "pb_corr",     "auroc_approx"       (legacy names)
    and every metric is stored under its "pre_*" name. Keys without "/" or with an
    unrecognized metric are ignored. Non-finite entries (NaN, ±inf) become NaN, which
    downstream code treats as "missing".
    """
    canonical_names = ("pre_wsd_mean", "pre_pb_corr", "pre_auroc_approx")
    # Every canonical name is accepted as-is and in its legacy form without "pre_".
    aliases: Dict[str, str] = {}
    for canonical in canonical_names:
        aliases[canonical] = canonical
        aliases[canonical[len("pre_"):]] = canonical

    per_fragment: Dict[str, Dict[str, np.ndarray]] = {}
    with np.load(npz_path, allow_pickle=False) as archive:
        for entry in archive.files:
            fragment, slash, metric_name = entry.rpartition("/")
            if not slash or metric_name not in aliases:
                continue
            raw = np.asarray(archive[entry])
            cleaned = np.where(np.isfinite(raw), raw, np.nan)
            per_fragment.setdefault(fragment, {})[aliases[metric_name]] = cleaned
    return per_fragment


def parse_state_meta(dir_path: Path) -> Optional[dict]:
    """Best-effort read of ``<dir_path>/state_meta.json``.

    Args:
        dir_path: Run directory expected to contain ``state_meta.json``.

    Returns:
        The parsed JSON object as a dict, or ``None`` when the file is missing,
        unreadable, not valid JSON, or its top-level value is not a JSON object
        (callers rely on ``dict`` methods such as ``.get``).
    """
    meta_path = dir_path / "state_meta.json"
    try:
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)
    except (OSError, ValueError, RecursionError):
        # OSError: missing/unreadable file. ValueError: malformed JSON or bad bytes
        # (JSONDecodeError and UnicodeDecodeError are subclasses).
        # RecursionError: pathologically nested document.
        return None
    return meta if isinstance(meta, dict) else None


# =========================
# ======== UTILS ==========
# =========================


def robust_topk(vec: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Pick the ``k`` largest finite entries of ``vec``.

    Returns ``(values, indices)`` sorted by descending value, where ``indices`` refer
    to positions in the original ``vec``. NaN/inf entries are never selected; when
    fewer than ``k`` finite entries exist, all of them are returned. An empty or
    all-non-finite input yields two empty arrays.
    """
    nothing = (np.array([]), np.array([], dtype=int))
    if vec.size == 0:
        return nothing

    finite_mask = np.isfinite(vec)
    if not finite_mask.any():
        return nothing

    positions = np.nonzero(finite_mask)[0]
    finite_vals = vec[finite_mask]
    negated = -finite_vals  # ascending order on the negation == descending on values

    if finite_vals.size <= k:
        ranking = np.argsort(negated)
    else:
        # Partial selection first (cheap), then a full sort of just the k winners.
        winners = np.argpartition(negated, k - 1)[:k]
        ranking = winners[np.argsort(negated[winners])]
    return finite_vals[ranking], positions[ranking]


def quantiles(vec: np.ndarray, qs=(0.5, 0.9, 0.95, 0.99)) -> Dict[str, float]:
    """Quantiles of the finite entries of ``vec``, keyed by percent label.

    Args:
        vec: Numeric array-like; NaN/inf entries are ignored.
        qs: Quantile levels in [0, 1].

    Returns:
        ``{"q<percent>": value}`` with one entry per level, where ``<percent>`` is the
        integer part of ``q * 100`` (0.5 -> "q50", 0.95 -> "q95"). Every value is NaN
        when ``vec`` has no finite entries.
    """
    arr = np.asarray(vec)
    valid = arr[np.isfinite(arr)]

    def label(q: float) -> str:
        # Round away binary float error before truncating: 0.29 * 100 evaluates to
        # 28.999999999999996, which a bare int() would mislabel as "q28".
        return f"q{int(round(q * 100, 9))}"

    if valid.size == 0:
        return {label(q): np.nan for q in qs}
    return {label(q): float(np.nanquantile(valid, q)) for q in qs}


def composite_rank(
    pre_pb: np.ndarray, pre_wsd: np.ndarray, shortlist_n: int
) -> List[Tuple[int, float, float, float]]:
    """
    Rank features by the mean of their per-metric z-scores.

    Args:
        pre_pb: Per-feature pre_pb_corr scores (NaN/inf = missing).
        pre_wsd: Per-feature pre_wsd_mean scores (NaN/inf = missing).
        shortlist_n: Maximum number of features to return.

    Returns:
        [(feature_id, pre_pb, pre_wsd, composite)] sorted by descending composite,
        of length ≤ shortlist_n (empty when there are no features or
        ``shortlist_n`` is not positive).

    Each vector is z-scored over its finite entries; a missing entry contributes a
    z-score of 0, and a vector without any finite entry contributes 0 everywhere.
    The two vectors may differ in length (e.g. one metric absent from a run, passed
    as an empty array): the shorter one is padded with NaN so the result still
    covers every feature of the longer one.
    """

    def cleaned(values, length: int) -> np.ndarray:
        # Flatten to float, map non-finite entries to NaN, pad with NaN to `length`.
        arr = np.asarray(values, dtype=float).ravel()
        out = np.full(length, np.nan)
        out[: arr.size] = np.where(np.isfinite(arr), arr, np.nan)
        return out

    def zscore(v: np.ndarray) -> np.ndarray:
        if not np.isfinite(v).any():
            return np.zeros_like(v)
        mu = float(np.nanmean(v))
        sd = max(float(np.nanstd(v)), 1e-12)  # guard against constant vectors
        z = (v - mu) / sd
        z[~np.isfinite(z)] = 0.0  # missing entries are neutral in the composite
        return z

    n_features = max(np.size(pre_pb), np.size(pre_wsd))
    if n_features == 0 or shortlist_n <= 0:
        return []

    pbv = cleaned(pre_pb, n_features)
    wv = cleaned(pre_wsd, n_features)
    comp = 0.5 * (zscore(pbv) + zscore(wv))

    order = np.argsort(-np.nan_to_num(comp, nan=-1e9))[:shortlist_n]
    return [(int(i), float(pbv[i]), float(wv[i]), float(comp[i])) for i in order]


def quadrant_label(pb: float, wsd: float, pb_thr: float, wsd_thr: float) -> str:
    """Classify a (pb, wsd) score pair against its two thresholds.

    A score counts as "high" only when it is finite and ``>=`` its threshold; NaN or
    infinite scores are always treated as "low".
    """
    pb_is_high = bool(np.isfinite(pb) and pb >= pb_thr)
    wsd_is_high = bool(np.isfinite(wsd) and wsd >= wsd_thr)
    labels = {
        (True, True): "high_both",
        (True, False): "high_pb_low_wsd",
        (False, True): "high_wsd_low_pb",
        (False, False): "low_both",
    }
    return labels[(pb_is_high, wsd_is_high)]


def pooled_threshold(values: List[float], fallback: float, q=0.80) -> float:
    """Return the ``q``-quantile of the finite entries of ``values``.

    ``fallback`` is returned unchanged when ``values`` holds no finite entry.
    """
    finite_vals = np.array(list(filter(np.isfinite, values)))
    return float(np.nanquantile(finite_vals, q)) if finite_vals.size else fallback


# =========================
# ======= MAIN ============
# =========================


def main():
    """Run the whole summarization pipeline driven by the module-level CONFIG block.

    All outputs are written into ``SUMMARY_OUTDIR``:
      1. ``pre_runs_catalog.csv``            one row per discovered run
      2. ``pre_all_scores.parquet``          top-M scores per (run, fragment, metric),
                                             only if ``WRITE_BIG_PARQUET``
         ``pre_top_wsd_shortlist.csv``       composite shortlist per (run, fragment)
      3. ``pre_per_run_fragment_stats.csv``  max / top-k / quantile stats per (run, fragment)
         ``pre_per_fragment_layer_wsd.csv``  those stats aggregated by fragment × layer
         ``pre_per_fragment_sae_wsd.csv``    those stats aggregated by fragment × SAE tag
      4. ``pre_fragment_difficulty_wsd.csv`` fragments ranked by median top-k pre_wsd
      5. ``pre_quadrant_counts.csv``         quadrant label counts per fragment
      6. ``pre_ablation_snippets.csv``       concatenated ablation CSVs, only if
                                             ``INGEST_ABLATIONS`` and any were found
      7. ``pre_grand_summary.md``            Markdown overview
         ``plots/...pdf``                    heatmap, only if ``MAKE_PLOTS``

    Returns None. Prints a message and returns early when no run or no fragment
    metric is found.
    """
    outdir = Path(SUMMARY_OUTDIR)
    ensure_dir(outdir)

    # ---- small local helpers (defined once, used by several stages) ----

    def metric_vectors(mdict):
        """Return (pre_pb, pre_wsd) for one fragment, NaN-padded to a common length.

        A metric absent from the run becomes an all-NaN vector, which every
        consumer below treats exactly like "no data"; equal lengths are required
        by the composite ranking.
        """
        pb = mdict.get("pre_pb_corr", np.full(0, np.nan))
        wsd = mdict.get("pre_wsd_mean", np.full(0, np.nan))
        n = max(pb.size, wsd.size)

        def pad(vec):
            if vec.size == n:
                return vec
            padded = np.full(n, np.nan)
            padded[: vec.size] = vec
            return padded

        return pad(pb), pad(wsd)

    def finite_max(vec):
        """Largest finite entry of ``vec`` or NaN (avoids numpy's all-NaN warning)."""
        valid = vec[np.isfinite(vec)]
        return float(valid.max()) if valid.size else np.nan

    def bundle(vec):
        """Max, top-k median/mean and quantiles over the finite entries of ``vec``."""
        valid = vec[np.isfinite(vec)]
        stats = dict(max=np.nan, top10_med=np.nan, top10_mean=np.nan)
        if valid.size:
            top_vals, _ = robust_topk(valid, TOPK_FEATURES)
            stats["max"] = float(valid.max())
            if top_vals.size:
                stats["top10_med"] = float(np.median(top_vals))
                stats["top10_mean"] = float(np.mean(top_vals))
        stats.update(quantiles(valid))
        return stats

    def agg_df(df: pd.DataFrame, group_cols: List[str]) -> pd.DataFrame:
        """Median/mean/std/count of every stats column, grouped by ``group_cols``."""
        id_cols = ("run_dir", "fragment", "layer", "sae_tag")
        agg_cols = [c for c in df.columns if c not in id_cols]
        out = df.groupby(group_cols, dropna=False)[agg_cols].agg(
            ["median", "mean", "std", "count"]
        )
        # Flatten the (column, statistic) MultiIndex into "column__statistic".
        out.columns = ["__".join(c) for c in out.columns.to_flat_index()]
        return out.reset_index()

    # 1) discover runs
    runs = find_all_npz(PRE_SCAN_ROOTS)
    if not runs:
        print("No final_metrics.npz found under:", PRE_SCAN_ROOTS)
        return

    # 2) load each run
    catalog_rows = []
    per_run_data: List[Tuple[RunMeta, Dict[str, Dict[str, np.ndarray]]]] = []

    for r in tqdm(runs, desc="Loading runs"):
        d = load_final_npz(r.npz_path)

        # Estimate K from the first fragment that carries any of the METRICS.
        any_vec = next(
            (d[frag][m] for frag in d for m in METRICS if m in d[frag]), None
        )
        meta = parse_state_meta(r.path)

        catalog_rows.append(
            dict(
                run_dir=str(r.path),
                layer=r.layer,
                sae_tag=r.sae_tag,
                K=len(any_vec) if any_vec is not None else 0,
                fragments=len(d),
                # Tolerate metadata files whose top-level JSON value is not an object.
                mols_processed=(
                    meta.get("mols_processed", None) if isinstance(meta, dict) else None
                ),
            )
        )
        per_run_data.append((r, d))

    runs_catalog = pd.DataFrame(catalog_rows).sort_values(
        ["layer", "sae_tag", "run_dir"], na_position="last"
    )
    runs_catalog.to_csv(outdir / "pre_runs_catalog.csv", index=False)

    # 3) build long-form and shortlist (WSD-first)
    long_rows, shortlist_rows = [], []
    pooled_top_pb, pooled_top_wsd = [], []

    for r, d in tqdm(per_run_data, desc="Summarizing per run"):
        run_cols = dict(run_dir=str(r.path), layer=r.layer, sae_tag=r.sae_tag)
        for frag, mdict in d.items():
            pre_pb, pre_wsd = metric_vectors(mdict)

            # Long-form: top-M per metric (only collected when it will be written)
            if WRITE_BIG_PARQUET:
                for metric_name, vec in (
                    ("pre_pb_corr", pre_pb),
                    ("pre_wsd_mean", pre_wsd),
                ):
                    vals, idxs = robust_topk(vec, BIG_PARQUET_TOP_M)
                    long_rows.extend(
                        dict(
                            **run_cols,
                            fragment=frag,
                            metric=metric_name,
                            feature_id=int(i),
                            score=float(v),
                        )
                        for v, i in zip(vals, idxs)
                    )

            # pool top-K for thresholds
            vpb, _ = robust_topk(pre_pb, TOPK_FEATURES)
            vwd, _ = robust_topk(pre_wsd, TOPK_FEATURES)
            pooled_top_pb.extend(float(x) for x in vpb)
            pooled_top_wsd.extend(float(x) for x in vwd)

            # joint shortlist (composite rank)
            for fid, spb, sw, comp in composite_rank(pre_pb, pre_wsd, SHORTLIST_N):
                shortlist_rows.append(
                    dict(
                        **run_cols,
                        fragment=frag,
                        feature_id=fid,
                        pre_pb_corr=spb,
                        pre_wsd_mean=sw,
                        composite=comp,
                    )
                )

    if WRITE_BIG_PARQUET and long_rows:
        try:
            pd.DataFrame(long_rows).to_parquet(
                outdir / "pre_all_scores.parquet", index=False
            )
        except ImportError as e:
            # pandas needs pyarrow or fastparquet; this output is optional, so do not
            # let a missing engine abort the remaining summaries.
            print("[warn] skipping Parquet output (no Parquet engine):", e)

    # Explicit columns so an empty shortlist still yields a CSV with a header row.
    shortlist_cols = [
        "run_dir",
        "layer",
        "sae_tag",
        "fragment",
        "feature_id",
        "pre_pb_corr",
        "pre_wsd_mean",
        "composite",
    ]
    df_short = pd.DataFrame(shortlist_rows, columns=shortlist_cols)
    if not df_short.empty:
        df_short = df_short.sort_values(
            ["fragment", "layer", "sae_tag", "composite"],
            ascending=[True, True, True, False],
        )
    df_short.to_csv(outdir / "pre_top_wsd_shortlist.csv", index=False)

    # 4) fragment difficulty (global) — WSD-first
    stat_keys = ("max", "top10_med", "top10_mean", "q50", "q90", "q95", "q99")
    frag_stats = []
    for r, d in tqdm(per_run_data, desc="Per-fragment stats"):
        for frag, mdict in d.items():
            pre_pb, pre_wsd = metric_vectors(mdict)
            row = dict(
                run_dir=str(r.path), layer=r.layer, sae_tag=r.sae_tag, fragment=frag
            )
            for prefix, vec in (("pre_pb", pre_pb), ("pre_wsd", pre_wsd)):
                stats = bundle(vec)
                for key in stat_keys:
                    row[f"{prefix}_{key}"] = stats[key]
            frag_stats.append(row)

    if not frag_stats:
        # Without any fragment row the group-bys below would fail on missing columns.
        print("No fragment metrics found in the discovered runs; nothing to summarize.")
        return

    df_frag_stats = pd.DataFrame(frag_stats)
    df_frag_stats.to_csv(outdir / "pre_per_run_fragment_stats.csv", index=False)

    # 5) Aggregate by fragment × layer and fragment × SAE
    by_layer = agg_df(df_frag_stats, ["fragment", "layer"])
    by_layer.to_csv(outdir / "pre_per_fragment_layer_wsd.csv", index=False)

    by_sae = agg_df(df_frag_stats, ["fragment", "sae_tag"])
    by_sae.to_csv(outdir / "pre_per_fragment_sae_wsd.csv", index=False)

    # 6) Global difficulty ranking by pre_wsd (median of per-run top-10)
    frag_order = (
        df_frag_stats.groupby("fragment", dropna=False)["pre_wsd_top10_med"]
        .median()
        .sort_values(ascending=False)
        .reset_index()
        .rename(columns={"pre_wsd_top10_med": "pre_wsd_top10_med_global_median"})
    )
    frag_order.to_csv(outdir / "pre_fragment_difficulty_wsd.csv", index=False)

    # 7) Quadrant analysis: thresholds from pooled top-K distributions
    #    (PARETO_TAU quantile; the fallbacks apply only when nothing was pooled)
    pb_thr = pooled_threshold(pooled_top_pb, fallback=0.10, q=PARETO_TAU)
    wsd_thr = pooled_threshold(pooled_top_wsd, fallback=0.02, q=PARETO_TAU)

    quad_rows = []
    for r, d in per_run_data:
        for frag, mdict in d.items():
            pre_pb, pre_wsd = metric_vectors(mdict)
            label = quadrant_label(
                finite_max(pre_pb), finite_max(pre_wsd), pb_thr, wsd_thr
            )
            quad_rows.append(
                dict(fragment=frag, layer=r.layer, sae_tag=r.sae_tag, label=label)
            )

    df_quad = pd.DataFrame(quad_rows)
    quadrant_counts = (
        df_quad.groupby(["fragment", "label"]).size().reset_index(name="count")
    )
    quadrant_counts["frac"] = quadrant_counts.groupby("fragment")["count"].transform(
        lambda x: x / x.sum()
    )
    quadrant_counts.to_csv(outdir / "pre_quadrant_counts.csv", index=False)

    # 8) Optional: ingest ablation ΔNLL snippets (if present)
    if INGEST_ABLATIONS:
        abl_rows = []
        for r, _ in tqdm(per_run_data, desc="Ingesting abl./ΔNLL (if any)"):
            abl_dir = r.path / "ablations"
            if not abl_dir.exists():
                continue
            for csv_path in abl_dir.rglob("*.csv"):
                try:
                    df = pd.read_csv(csv_path)
                except Exception as e:
                    # Best effort: one unreadable file must not abort the summary,
                    # but say so instead of dropping it silently.
                    print(f"[warn] could not read ablation CSV {csv_path}: {e}")
                    continue
                # Expect columns like: fragment, feature_id, delta_nll_mean, delta_nll_med, n_tokens, etc.
                # If your schema differs, tweak here.
                df["run_dir"] = str(r.path)
                df["layer"] = r.layer
                df["sae_tag"] = r.sae_tag
                abl_rows.append(df)
        if abl_rows:
            df_abl = pd.concat(abl_rows, ignore_index=True)
            df_abl.to_csv(outdir / "pre_ablation_snippets.csv", index=False)

    # 9) Executive summary (Markdown)
    lines = []
    lines.append("# Causal Pre‑Screen Executive Summary\n")
    lines.append(f"- Runs found: **{len(runs)}**")
    lines.append(f"- Primary metric: **pre_wsd_mean** (Top‑K={TOPK_FEATURES})")
    lines.append(
        f"- Pareto tau for quadrant thresholds: {PARETO_TAU} (pre_pb_thr≈{pb_thr:.3g}, pre_wsd_thr≈{wsd_thr:.3g})\n"
    )

    # Top fragments by pre-WSD difficulty
    lines.append("## Hardest fragments (highest global median of top‑10 pre‑WSD)\n")
    for _, r0 in frag_order.head(15).iterrows():
        lines.append(
            f"- **{r0['fragment']}**: median top‑10 pre‑WSD = {r0['pre_wsd_top10_med_global_median']:.3g}"
        )

    # Per-layer highlights. Runs whose layer could not be parsed have a NaN layer;
    # they cannot be reported as "layer N", so they are left out of the layer views.
    known_layers = by_layer.dropna(subset=["layer"])
    if not known_layers.empty:
        lines.append("\n## Layer profiles (median of per‑run top‑10 pre‑WSD)\n")
        piv = known_layers.pivot(
            index="fragment", columns="layer", values="pre_wsd_top10_med__median"
        )
        for frag in frag_order["fragment"].tolist()[:5]:
            if frag not in piv.index:
                continue
            vals = piv.loc[frag].dropna()
            if vals.empty:
                continue
            best_layer = int(vals.idxmax())
            best_val = float(vals.max())
            lines.append(
                f"- **{frag}** peaks at layer **{best_layer}** with median(top‑10 pre‑WSD)≈{best_val:.3g}"
            )

    # The summary contains non-ASCII characters (≈, ×, non-breaking hyphens), so the
    # encoding must be explicit rather than the platform default.
    (outdir / "pre_grand_summary.md").write_text("\n".join(lines), encoding="utf-8")

    # 10) Optional plots
    if MAKE_PLOTS and not known_layers.empty:
        try:
            import matplotlib.pyplot as plt

            ensure_dir(outdir / "plots")

            # Nature-ish minimalist style
            plt.rcParams.update(
                {
                    "figure.dpi": 120,
                    "savefig.dpi": 300,
                    "font.size": 10,
                    "axes.spines.top": False,
                    "axes.spines.right": False,
                    "axes.grid": False,
                    "grid.alpha": 0.15,
                    "legend.frameon": False,
                }
            )

            piv = known_layers.pivot(
                index="fragment", columns="layer", values="pre_wsd_top10_med__median"
            )
            piv = piv.sort_values(by=piv.columns.tolist(), ascending=False)
            fig, ax = plt.subplots(figsize=(9, max(4, 0.24 * len(piv))))
            # Missing (fragment, layer) cells are drawn as 0.
            im = ax.imshow(np.nan_to_num(piv.values, nan=0.0), aspect="auto")
            ax.set_yticks(range(len(piv.index)))
            ax.set_yticklabels(piv.index, fontsize=7)
            ax.set_xticks(range(len(piv.columns)))
            ax.set_xticklabels(piv.columns)
            ax.set_title("Median of top‑10 pre‑WSD by fragment × layer")
            cbar = fig.colorbar(im, ax=ax, shrink=0.7)
            cbar.set_label("pre‑WSD (median of per‑run top‑10)")
            fig.tight_layout()
            fig.savefig(outdir / "plots/pre_frag_by_layer_wsd_top10_med_median.pdf")
            plt.close(fig)
        except Exception as e:
            print("[warn] plotting failed:", e)

    print(f"✓ Wrote pre-screen summaries to: {SUMMARY_OUTDIR}")

# Script entry point: run the summarization only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
