"""Plots for causal chain depth analyses."""

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Primitive labels that extend the current run counted by chain_depth().
COMPUTE_PHASE = {"COMPUTE", "CHECK", "SETUP"}
# Primitive labels that reset the current run in chain_depth(). Labels in
# neither set leave the run untouched.
INTERRUPT = {"HYPOTHESIZE", "BACKTRACK", "ERROR_DETECT"}

# Directory the figures are written to. NOTE: it is created as a side effect
# of importing this module, not only when main() runs.
OUT = Path("results/exploration_analysis/causal_chain_depth_obs")
OUT.mkdir(exist_ok=True, parents=True)
# Parent directory holding one sub-directory per analysed run; main() reads
# "<ROOT>/<run dir>/trace_level_metrics.parquet" from it.
ROOT = Path("results/exploration_analysis")


def chain_depth(seq):
    """Return the length of the longest compute-phase run in *seq*.

    A run grows by one for every element found in ``COMPUTE_PHASE`` and is
    reset to zero by any element found in ``INTERRUPT``. Elements that belong
    to neither set are ignored: they neither extend nor break the run.

    Args:
        seq: A list or NumPy array of primitive labels, or a string holding
            the JSON encoding of such a list.

    Returns:
        The longest run length as an int. ``0`` is returned when *seq* is a
        string that is not valid JSON, when the (decoded) value is not a
        list / ``np.ndarray`` (tuples included), or when it is empty.
    """
    primitives = seq
    if isinstance(primitives, str):
        # Sequences may arrive JSON-serialised; undecodable text counts as
        # "no chain" rather than an error.
        try:
            primitives = json.loads(primitives)
        except (json.JSONDecodeError, TypeError):
            return 0

    if not isinstance(primitives, (list, np.ndarray)):
        return 0
    if len(primitives) == 0:
        return 0

    longest = 0
    run = 0
    for primitive in primitives:
        if primitive in COMPUTE_PHASE:
            run += 1
            longest = max(longest, run)
        elif primitive in INTERRUPT:
            run = 0
    return longest


def load(parquet):
    """Read a trace-level metrics parquet file and add a ``chain_depth`` column.

    Args:
        parquet: Path of the parquet file. It must contain a
            ``primitive_sequence`` column.

    Returns:
        The loaded DataFrame with an extra ``chain_depth`` column holding
        ``chain_depth()`` of each row's ``primitive_sequence``.
    """
    frame = pd.read_parquet(parquet)
    depths = frame["primitive_sequence"].apply(chain_depth)
    frame["chain_depth"] = depths
    return frame


def _save_figure(fig, filename):
    """Lay out *fig*, write it to ``OUT / filename`` at 130 dpi, close it and report the path."""
    target = OUT / filename
    fig.tight_layout()
    fig.savefig(target, dpi=130)
    plt.close(fig)
    print(f"Saved: {target}")


def _plot_depth_by_checkpoint(runs, frames, cmap):
    """Plot 1: box plot of the chain-depth distribution per run, coloured by training mix."""
    from matplotlib.patches import Patch

    labels = [label for label, _dirname, _mix in runs]
    data = [frames[label]["chain_depth"].values for label in labels]
    colors = [cmap[mix] for _label, _dirname, mix in runs]

    fig, ax = plt.subplots(figsize=(10, 5))
    bp = ax.boxplot(data, patch_artist=True, showfliers=False)
    # Tick labels are set explicitly instead of through boxplot(labels=...):
    # that keyword is deprecated since Matplotlib 3.9 (renamed tick_labels),
    # whereas set_xticklabels behaves the same on every version.
    ax.set_xticklabels(labels, rotation=20, ha="right")
    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.65)
    ax.set_ylabel("Chain depth (max consecutive COMPUTE/CHECK/SETUP)")
    ax.set_title("OlymMATH Hard: Chain depth distribution per checkpoint")
    # Custom legend: one swatch per training mix rather than per box.
    legend = [Patch(color=cmap[k], alpha=0.65, label=k) for k in cmap]
    ax.legend(handles=legend, loc="upper right", fontsize=9)
    _save_figure(fig, "fig1_chain_depth_by_checkpoint.png")


def _plot_depth_vs_pass32(runs, frames, cmap):
    """Plot 2: scatter of mean chain depth against pass@32, annotated with Pearson r."""
    from scipy.stats import pearsonr

    # Use the table from v90_combined report.md for pass@32 alignment
    pass32 = {
        "base": 16.0, "SFT v2": 23.0, "curriculum s55": 22.0, "entropy s5": 28.0,
        "prod s15 (best math)": 36.0, "prod s25": 23.0, "mini s20": 28.0, "GSPO v2 s20": 29.0,
    }

    fig, ax = plt.subplots(figsize=(8, 5.5))
    means = []
    p32 = []
    for label, _dirname, mix in runs:
        mean_depth = frames[label]["chain_depth"].mean()
        score = pass32[label]
        ax.scatter(mean_depth, score, color=cmap[mix], s=110, alpha=0.85,
                   edgecolor="black", linewidth=0.7)
        ax.annotate(label, (mean_depth, score), xytext=(7, -3),
                    textcoords="offset points", fontsize=9)
        means.append(mean_depth)
        p32.append(score)
    ax.set_xlabel("Mean chain depth on OlymMATH Hard rollouts")
    ax.set_ylabel("OlymMATH Hard pass@32 (%)")
    ax.set_title("Chain depth ≠ math transfer monotonically\n(prod s15 has best math but moderate depth)")
    ax.grid(alpha=0.3)
    # Pearson r across runs (one point per run).
    r, p = pearsonr(means, p32)
    ax.text(0.04, 0.95, f"Pearson r = {r:+.2f}, p = {p:.3f}",
            transform=ax.transAxes, fontsize=10, va="top",
            bbox=dict(facecolor="white", alpha=0.85, edgecolor="gray"))
    _save_figure(fig, "fig2_chain_depth_vs_pass32.png")


def _plot_solved_vs_unsolved(runs, frames):
    """Plot 3: grouped bars of mean chain depth for solved vs unsolved traces per run."""
    min_traces = 5  # groups smaller than this are plotted as NaN (no bar)
    rows = []
    for label, _dirname, _mix in runs:
        df = frames[label]
        # .eq(True) / .eq(False) rather than a bare boolean mask so that rows
        # whose "correct" value is neither True nor False stay in neither group.
        solved = df.loc[df["correct"].eq(True), "chain_depth"]
        unsolved = df.loc[df["correct"].eq(False), "chain_depth"]
        rows.append({
            "label": label,
            "solved_mean": solved.mean() if len(solved) >= min_traces else np.nan,
            "unsolved_mean": unsolved.mean() if len(unsolved) >= min_traces else np.nan,
            "n_solved": len(solved),
            "n_unsolved": len(unsolved),
        })
    sdf = pd.DataFrame(rows)

    fig, ax = plt.subplots(figsize=(11, 5))
    x = np.arange(len(sdf))
    w = 0.35
    ax.bar(x - w / 2, sdf["solved_mean"], w, label="Solved", color="#117733", alpha=0.85)
    ax.bar(x + w / 2, sdf["unsolved_mean"], w, label="Unsolved", color="#cc6677", alpha=0.85)
    ax.set_xticks(x)
    ax.set_xticklabels(sdf["label"], rotation=20, ha="right")
    ax.set_ylabel("Mean chain depth")
    ax.set_title("OlymMATH Hard: solved vs unsolved chain depth\n(prod s15 is the only RL run where solved > unsolved cleanly)")
    ax.legend()
    ax.grid(alpha=0.3, axis="y")
    _save_figure(fig, "fig3_solved_vs_unsolved.png")


def _plot_puzzle_types():
    """Plot 4: mean chain depth per puzzle task within each puzzle checkpoint."""
    checkpoints = [("DSR SFT v2", "v90_sft_puzzles"),
                   ("GSPO v2 s20", "v90_gspo_puzzles")]
    tasks = [("bridges_8x8de_pass32", "Bridges 8x8de"),
             ("undead_5x5de_pass32", "Undead 5x5de")]

    rows = []
    for ck, dirname in checkpoints:
        df = load(ROOT / dirname / "trace_level_metrics.parquet")
        for task, short in tasks:
            sub = df[df["task_name"] == task]
            rows.append({"checkpoint": ck, "task": short, "mean": sub["chain_depth"].mean()})
    # One row per checkpoint, one column per task.
    pivot = pd.DataFrame(rows).pivot(index="checkpoint", columns="task", values="mean")

    fig, ax = plt.subplots(figsize=(7, 4.5))
    pivot.plot(kind="bar", ax=ax, color=["#4477aa", "#cc6677"], alpha=0.85,
               edgecolor="black", linewidth=0.7)
    ax.set_ylabel("Mean chain depth")
    ax.set_title("Within-checkpoint puzzle-type chain depth")
    ax.set_xlabel("")
    ax.tick_params(axis="x", labelrotation=0)
    _save_figure(fig, "fig4_within_checkpoint_puzzle.png")


def main():
    """Render the four chain-depth figures into ``OUT``.

    Reads ``trace_level_metrics.parquet`` from each run directory under
    ``ROOT`` and writes, in order:

    1. ``fig1_chain_depth_by_checkpoint.png`` - chain-depth box plot per run.
    2. ``fig2_chain_depth_vs_pass32.png`` - mean chain depth vs pass@32.
    3. ``fig3_solved_vs_unsolved.png`` - mean depth of solved vs unsolved traces.
    4. ``fig4_within_checkpoint_puzzle.png`` - mean depth per puzzle task.

    The path of every saved figure is printed to stdout.
    """
    # (display label, run directory under ROOT, training-mix category)
    runs = [
        ("base", "v90_base_math", "no_RL"),
        ("SFT v2", "v90_sft_math", "no_RL"),
        ("curriculum s55", "v90_curriculum_s55_math", "small+large"),
        ("entropy s5", "v90_entropy_s5_math", "small+large"),
        ("prod s15 (best math)", "v90_prod_s15_math", "small+large"),
        ("prod s25", "v90_prod_s25_math", "small+large"),
        ("mini s20", "v90_mini_math", "small-only"),
        ("GSPO v2 s20", "v90_gspo_math", "small+large"),
    ]
    # Colour per training-mix category, shared by plots 1 and 2.
    cmap = {"no_RL": "#888888", "small-only": "#cc6677", "small+large": "#4477aa"}

    # Load every math run once and reuse it across plots 1-3 instead of
    # re-reading and re-parsing the same parquet file for each plot.
    frames = {
        label: load(ROOT / dirname / "trace_level_metrics.parquet")
        for label, dirname, _mix in runs
    }

    _plot_depth_by_checkpoint(runs, frames, cmap)
    _plot_depth_vs_pass32(runs, frames, cmap)
    _plot_solved_vs_unsolved(runs, frames)
    _plot_puzzle_types()


if __name__ == "__main__":
    main()
