#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
WSD-first summarizer for fragment_scan outputs across many runs (layers, SAE sizes).
Reads each `final_metrics.npz`, aggregates NaN-safely, and emits WSD-centric CSV/Parquet summaries.

USAGE: edit the CONFIG block and run:
    uv run lmkit/sparse/summarize_fragment_scans_wsd.py
"""

from __future__ import annotations

import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

# =========================
# ======= CONFIG ==========
# =========================

# Root directory holding every run directory (each with a final_metrics.npz).
frag_dir = "fragment_scan_output_all"

# SAE variants to summarize; comment an entry out to skip it.
saes = [
    "relu_4x_e9a211",
    "relu_8x_95aa15",
    # "relu_16x_b0334a",
]

# One search root per (SAE, layer) pair, for layers 1..5, shaped like:
#   fragment_scan_output_all/run_<SAE>_layer<d>_layer<d>
SCAN_ROOTS: List[str] = [
    f"{frag_dir}/run_{sae_name}_layer{layer_no}_layer{layer_no}"
    for sae_name in saes
    for layer_no in range(1, 6)
]

# Where all summary artifacts are written.
SUMMARY_OUTDIR = "fragment_scan_summaries_wsd"

# --------- WSD-centric knobs ----------
PRIMARY_METRIC = "wsd_mean"
INCLUDE_PB_CORR = False  # True -> also emit pb_corr rows for context

# Ranking & summaries
TOPK_FEATURES = 10  # top-k features per (run, fragment) used for per-run stats
SHORTLIST_N = 25  # number of top features (by WSD) exported to the shortlists
WRITE_BIG_PARQUET = True  # write the long-form top-M table (can be large)
BIG_PARQUET_TOP_M = 50  # row cap per (run, fragment, metric) in that table

# Optional plots
MAKE_PLOTS = True

# =========================
# ====== HELPERS ==========
# =========================


@dataclass
class RunMeta:
    """Location and parsed identity of one fragment-scan run.

    Fields:
        path: directory that contains ``final_metrics.npz``.
        layer: layer number parsed from the path, or ``None`` if not found.
        sae_tag: SAE tag parsed from the path.
        npz_path: full path to the ``final_metrics.npz`` file itself.
    """

    path: Path
    layer: Optional[int]
    sae_tag: str
    npz_path: Path


def ensure_dir(p: Path) -> None:
    """Create directory ``p`` along with any missing parents.

    Does nothing if the directory already exists.
    """
    p.mkdir(exist_ok=True, parents=True)


def find_all_npz(roots: List[str]) -> List[RunMeta]:
    """Locate every ``final_metrics.npz`` under the given roots.

    Args:
        roots: directories to search recursively. Roots that do not exist
            are skipped silently.

    Returns:
        One ``RunMeta`` per distinct file. ``layer`` is the first
        ``layer<digits>`` match in the containing directory's path (``None``
        if absent) and ``sae_tag`` the first ``relu_<n>x_<hash>`` match
        (``"unknown"`` if absent).

    Within each root the files are visited in sorted path order so the
    result does not depend on filesystem enumeration order, and a file
    reachable from more than one root is reported only once so overlapping
    roots cannot double-count a run in later aggregates.
    """
    # Compile once; both patterns are applied to every discovered directory.
    layer_re = re.compile(r"layer(\d+)")
    sae_re = re.compile(r"(relu_\d+x_[A-Za-z0-9]+)")

    metas: List[RunMeta] = []
    seen = set()  # resolved npz paths already emitted
    for root in roots:
        root_path = Path(root)
        if not root_path.exists():
            continue
        for npz in sorted(root_path.rglob("final_metrics.npz")):
            resolved = npz.resolve()
            if resolved in seen:
                continue
            seen.add(resolved)

            d = npz.parent
            d_str = str(d)
            m_layer = layer_re.search(d_str)
            m_sae = sae_re.search(d_str)
            metas.append(
                RunMeta(
                    path=d,
                    layer=int(m_layer.group(1)) if m_layer else None,
                    sae_tag=m_sae.group(1) if m_sae else "unknown",
                    npz_path=npz,
                )
            )
    return metas


def load_final_npz(
    npz_path: Path, primary: str, include_pb: bool
) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Read one ``final_metrics.npz`` into a nested mapping.

    Archive entries are expected to be named ``<fragment>/<metric>``; the
    split happens at the LAST slash, and entries without a slash are ignored.
    Only ``primary``, ``wsd_max`` and (when ``include_pb`` is true)
    ``pb_corr`` are retained.

    Returns: dict[fragment][metric] -> array in which every non-finite
    entry (NaN or +/-inf) has been replaced by NaN.
    """
    keep = {primary, "wsd_max"} | ({"pb_corr"} if include_pb else set())
    per_fragment: Dict[str, Dict[str, np.ndarray]] = {}

    with np.load(npz_path, allow_pickle=False) as archive:
        for name in archive.files:
            fragment, slash, metric = name.rpartition("/")
            if not slash or metric not in keep:
                continue
            values = np.asarray(archive[name])
            # NaNs stay NaN for later NaN-aware aggregation; infs become NaN too.
            cleaned = np.where(np.isfinite(values), values, np.nan)
            per_fragment.setdefault(fragment, {})[metric] = cleaned

    return per_fragment


def parse_state_meta(dir_path: Path) -> Optional[dict]:
    """Best-effort read of ``state_meta.json`` from a run directory.

    Args:
        dir_path: directory expected to contain ``state_meta.json``.

    Returns:
        The parsed JSON object, or ``None`` when the file is missing,
        unreadable, not valid JSON, or its top-level value is not a JSON
        object. A missing file is silent; the other cases print a
        ``[warn]`` line so the problem is visible without aborting the run.
    """
    p = dir_path / "state_meta.json"
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(p, encoding="utf-8") as f:
            meta = json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[warn] could not read {p}: {e}")
        return None
    if not isinstance(meta, dict):
        # Callers use dict methods on the result; honor the declared return type.
        print(f"[warn] {p} does not contain a JSON object; ignoring")
        return None
    return meta


def robust_topk(vec: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return (values, indices) of the top-k finite entries, largest first.

    Args:
        vec: 1-D array of scores; NaN and +/-inf entries are ignored.
        k: maximum number of entries to return.

    Returns:
        ``(values, indices)`` where ``indices`` index into the original
        ``vec``. Both are empty when ``k <= 0``, when ``vec`` is empty, or
        when it has no finite entry. Fewer than ``k`` entries are returned
        when fewer finite values exist.
    """
    vec = np.asarray(vec)
    # A non-positive k would otherwise reach argpartition/slicing with a
    # negative bound and return an arbitrary subset instead of nothing.
    if k <= 0 or vec.size == 0:
        return np.array([]), np.array([], dtype=int)
    valid = np.isfinite(vec)
    if not valid.any():
        return np.array([]), np.array([], dtype=int)
    v = vec[valid]
    idx = np.where(valid)[0]  # positions of the finite entries in `vec`
    if v.size <= k:
        order = np.argsort(-v)
        return v[order], idx[order]
    # Select the k largest in O(n), then sort only those k.
    order = np.argpartition(-v, k - 1)[:k]
    order = order[np.argsort(-v[order])]
    return v[order], idx[order]


def quantiles(vec: np.ndarray, qs=(0.5, 0.9, 0.95, 0.99)) -> Dict[str, float]:
    """Compute quantiles of the finite entries of ``vec``.

    Args:
        vec: array of values; NaN and +/-inf entries are ignored.
        qs: quantile levels in [0, 1].

    Returns:
        Mapping ``"q<percent>" -> value`` in the order of ``qs``, where
        ``<percent>`` is the integer part of ``q * 100`` (0.5 -> ``"q50"``).
        Every value is NaN when ``vec`` has no finite entry.
    """
    # Round before truncating: 0.29 * 100 is 28.999999999999996 in binary
    # floating point, which a bare int() would label "q28".
    keys = [f"q{int(round(q * 100, 9))}" for q in qs]
    vec = np.asarray(vec)
    valid = vec[np.isfinite(vec)]
    if valid.size == 0:
        return {key: np.nan for key in keys}
    # `valid` is already NaN-free, so one vectorized quantile call suffices.
    values = np.quantile(valid, list(qs))
    return {key: float(val) for key, val in zip(keys, values)}


# =========================
# ======= MAIN ============
# =========================


def main():
    """Summarize every discovered fragment-scan run into WSD-centric tables.

    All outputs are written under ``SUMMARY_OUTDIR``:

      1. Discover ``final_metrics.npz`` files under ``SCAN_ROOTS``.
      2. Load them and write ``runs_catalog.csv``.
      3. Write ``all_scores.parquet`` (if ``WRITE_BIG_PARQUET``),
         ``top_wsd_shortlist.csv`` and ``top_wsd_max_shortlist.csv``.
      4. Compute per-(run, fragment) statistics of the primary metric.
      5. Write ``per_fragment_layer_wsd.csv`` and ``per_fragment_sae_wsd.csv``.
      6. Write ``fragment_difficulty_wsd.csv``.
      7. Write ``best_wsd_exemplars.csv``.
      8. Write ``wsd_grand_summary.md``.
      9. Optionally write a fragment x layer heatmap PDF under ``plots/``.

    Returns early (after printing a message) when no run is found, or when
    no run contains any of the wanted metrics.
    """
    outdir = Path(SUMMARY_OUTDIR)
    ensure_dir(outdir)

    # 1) Discover runs
    runs = find_all_npz(SCAN_ROOTS)
    if not runs:
        print("No final_metrics.npz found under:", SCAN_ROOTS)
        return

    # 2) Load runs with progress
    catalog_rows = []
    per_run_data: List[Tuple[RunMeta, Dict[str, Dict[str, np.ndarray]]]] = []

    for r in tqdm(runs, desc="Loading runs", unit="run"):
        d = load_final_npz(
            r.npz_path, primary=PRIMARY_METRIC, include_pb=INCLUDE_PB_CORR
        )
        # infer K from the first fragment that carries the primary metric
        any_vec = next(
            (md[PRIMARY_METRIC] for md in d.values() if PRIMARY_METRIC in md), None
        )
        K = len(any_vec) if any_vec is not None else 0
        meta = parse_state_meta(r.path)
        # Guard on type: a state_meta.json whose root is not an object has no .get
        mols = meta.get("mols_processed", None) if isinstance(meta, dict) else None
        catalog_rows.append(
            dict(
                run_dir=str(r.path),
                layer=r.layer,
                sae_tag=r.sae_tag,
                K=K,
                fragments=len(d),
                mols_processed=mols,
            )
        )
        per_run_data.append((r, d))

    runs_catalog = pd.DataFrame(catalog_rows).sort_values(
        ["layer", "sae_tag", "run_dir"]
    )
    runs_catalog.to_csv(outdir / "runs_catalog.csv", index=False)

    # 3) Long-form table (optional) & WSD shortlists
    long_rows = []
    shortlist_rows = []
    max_wsd_shortlist_rows = []
    id_cols = ["run_dir", "layer", "sae_tag", "fragment"]
    no_data = np.full(0, np.nan)  # stand-in for a metric missing from a fragment

    for r, d in tqdm(per_run_data, desc="Indexing scores", unit="run"):
        for frag, mdict in d.items():
            base = dict(
                run_dir=str(r.path), layer=r.layer, sae_tag=r.sae_tag, fragment=frag
            )
            wsd = mdict.get(PRIMARY_METRIC, no_data)
            wsd_max = mdict.get("wsd_max", no_data)

            # Long-form rows are only ever used for the parquet file, so skip
            # building them when that output is disabled.
            if WRITE_BIG_PARQUET:
                vals, idxs = robust_topk(wsd, BIG_PARQUET_TOP_M)
                long_rows.extend(
                    dict(base, metric=PRIMARY_METRIC, feature_id=int(i), score=float(v))
                    for v, i in zip(vals, idxs)
                )
                # Optional: pb_corr top-M alongside (context only)
                if INCLUDE_PB_CORR and "pb_corr" in mdict:
                    pv, pi = robust_topk(mdict["pb_corr"], BIG_PARQUET_TOP_M)
                    long_rows.extend(
                        dict(base, metric="pb_corr", feature_id=int(i), score=float(v))
                        for v, i in zip(pv, pi)
                    )

            # Shortlist (top-N WSD mean)
            sv, si = robust_topk(wsd, SHORTLIST_N)
            shortlist_rows.extend(
                dict(base, feature_id=int(i), wsd_mean=float(v)) for v, i in zip(sv, si)
            )

            # Shortlist (top-N WSD max)
            sv_max, si_max = robust_topk(wsd_max, SHORTLIST_N)
            max_wsd_shortlist_rows.extend(
                dict(base, feature_id=int(i), wsd_max=float(v))
                for v, i in zip(sv_max, si_max)
            )

    if WRITE_BIG_PARQUET and long_rows:
        pd.DataFrame(long_rows).to_parquet(outdir / "all_scores.parquet", index=False)

    if not shortlist_rows:
        print(f"[warn] No {PRIMARY_METRIC} data found; top_wsd_shortlist.csv is empty")
    # Explicit columns keep sort_values (and later lookups) working even when
    # no run produced a single shortlist row.
    df_short = pd.DataFrame(
        shortlist_rows, columns=id_cols + ["feature_id", "wsd_mean"]
    ).sort_values(
        ["fragment", "layer", "sae_tag", "wsd_mean"],
        ascending=[True, True, True, False],
    )
    df_short.to_csv(outdir / "top_wsd_shortlist.csv", index=False)

    if max_wsd_shortlist_rows:
        df_short_max = pd.DataFrame(max_wsd_shortlist_rows).sort_values(
            ["fragment", "layer", "sae_tag", "wsd_max"],
            ascending=[True, True, True, False],
        )
        df_short_max.to_csv(outdir / "top_wsd_max_shortlist.csv", index=False)
    else:
        print("[warn] No wsd_max data found; skipping top_wsd_max_shortlist.csv")

    # 4) Fragment difficulty (WSD) with progress
    def bundle(vec):
        """Max, top-k median/mean and quantiles of an all-finite vector."""
        if vec.size == 0:
            return dict(
                max=np.nan, top10_med=np.nan, top10_mean=np.nan, **quantiles(vec)
            )
        vals, _ = robust_topk(vec, TOPK_FEATURES)
        return dict(
            max=float(np.nanmax(vec)),
            top10_med=float(np.nanmedian(vals)) if vals.size else np.nan,
            top10_mean=float(np.nanmean(vals)) if vals.size else np.nan,
            **quantiles(vec),
        )

    frag_stats = []
    for r, d in tqdm(per_run_data, desc="Computing fragment stats", unit="run"):
        for frag, mdict in d.items():
            wsd = mdict.get(PRIMARY_METRIC, no_data)
            wsd_valid = wsd[np.isfinite(wsd)] if wsd.size else np.array([])
            b = bundle(wsd_valid)
            frag_stats.append(
                dict(
                    run_dir=str(r.path),
                    layer=r.layer,
                    sae_tag=r.sae_tag,
                    fragment=frag,
                    # NOTE: this is the max of the PRIMARY metric over features,
                    # not the separate "wsd_max" metric stored in the npz.
                    wsd_max=b["max"],
                    wsd_top10_med=b["top10_med"],
                    wsd_top10_mean=b["top10_mean"],
                    wsd_q50=b["q50"],
                    wsd_q90=b["q90"],
                    wsd_q95=b["q95"],
                    wsd_q99=b["q99"],
                )
            )

    if not frag_stats:
        # Without any rows the frame has no columns and every groupby below
        # would raise KeyError; stop cleanly instead.
        print("[warn] No fragment metrics found in any run; nothing further to summarize")
        return

    df_frag_stats = pd.DataFrame(frag_stats)

    # 5) Aggregations by fragment × layer / fragment × SAE
    def agg_df(df: pd.DataFrame, group_cols: List[str]) -> pd.DataFrame:
        agg_cols = [
            c
            for c in df.columns
            if c not in ("run_dir", "fragment", "layer", "sae_tag")
        ]
        out = df.groupby(group_cols, dropna=False)[agg_cols].agg(
            ["median", "mean", "std", "count"]
        )
        # flatten (column, statistic) pairs into "column__statistic"
        out.columns = ["__".join(c) for c in out.columns.to_flat_index()]
        return out.reset_index()

    by_layer = agg_df(df_frag_stats, ["fragment", "layer"])
    by_layer.to_csv(outdir / "per_fragment_layer_wsd.csv", index=False)

    by_sae = agg_df(df_frag_stats, ["fragment", "sae_tag"])
    by_sae.to_csv(outdir / "per_fragment_sae_wsd.csv", index=False)

    # 6) Global difficulty ranking by WSD (median top-10 across all runs)
    frag_order = (
        df_frag_stats.groupby("fragment", dropna=False)["wsd_top10_med"]
        .median()
        .sort_values(ascending=False)
        .reset_index()
        .rename(columns={"wsd_top10_med": "wsd_top10_med_global_median"})
    )
    frag_order.to_csv(outdir / "fragment_difficulty_wsd.csv", index=False)

    # 7) Best exemplars: for each (fragment × layer), pick the run with highest median(top-10 WSD)
    exemplar_rows = []
    # dropna=False matches agg_df, so runs whose layer could not be parsed are
    # not silently dropped here.
    for (frag, layer), grp in df_frag_stats.groupby(
        ["fragment", "layer"], dropna=False
    ):
        top10_med = grp["wsd_top10_med"]
        # idxmax is undefined for an all-NaN group (e.g. a fragment without
        # the primary metric); there is no exemplar to report in that case.
        if not top10_med.notna().any():
            continue
        j = top10_med.idxmax()
        row = grp.loc[j]
        # find its top-N features from df_short
        sel = (
            df_short[
                (df_short["run_dir"] == row["run_dir"]) & (df_short["fragment"] == frag)
            ]
            .sort_values("wsd_mean", ascending=False)
            .head(min(10, SHORTLIST_N))
        )
        top_feats = ";".join(str(int(x)) for x in sel["feature_id"].tolist())
        exemplar_rows.append(
            dict(
                fragment=frag,
                layer=int(layer) if pd.notna(layer) else None,
                sae_tag=str(row["sae_tag"]),
                run_dir=str(row["run_dir"]),
                wsd_top10_med=float(row["wsd_top10_med"]),
                best_feature_ids=top_feats,
            )
        )
    pd.DataFrame(
        exemplar_rows,
        columns=[
            "fragment",
            "layer",
            "sae_tag",
            "run_dir",
            "wsd_top10_med",
            "best_feature_ids",
        ],
    ).to_csv(outdir / "best_wsd_exemplars.csv", index=False)

    # 8) WSD-first executive summary
    lines = []
    lines.append("# WSD-First Executive Summary\n")
    lines.append(f"- Runs found: **{len(runs)}**")
    lines.append(f"- Primary metric: **{PRIMARY_METRIC}** (Top‑K={TOPK_FEATURES})\n")

    lines.append("## Hardest fragments (highest global median of top‑10 WSD)\n")
    for _, r0 in frag_order.head(15).iterrows():
        lines.append(
            f"- **{r0['fragment']}**: median top‑10 WSD = {r0['wsd_top10_med_global_median']:.4g}"
        )

    # Layer peaks for top fragments
    if not by_layer.empty:
        lines.append("\n## Layer peaks (median of per‑run top‑10 WSD)\n")
        piv = by_layer.pivot(
            index="fragment", columns="layer", values="wsd_top10_med__median"
        )
        for frag in frag_order["fragment"].tolist()[:5]:
            if frag in piv.index and np.isfinite(piv.loc[frag]).any():
                vals = piv.loc[frag]
                peak = vals.idxmax()
                # the peak column label is NaN when it belongs to runs whose
                # layer could not be parsed; int(NaN) would raise
                best_layer = int(peak) if pd.notna(peak) else "unknown"
                best_val = float(vals.max())
                lines.append(
                    f"- **{frag}** peaks at layer **{best_layer}** with median(top‑10 WSD)≈{best_val:.4g}"
                )

    # The summary contains non-ASCII characters; pin the encoding so the
    # write does not depend on (or fail under) the platform default.
    (outdir / "wsd_grand_summary.md").write_text("\n".join(lines), encoding="utf-8")

    # 9) Optional plots
    if MAKE_PLOTS and not by_layer.empty:
        # Plotting is best-effort: a missing matplotlib or a rendering error
        # must not invalidate the tables already written.
        try:
            import matplotlib.pyplot as plt

            ensure_dir(outdir / "plots")

            piv = by_layer.pivot(
                index="fragment", columns="layer", values="wsd_top10_med__median"
            )
            piv = piv.sort_values(
                by=piv.columns.tolist(), ascending=piv.columns.size == 0
            )
            fig, ax = plt.subplots(figsize=(9, max(4, 0.24 * len(piv))))
            # NaN cells are drawn as 0.0
            im = ax.imshow(np.nan_to_num(piv.values, nan=0.0), aspect="auto")
            ax.set_yticks(range(len(piv.index)))
            ax.set_yticklabels(piv.index, fontsize=7)
            ax.set_xticks(range(len(piv.columns)))
            ax.set_xticklabels(piv.columns)
            ax.set_title("Median of top‑10 WSD by fragment × layer")
            fig.colorbar(im, ax=ax, shrink=0.7)
            fig.tight_layout()
            fig.savefig(outdir / "plots/frag_by_layer_wsd_top10_med_median.pdf")
            plt.close(fig)
        except Exception as e:
            print("[warn] plotting failed:", e)

    print(f"✓ Wrote WSD-centric summaries to: {SUMMARY_OUTDIR}")


# Script entry point: run the summarizer only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
