"""
Stratum analysis for V4-Pro x ProofNet#.

Computes:
  (1) per-problem signal-coverage cell for each method (Vanilla/Lean-Retry/SAF/Sample-Filter)
  (2) bootstrap 95% CI on per-cell counts and on Delta-vs-Vanilla
  (3) Vanilla -> Lean-Retry transition matrix (4x4)
  (4) case study candidates: TYPE_ONLY recoveries, durable SEM_ONLY, GTED false negatives

Data sources (all on V4-Pro x ProofNet# x 186 test, paths configurable
via env vars; see top-level README):
  - runs:   <RUNS_DIR>/{B1,B3,B4,B5}.jsonl
  - opus:   <JUDGE_DIR>/v4pro_{b1,b3,b4,b5}_opus.jsonl
  - gted:   <JUDGE_DIR>/v4pro_{b1,b3,b4,b5}_gted.jsonl
  - gold:   <GOLD_PATH>

Cell rule (semantic axis = Opus `equiv` verdict; type axis = the `elab_ok` flag
recorded in the Opus judge file):
  TS (TRUE_SUCCESS) = elab_ok AND equiv
  SO (SEM_ONLY)     = elab_ok AND NOT equiv
  TO (TYPE_ONLY)    = NOT elab_ok AND equiv
  BF (both fail)    = NOT elab_ok AND NOT equiv
"""
from __future__ import annotations
import json, random, sys
import os
from pathlib import Path
from collections import Counter, defaultdict

# Force UTF-8 console output. Guarded: a replaced sys.stdout (some notebook or
# capture streams, or None under pythonw) need not provide reconfigure().
_reconfigure = getattr(sys.stdout, "reconfigure", None)
if _reconfigure is not None:
    _reconfigure(encoding="utf-8")

# Input locations; each can be overridden via the named environment variable.
# Relative defaults resolve against the current working directory.
ROOT_RUNS  = Path(os.environ.get("RUNS_DIR", "../data/runs/proofnet_186/v4_pro"))
ROOT_JUDGE = Path(os.environ.get("JUDGE_DIR", "../data/judge"))
GOLD_PATH  = Path(os.environ.get("GOLD_PATH", "../data/proofnetsharp_test.json"))

# (display name, run-file tag, judge-file tag) for each compared method.
METHODS = [("Vanilla","B1","b1"), ("Lean-Retry","B3","b3"),
           ("SAF","B4","b4"), ("Sample-Filter","B5","b5")]

def jl(path, bom=False):
    """Read a JSON Lines file and return its records as a list.

    Blank / whitespace-only lines are skipped. With ``bom=True`` the file is
    decoded as ``utf-8-sig`` so a leading byte-order mark is stripped.
    """
    enc = "utf-8-sig" if bom else "utf-8"
    # `with` closes the handle deterministically instead of relying on GC.
    with open(path, encoding=enc) as fh:
        return [json.loads(line) for line in fh if line.strip()]

def cell_of(elab_ok: bool, equiv: bool) -> str:
    """Map a (type, semantic) verdict pair to its stratum label.

    TS: elaborates and equivalent; SO: elaborates but not equivalent;
    TO: fails to elaborate but equivalent; BF: neither.
    Both arguments are interpreted by truthiness.
    """
    if not elab_ok:
        return "TO" if equiv else "BF"
    return "TS" if equiv else "SO"

# Gold problem set keyed by id.
# NOTE(review): `gold` is not referenced by the sections below; it is still
# loaded so a missing or malformed gold file fails loudly up front.
with open(GOLD_PATH, encoding="utf-8") as gold_file:
    gold = {x["id"]: x for x in json.load(gold_file)}

# problem_id -> method display name -> per-method record.
per_problem: dict[str, dict] = {}
for name, runtag, opustag in METHODS:
    runs = {r["problem_id"]: r for r in jl(ROOT_RUNS / f"{runtag}.jsonl")}
    # Opus judge files are decoded BOM-tolerantly (utf-8-sig).
    opus = {v["problem_id"]: v for v in jl(ROOT_JUDGE / f"v4pro_{opustag}_opus.jsonl", bom=True)}
    gted = {g["problem_id"]: g for g in jl(ROOT_JUDGE / f"v4pro_{opustag}_gted.jsonl")}
    # The Opus file determines which problems this method contributes; run and
    # GTED records are attached when present (empty dict / None otherwise).
    for pid, v in opus.items():
        elab_ok = bool(v["elab_ok"])
        equiv   = bool(v["equiv"])
        cell    = cell_of(elab_ok, equiv)
        d = per_problem.setdefault(pid, {})
        d[name] = {
            "elab_ok": elab_ok, "equiv": equiv, "cell": cell,
            "opus_reason": v.get("reason",""), "opus_conf": v.get("confidence",""),
            "opus_prompt": v.get("prompt",""),
            "gted_sim": gted.get(pid,{}).get("ted_similarity"),
            "run": runs.get(pid, {}),
        }

PIDS = sorted(per_problem.keys())
# Explicit raises (not `assert`) so these checks still run under `python -O`.
if len(PIDS) != 186:
    raise AssertionError(f"expected 186 problems, got {len(PIDS)}")
# Every later section indexes per_problem[pid][method] for all methods, so a
# judge file that skips a problem would otherwise surface as a bare KeyError.
_incomplete = {
    pid: [m for m, _, _ in METHODS if m not in per_problem[pid]]
    for pid in PIDS
}
_incomplete = {pid: ms for pid, ms in _incomplete.items() if ms}
if _incomplete:
    _pid, _ms = next(iter(_incomplete.items()))
    raise AssertionError(
        f"{len(_incomplete)} problems lack a verdict for some method "
        f"(e.g. {_pid}: missing {', '.join(_ms)})"
    )

# Section (1): per-method stratum counts plus two derived percentages.
banner = "=" * 70
print(banner)
print("(1) Per-method cell counts on V4-Pro x ProofNet# (n=186)")
print(banner)
print(f"{'Method':<16} {'TS':>4} {'SO':>4} {'TO':>4} {'BF':>4}  TC%   SFstrict%")
n_total = len(PIDS)
for method_name, *_rest in METHODS:
    tally = Counter(per_problem[pid][method_name]["cell"] for pid in PIDS)
    # TC% counts TS + SO (the elab_ok outputs); SFstrict% counts TS only.
    tc_pct = (tally["TS"] + tally["SO"]) / n_total * 100
    sf_pct = tally["TS"] / n_total * 100
    counts = " ".join(f"{tally[k]:>4}" for k in ("TS", "SO", "TO", "BF"))
    print(f"{method_name:<16} {counts}  {tc_pct:5.2f} {sf_pct:6.2f}")

# Bootstrap settings for section (2): resample size and resample count.
N = len(PIDS)   # problems drawn per resample (problem-level resampling)
B = 10_000      # number of bootstrap resamples

print()
print("="*70)
# Header derived from B so the printed count cannot drift from the real one.
print(f"(2) Bootstrap 95% CI ({B:,} resamples, problem-level)")
print("="*70)

# Fixed seed so the reported intervals are reproducible run-to-run.
random.seed(20260512)

# method name -> list of cells aligned with PIDS (index i <-> PIDS[i]).
labels = {name: [per_problem[p][name]["cell"] for p in PIDS] for name,_,_ in METHODS}

def boot_metric(method_name: str, ref_name: str | None, cell: str, *,
                labels_map: dict[str, list[str]] | None = None,
                n_boot: int | None = None,
                rng: random.Random | None = None):
    """Problem-level bootstrap for a cell count, or a paired difference of counts.

    With a falsy ``ref_name`` the statistic is ``count(cell)`` in
    ``labels_map[method_name]``. Otherwise it is
    ``count(cell, method_name) - count(cell, ref_name)``, computed on the *same*
    resampled indices for both methods (a paired bootstrap).

    Args:
        method_name: key into ``labels_map``.
        ref_name: optional reference key; falsy means "no reference".
        cell: cell label to count (e.g. "TS").
        labels_map: method -> per-problem cell list, all aligned by index.
            Defaults to the module-level ``labels``.
        n_boot: number of resamples; defaults to the module-level ``B``.
        rng: object providing ``randrange``; defaults to the global ``random``
            module, so the script's ``random.seed(...)`` governs reproducibility.

    Returns:
        ``(point, lo, hi)``: the statistic on the observed data, and the sorted
        resample values at indices ``int(0.025*n_boot)`` and
        ``int(0.975*n_boot)`` (a 95% percentile interval).

    Raises:
        ValueError: if the label list is empty, ``n_boot < 1``, or the method
            and reference label lists differ in length.
    """
    if labels_map is None:
        labels_map = labels
    if n_boot is None:
        n_boot = B
    randrange = random.randrange if rng is None else rng.randrange

    # 0/1 indicators computed once, so each resample is an integer sum rather
    # than repeated string comparisons.
    ind_m = [int(x == cell) for x in labels_map[method_name]]
    ind_r = [int(x == cell) for x in labels_map[ref_name]] if ref_name else None
    n = len(ind_m)
    if n == 0:
        raise ValueError("cannot bootstrap an empty label list")
    if n_boot < 1:
        raise ValueError(f"n_boot must be >= 1, got {n_boot}")
    if ind_r is not None and len(ind_r) != n:
        raise ValueError(
            f"label lists differ in length: {method_name}={n}, {ref_name}={len(ind_r)}"
        )

    samples = []
    for _ in range(n_boot):
        # Exactly n randrange(n) draws per resample, shared by both methods.
        idx = [randrange(n) for _ in range(n)]
        stat = sum(ind_m[i] for i in idx)
        if ind_r is not None:
            stat -= sum(ind_r[i] for i in idx)
        samples.append(stat)
    samples.sort()
    lo = samples[int(0.025 * n_boot)]
    hi = samples[int(0.975 * n_boot)]
    point = sum(ind_m) - (sum(ind_r) if ind_r is not None else 0)
    return point, lo, hi

# Cell order for both CI tables. boot_metric draws from the shared, seeded
# random stream, so the method-major / cell-minor call order below determines
# which resamples each interval sees; reordering would change the reported CIs.
_ci_cells = ("TS", "SO", "TO", "BF")

print()
print("Absolute counts with 95% CI:")
print(f"{'Method':<16}  {'TS':>14}  {'SO':>14}  {'TO':>14}  {'BF':>14}")
for name,_,_ in METHODS:
    parts = []
    for cell in _ci_cells:
        p, lo, hi = boot_metric(name, None, cell)
        parts.append(f"{p:3d} [{lo:3d},{hi:3d}]")
    print(f"{name:<16}  " + "  ".join(f"{x:>14}" for x in parts))

print()
print("Delta vs Vanilla with 95% CI:")
print(f"{'Method':<16}  {'dTS':>14}  {'dSO':>14}  {'dTO':>14}  {'dBF':>14}")
for name,_,_ in METHODS:
    if name == "Vanilla":
        continue
    parts = []
    for cell in _ci_cells:
        p, lo, hi = boot_metric(name, "Vanilla", cell)
        # The `+` format spec keeps the sign adjacent to the digits ("+5", "-3",
        # "+0"); prefixing a sign onto a padded number produced "+  5".
        parts.append(f"{p:+3d} [{lo:+3d},{hi:+3d}]")
    print(f"{name:<16}  " + "  ".join(f"{x:>14}" for x in parts))

print()
print("="*70)
print("(3) Per-problem transition matrix Vanilla -> Lean-Retry")
print("="*70)

CELL_ORDER = ["TS","SO","TO","BF"]
# (vanilla_cell, retry_cell) -> problem count, and -> problem ids in PIDS order.
trans = {(a,b):0 for a in CELL_ORDER for b in CELL_ORDER}
trans_pids = defaultdict(list)
for pid in PIDS:
    v = per_problem[pid]["Vanilla"]["cell"]
    r = per_problem[pid]["Lean-Retry"]["cell"]
    trans[(v,r)] += 1
    trans_pids[(v,r)].append(pid)

print()
print("Rows = Vanilla cell; Columns = Lean-Retry cell. Cell = count of problems.")
print(f"{'':<8}" + "".join(f"{c:>6}" for c in CELL_ORDER) + f"{'row sum':>10}")
for v_cell in CELL_ORDER:
    row = [trans[(v_cell,r_cell)] for r_cell in CELL_ORDER]
    print(f"{v_cell:<8}" + "".join(f"{x:>6}" for x in row) + f"{sum(row):>10}")
# Column sums, with the grand total placed under the "row sum" header.
print(f"{'col sum':<8}"
      + "".join(f"{sum(trans[(v,r)] for v in CELL_ORDER):>6}" for r in CELL_ORDER)
      + f"{sum(trans.values()):>10}")

# TS accounting: row TS is the Vanilla TS total, column TS the Lean-Retry total.
ts_gain_sources = {(v_cell,"TS"):trans[(v_cell,"TS")] for v_cell in CELL_ORDER if v_cell != "TS"}
ts_stay = trans[("TS","TS")]
ts_loss = sum(trans[("TS",r)] for r in CELL_ORDER if r != "TS")
new_ts = sum(ts_gain_sources.values())
ts_before = ts_stay + ts_loss
ts_after = ts_stay + new_ts
ts_net = ts_after - ts_before

print()
# Report the actual net change instead of an unfilled "+N" placeholder.
print(f"Where did the {ts_net:+d} TRUE_SUCCESS under Lean-Retry come from?")
print(f"  TS->TS (already correct, stayed correct):  {ts_stay}")
print(f"  TS->{{SO,TO,BF}} (lost in retry):           {ts_loss}")
print(f"  TO->TS (type stratum recovered):           {trans[('TO','TS')]}")
print(f"  SO->TS (semantic stratum 'recovered'):     {trans[('SO','TS')]}")
print(f"  BF->TS (both-fail rescued):                {trans[('BF','TS')]}")
print(f"  Net TS change = {ts_after} - {ts_before} = {ts_net:+d}")

def _reason(pid: str, method: str, width: int) -> str:
    """Return the first *width* chars of the Opus reason for (pid, method).

    opus_reason is stored from the judge record's "reason" field via dict.get,
    which yields None when the judge wrote a JSON null; treat that as "".
    """
    return (per_problem[pid][method]["opus_reason"] or "")[:width]


def _gted_false_negatives(method: str, threshold: float = 0.5) -> list[tuple[str, float]]:
    """Return (pid, gted_sim) pairs where Opus puts *method*'s output in TS
    (elab_ok and equiv) but GTED similarity is below *threshold*.

    Problems without a GTED score are skipped. Sorted by ascending similarity;
    ties keep PIDS order (stable sort).
    """
    hits = []
    for pid in PIDS:
        rec = per_problem[pid][method]
        sim = rec["gted_sim"]
        if rec["equiv"] and rec["elab_ok"] and sim is not None and sim < threshold:
            hits.append((pid, sim))
    hits.sort(key=lambda item: item[1])
    return hits


print()
print("="*70)
print("(4) Case study candidates")
print("="*70)

print()
print("(a) TYPE_ONLY -> TRUE_SUCCESS recoveries (Vanilla TO, Lean-Retry TS):")
print(f"    n = {trans[('TO','TS')]}")
for pid in trans_pids[("TO","TS")][:5]:
    print(f"    - {pid}")
    print(f"        opus(vanilla): {_reason(pid, 'Vanilla', 120)}")

print()
print("(b) SEM_ONLY that survives all three elab-feedback methods:")
# SO under Vanilla and under every elab-feedback method, i.e. SO for all methods.
durable_so = [
    pid for pid in PIDS
    if all(per_problem[pid][m]["cell"] == "SO" for m, _, _ in METHODS)
]
print(f"    n = {len(durable_so)}")
for pid in durable_so[:5]:
    print(f"    - {pid}")
    print(f"        opus(vanilla): {_reason(pid, 'Vanilla', 140)}")

print()
# Header states the full filter: the check also requires elab_ok (TS cell).
print("(c) GTED false negatives (Opus elab_ok and equiv=True, GTED similarity < 0.5):")
gted_fn = _gted_false_negatives("Vanilla")
print(f"    n = {len(gted_fn)} (Vanilla outputs only; counts will be much larger for B3/B4/B5)")
for pid, sim in gted_fn[:5]:
    print(f"    - {pid}  GTED sim = {sim:.3f}")
    print(f"        opus(vanilla): {_reason(pid, 'Vanilla', 120)}")

print()
print("    Same on Lean-Retry (B3) -- this is where the +26-37pp gap accumulates:")
gted_fn_b3 = _gted_false_negatives("Lean-Retry")
print(f"    n = {len(gted_fn_b3)}")
for pid, sim in gted_fn_b3[:5]:
    print(f"    - {pid}  GTED sim = {sim:.3f}")
    print(f"        opus(retry): {_reason(pid, 'Lean-Retry', 120)}")

print()
print("="*70)
print("DONE.")
