#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Consolidate 'pointer' experiment outputs into a compact, reviewer-friendly summary.

Inputs can be given as a single root directory that contains the subfolders:
  <root>/{suite,must_adds,stability}/...

Usage examples:
  uv run experiments/summarize_pointer_results.py --root_dir experiments/pointer_suite_all --out_dir experiments/pointer_suite_all/report
  # or explicit dirs:
  uv run experiments/summarize_pointer_results.py --suite_dir experiments/pointer_suite_all/suite --must_dir experiments/pointer_suite_all/must_adds --stab_dir experiments/pointer_suite_all/stability --out_dir experiments/pointer_suite_all/report
"""

import argparse
import json
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


# ----------------------------
# filesystem helpers
# ----------------------------
def ensure_outdir(path: str) -> None:
    """Create directory *path* (including parents) if it does not exist yet.

    An empty *path* refers to the current working directory, which always
    exists, so it is treated as a no-op instead of letting ``os.makedirs("")``
    raise FileNotFoundError.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)


def first_existing(paths: List[str]) -> Optional[str]:
    """Return the first candidate in *paths* that is non-empty and exists on disk.

    Empty strings / None entries are skipped; returns None when nothing matches.
    """
    hits = (candidate for candidate in paths if candidate and os.path.exists(candidate))
    return next(hits, None)


def safe_read_csv(path: Optional[str]) -> Optional[pd.DataFrame]:
    """Best-effort CSV loader.

    Returns None when *path* is empty/None or does not exist. Any parse error
    is reported with a ``[skip]`` message and also yields None, so callers can
    treat every input as optional.
    """
    if path and os.path.exists(path):
        try:
            frame = pd.read_csv(path)
        except Exception as err:
            print(f"[skip] CSV read failed: {path} ({err})")
            frame = None
        return frame
    return None


def safe_read_json(path: Optional[str]) -> Optional[dict]:
    """Best-effort JSON loader.

    Returns None when *path* is empty/None or does not exist, and also when the
    file cannot be opened or parsed (a ``[skip]`` message is printed). The file
    is decoded as UTF-8 explicitly, since the platform default encoding may
    differ and would mangle non-ASCII content.

    Note: the parsed top-level value is returned as-is (it may not be a dict);
    callers in this script check ``isinstance(..., dict)``.
    """
    if not path or not os.path.exists(path):
        return None
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[skip] JSON read failed: {path} ({e})")
        return None


# ----------------------------
# math helpers (no SciPy dependency)
# ----------------------------
def spearman_np(x: np.ndarray, y: np.ndarray) -> Optional[float]:
    """Spearman rank correlation of *x* and *y*, computed without SciPy.

    Pairs where either value is NaN are dropped first. Ranks use average
    tie-breaking, and rho is the Pearson correlation of those ranks
    (population moments). Returns None when fewer than 3 pairs remain or
    when either rank vector is constant.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    both_valid = ~np.isnan(xs) & ~np.isnan(ys)
    xs, ys = xs[both_valid], ys[both_valid]
    if xs.size < 3:
        return None

    def ranks(values: np.ndarray) -> np.ndarray:
        return pd.Series(values).rank(method="average").to_numpy()

    rank_x, rank_y = ranks(xs), ranks(ys)
    spread_x, spread_y = rank_x.std(), rank_y.std()
    if spread_x <= 0 or spread_y <= 0:
        return None
    centered_x = rank_x - rank_x.mean()
    centered_y = rank_y - rank_y.mean()
    covariance = (centered_x * centered_y).mean()
    return float(covariance / (spread_x * spread_y))


# ----------------------------
# schema normalizers
# ----------------------------
def normalize_scatter(df_raw: pd.DataFrame) -> pd.DataFrame:
    """
    Normalize to long-form with columns:
      task ['ring'|'paren'], layer, head, pointer_mass,
      delta_margin_event, delta_margin_control,
      delta_acc_event (opt), delta_acc_control (opt)
    Accepts schema from pointer_suite or must_adds.

    Two input layouts are recognised:

    * Case A (long-form): has ``task``, ``delta_margin_event`` and
      ``delta_margin_control``. Known columns are passed through. A missing
      ``pointer_mass`` is filled from ``pointer_ring`` for ring rows and from
      ``pointer_paren`` for all other rows (NaN when that column is absent).
    * Case B (wide-form): has ``layer``, ``head``, ``dmargin_ctrl`` plus
      ``dmargin_ring`` and/or ``dmargin_paren``. Every non-NaN task margin
      becomes one output row; a ``pointer_mass`` column, if present, takes
      precedence over the task-specific ``pointer_<task>`` column.

    Known metric columns are coerced to numeric first (unparseable -> NaN).

    Raises:
        ValueError: if the columns match neither layout.
    """
    long_columns = [
        "task",
        "layer",
        "head",
        "pointer_mass",
        "delta_margin_event",
        "delta_margin_control",
        "delta_acc_event",
        "delta_acc_control",
    ]
    numeric_columns = [
        "layer",
        "head",
        "pointer_mass",
        "pointer_ring",
        "pointer_paren",
        "delta_margin_event",
        "delta_margin_control",
        "delta_acc_event",
        "delta_acc_control",
        "dmargin_ring",
        "dmargin_paren",
        "dmargin_ctrl",
        "dacc_ring",
        "dacc_paren",
        "dacc_ctrl",
    ]

    df = df_raw.copy()
    for c in numeric_columns:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")
    cols = set(df.columns)

    # Case A: already long-form with 'task'
    if {"task", "delta_margin_event", "delta_margin_control"} <= cols:
        out = df.copy()
        if "pointer_mass" not in cols:
            # Use NaN (not None) for absent pointer columns so the result stays
            # float-typed and survives later float()/sort operations.
            missing = pd.Series(np.nan, index=out.index)
            ring_pm = out["pointer_ring"] if "pointer_ring" in cols else missing
            paren_pm = out["pointer_paren"] if "pointer_paren" in cols else missing
            out["pointer_mass"] = ring_pm.where(out["task"] == "ring", paren_pm)
        return out[[c for c in long_columns if c in out.columns]]

    # Case B: wide-form from pointer_suite -- ALL of these are required
    # (a subset check, not an intersection: each is indexed unconditionally below).
    required = {"layer", "head", "dmargin_ctrl"}
    tasks = [t for t in ("ring", "paren") if f"dmargin_{t}" in cols]
    if not required <= cols or not tasks:
        raise ValueError(
            "scatter CSV: unrecognized schema. "
            f"Missing wide-form columns: {sorted(required - cols) or 'none'}; "
            f"task margin columns found: {tasks or 'none'}."
        )

    has_dacc_ctrl = "dacc_ctrl" in cols
    rows = []
    for _, r in df.iterrows():
        layer, head = int(r["layer"]), int(r["head"])
        # Per input row, emit ring before paren (stable output order).
        for task in tasks:
            dmargin = r[f"dmargin_{task}"]
            if pd.isna(dmargin):
                continue
            dacc_col = f"dacc_{task}"
            rows.append(
                dict(
                    task=task,
                    layer=layer,
                    head=head,
                    pointer_mass=float(
                        r.get("pointer_mass", r.get(f"pointer_{task}", np.nan))
                    ),
                    delta_margin_event=float(dmargin),
                    delta_margin_control=float(r["dmargin_ctrl"]),
                    delta_acc_event=float(r[dacc_col]) if dacc_col in cols else np.nan,
                    delta_acc_control=float(r["dacc_ctrl"]) if has_dacc_ctrl else np.nan,
                )
            )

    # Explicit columns keep the long-form schema even when no rows survive.
    return pd.DataFrame(rows, columns=long_columns)


def normalize_redundancy(df_raw: pd.DataFrame) -> Optional[pd.DataFrame]:
    """
    Expect columns from must_adds (preferred):
      task, heads, pm1, pm2, dM_h1_evt, dM_h2_evt, dM_both_evt, dM_sum_evt,
      dM_both_ctl, dM_sum_ctl, dA_h1_evt, dA_h2_evt, dA_both_evt, dA_sum_evt

    Returns a copy with metric columns coerced to numeric (unparseable -> NaN)
    and synergy columns added when absent:
      syn_evt = dM_both_evt - dM_sum_evt
      syn_ctl = dM_both_ctl - dM_sum_ctl
    The label columns ``task`` and ``heads`` are left untouched.
    Returns None if any required column is missing.
    """
    required = {
        "task",
        "heads",
        "pm1",
        "pm2",
        "dM_h1_evt",
        "dM_h2_evt",
        "dM_both_evt",
        "dM_sum_evt",
        "dM_both_ctl",
        "dM_sum_ctl",
    }
    if not required <= set(df_raw.columns):
        return None

    df = df_raw.copy()
    # Only metrics are numeric; coercing "task"/"heads" would turn labels like
    # "ring" into NaN and break any per-task lookup downstream.
    numeric = (required - {"task", "heads"}) | {
        "dA_h1_evt",
        "dA_h2_evt",
        "dA_both_evt",
        "dA_sum_evt",
        "syn_evt",
        "syn_ctl",
    }
    for c in sorted(numeric & set(df.columns)):
        df[c] = pd.to_numeric(df[c], errors="coerce")

    if "syn_evt" not in df.columns:
        df["syn_evt"] = df["dM_both_evt"] - df["dM_sum_evt"]
    if "syn_ctl" not in df.columns:
        df["syn_ctl"] = df["dM_both_ctl"] - df["dM_sum_ctl"]
    return df


# ----------------------------
# report helpers
# ----------------------------
def render_head_table(df_long: pd.DataFrame, task: str, topn: int) -> pd.DataFrame:
    """Return the *topn* best heads for *task*, ranked for the report.

    Ranking key is ESI_margin = delta_margin_event - delta_margin_control,
    with pointer_mass as the tie-breaker; both are sorted descending. Only the
    display columns that actually exist are kept. If *task* has no rows, the
    (empty) filtered frame is returned unchanged.
    """
    rows = df_long.loc[df_long["task"] == task]
    if rows.empty:
        return rows.copy()

    ranked = rows.assign(
        ESI_margin=rows["delta_margin_event"] - rows["delta_margin_control"]
    ).sort_values(["ESI_margin", "pointer_mass"], ascending=False)

    display = (
        "layer",
        "head",
        "pointer_mass",
        "delta_margin_event",
        "delta_margin_control",
        "ESI_margin",
        "delta_acc_event",
    )
    shown = [name for name in display if name in ranked.columns]
    return ranked.head(topn)[shown]


def default_recommended(df_long: pd.DataFrame) -> Dict[str, str]:
    """Pick the best head (and a runner-up) per task for downstream reuse.

    Heads are ranked by ESI_margin (delta_margin_event - delta_margin_control),
    ties broken by pointer_mass, both descending. The result maps "<task>"
    and, when a second head exists, "<task>_backup" to labels like "L2H7".
    Tasks without rows are omitted.
    """
    picks: Dict[str, str] = {}
    for task in ("ring", "paren"):
        rows = df_long.loc[df_long["task"] == task]
        if rows.empty:
            continue
        ranked = rows.assign(
            ESI_margin=rows["delta_margin_event"] - rows["delta_margin_control"]
        ).sort_values(["ESI_margin", "pointer_mass"], ascending=False)
        labels = [
            f"L{int(layer)}H{int(head)}"
            for layer, head in zip(ranked["layer"].iloc[:2], ranked["head"].iloc[:2])
        ]
        picks[task] = labels[0]
        if len(labels) > 1:
            picks[f"{task}_backup"] = labels[1]
    return picks


def make_md_report(
    headline: Dict,
    top_ring: pd.DataFrame,
    top_paren: pd.DataFrame,
    add_ring: Optional[pd.Series],
    add_paren: Optional[pd.Series],
    robust: Dict,
    stability: Dict,
    probes: Dict,
) -> str:
    """Render the consolidated pointer-head summary as a Markdown document.

    Args:
        headline: may hold "ring_summary" / "paren_summary" strings.
        top_ring, top_paren: ranked head tables; None or empty renders
            "_no heads found_".
        add_ring, add_paren: one two-head additivity row per task, or None.
        robust: per-task dict with optional "head", "rho_pm", "rho_dM", "n_bins".
        stability: per-task dict with optional "spearman", "jaccard_top5",
            "jaccard_top10".
        probes: per-task dict with optional "pre_ov_acc", "post_ov_acc", "chance".

    Metrics that are missing or null (None) are omitted instead of raising.
    Values that cannot be parsed as numbers render as "nan"; an unparseable
    layer/head index renders as "?".

    Returns:
        The Markdown text.
    """
    nan = float("nan")

    def num(value) -> float:
        """float(value), mapping None / unparseable values to NaN."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return nan

    def has(d: Dict, key: str) -> bool:
        """True when *key* is present with a non-null value."""
        return d.get(key) is not None

    def int_label(value) -> str:
        v = num(value)
        return "?" if np.isnan(v) else str(int(v))

    def cell(row: pd.Series, key: str) -> float:
        return num(row[key]) if key in row.index else nan

    def fmt_head_table(df: Optional[pd.DataFrame]) -> str:
        if df is None or df.empty:
            return "_no heads found_\n"
        header = (
            "| layer | head | pointer | ΔM evt | ΔM ctl | ESI | Δacc evt |\n"
            "|---:|---:|---:|---:|---:|---:|---:|\n"
        )
        body = []
        for _, r in df.iterrows():
            body.append(
                f"| {int_label(cell(r, 'layer'))} | {int_label(cell(r, 'head'))} "
                f"| {cell(r, 'pointer_mass'):.3f} "
                f"| {cell(r, 'delta_margin_event'):.3f} | {cell(r, 'delta_margin_control'):.3f} "
                f"| {cell(r, 'ESI_margin'):.3f} | {cell(r, 'delta_acc_event'):.3f} |"
            )
        return header + "\n".join(body) + "\n"

    def fmt_add_row(row: Optional[pd.Series]) -> str:
        if row is None or len(row) == 0:
            return "_not available_"

        def f(k: str) -> float:
            return num(row.get(k, nan))

        return (
            f"**{row.get('heads', '?')}** — "
            f"ΔM_both(evt)={f('dM_both_evt'):.3f}, ΔM_sum(evt)={f('dM_sum_evt'):.3f}, syn(evt)={f('syn_evt'):.3f}; "
            f"ΔM_both(ctl)={f('dM_both_ctl'):.3f}, ΔM_sum(ctl)={f('dM_sum_ctl'):.3f}, syn(ctl)={f('syn_ctl'):.3f}."
        )

    def fmt_robust(task: str) -> str:
        d = robust.get(task, {})
        if not d:
            return "_not available_"
        parts = []
        if "head" in d:
            parts.append(f"head **{d['head']}**")
        if has(d, "rho_pm"):
            parts.append(f"ρ(pointer, span)={num(d['rho_pm']):.3f}")
        if has(d, "rho_dM"):
            parts.append(f"ρ(ΔM, span)={num(d['rho_dM']):.3f}")
        if "n_bins" in d:
            parts.append(f"bins={d['n_bins']}")
        return ", ".join(parts) if parts else "_not available_"

    def fmt_stab(task: str) -> str:
        d = stability.get(task, {})
        if not d:
            return "_not available_"
        pieces = []
        if has(d, "spearman"):
            pieces.append(f"ρ={num(d['spearman']):.4f}")
        if has(d, "jaccard_top5") and has(d, "jaccard_top10"):
            pieces.append(
                f"J@5={num(d['jaccard_top5']):.2f}, J@10={num(d['jaccard_top10']):.2f}"
            )
        return ", ".join(pieces) if pieces else "_not available_"

    def fmt_probes(task: str) -> str:
        d = probes.get(task, {})
        if not d:
            return "_not available_"
        bits = []
        if has(d, "pre_ov_acc"):
            bits.append(f"pre‑OV={num(d['pre_ov_acc']):.3f}")
        if has(d, "post_ov_acc"):
            bits.append(f"post‑OV={num(d['post_ov_acc']):.3f}")
        if has(d, "chance"):
            bits.append(f"chance={num(d['chance']):.3f}")
        return ", ".join(bits) if bits else "_not available_"

    lines = []
    lines.append("# Pointer‑head Summary\n")
    lines.append("## Headline\n")
    lines.append(f"- **Ring:** {headline.get('ring_summary', 'n/a')}")
    lines.append(f"- **Paren:** {headline.get('paren_summary', 'n/a')}\n")

    lines.append("## Top heads by event‑specificity (ESI) and pointer mass\n")
    lines.append("### Rings\n")
    lines.append(fmt_head_table(top_ring))
    lines.append("### Parentheses\n")
    lines.append(fmt_head_table(top_paren))

    lines.append("## Two‑head additivity (events vs controls)\n")
    lines.append(f"- Rings: {fmt_add_row(add_ring)}")
    lines.append(f"- Parentheses: {fmt_add_row(add_paren)}\n")

    lines.append("## Robustness to structure\n")
    lines.append(f"- Rings: {fmt_robust('ring')}")
    lines.append(f"- Parentheses: {fmt_robust('paren')}\n")

    lines.append("## Stability across checkpoints\n")
    lines.append(f"- Rings: {fmt_stab('ring')}")
    lines.append(f"- Parentheses: {fmt_stab('paren')}\n")

    lines.append("## Value‑stream probes\n")
    lines.append(f"- Rings: {fmt_probes('ring')}")
    lines.append(f"- Parentheses: {fmt_probes('paren')}\n")

    return "\n".join(lines)


# ----------------------------
# main
# ----------------------------
def _subdir(root_dir: str, override: str, name: str) -> str:
    """Return *override* when given, else ``<root_dir>/<name>`` ("" without a root)."""
    if override:
        return override
    return os.path.join(root_dir, name) if root_dir else ""


def _in_dir(dirpath: str, name: str) -> str:
    """Join *name* onto *dirpath*; an unset ("") directory yields ""."""
    return os.path.join(dirpath, name) if dirpath else ""


def _resolve_inputs(
    suite_dir: str, must_dir: str, stab_dir: str
) -> Dict[str, Optional[str]]:
    """Locate every input artifact, trying candidate directories in priority order.

    Returns a dict keyed by artifact id (in report order); missing files map to None.
    """

    def pick(filename: str, *dirs: str) -> Optional[str]:
        return first_existing([_in_dir(d, filename) for d in dirs])

    return {
        # Prefer suite/ for these, else stability/.
        "summary_csv": pick("pointer_suite_summary.csv", suite_dir, stab_dir),
        # Prefer must_adds/ for scatter/redundancy, else suite/, else stability/.
        "scatter_csv": pick(
            "pointer_vs_global_scatter.csv", must_dir, suite_dir, stab_dir
        ),
        "redundancy_csv": pick("redundancy_table.csv", must_dir, suite_dir, stab_dir),
        "robustness_json": pick("robustness_curves.json", suite_dir, stab_dir),
        "value_json": pick("value_probe.json", suite_dir, stab_dir),
        # Stability metrics only come from stability/.
        "stability_json": pick("stability_summary.json", stab_dir),
    }


def _load_head_metrics(scatter_csv: Optional[str], out_dir: str) -> pd.DataFrame:
    """Read and normalize the scatter CSV, add ESI_margin, and write head_metrics_long.csv.

    Best-effort: returns an empty frame when the CSV is missing/unreadable.
    If normalization fails, a warning is printed and whatever had been built
    so far is returned.
    """
    df_raw = safe_read_csv(scatter_csv)
    df_long = pd.DataFrame()
    if df_raw is None:
        return df_long
    try:
        df_long = normalize_scatter(df_raw).copy()
        df_long["ESI_margin"] = (
            df_long["delta_margin_event"] - df_long["delta_margin_control"]
        )
        out_long = os.path.join(out_dir, "head_metrics_long.csv")
        df_long.to_csv(out_long, index=False)
        print(f"[ok] wrote {out_long}")
    except Exception as e:  # deliberate best-effort: report still gets built
        print(f"[warn] could not normalize scatter CSV: {e}")
    return df_long


def _build_headline(df_long: pd.DataFrame) -> Dict[str, str]:
    """One line per task: number of heads with ESI>0 and the top-ranked head."""
    headline: Dict[str, str] = {}
    for task in ("ring", "paren"):
        label = f"{task}_summary"
        sub = df_long[df_long["task"] == task] if not df_long.empty else pd.DataFrame()
        if sub.empty:
            headline[label] = "no heads found"
            continue
        best = sub.sort_values(["ESI_margin", "pointer_mass"], ascending=False).iloc[0]
        headline[label] = (
            f"{int(np.sum(sub['ESI_margin'] > 0))}/{len(sub)} heads with ESI>0; "
            f"top {task} head L{int(best['layer'])}H{int(best['head'])} "
            f"(pointer={float(best['pointer_mass']):.3f}, ESI={float(best['ESI_margin']):.3f})"
        )
    return headline


def _load_additivity(
    redundancy_csv: Optional[str], out_dir: str
) -> Tuple[Optional[pd.Series], Optional[pd.Series]]:
    """Pick the head pair with the largest dM_both_evt per task; write additivity_summary.csv.

    Returns (ring_row, paren_row); either is None when unavailable.
    """
    df_raw = safe_read_csv(redundancy_csv)
    df_red = normalize_redundancy(df_raw) if isinstance(df_raw, pd.DataFrame) else None
    if not isinstance(df_red, pd.DataFrame) or df_red.empty:
        return None, None

    def best_pair(task: str) -> Optional[pd.Series]:
        sub = df_red[df_red["task"] == task]
        if sub.empty:
            return None
        return sub.sort_values("dM_both_evt", ascending=False).iloc[0]

    add_ring, add_paren = best_pair("ring"), best_pair("paren")
    out_add = os.path.join(out_dir, "additivity_summary.csv")
    df_red.to_csv(out_add, index=False)
    print(f"[ok] wrote {out_add}")
    return add_ring, add_paren


def _pick_head_name(available: Dict[str, dict], preferred: List[str]) -> Optional[str]:
    """First name in *preferred* present in *available*, else the alphabetically first key."""
    if not isinstance(available, dict):
        return None
    for name in preferred:
        if name in available:
            return name
    return sorted(available.keys())[0] if available else None


def _extract_curves(entry: dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Pull (bins, pointer-mass curve, delta-margin curve) from a robustness entry.

    Each curve is looked up under several accepted key aliases; a curve that is
    absent or not a numeric list comes back as an empty array.
    """

    def vec(*keys: str) -> np.ndarray:
        if not isinstance(entry, dict):
            return np.array([])
        for k in keys:
            if k in entry and isinstance(entry[k], (list, tuple)):
                try:
                    return np.array(entry[k], dtype=float)
                except (TypeError, ValueError):
                    pass  # try the next alias
        return np.array([])

    x = vec("x", "bins", "span_bins", "depth_bins", "span", "depth", "distance")
    pm = vec("pointer_mass", "pm", "pointer")
    dM = vec("delta_margin", "dM", "deltaM", "d_margin")
    return x, pm, dM


def _load_robustness(robust_json: Optional[str]) -> Dict[str, dict]:
    """Spearman correlation of each curve with the structure bins, for one head per task."""
    robustness: Dict[str, dict] = {}
    rdata = safe_read_json(robust_json)
    if not isinstance(rdata, dict):
        return robustness
    preferred_heads = {"ring": ["L2H7", "L3H1"], "paren": ["L2H3"]}
    for task, preferred in preferred_heads.items():
        tdata = rdata.get(task, {})
        if not isinstance(tdata, dict) or not tdata:
            continue
        head_name = _pick_head_name(tdata, preferred)
        if not head_name:
            continue
        x, pm, dM = _extract_curves(tdata.get(head_name, {}))
        rd: Dict = {"head": head_name}
        # Curves must align bin-for-bin with x; a mismatch would make
        # spearman_np fail on boolean indexing, so skip it with a warning.
        if x.size and pm.size:
            if pm.size == x.size:
                rd["rho_pm"], rd["n_bins"] = spearman_np(x, pm), int(x.size)
            else:
                print(
                    f"[warn] robustness {task}/{head_name}: pointer curve has "
                    f"{pm.size} points but {x.size} bins; skipped"
                )
        if x.size and dM.size:
            if dM.size == x.size:
                rd["rho_dM"] = spearman_np(x, dM)
            else:
                print(
                    f"[warn] robustness {task}/{head_name}: delta-margin curve has "
                    f"{dM.size} points but {x.size} bins; skipped"
                )
        robustness[task] = rd
    return robustness


def _load_stability(stability_json: Optional[str]) -> Dict[str, dict]:
    """Keep spearman / jaccard_top5 / jaccard_top10 per task from the stability JSON."""
    stability: Dict[str, dict] = {}
    sdata = safe_read_json(stability_json)
    if not isinstance(sdata, dict):
        return stability
    for task in ("ring", "paren"):
        d = sdata.get(task, {})
        if isinstance(d, dict) and d:
            stability[task] = {
                k: d[k] for k in ("spearman", "jaccard_top5", "jaccard_top10") if k in d
            }
    return stability


def _load_probes(value_json: Optional[str]) -> Dict[str, dict]:
    """Value-probe accuracies per task, mapping legacy key names onto canonical ones."""
    probes: Dict[str, dict] = {}
    vdata = safe_read_json(value_json)
    if not isinstance(vdata, dict):
        return probes
    wanted = ("pre_ov_acc", "post_ov_acc", "chance", "pre_acc", "post_acc", "chance_acc")
    aliases = (
        ("pre_acc", "pre_ov_acc"),
        ("post_acc", "post_ov_acc"),
        ("chance_acc", "chance"),
    )
    for task in ("ring", "paren"):
        d = vdata.get(task, {})
        if not isinstance(d, dict):
            continue
        keep = {k: d[k] for k in wanted if k in d}
        for alias, canonical in aliases:
            if alias in keep and canonical not in keep:
                keep[canonical] = keep.pop(alias)
        if keep:
            probes[task] = keep
    return probes


def main():
    """CLI entry point: gather pointer-experiment artifacts and write the summary.

    Always writes report.md and report.json to --out_dir. Also writes
    head_metrics_long.csv, top_heads_{ring,paren}.csv, additivity_summary.csv
    and recommended_heads.json when the corresponding inputs are available.
    """
    ap = argparse.ArgumentParser(description="Summarize pointer results.")
    ap.add_argument(
        "--root_dir",
        default="",
        help="Root containing suite/, must_adds/, stability/ (optional).",
    )
    ap.add_argument(
        "--suite_dir", default="", help="Override path to suite dir (optional)."
    )
    ap.add_argument(
        "--must_dir", default="", help="Override path to must_adds dir (optional)."
    )
    ap.add_argument(
        "--stab_dir", default="", help="Override path to stability dir (optional)."
    )
    ap.add_argument("--out_dir", default="experiments/pointer_report")
    ap.add_argument("--topn", type=int, default=8)
    args = ap.parse_args()

    # Explicit per-subfolder dirs win; otherwise derive them from --root_dir.
    suite_dir = _subdir(args.root_dir, args.suite_dir, "suite")
    must_dir = _subdir(args.root_dir, args.must_dir, "must_adds")
    stab_dir = _subdir(args.root_dir, args.stab_dir, "stability")

    ensure_outdir(args.out_dir)

    inputs = _resolve_inputs(suite_dir, must_dir, stab_dir)
    print("[inputs]")
    for key, path in inputs.items():
        print(f"  {key:<16}:", path or "<missing>")

    # The summary CSV is not used in the report yet; it is read only so that
    # a malformed file is flagged early.
    safe_read_csv(inputs["summary_csv"])

    df_long = _load_head_metrics(inputs["scatter_csv"], args.out_dir)

    # --- top tables (module-level render_head_table)
    if df_long.empty:
        top_ring, top_paren = pd.DataFrame(), pd.DataFrame()
    else:
        top_ring = render_head_table(df_long, "ring", args.topn)
        top_paren = render_head_table(df_long, "paren", args.topn)
    for task, table in (("ring", top_ring), ("paren", top_paren)):
        if not table.empty:
            out_top = os.path.join(args.out_dir, f"top_heads_{task}.csv")
            table.to_csv(out_top, index=False)
            print(f"[ok] wrote {out_top}")

    headline = _build_headline(df_long)
    add_ring, add_paren = _load_additivity(inputs["redundancy_csv"], args.out_dir)
    robustness = _load_robustness(inputs["robustness_json"])
    stability = _load_stability(inputs["stability_json"])
    probes = _load_probes(inputs["value_json"])

    # Heads to reuse downstream; computed once and shared by both outputs.
    recommended = default_recommended(df_long) if not df_long.empty else {}

    md = make_md_report(
        headline=headline,
        top_ring=top_ring,
        top_paren=top_paren,
        add_ring=add_ring,
        add_paren=add_paren,
        robust=robustness,
        stability=stability,
        probes=probes,
    )
    report_md = os.path.join(args.out_dir, "report.md")
    # The report contains non-ASCII (Δ, ρ, —); don't depend on the locale encoding.
    with open(report_md, "w", encoding="utf-8") as f:
        f.write(md)
    print(f"[ok] wrote {report_md}")

    out_json = dict(
        headline=headline,
        top_ring=top_ring.to_dict(orient="records") if not top_ring.empty else [],
        top_paren=top_paren.to_dict(orient="records") if not top_paren.empty else [],
        additivity=dict(
            ring=(add_ring.to_dict() if isinstance(add_ring, pd.Series) else None),
            paren=(add_paren.to_dict() if isinstance(add_paren, pd.Series) else None),
        ),
        robustness=robustness,
        stability=stability,
        probes=probes,
        recommended_heads=recommended,
        inputs=inputs,
    )
    report_json = os.path.join(args.out_dir, "report.json")
    with open(report_json, "w", encoding="utf-8") as f:
        json.dump(out_json, f, indent=2)
    print(f"[ok] wrote {report_json}")

    if not df_long.empty:
        rec_path = os.path.join(args.out_dir, "recommended_heads.json")
        with open(rec_path, "w", encoding="utf-8") as f:
            json.dump(recommended, f, indent=2)
        print(f"[ok] wrote {rec_path}")

    print("\n[done] Summary ready.")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
