#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Summarize fragment_scan outputs across many runs (layers, SAE sizes).
Reads every `final_metrics.npz`, aggregates NaN-safely, and emits CSV/Parquet summaries,
a Markdown executive summary, and (optionally) a heatmap PDF.

USAGE: just edit the CONFIG block below and run:
    uv run lmkit/sparse/summarize_fragment_scans.py
"""

from __future__ import annotations

import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm.auto import tqdm  # ← progress bars

# =========================
# ======= CONFIG ==========
# =========================

# Root directories to search recursively for *final_metrics.npz*
# You can list each SAE-size’s output root here, or one parent that contains them all.

frag_dir = "fragment_scan_output_all"
saes = [
    "relu_4x_e9a211",
    "relu_8x_95aa15",
    # "relu_16x_b0334a",
]

# One scan root per (SAE, layer) pair for layers 1..5, grouped by SAE in the
# order listed above.
SCAN_ROOTS = [
    f"{frag_dir}/run_{sae_name}_layer{layer_no}_layer{layer_no}"
    for sae_name in saes
    for layer_no in range(1, 6)
]

# Directory that receives every summary file; main() creates it if missing.
SUMMARY_OUTDIR = "fragment_scan_summaries"

# Metrics to load; AUROC often disabled during scans, so default to two.
# Only npz keys whose part after the last "/" is in this tuple are loaded.
# NOTE: main() looks up "pb_corr" and "wsd_mean" by name, so any extra entry
# here is loaded but does not appear in the summaries.
METRICS = ("pb_corr", "wsd_mean")

# For ranking & summaries:
TOPK_FEATURES = 10  # top-k per (run, fragment, metric); output columns are still named "top10_*"
SHORTLIST_N = 25  # features per (run, fragment) exported to the shortlist, ranked by WSD z-score
PARETO_TAU = 0.80  # quantile of the pooled top-k scores used as pb / wsd quadrant thresholds
PLOT_TOP_N = 20  # top-N fragments (by global WSD ranking) shown in the heatmap

# Optional: write a big long-form table (all_scores.parquet) holding the
# top-M features per (run, fragment, metric); can be large.
WRITE_BIG_PARQUET = True
BIG_PARQUET_TOP_M = 50

# Optional plots: a fragment-by-layer heatmap PDF under <SUMMARY_OUTDIR>/plots.
# matplotlib is imported lazily and a plotting failure only prints a warning.
MAKE_PLOTS = True

# =========================
# ====== HELPERS ==========
# =========================


@dataclass
class RunMeta:
    """Location and parsed identifiers of one fragment-scan run.

    One instance is created by `find_all_npz` for every `final_metrics.npz`
    it discovers.
    """

    path: Path  # directory containing final_metrics.npz
    layer: Optional[int]  # first "layer<N>" match in the directory path, else None
    sae_tag: str  # first "relu_<N>x_<id>" match in the directory path, else "unknown"
    npz_path: Path  # full path to final_metrics.npz


def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    Calling it on a directory that already exists is a no-op.
    """
    p.mkdir(exist_ok=True, parents=True)


def find_all_npz(roots: List[str]) -> List[RunMeta]:
    """Collect a RunMeta for every ``final_metrics.npz`` below the given roots.

    Roots that do not exist are skipped. The layer number and the SAE tag are
    taken from the first regex match in the string form of the file's parent
    directory; a run without a match gets ``layer=None`` / ``sae_tag="unknown"``.

    Args:
        roots: Directories to search recursively.

    Returns:
        One RunMeta per file found, in discovery order.
    """
    found: List[RunMeta] = []
    # The progress bar advances per root; each rglob may itself visit many files.
    for root_name in tqdm(roots, desc="Scanning roots", leave=False):
        base = Path(root_name)
        if base.exists():
            for npz_file in base.rglob("final_metrics.npz"):
                run_dir = npz_file.parent
                run_dir_text = str(run_dir)
                layer_hit = re.search(r"layer(\d+)", run_dir_text)
                sae_hit = re.search(r"(relu_\d+x_[A-Za-z0-9]+)", run_dir_text)
                found.append(
                    RunMeta(
                        path=run_dir,
                        layer=None if layer_hit is None else int(layer_hit.group(1)),
                        sae_tag="unknown" if sae_hit is None else sae_hit.group(1),
                        npz_path=npz_file,
                    )
                )
    return found


def load_final_npz(npz_path: Path) -> Dict[str, Dict[str, np.ndarray]]:
    """Read one ``final_metrics.npz`` into a nested fragment -> metric mapping.

    Archive keys look like "fg:primary_amine/pb_corr" or
    "ring:benzene/wsd_mean": everything before the last "/" is the fragment
    name, the remainder is the metric name. Keys without a "/" and metrics
    not listed in METRICS are ignored.

    Returns:
        dict[fragment][metric] -> score vector (documented above as shape (K,)),
        with +/-inf replaced by NaN so that later statistics can skip them.
    """
    per_fragment: Dict[str, Dict[str, np.ndarray]] = {}
    with np.load(npz_path, allow_pickle=False) as archive:
        for name in archive.files:
            frag, slash, metric = name.rpartition("/")
            if not slash or metric not in METRICS:
                continue
            raw = np.asarray(archive[name])
            # NaN stays NaN on purpose; infinities are folded into NaN as well.
            cleaned = np.nan_to_num(raw, nan=np.nan, posinf=np.nan, neginf=np.nan)
            if frag not in per_fragment:
                per_fragment[frag] = {}
            per_fragment[frag][metric] = cleaned
    return per_fragment


def parse_state_meta(dir_path: Path) -> Optional[dict]:
    """Load ``state_meta.json`` from a run directory, if it is usable.

    This is a best-effort read of optional metadata, so no error is raised.

    Args:
        dir_path: Directory expected to contain ``state_meta.json``.

    Returns:
        The parsed JSON object, or None when the file is missing or
        unreadable, is not valid UTF-8 / JSON, or its top-level value is not
        an object (callers use ``.get`` on the result).
    """
    meta_file = dir_path / "state_meta.json"
    try:
        with open(meta_file, encoding="utf-8") as f:
            meta = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file.
        # ValueError: JSONDecodeError and UnicodeDecodeError are both subclasses.
        return None
    return meta if isinstance(meta, dict) else None


def robust_topk(vec: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return the ``k`` largest finite entries of ``vec`` and where they sit.

    NaN and +/-inf entries are ignored. If at most ``k`` finite entries
    exist, all of them are returned.

    Returns:
        (values, indices): values in descending order and the matching
        positions in the original ``vec``. Both are empty when ``vec`` holds
        no finite entry.
    """
    finite_mask = np.isfinite(vec)
    if not finite_mask.any():
        return np.array([]), np.array([], dtype=int)

    positions = np.nonzero(finite_mask)[0]
    kept = vec[finite_mask]
    negated = -kept  # ascending order of the negation == descending order

    if kept.size <= k:
        picked = np.argsort(negated)
    else:
        # Partial selection first, then order only the k survivors.
        picked = np.argpartition(negated, k - 1)[:k]
        picked = picked[np.argsort(negated[picked])]
    return kept[picked], positions[picked]


def quantiles(vec: np.ndarray, qs=(0.5, 0.9, 0.95, 0.99)) -> Dict[str, float]:
    """Compute quantiles over the finite entries of ``vec``.

    Args:
        vec: Score vector; NaN and +/-inf entries are ignored.
        qs: Quantile levels in [0, 1].

    Returns:
        Mapping "q<percent>" -> quantile value, e.g. {"q50": ..., "q90": ...}.
        Every value is NaN when ``vec`` has no finite entry.
    """
    qs = tuple(qs)
    # round(), not int(): q * 100 is not exact in binary floating point
    # (0.29 * 100 == 28.999999999999996), and truncation would mislabel the key.
    labels = [f"q{round(q * 100)}" for q in qs]

    arr = np.asarray(vec, dtype=float)
    finite = arr[np.isfinite(arr)]
    if finite.size == 0 or not labels:
        return {label: np.nan for label in labels}

    # A single call evaluates all requested levels at once.
    values = np.nanquantile(finite, list(qs))
    return {label: float(v) for label, v in zip(labels, values)}


# MODIFIED SECTION 1: Composite rank is now driven by WSD
def composite_rank(
    pb: np.ndarray, wsd: np.ndarray, shortlist_n: int
) -> List[Tuple[int, float, float, float]]:
    """Shortlist the features with the highest WSD score.

    The composite score is the z-score of the WSD value, computed over the
    finite WSD entries only. Features whose WSD is NaN or +/-inf are left out
    of the ranking altogether (the same policy as ``robust_topk``), so they
    can never displace a feature that has a real score. Ties are broken by
    ascending feature index.

    Args:
        pb: Per-feature pb_corr scores; may be shorter than ``wsd`` or empty.
        wsd: Per-feature wsd_mean scores.
        shortlist_n: Maximum number of features to return.

    Returns:
        Up to ``shortlist_n`` tuples ``(feature_id, pb_score, wsd_score,
        composite)`` ordered by descending composite. ``pb_score`` is NaN when
        ``pb`` has no entry for that feature. The list is empty when ``wsd``
        has no finite entry or ``shortlist_n`` is not positive.
    """
    pbv = np.asarray(pb, dtype=float)
    wv = np.asarray(wsd, dtype=float)

    finite_ids = np.flatnonzero(np.isfinite(wv))
    if shortlist_n <= 0 or finite_ids.size == 0:
        return []

    scores = wv[finite_ids]
    # Floor the spread so a constant vector yields z = 0 instead of dividing by zero.
    spread = max(float(scores.std()), 1e-12)
    zscores = (scores - scores.mean()) / spread

    # Stable sort keeps equal scores in feature-index order (reproducible output).
    order = np.argsort(-zscores, kind="stable")[:shortlist_n]

    out = []
    for j in order:
        fid = int(finite_ids[j])
        pb_score = float(pbv[fid]) if fid < pbv.size else np.nan
        out.append((fid, pb_score, float(scores[j]), float(zscores[j])))
    return out


def quadrant_label(pb: float, wsd: float, pb_thr: float, wsd_thr: float) -> str:
    """Classify a (pb, wsd) score pair into one of four quadrants.

    A score counts as "high" only if it is finite and at least its threshold;
    NaN or infinite scores count as "low".

    Returns:
        "high_both", "high_pb_low_wsd", "high_wsd_low_pb" or "low_both".
    """
    pb_high = bool(np.isfinite(pb) and pb >= pb_thr)
    wsd_high = bool(np.isfinite(wsd) and wsd >= wsd_thr)
    names = {
        (True, True): "high_both",
        (True, False): "high_pb_low_wsd",
        (False, True): "high_wsd_low_pb",
        (False, False): "low_both",
    }
    return names[(pb_high, wsd_high)]


def guess_thresholds_across_runs(values: List[float], fallback: float, q=0.8) -> float:
    """Return the ``q``-quantile of the finite entries in ``values``.

    Args:
        values: Pooled scores; NaN and +/-inf entries are dropped.
        fallback: Returned unchanged when no finite entry remains.
        q: Quantile level in [0, 1].
    """
    finite = [x for x in values if np.isfinite(x)]
    if not finite:
        return fallback
    return float(np.nanquantile(np.array(finite), q))


# =========================
# ======= MAIN ============
# =========================


def main():
    """Run the whole summarization pipeline over ``SCAN_ROOTS``.

    Steps:
      1. Discover every ``final_metrics.npz`` below the scan roots.
      2. Load each run and write ``runs_catalog.csv``.
      3. Index the top features per (run, fragment, metric): optional
         ``all_scores.parquet`` and ``top_features_shortlist.csv``.
      4. Compute per-(run, fragment) score statistics.
      5. Aggregate them by fragment x layer and fragment x SAE
         (``per_fragment_layer_summary.csv``, ``per_fragment_sae_summary.csv``).
      6. Rank fragments globally by WSD (``fragment_difficulty_by_WSD.csv``).
      7. Assign pb/wsd quadrants (``quadrant_counts.csv``).
      8. Write the Markdown executive summary (``grand_summary_by_WSD.md``).
      9. Optionally draw the fragment-by-layer heatmap PDF under ``plots/``.

    Everything is written below ``SUMMARY_OUTDIR``. The function returns early,
    with a message, when no run is found or no run contains a configured metric.
    """
    outdir = Path(SUMMARY_OUTDIR)
    ensure_dir(outdir)

    # 1) discover runs
    runs = find_all_npz(SCAN_ROOTS)
    if not runs:
        print("No final_metrics.npz found under:", SCAN_ROOTS)
        return

    # 2) load each run → per-fragment metric vectors
    catalog_rows = []
    per_run_data = []  # (RunMeta, metrics: Dict[frag]->Dict[metric]->(K,))
    for r in tqdm(runs, desc="Loading runs"):
        d = load_final_npz(r.npz_path)
        # K is estimated from the first metric vector found in this run.
        any_vec = next(
            (mdict[m] for mdict in d.values() for m in METRICS if m in mdict),
            None,
        )
        K = len(any_vec) if any_vec is not None else 0
        meta = parse_state_meta(r.path)
        # state_meta.json is optional; also tolerate a non-object JSON payload.
        mols = meta.get("mols_processed") if isinstance(meta, dict) else None
        catalog_rows.append(
            dict(
                run_dir=str(r.path),
                layer=r.layer,
                sae_tag=r.sae_tag,
                K=K,
                fragments=len(d),
                mols_processed=mols,
            )
        )
        per_run_data.append((r, d))

    runs_catalog = pd.DataFrame(catalog_rows).sort_values(
        ["layer", "sae_tag", "run_dir"]
    )
    runs_catalog.to_csv(outdir / "runs_catalog.csv", index=False)

    # Number of (run, fragment) items; also the length of every progress bar below.
    total_frag_items = sum(len(d) for _, d in per_run_data)
    if total_frag_items == 0:
        # Without this guard the empty, column-less frames built below would
        # make sort_values/groupby fail with a KeyError.
        print("No fragment metrics found in any run for METRICS =", METRICS)
        return

    # Stand-in for a metric that is missing from a fragment (never mutated).
    missing = np.full(0, np.nan)

    # 3) build long-form table of top-M features (per run × fragment × metric)
    long_rows = []
    shortlist_rows = []
    all_pb_tops, all_wsd_tops = [], []  # pooled top-k scores → quadrant thresholds

    with tqdm(total=total_frag_items, desc="Indexing top features") as pbar:
        for r, d in per_run_data:
            for frag, mdict in d.items():
                pb = mdict.get("pb_corr", missing)
                wsd = mdict.get("wsd_mean", missing)

                if WRITE_BIG_PARQUET:  # these rows are only needed for the parquet file
                    for metric_name, vec in (("pb_corr", pb), ("wsd_mean", wsd)):
                        vals, idxs = robust_topk(vec, BIG_PARQUET_TOP_M)
                        for v, i in zip(vals, idxs):
                            long_rows.append(
                                dict(
                                    run_dir=str(r.path),
                                    layer=r.layer,
                                    sae_tag=r.sae_tag,
                                    fragment=frag,
                                    metric=metric_name,
                                    feature_id=int(i),
                                    score=float(v),
                                )
                            )

                top_pb_vals, _ = robust_topk(pb, TOPK_FEATURES)
                top_wsd_vals, _ = robust_topk(wsd, TOPK_FEATURES)
                all_pb_tops.extend(float(x) for x in top_pb_vals)
                all_wsd_tops.extend(float(x) for x in top_wsd_vals)

                for fid, spb, sw, comp in composite_rank(pb, wsd, SHORTLIST_N):
                    shortlist_rows.append(
                        dict(
                            run_dir=str(r.path),
                            layer=r.layer,
                            sae_tag=r.sae_tag,
                            fragment=frag,
                            feature_id=fid,
                            pb_corr=spb,
                            wsd_mean=sw,
                            composite=comp,
                        )
                    )
                pbar.update(1)

    if WRITE_BIG_PARQUET and long_rows:
        df_long = pd.DataFrame(long_rows)
        df_long.to_parquet(outdir / "all_scores.parquet", index=False)

    # Explicit columns keep sort_values working even if the shortlist is empty
    # (e.g. no run provides wsd_mean).
    shortlist_cols = [
        "run_dir",
        "layer",
        "sae_tag",
        "fragment",
        "feature_id",
        "pb_corr",
        "wsd_mean",
        "composite",
    ]
    df_short = pd.DataFrame(shortlist_rows, columns=shortlist_cols).sort_values(
        ["fragment", "layer", "sae_tag", "composite"],
        ascending=[True, True, True, False],
    )
    df_short.to_csv(outdir / "top_features_shortlist.csv", index=False)

    # 4) fragment difficulty (global)
    def bundle(vec):
        """Summary statistics of an already finite-filtered score vector."""
        if vec.size == 0:
            return dict(
                max=np.nan,
                top10_med=np.nan,
                top10_mean=np.nan,
                **quantiles(vec),
            )
        vals, _ = robust_topk(vec, TOPK_FEATURES)
        return dict(
            max=float(np.nanmax(vec)),
            top10_med=float(np.nanmedian(vals)) if vals.size else np.nan,
            top10_mean=float(np.nanmean(vals)) if vals.size else np.nan,
            **quantiles(vec),
        )

    def finite_max(vec):
        """Largest finite entry, or NaN if there is none (no all-NaN warning)."""
        finite = vec[np.isfinite(vec)]
        return float(finite.max()) if finite.size else np.nan

    stat_keys = ("max", "top10_med", "top10_mean", "q50", "q90", "q95", "q99")
    frag_stats = []
    with tqdm(total=total_frag_items, desc="Summarizing per-fragment stats") as pbar:
        for r, d in per_run_data:
            for frag, mdict in d.items():
                pb = mdict.get("pb_corr", missing)
                wsd = mdict.get("wsd_mean", missing)

                row = dict(
                    run_dir=str(r.path),
                    layer=r.layer,
                    sae_tag=r.sae_tag,
                    fragment=frag,
                )
                # Columns come out as pb_max … pb_q99 followed by wsd_max … wsd_q99.
                for prefix, vec in (("pb", pb), ("wsd", wsd)):
                    stats = bundle(vec[np.isfinite(vec)])
                    for key in stat_keys:
                        row[f"{prefix}_{key}"] = stats[key]
                frag_stats.append(row)
                pbar.update(1)

    df_frag_stats = pd.DataFrame(frag_stats)

    # 5) Aggregate by fragment × layer and fragment × SAE
    def agg_df(df: pd.DataFrame, group_cols: List[str]) -> pd.DataFrame:
        """Median/mean/std/count of every statistic column per group."""
        id_cols = ("run_dir", "fragment", "layer", "sae_tag")
        agg_cols = [c for c in df.columns if c not in id_cols]
        out = df.groupby(group_cols, dropna=False)[agg_cols].agg(
            ["median", "mean", "std", "count"]
        )
        # Flatten the (column, aggregate) MultiIndex to "column__aggregate".
        out.columns = ["__".join(c) for c in out.columns.to_flat_index()]
        return out.reset_index()

    by_layer = agg_df(df_frag_stats, ["fragment", "layer"])
    by_layer.to_csv(outdir / "per_fragment_layer_summary.csv", index=False)

    by_sae = agg_df(df_frag_stats, ["fragment", "sae_tag"])
    by_sae.to_csv(outdir / "per_fragment_sae_summary.csv", index=False)

    # 6) Global difficulty ranking, based on WSD
    frag_order = (
        df_frag_stats.groupby("fragment", dropna=False)["wsd_top10_med"]
        .median()
        .sort_values(ascending=False)
        .reset_index()
    )
    frag_order = frag_order.rename(
        columns={"wsd_top10_med": "wsd_top10_med_global_median"}
    )
    frag_order.to_csv(outdir / "fragment_difficulty_by_WSD.csv", index=False)

    # 7) Quadrant analysis thresholds
    pb_thr = guess_thresholds_across_runs(all_pb_tops, fallback=0.2, q=PARETO_TAU)
    wsd_thr = guess_thresholds_across_runs(all_wsd_tops, fallback=0.0, q=PARETO_TAU)

    quad_rows = []
    with tqdm(total=total_frag_items, desc="Quadrant assignment") as pbar:
        for r, d in per_run_data:
            for frag, mdict in d.items():
                best_pb = finite_max(mdict.get("pb_corr", missing))
                best_wsd = finite_max(mdict.get("wsd_mean", missing))
                label = quadrant_label(best_pb, best_wsd, pb_thr, wsd_thr)
                quad_rows.append(
                    dict(fragment=frag, layer=r.layer, sae_tag=r.sae_tag, label=label)
                )
                pbar.update(1)

    df_quad = pd.DataFrame(quad_rows)
    quadrant_counts = (
        df_quad.groupby(["fragment", "label"]).size().reset_index(name="count")
    )
    quadrant_counts["frac"] = quadrant_counts.groupby("fragment")["count"].transform(
        lambda x: x / x.sum()
    )
    quadrant_counts.to_csv(outdir / "quadrant_counts.csv", index=False)

    # 8) Executive summary, based on WSD
    lines = []
    lines.append("# Fragment-Scan Executive Summary (Ranked by WSD)\n")
    lines.append(f"- Runs found: **{len(runs)}**")
    lines.append(f"- Metrics aggregated: {', '.join(METRICS)}")
    lines.append(f"- Top-K used for summaries: K={TOPK_FEATURES}")
    lines.append(
        f"- Pareto tau for quadrant thresholds: {PARETO_TAU} (pb_thr≈{pb_thr:.3g}, wsd_thr≈{wsd_thr:.3g})\n"
    )

    lines.append("## Top fragments by within-sequence discriminativity (WSD)\n")
    top_frag = frag_order.head(15)
    for _, r0 in top_frag.iterrows():
        lines.append(
            f"- **{r0['fragment']}**: median top-10 WSD = {r0['wsd_top10_med_global_median']:.3g}"
        )

    if not by_layer.empty:
        lines.append("\n## Layer profiles (median of per-run wsd_top10_med)\n")
        pivot = by_layer.pivot(
            index="fragment", columns="layer", values="wsd_top10_med__median"
        )
        for frag in top_frag["fragment"].tolist()[:5]:
            if frag not in pivot.index:
                continue
            vals = pivot.loc[frag]
            if np.isfinite(vals).any():
                peak = vals.idxmax()
                # The layer label is NaN for runs whose layer could not be parsed.
                best_layer = int(peak) if pd.notna(peak) else None
                best_val = float(vals.max())
            else:
                best_layer, best_val = None, np.nan
            lines.append(
                f"- **{frag}** peaks at layer **{best_layer}** with median(top-10 WSD)≈{best_val:.3g}"
            )

    # The summary contains non-ASCII characters ("≈"), so do not rely on the
    # platform's default encoding.
    (outdir / "grand_summary_by_WSD.md").write_text("\n".join(lines), encoding="utf-8")

    # 9) Optional plots, based on WSD
    if MAKE_PLOTS and not by_layer.empty:
        try:
            import matplotlib.pyplot as plt

            ensure_dir(outdir / "plots")

            # Names of the top N fragments from the global ranking
            top_n_fragments = frag_order.head(PLOT_TOP_N)["fragment"].tolist()

            # Full fragment × layer table first ...
            piv_full = by_layer.pivot(
                index="fragment", columns="layer", values="wsd_top10_med__median"
            )

            # ... then keep only the top N fragments, in ranking order
            piv_top_n = piv_full.reindex(top_n_fragments).dropna(how="all")

            fig, ax = plt.subplots(figsize=(8, max(4, 0.24 * len(piv_top_n))))
            # Missing cells are drawn as 0.0
            im = ax.imshow(np.nan_to_num(piv_top_n.values, nan=0.0), aspect="auto")

            ax.set_yticks(range(len(piv_top_n.index)))
            ax.set_yticklabels(piv_top_n.index, fontsize=7)
            ax.set_xticks(range(len(piv_top_n.columns)))
            ax.set_xticklabels(piv_top_n.columns)
            ax.set_title(
                f"Median of top-10 wsd_mean for Top {PLOT_TOP_N} Fragments by Layer"
            )

            fig.colorbar(im, ax=ax, shrink=0.7)
            fig.tight_layout()
            fig.savefig(
                outdir
                / f"plots/frag_by_layer_wsd_top10_med_median_top_{PLOT_TOP_N}.pdf"
            )
            plt.close(fig)
        except Exception as e:
            # Plots are optional: report the problem but keep the summaries.
            print("[warn] plotting failed:", e)

    print(f"✓ Wrote summaries to: {SUMMARY_OUTDIR}")


if __name__ == "__main__":
    main()
