"""
Predictive framework for elab-feedback gains.

Tests the claim: TC gain under elab-feedback is approximately a stable
fraction of Vanilla off-diagonal mass, with stratum-specific recovery
rates that are method-invariant within a (model, dataset) cell and
benchmark-type-modulated across cells.

Outputs:
  (A) Per-stratum recovery rates on V4-Pro x ProofNet# (dual-judged):
      are the TO->TS, SO->TS and BF->TS rates (and TS retention) stable
      across the three elab-feedback methods?
  (B) Recovery rate of Vanilla elab-fails across the (model, dataset)
      cells under Lean-Retry (SAF shown alongside), plus the OLS fit
      ΔTC = α * Vanilla_elab_fail + β and per-cell residuals.
  (C) Benchmark-type modulation: does the Lean-Retry recovery rate
      cluster within a benchmark and differ between ProofNet# and MiniF2F?
  (D) Predicted vs observed ΔTS per method on V4-Pro x ProofNet#, using
      the stratum rates fitted on Lean-Retry in (A).
"""
from __future__ import annotations
import json, math, sys
import os
from pathlib import Path
from collections import Counter, defaultdict

# Force UTF-8 output so the Δ glyphs in the report print on consoles whose
# default encoding is not UTF-8. reconfigure() exists only on
# io.TextIOWrapper, so skip it when stdout has been replaced by some other
# stream object (or is None, as under pythonw).
_reconfigure = getattr(sys.stdout, "reconfigure", None)
if _reconfigure is not None:
    _reconfigure(encoding="utf-8")

# Input locations, each overridable through an environment variable.
# NOTE(review): ROOT_RUNS appears unused in this script — confirm before removing.
ROOT_RUNS  = Path(os.environ.get("RUNS_DIR", "../data/runs/proofnet_186/v4_pro"))
ROOT_JUDGE = Path(os.environ.get("JUDGE_DIR", "../data/judge"))
MASTER     = Path(os.environ.get("MASTER_TABLE", "../data/judge/master_table.json"))

# (display name, run tag, judge-file tag). The judge-file tag selects
# ROOT_JUDGE / f"v4pro_{tag}_opus.jsonl"; "Vanilla" is the baseline whose
# strata serve as denominators for the other three methods.
METHODS = [("Vanilla","B1","b1"), ("Lean-Retry","B3","b3"),
           ("SAF","B4","b4"), ("Sample-Filter","B5","b5")]

def jl(path, bom=False):
    """Load a JSON-Lines file into a list of decoded records.

    Whitespace-only lines are skipped.

    Args:
        path: File to read (str or Path).
        bom: When True decode as ``utf-8-sig`` so a leading byte-order mark
            is stripped; otherwise decode as plain UTF-8.

    Returns:
        One decoded JSON value per non-blank line, in file order.
    """
    enc = "utf-8-sig" if bom else "utf-8"
    # The context manager closes the handle deterministically instead of
    # leaving an open file object for the garbage collector.
    with open(path, encoding=enc) as fh:
        return [json.loads(line) for line in fh if line.strip()]

def cell_of(elab_ok, equiv):
    """Map an (elab_ok, equiv) judgement to its stratum label.

    Both arguments are read by truthiness:
        elab_ok and equiv          -> "TS"
        elab_ok and not equiv      -> "SO"
        not elab_ok and equiv      -> "TO"
        neither                    -> "BF"
    """
    strata = {
        (True, True): "TS",
        (True, False): "SO",
        (False, True): "TO",
        (False, False): "BF",
    }
    return strata[(bool(elab_ok), bool(equiv))]

# per_problem[problem_id][method display name] ->
#     {"elab_ok": bool, "equiv": bool, "cell": "TS" | "SO" | "TO" | "BF"}
# built from one Opus-judged JSONL file per method.
per_problem = {}
for name, _runtag, opustag in METHODS:
    records = jl(ROOT_JUDGE / f"v4pro_{opustag}_opus.jsonl", bom=True)
    # Keyed by problem_id; if an id repeats within a file the last record wins.
    opus = {v["problem_id"]: v for v in records}
    for pid, v in opus.items():
        # NOTE(review): bool() is truthiness, so a string flag such as "false"
        # would count as True — confirm the judge files store real booleans.
        elab_ok = bool(v["elab_ok"])
        equiv = bool(v["equiv"])
        per_problem.setdefault(pid, {})[name] = {
            "elab_ok": elab_ok,
            "equiv": equiv,
            "cell": cell_of(elab_ok, equiv),
        }
PIDS = sorted(per_problem)

# Validate with explicit raises rather than `assert` (stripped by python -O).
# Later sections index per_problem[pid][method] for every pid in PIDS and
# every method in METHODS, so a judge file that skips a problem must fail
# here with a clear message rather than as a bare KeyError further down.
if len(PIDS) != 186:
    raise SystemExit(f"expected 186 judged problems, found {len(PIDS)}")
coverage_gaps = {}
for method, _, _ in METHODS:
    gap = sum(1 for pid in PIDS if method not in per_problem[pid])
    if gap:
        coverage_gaps[method] = gap
if coverage_gaps:
    raise SystemExit(f"judge files are missing problems (method: count): {coverage_gaps}")

banner = "=" * 72
print(banner)
print("(A) Per-stratum recovery rate (Vanilla cell -> TS under method)")
print("    V4-Pro x ProofNet#, dual-judged (Opus)")
print(banner)

# Vanilla stratum of every problem; its per-stratum counts are the
# denominators of every rate in this section.
vanilla_cell = {pid: per_problem[pid]["Vanilla"]["cell"] for pid in PIDS}
denom = Counter(vanilla_cell.values())

count_summary = ", ".join(f"{s}={denom[s]}" for s in ("TS", "SO", "TO", "BF"))
print(f"\nVanilla cell counts: {count_summary}\n")

# Column order of the table; the TS column is retention (TS -> TS).
strata = ("TO", "SO", "BF", "TS")
print(f"{'Method':<16}" + "".join(f"  {h:>14}" for h in ("TO->TS", "SO->TS", "BF->TS", "TS retain")))

# recov_rates[method] = (problems landing in TS per Vanilla stratum,
#                        percentage of that Vanilla stratum, 0.0 if empty)
recov_rates = {}
for method in (m for m, _, _ in METHODS if m != "Vanilla"):
    landed = Counter(vanilla_cell[pid] for pid in PIDS
                     if per_problem[pid][method]["cell"] == "TS")
    hits = {s: landed[s] for s in strata}
    pct = {s: (hits[s] / denom[s] * 100 if denom[s] else 0.0) for s in strata}
    recov_rates[method] = (hits, pct)
    print(f"{method:<16}" + "".join(
        f"  {hits[s]:>3}/{denom[s]:<3} = {pct[s]:>4.1f}%" for s in strata))

print()
print("Cross-method spread (mean +/- max-min half-range):")
for s in strata:
    per_method = [pct_map[s] for _, pct_map in recov_rates.values()]
    lo, hi = min(per_method), max(per_method)
    avg = sum(per_method) / len(per_method)
    print(f"  {s}->TS rate: mean = {avg:5.1f}%, half-spread = {(hi - lo) / 2:4.1f}%  ({lo:.1f}-{hi:.1f})")

print()
print("="*72)
print("(B) Recovery rate across 6 (model, dataset) cells under Lean-Retry")
print("    Uses elab-only signal (TC%). Recovery = ΔTC / Vanilla_elab_fail.")
print("="*72)

# The master table is a JSON list of rows carrying at least "model",
# "dataset", "method" and "tc_pct". Read it with an explicit encoding
# ("utf-8-sig" also tolerates a leading BOM) and close the handle.
with open(MASTER, encoding="utf-8-sig") as fh:
    master = json.load(fh)
mt = {(r["model"], r["dataset"], r["method"]): r for r in master}

print()
print(f"{'(model, dataset)':<28}  {'V_fail%':>8}  {'B3 ΔTC':>8}  {'recov%':>8}"
      f"  {'B4 ΔTC':>8}  {'recov%':>8}")
# One row per available (model, dataset):
#   (model, dataset, V_fail%, B3 ΔTC, B3 recov%, B4 ΔTC, B4 recov%)
cells = []
for model in ("v4pro", "qwen", "mimo"):
    for ds in ("proofnet", "minif2f"):
        if (model, ds, "B1") not in mt:
            continue  # no Vanilla row for this cell: skip silently
        absent = [m for m in ("B3", "B4") if (model, ds, m) not in mt]
        if absent:
            # A missing B3/B4 row would otherwise raise KeyError; report and skip.
            print(f"  warning: {model} x {ds} has no {'/'.join(absent)} row; skipped",
                  file=sys.stderr)
            continue
        v_tc = mt[(model, ds, "B1")]["tc_pct"]
        v_fail = 100.0 - v_tc  # assumes tc_pct is a percentage (0-100)
        d3 = mt[(model, ds, "B3")]["tc_pct"] - v_tc
        d4 = mt[(model, ds, "B4")]["tc_pct"] - v_tc
        # Recovery: share of Vanilla's elab-fail mass turned into TC gain.
        r3 = d3 / v_fail * 100 if v_fail > 0 else float("nan")
        r4 = d4 / v_fail * 100 if v_fail > 0 else float("nan")
        cells.append((model, ds, v_fail, d3, r3, d4, r4))
        print(f"{model+' x '+ds:<28}  {v_fail:>7.1f}%  {d3:>+7.1f}  {r3:>7.1f}%"
              f"  {d4:>+7.1f}  {r4:>7.1f}%")


def _ols(xs, ys):
    """Fit the least-squares line ys ~ slope * xs + intercept.

    Returns (slope, intercept, pearson_r). Any quantity whose denominator is
    zero (no points, a single point, or no spread in xs / ys) is NaN instead
    of raising ZeroDivisionError.
    """
    nan = float("nan")
    n = len(xs)
    if n == 0:
        return nan, nan, nan
    mx = sum(xs) / n
    my = sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    slope = sxy / sxx if sxx > 0 else nan
    intercept = my - slope * mx  # NaN propagates when slope is NaN
    r = sxy / math.sqrt(sxx * syy) if sxx * syy > 0 else nan
    return slope, intercept, r


# Lean-Retry (B3): ΔTC against Vanilla elab-fail mass, with per-cell residuals.
slope, intercept, r = _ols([c[2] for c in cells], [c[3] for c in cells])
print()
print(f"OLS fit (Lean-Retry):  ΔTC = {slope:.3f} * V_fail + {intercept:+.3f}")
print(f"Pearson r = {r:.3f},  R^2 = {r*r:.3f}")
print()
print("Per-cell residual (Lean-Retry, in pp):")
for (model, ds, v_fail, d3, r3, d4, r4) in cells:
    pred = slope * v_fail + intercept
    res = d3 - pred
    print(f"  {model+' x '+ds:<28}  predicted={pred:+5.2f}  actual={d3:+5.2f}  residual={res:+5.2f}")

# SAF (B4): same fit, no residual table.
slope_saf, intercept_saf, r_saf = _ols([c[2] for c in cells], [c[5] for c in cells])
print()
print(f"OLS fit (SAF):         ΔTC = {slope_saf:.3f} * V_fail + {intercept_saf:+.3f}")
print(f"Pearson r = {r_saf:.3f},  R^2 = {r_saf*r_saf:.3f}")

print()
print("="*72)
print("(C) Benchmark-type modulation of recovery rate (Lean-Retry)")
print("="*72)


def _print_recovery_summary(prefix, values):
    """Print mean and [min, max] of the recovery percentages in *values*.

    NaN entries (cells whose Vanilla TC was already >= 100%) are excluded
    because they would poison the mean and make min/max order-dependent.
    An empty selection prints "n/a" instead of raising.
    """
    finite = [v for v in values if not math.isnan(v)]
    if not finite:
        print(f"{prefix}recovery = n/a (no cells)")
        return
    mean = sum(finite) / len(finite)
    print(f"{prefix}recovery = {mean:.1f}%  range [{min(finite):.1f}, {max(finite):.1f}]")


# In each `cells` row, c[1] is the dataset and c[4] the Lean-Retry recovery %.
_print_recovery_summary("  ProofNet# (type-complex):  ",
                        [c[4] for c in cells if c[1] == "proofnet"])
_print_recovery_summary("  MiniF2F   (type-simple):   ",
                        [c[4] for c in cells if c[1] == "minif2f"])
print()
# NOTE(review): the interpretation below is fixed text; it is printed
# regardless of whether the two summaries above actually support it.
print("  Interpretation: type-simple benchmarks expose a higher share of easy")
print("  surface-form errors that elab-feedback can rewrite, while type-complex")
print("  benchmarks accumulate hard TO that survives K=3 refinement.")

print()
print("="*72)
print("(D) Predicted vs observed ΔTS on V4-Pro x ProofNet# from stratum-rates")
print("="*72)
# Stratum rates (TO/SO/BF -> TS, plus TS retention) are taken from one
# reference method and applied to the Vanilla stratum counts to predict every
# method's change in TS count. Small residuals for the other methods support
# the claim that the rates are method-invariant.
ref = "Lean-Retry"
rates = recov_rates[ref][1]
print(f"\nLaw fit from {ref}:  TO->TS={rates['TO']:.1f}%, SO->TS={rates['SO']:.1f}%, "
      f"BF->TS={rates['BF']:.1f}%, TS retention={rates['TS']:.1f}%")
print(f"Vanilla mass: TO={denom['TO']}, SO={denom['SO']}, BF={denom['BF']}, TS={denom['TS']}")
print()
print(f"{'Method':<16}  {'Predicted ΔTS':>14}  {'Observed ΔTS':>14}  {'Residual':>10}")

# The prediction uses only the reference rates and the Vanilla counts, so it
# is identical for every method and is computed once. For the reference
# method itself the residual is zero up to float rounding: every problem that
# is TS under it belongs to exactly one Vanilla stratum.
expected_gain = sum(rates[s] / 100 * denom[s] for s in ("TO", "SO", "BF"))
expected_loss = denom["TS"] - rates["TS"] / 100 * denom["TS"]
pred = expected_gain - expected_loss
for name, _, _ in METHODS:
    if name == "Vanilla":
        continue
    obs_TS = sum(1 for pid in PIDS if per_problem[pid][name]["cell"] == "TS")
    obs_delta = obs_TS - denom["TS"]
    print(f"{name:<16}  {pred:>+13.1f}   {obs_delta:>+13d}   {obs_delta-pred:>+9.2f}")

print()
print("="*72)
print("DONE.")
