"""Causal chain-depth analyses on existing v90 parquets.

What we can do with what's on S3:
  - Analysis 3' (mini-vs-prod natural A/B): mini training mix (small only) vs
    prod (small+large) on OlymMATH Hard.
  - Analysis 1' (within-checkpoint puzzle-type chain depth): bridges 8x8de vs
    undead 5x5de, both within v90_sft_puzzles and v90_gspo_puzzles.
  - Analysis 2' (solved vs unsolved chain depth on OlymMATH Hard).

Chain depth = length of the longest run of compute-phase primitives
{COMPUTE, CHECK, SETUP} in primitive_sequence that is not broken by an
exploration-interrupt primitive {HYPOTHESIZE, BACKTRACK, ERROR_DETECT}; each
interrupt resets the current run to zero. Other primitives (PLAN, ENUMERATE,
DECOMPOSE, VERIFY, SUMMARIZE, OTHER) neither extend nor interrupt a run.
"""

import json
from pathlib import Path

import numpy as np
import pandas as pd

# Primitives that extend the current compute chain by one step.
COMPUTE_PHASE = {"SETUP", "COMPUTE", "CHECK"}
# Primitives that break the current compute chain (reset its length to zero).
INTERRUPT = {"ERROR_DETECT", "BACKTRACK", "HYPOTHESIZE"}

# Destination directory for every CSV written by main(); created at import time.
OUT = Path("results") / "exploration_analysis" / "causal_chain_depth_obs"
OUT.mkdir(parents=True, exist_ok=True)


def chain_depth(seq, compute_phase=None, interrupt=None):
    """Return the longest run of compute-phase primitives in ``seq``.

    A running counter is incremented on every primitive in ``compute_phase``
    and reset to zero on every primitive in ``interrupt``; any other primitive
    leaves it unchanged. The largest counter value reached is returned.

    Args:
        seq: The primitive labels of one trace. A list, tuple or 1-D numpy
            array is used directly; a string is first parsed as JSON. Anything
            else (``None``, NaN, unparsable JSON, a JSON value that is not a
            list, a 0-d or multi-dimensional array) yields 0.
        compute_phase: Labels that extend a run. Defaults to ``COMPUTE_PHASE``.
        interrupt: Labels that reset a run. Defaults to ``INTERRUPT``.

    Returns:
        int: The chain depth; 0 for empty or unusable input.
    """
    if compute_phase is None:
        compute_phase = COMPUTE_PHASE
    if interrupt is None:
        interrupt = INTERRUPT

    if isinstance(seq, str):
        try:
            seq = json.loads(seq)
        except (json.JSONDecodeError, TypeError):
            return 0
    # A 0-d array has no len(), and rows of a 2-D array are unhashable, so
    # only 1-D arrays are treated as primitive sequences.
    if isinstance(seq, np.ndarray) and seq.ndim != 1:
        return 0
    if not isinstance(seq, (list, tuple, np.ndarray)) or len(seq) == 0:
        return 0

    best = cur = 0
    for p in seq:
        if p in compute_phase:
            cur += 1
            best = max(best, cur)
        elif p in interrupt:
            cur = 0
    return best


def add_chain_depth(df):
    """Return a copy of ``df`` with a ``chain_depth`` column.

    The column is ``chain_depth`` applied to each entry of
    ``df["primitive_sequence"]``; the input frame is not modified.
    """
    depths = df["primitive_sequence"].apply(chain_depth)
    return df.assign(chain_depth=depths)


def summary(df, by, label):
    """Summarise ``chain_depth`` per group of ``by`` as a flat table.

    The result has one row per group with the count, mean (2 dp), median,
    90th/95th percentiles (1 dp) and maximum of ``chain_depth``. A leading
    ``view`` column holds ``label`` so tables from several views can be
    concatenated.
    """
    depth = df.groupby(by)["chain_depth"]
    columns = [
        ("n", depth.size()),
        ("mean", depth.mean().round(2)),
        ("median", depth.median()),
        ("p90", depth.quantile(0.9).round(1)),
        ("p95", depth.quantile(0.95).round(1)),
        ("max", depth.max()),
    ]
    table = pd.DataFrame(dict(columns)).reset_index()
    table.insert(0, "view", label)
    return table


def mannwhitney(x, y, alternative="two-sided"):
    """Run a Mann-Whitney U test of sample ``x`` against sample ``y``.

    Thin wrapper around ``scipy.stats.mannwhitneyu`` with its default
    ``method="auto"``: scipy uses the exact null distribution for small
    samples without ties and the asymptotic normal approximation otherwise.

    Args:
        x, y: 1-D samples (lists, numpy arrays or pandas Series).
        alternative: ``"two-sided"`` (default), ``"less"`` or ``"greater"``;
            passed straight to scipy.

    Returns:
        An unpackable ``(U, p)`` pair: scipy's result (U statistic for ``x``,
        p-value) or ``(nan, nan)`` when either sample is empty. The explicit
        NaN avoids relying on scipy-version-specific handling of empty input
        (an exception in some releases, NaN plus a warning in others).
    """
    from scipy.stats import mannwhitneyu

    if len(x) == 0 or len(y) == 0:
        return np.nan, np.nan
    return mannwhitneyu(x, y, alternative=alternative)


def load(parquet):
    """Read a trace-level metrics parquet and attach its ``chain_depth`` column."""
    frame = pd.read_parquet(parquet)
    return add_chain_depth(frame)


def main():
    """Run analyses 3', 1' and 2' and write one CSV per analysis under ``OUT``.

    Each parquet is read (and its chain depth computed) exactly once: the
    OlymMATH runs are shared by Analysis 3' and Analysis 2', and the puzzle
    checkpoints are reused for the bridges-vs-undead test. Only the columns
    the analyses read are kept in memory.
    """
    root = Path("results/exploration_analysis")

    # Display label -> result directory of each OlymMATH Hard evaluation.
    runs = {
        "base": "v90_base_math",
        "sft_v2": "v90_sft_math",
        "gspo_v2_s20": "v90_gspo_math",
        "mini_s20": "v90_mini_math",
        "prod_s15": "v90_prod_s15_math",
        "prod_s25": "v90_prod_s25_math",
        "curriculum_s55": "v90_curriculum_s55_math",
        "entropy_s5": "v90_entropy_s5_math",
    }
    math_runs = {
        label: _load_metrics(root / dirname, ["chain_depth", "correct"])
        for label, dirname in runs.items()
    }

    _analysis3_training_mix(math_runs)
    _analysis1_puzzle_types(root)
    _analysis2_solved_vs_unsolved(math_runs)


def _load_metrics(run_dir, columns):
    """Load ``run_dir/trace_level_metrics.parquet`` with chain depth, keeping only ``columns``."""
    return load(run_dir / "trace_level_metrics.parquet")[columns]


def _mwu_or_nan(x, y):
    """Return ``mannwhitney(x, y)``, or ``(nan, nan)`` if either sample is empty."""
    if len(x) == 0 or len(y) == 0:
        return np.nan, np.nan
    u, p = mannwhitney(x, y)
    return u, p


def _analysis3_training_mix(math_runs):
    """Analysis 3': per-run chain-depth table plus mini-vs-other U tests."""
    print("\n=== Analysis 3': training mix → math chain depth ===")

    rows = []
    for label, df in math_runs.items():
        depth = df["chain_depth"]
        rows.append({
            "label": label,
            "training_mix": _classify_mix(label),
            "n_traces": len(df),
            "mean": depth.mean(),
            "median": depth.median(),
            "p90": depth.quantile(0.9),
            "p95": depth.quantile(0.95),
            "max": depth.max(),
            "frac_correct": df["correct"].mean(),
        })

    a3 = pd.DataFrame(rows)
    a3.to_csv(OUT / "analysis3_training_mix.csv", index=False)
    print(a3.to_string(index=False))

    # Statistical tests: mini (small only) against the small+large runs and SFT.
    mini = math_runs["mini_s20"]["chain_depth"]
    comparisons = [
        ("\nMini (small-only) vs Prod s15 (small+large) Mann-Whitney U:",
         mini, math_runs["prod_s15"]["chain_depth"]),
        ("Mini vs GSPO v2 s20 (vanilla RL on small+large):",
         mini, math_runs["gspo_v2_s20"]["chain_depth"]),
        ("SFT vs Mini (RL effect on small-only training):",
         math_runs["sft_v2"]["chain_depth"], mini),
    ]
    for header, x, y in comparisons:
        print(header)
        u, p = _mwu_or_nan(x, y)
        print(f"  U={u:.0f}  p={p:.4g}")


def _analysis1_puzzle_types(root):
    """Analysis 1': chain depth per puzzle task within each puzzle checkpoint."""
    print("\n=== Analysis 1': within-checkpoint puzzle-type chain depth ===")

    checkpoints = [("dsr_sft_v2", "v90_sft_puzzles"),
                   ("gspo_v2_s20", "v90_gspo_puzzles")]
    frames = {
        ck_label: _load_metrics(root / dirname,
                                ["task_name", "chain_depth", "correct"])
        for ck_label, dirname in checkpoints
    }

    rows = []
    for ck_label, df in frames.items():
        # sort=False keeps tasks in order of first appearance.
        for task, sub in df.groupby("task_name", sort=False):
            depth = sub["chain_depth"]
            rows.append({
                "checkpoint": ck_label,
                "task": task,
                "n_traces": len(sub),
                "mean": round(depth.mean(), 2),
                "median": depth.median(),
                "p90": round(depth.quantile(0.9), 1),
                "p95": round(depth.quantile(0.95), 1),
                "frac_correct": round(sub["correct"].mean(), 3),
            })

    a1 = pd.DataFrame(rows)
    a1.to_csv(OUT / "analysis1_within_checkpoint_puzzle.csv", index=False)
    print(a1.to_string(index=False))

    # Within each checkpoint: bridges_8x8de vs undead_5x5de. A task missing
    # from a checkpoint yields an empty sample and is reported as NaN.
    for ck_label, df in frames.items():
        bridges = df.loc[df["task_name"] == "bridges_8x8de_pass32", "chain_depth"]
        undead = df.loc[df["task_name"] == "undead_5x5de_pass32", "chain_depth"]
        u, p = _mwu_or_nan(bridges, undead)
        print(f"  {ck_label}: bridges_8x8de vs undead_5x5de  U={u:.0f}  p={p:.4g}")


def _analysis2_solved_vs_unsolved(math_runs, min_solved=5):
    """Analysis 2': compare chain depth of solved vs unsolved traces per run."""
    print("\n=== Analysis 2': solved vs unsolved chain depth on OlymMATH Hard ===")

    rows = []
    for label, df in math_runs.items():
        solved = df.loc[df["correct"].eq(True), "chain_depth"]
        unsolved = df.loc[df["correct"].eq(False), "chain_depth"]
        row = {"checkpoint": label,
               "n_solved": len(solved),
               "n_unsolved": len(unsolved)}
        # Skip the test when fewer than `min_solved` traces were solved or
        # none were unsolved; an empty sample would otherwise break the test
        # and the int(U) conversion below.
        if len(solved) < min_solved or unsolved.empty:
            row.update(solved_mean=np.nan, unsolved_mean=np.nan,
                       delta=np.nan, U=np.nan, p=np.nan)
        else:
            u, p = mannwhitney(solved, unsolved)
            row.update(
                solved_mean=round(solved.mean(), 2),
                unsolved_mean=round(unsolved.mean(), 2),
                delta=round(solved.mean() - unsolved.mean(), 2),
                U=int(u),
                p=float(p),
            )
        rows.append(row)

    a2 = pd.DataFrame(rows)
    a2.to_csv(OUT / "analysis2_solved_vs_unsolved.csv", index=False)
    print(a2.to_string(index=False))


def _classify_mix(label):
    """Heuristic: mini = small-only training, others = small+large or no RL."""
    if label == "base":
        return "no_RL"
    if label == "sft_v2":
        return "no_RL_sft_only"
    if label == "mini_s20":
        return "small_only"
    if label in ("gspo_v2_s20", "prod_s15", "prod_s25", "curriculum_s55", "entropy_s5"):
        return "small_plus_large"
    return "?"


if __name__ == "__main__":
    # Run all analyses only when executed as a script, not on import.
    main()
