#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
EduFrameTrap / ICML Position Paper — Reproducible analysis script (supplementary)

This script reproduces the *paper-facing* tables/figures from a run log JSONL,
using the "FINAL label" pipeline:
  - final_label := human_label if present else model_final_label
  - SYC := final_label ∈ {CS-SYC, AUTH-SYC, FACE-SYC, DIR-SYC}
  - DISAGREE := judge_a.label != judge_b.label  (two-judge disagreement)

It generates:
TABLES (CSV + LaTeX):
  Table1  overall_by_tutor                 (N, SYC%, DISAGREE%)
  Table2  reliability_summary              (DISAGREE%, JudgeA_SYC%, JudgeB_SYC%)
  Table3  syc_by_pressure_mode_wilson      (per tutor: CS/AUTH/SOCIAL with 95% Wilson ±)
  Table4  confidence_by_mode               (per tutor: mode × C=1/2/3)
  Table5  subtype_counts_by_domain         (Auth/CS/Face counts + Disagree%)
  Table6* false_negatives_vs_humans        (optional; judge-consensus PASS but human SYC)

FIGURES (PNG + PDF):
  Figure1 syc_rates_by_pressure_and_domain (2-panel bars with Wilson CIs + n=k_syc)
  Figure2 fragility_heatmaps_domains       (domain × pressure, per tutor; annotate rate + n=k_syc)
  Figure3 judge_disagreement_by_domain     (bar chart, grouped by tutor)
  Figure4 confidence_x_pressure_heatmaps   (C × pressure, per tutor; shared scale like domain heatmap)
          + stitched side-by-side layout with shared colorbar and no shared y-label text.

Notes:
  - No seaborn used (matplotlib only).
  - Colormap defaults to YlOrRd (yellow→orange→red) as requested for the paper figures.
  - For Figure4, scaling is shared (default 0–35%) to match Figure2-style comparability.

Expected JSONL fields (robust to extras):
  split, tutor_model, tutor_turn2, pressure_mode, domain, confidence,
  judge_a{label,...}, judge_b{label,...}, human_label, model_final_label

If your file uses slightly different names, adjust FIELD_* constants below.

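Usage:
  python analyze_sycophancy.py

  No command-line arguments are parsed. The run file, output directory,
  filename prefix, and heatmap scale are configured through the
  FORCE_RUN_FILE, OUT_DIR, PAPER_PREFIX, and SHARED_VMAX constants below.
  By default the newest '*with_human*.jsonl' in data/runs is analyzed.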
"""


import os
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Optional dependency: Pillow is only needed to stitch the Figure4 composite.
# Any failure while importing it (not only ImportError) just disables stitching.
try:
    from PIL import Image  # type: ignore
except Exception:
    PIL_AVAILABLE = False
else:
    PIL_AVAILABLE = True


# ================================================================
# 0) PROJECT PATHS (Thonny-friendly defaults)
# ================================================================
# All paths are derived from this script's own location, so running it from an
# IDE works regardless of the current working directory.
THIS_FILE = Path(__file__).resolve()
CODE_DIR = THIS_FILE.parent              # assumed layout: EduFrameTrap/code
PROJECT_ROOT = CODE_DIR.parent           # EduFrameTrap
RUNS_DIR = PROJECT_ROOT.joinpath("data", "runs")
OUT_DIR = PROJECT_ROOT.joinpath("icml_figures_with_counts")

# Filename prefix shared by every table/figure written (keep stable for the paper).
PAPER_PREFIX = "eduframetrap"

# Name of a run file inside RUNS_DIR to use instead of auto-picking the newest one; None = auto.
FORCE_RUN_FILE = None  # e.g. "eval_test_default_20260126_014856_c3092ce3_with_human.jsonl"

# Rows are kept only when their "split" field equals this (ignored if the field is absent).
SPLIT_TO_USE = "test"

# Heatmap look: colormap name and the top of the shared 0..SHARED_VMAX (%) color scale
# used by both the domain and the confidence heatmaps.
CMAP_NAME = "YlOrRd"     # yellow → orange → red
SHARED_VMAX = 35.0

# ================================================================
# 1) SCHEMA CONFIG
# ================================================================
# JSONL keys read by the loader; edit these if the run log uses other names.
FIELD_SPLIT = "split"
FIELD_TUTOR_MODEL = "tutor_model"
FIELD_TUTOR_T2 = "tutor_turn2"
FIELD_DOMAIN = "domain"
FIELD_PRESSURE = "pressure_mode"
FIELD_CONF = "confidence"
FIELD_JUDGE_A = "judge_a"
FIELD_JUDGE_B = "judge_b"
FIELD_HUMAN = "human_label"
FIELD_MODEL_FINAL = "model_final_label"

# Canonical pressure modes, in display order.
PRESSURE_ORDER = ["authority", "context_switch", "social"]
# Raw spelling -> canonical mode: every canonical name maps to itself, plus hyphenated aliases.
PRESSURE_CANON = {
    **{mode: mode for mode in PRESSURE_ORDER},
    "context-switch": "context_switch",
    "social-affective": "social",
}

DOMAIN_ORDER = ["computer_science", "economics", "physics", "math", "chemistry", "biology"]
CONF_ORDER = [1, 2, 3]

# Display names: snake_case domain ids rendered in Title Case ("computer_science" -> "Computer Science").
DOMAIN_LABELS = {dom: dom.replace("_", " ").title() for dom in DOMAIN_ORDER}

# Paper-facing pressure-mode names, aligned with PRESSURE_ORDER.
PRESSURE_LABELS_PAPER = dict(zip(PRESSURE_ORDER, ["authority", "context-switch", "social-affective"]))

# Raw tutor model id -> display name (unmapped ids are shown verbatim).
MODEL_MAP = {
    "gpt-5.2-2025-12-11": "GPT-5.2",
    "claude-sonnet-4-5": "Claude 4.5",
}

# FINAL labels that count as a sycophancy failure.
SYC_LABELS = {"CS-SYC", "AUTH-SYC", "FACE-SYC", "DIR-SYC"}


# ================================================================
# 2) HELPERS
# ================================================================
def ensure_dir(p: Path) -> None:
    """Create directory ``p`` and any missing parents; do nothing if it already exists."""
    os.makedirs(p, exist_ok=True)


def safe_jsonl_load(path: Path) -> List[Dict[str, Any]]:
    """Load one JSON object per non-blank line of a JSONL file.

    Lines that are not valid JSON, or that decode to something other than a
    JSON object (e.g. a bare list or number), are skipped. The number of
    skipped lines is reported with a ``[warn]`` message so dropped records are
    visible in the run log instead of silently changing N.

    The file is decoded as ``utf-8-sig``. A leading byte-order mark therefore
    does not make the first record unparsable. Files without a BOM decode
    exactly as with plain ``utf-8``.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        The decoded records, in file order.
    """
    rows: List[Dict[str, Any]] = []
    skipped = 0
    with path.open("r", encoding="utf-8-sig") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                obj = json.loads(s)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Only objects become DataFrame rows; anything else would corrupt the frame.
            if not isinstance(obj, dict):
                skipped += 1
                continue
            rows.append(obj)
    if skipped:
        print(f"[warn] {path.name}: skipped {skipped} malformed/non-object line(s)")
    return rows


def get_judge_label(val: Any) -> Optional[str]:
    """Return a judge record's verdict label, or None when there is no usable one.

    Only dict records are inspected. Their ``"label"`` entry is converted to
    ``str`` and stripped. A missing/None label, a blank label, or a non-dict
    ``val`` all yield None.
    """
    if not isinstance(val, dict):
        return None
    raw = val.get("label")
    if raw is None:
        return None
    cleaned = str(raw).strip()
    return cleaned or None


def wilson_ci(k: int, n: int, z: float = 1.96) -> Tuple[float, float]:
    """Wilson score interval for a binomial proportion ``k / n``.

    Args:
        k: Number of successes.
        n: Number of trials.
        z: Standard-normal quantile; the default 1.96 gives a ~95% interval.

    Returns:
        ``(lower, upper)`` as fractions, or ``(nan, nan)`` when ``n <= 0``.
    """
    if n > 0:
        p_hat = k / n
        z_sq = z**2
        scale = 1 + z_sq / n
        mid = (p_hat + z_sq / (2 * n)) / scale
        spread = z * np.sqrt((p_hat * (1 - p_hat) + z_sq / (4 * n)) / n) / scale
        return mid - spread, mid + spread
    return (np.nan, np.nan)


def pct(x: float, nd: int = 1) -> float:
    """Convert fraction ``x`` to a percentage rounded to ``nd`` decimals, returned as a plain float."""
    scaled = x * 100.0
    return float(np.round(scaled, nd))


def tutor_sort_key(name: str) -> int:
    """Sort key placing GPT tutors first (0), Claude tutors second (1), and all others last (99)."""
    lowered = str(name).lower()
    for rank, needle in enumerate(("gpt", "claude")):
        if needle in lowered:
            return rank
    return 99


def to_latex_table(df: pd.DataFrame, path: Path, caption: str, label: str) -> None:
    r"""Write ``df`` to ``path`` as a LaTeX ``table`` float with booktabs rules.

    The first column is left-aligned and the remaining columns right-aligned.

    Header cells come from the column names and are LaTeX-escaped. Names such
    as ``SYC_pct`` contain ``_``, which does not compile outside math mode.

    Body cells are emitted verbatim, so callers may pre-format them with LaTeX
    (e.g. ``"12.3\%"``). Missing values (None/NaN) are written as ``--`` instead
    of a literal ``nan``.

    ``caption`` and ``label`` are inserted as given. The output uses
    ``\toprule``/``\midrule``/``\bottomrule``, so the including document needs
    ``\usepackage{booktabs}``.
    """
    escapes = str.maketrans({
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    })

    def header_cell(name: Any) -> str:
        return r"\textbf{" + str(name).translate(escapes) + "}"

    def body_cell(value: Any) -> str:
        # pd.isna is only meaningful for scalars; anything else is rendered as-is.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return "--"
        return str(value)

    col_spec = "l" + "r" * max(len(df.columns) - 1, 0)
    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\small",
        r"\setlength{\tabcolsep}{4pt}",
        r"\begin{tabular}{" + col_spec + "}",
        r"\toprule",
        " & ".join(header_cell(c) for c in df.columns) + r" \\",
        r"\midrule",
    ]
    for _, row in df.iterrows():
        lines.append(" & ".join(body_cell(row[c]) for c in df.columns) + r" \\")
    lines += [
        r"\bottomrule",
        r"\end{tabular}",
        r"\vspace{2pt}",
        r"\caption{" + caption + "}",
        r"\label{" + label + "}",
        r"\end{table}",
    ]
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def pick_run_file() -> Path:
    """Locate the JSONL run log to analyze.

    Resolution order:
      1. ``RUNS_DIR / FORCE_RUN_FILE`` when ``FORCE_RUN_FILE`` is set (it must exist).
      2. The most recently modified ``*with_human*.jsonl`` in ``RUNS_DIR``.
      3. The most recently modified ``*.jsonl`` in ``RUNS_DIR``.

    Raises:
        FileNotFoundError: if the forced file or ``RUNS_DIR`` is missing, or
            ``RUNS_DIR`` contains no ``.jsonl`` file.
    """
    if FORCE_RUN_FILE:
        forced = RUNS_DIR / FORCE_RUN_FILE
        if forced.exists():
            return forced
        raise FileNotFoundError(f"FORCE_RUN_FILE not found: {forced}")

    if not RUNS_DIR.exists():
        raise FileNotFoundError(f"Runs dir not found: {RUNS_DIR}")

    def newest(pattern: str) -> Optional[Path]:
        # max() keeps the first of equally-recent files, matching a stable reverse sort.
        return max(RUNS_DIR.glob(pattern), key=lambda q: q.stat().st_mtime, default=None)

    for pattern in ("*with_human*.jsonl", "*.jsonl"):
        candidate = newest(pattern)
        if candidate is not None:
            return candidate

    raise FileNotFoundError(f"No .jsonl files found in {RUNS_DIR}")


# ================================================================
# 3) LOAD + PREPROCESS
# ================================================================
def load_and_preprocess(run_files: List[Path]) -> pd.DataFrame:
    """Load run logs and derive the per-row columns used by all tables and figures.

    Steps:
      1. Concatenate records from every existing file in ``run_files``
         (missing files are reported and skipped).
      2. Keep rows whose ``FIELD_SPLIT`` equals ``SPLIT_TO_USE`` (only if
         that column exists).
      3. Check required columns, then keep rows having a turn-2 tutor
         response and both judge records.
      4. Derive ``label_a``/``label_b``/``is_disagree``. Derive
         ``final_used`` = human label if present, else model final label;
         rows with neither are dropped. Then derive the ``is_syc_*`` flags.
         String labels are whitespace-stripped and blank labels count as missing.
      5. Canonicalize pressure modes case/whitespace-insensitively and drop
         unknown ones. Map tutor ids to display names. Store pressure, domain
         and confidence as ordered categoricals. Confidence values outside
         ``CONF_ORDER`` (including missing or non-numeric ones) become a
         missing category instead of crashing or being truncated. Affected
         rows are reported with ``[warn]`` messages.

    Args:
        run_files: JSONL run logs to combine.

    Returns:
        The preprocessed frame, or an empty frame if no records were loaded.

    Raises:
        ValueError: if a required column is missing.
    """
    def clean_label(v: Any) -> Any:
        # Strip string labels; a blank string is treated as "no label".
        if isinstance(v, str):
            v = v.strip()
            return v if v else np.nan
        return v

    def canon_pressure(x: Any) -> str:
        key = str(x).strip().lower()
        return PRESSURE_CANON.get(key, key)

    all_rows: List[Dict[str, Any]] = []
    for f in run_files:
        if not f.exists():
            print(f"[warn] Missing file: {f}")
            continue
        all_rows.extend(safe_jsonl_load(f))

    if not all_rows:
        return pd.DataFrame()

    df = pd.DataFrame(all_rows)

    # Optional split filter
    if FIELD_SPLIT in df.columns:
        df = df[df[FIELD_SPLIT] == SPLIT_TO_USE].copy()

    required_cols = [FIELD_TUTOR_T2, FIELD_JUDGE_A, FIELD_JUDGE_B, FIELD_TUTOR_MODEL,
                     FIELD_DOMAIN, FIELD_PRESSURE, FIELD_CONF]
    for c in required_cols:
        if c not in df.columns:
            raise ValueError(f"Missing required column: {c}")

    df = df[df[FIELD_TUTOR_T2].notna() & df[FIELD_JUDGE_A].notna() & df[FIELD_JUDGE_B].notna()].copy()

    df["label_a"] = df[FIELD_JUDGE_A].apply(get_judge_label)
    df["label_b"] = df[FIELD_JUDGE_B].apply(get_judge_label)
    # NOTE(review): pandas' `!=` yields True when either side is missing, so a row where a
    # judge record has no usable label counts as a disagreement — confirm this is intended.
    df["is_disagree"] = df["label_a"] != df["label_b"]

    for col in [FIELD_HUMAN, FIELD_MODEL_FINAL]:
        if col not in df.columns:
            df[col] = np.nan
        df[col] = df[col].map(clean_label)

    # FINAL label: the human label wins when present, otherwise the model's final label.
    df["final_used"] = df[FIELD_HUMAN].where(df[FIELD_HUMAN].notna(), df[FIELD_MODEL_FINAL])
    df = df[df["final_used"].notna()].copy()

    df["is_syc_final"] = df["final_used"].isin(SYC_LABELS)
    df["is_syc_a"] = df["label_a"].isin(SYC_LABELS)
    df["is_syc_b"] = df["label_b"].isin(SYC_LABELS)

    df[FIELD_PRESSURE] = df[FIELD_PRESSURE].map(canon_pressure)
    known_mode = df[FIELD_PRESSURE].isin(PRESSURE_ORDER)
    if not known_mode.all():
        print(f"[warn] Dropping {int((~known_mode).sum())} row(s) with an unrecognized pressure mode.")
    df = df[known_mode].copy()

    df["model_clean"] = df[FIELD_TUTOR_MODEL].map(lambda x: MODEL_MAP.get(str(x), str(x)))

    df[FIELD_PRESSURE] = pd.Categorical(df[FIELD_PRESSURE], PRESSURE_ORDER, ordered=True)

    unknown_domain = ~df[FIELD_DOMAIN].isin(DOMAIN_ORDER)
    if unknown_domain.any():
        print(f"[warn] {int(unknown_domain.sum())} row(s) have a domain outside DOMAIN_ORDER; "
              "they count toward overall rates but not per-domain breakdowns.")
    df[FIELD_DOMAIN] = pd.Categorical(df[FIELD_DOMAIN], DOMAIN_ORDER, ordered=True)

    # Coerce confidence to numbers; keep exact CONF_ORDER levels, anything else -> missing.
    conf_num = pd.to_numeric(df[FIELD_CONF], errors="coerce")
    conf_vals = [int(v) if pd.notna(v) and v in CONF_ORDER else None for v in conf_num]
    n_bad_conf = sum(v is None for v in conf_vals)
    if n_bad_conf:
        print(f"[warn] {n_bad_conf} row(s) have a confidence outside {CONF_ORDER}; treated as missing.")
    df[FIELD_CONF] = pd.Categorical(conf_vals, CONF_ORDER, ordered=True)

    return df


# ================================================================
# 4) TABLES
# ================================================================
def table_overall_by_tutor(df: pd.DataFrame) -> pd.DataFrame:
    """Table 1: per-tutor row count, final-label SYC% and judge DISAGREE%.

    Returns:
        One row per ``model_clean`` value with columns
        ``Tutor, N, SYC_pct, DISAGREE_pct``, ordered GPT → Claude → others.
    """
    def as_pct(s: pd.Series) -> float:
        return pct(float(s.mean()))

    summary = (
        df.groupby(["model_clean"], observed=True)
        .agg(
            N=("is_syc_final", "size"),
            SYC_pct=("is_syc_final", as_pct),
            DISAGREE_pct=("is_disagree", as_pct),
        )
        .reset_index()
        .rename(columns={"model_clean": "Tutor"})
    )
    return summary.sort_values("Tutor", key=lambda col: col.map(tutor_sort_key))


def table_reliability_summary(df: pd.DataFrame) -> pd.DataFrame:
    """Table 2: per-tutor judge reliability summary.

    Returns:
        Columns ``Tutor, N, DISAGREE_pct, JudgeA_SYC_pct, JudgeB_SYC_pct``.
        The judge columns are the share of rows each judge labelled with a SYC
        subtype. Rows are ordered GPT → Claude → others.
    """
    rate_sources = {
        "DISAGREE_pct": "is_disagree",
        "JudgeA_SYC_pct": "is_syc_a",
        "JudgeB_SYC_pct": "is_syc_b",
    }
    spec = {"N": ("is_syc_final", "size")}
    spec.update({name: (src, lambda s: pct(float(s.mean()))) for name, src in rate_sources.items()})

    summary = df.groupby(["model_clean"], observed=True).agg(**spec)
    summary = summary.reset_index().rename(columns={"model_clean": "Tutor"})
    return summary.sort_values("Tutor", key=lambda col: col.map(tutor_sort_key))


def table_mode_compact_wilson(df: pd.DataFrame) -> pd.DataFrame:
    r"""Table 3: per-tutor final-label SYC rate for each pressure mode, with Wilson error.

    Each cell is a LaTeX-ready string such as ``12.3\%\,($\pm$\,4.5)``. The
    ± value is half the 95% Wilson interval width, in percentage points.
    A tutor with no rows for a mode gets ``--``.

    Returns:
        Columns ``Tutor, CS, AUTH, SOCIAL``, ordered GPT → Claude → others.
        The frame is empty (with those columns) when ``df`` has no rows.
    """
    mode_columns = (("context_switch", "CS"), ("authority", "AUTH"), ("social", "SOCIAL"))

    def fmt(k: int, n: int) -> str:
        if n == 0:
            return "--"
        lo, hi = wilson_ci(k, n)
        half = (hi - lo) / 2.0 * 100.0
        # \pm is a math-mode command: without $...$ LaTeX aborts with "Missing $ inserted".
        return f"{k / n * 100.0:.1f}\\%\\,($\\pm$\\,{half:.1f})"

    rows = []
    for tutor, sub_t in df.groupby("model_clean", observed=True):
        row = {"Tutor": tutor}
        for mode, col in mode_columns:
            cell = sub_t[sub_t[FIELD_PRESSURE] == mode]
            row[col] = fmt(int(cell["is_syc_final"].sum()), len(cell))
        rows.append(row)

    out = pd.DataFrame(rows, columns=["Tutor", "CS", "AUTH", "SOCIAL"])
    return out.sort_values("Tutor", key=lambda s: s.map(tutor_sort_key))


def table_confidence_by_mode(df: pd.DataFrame) -> pd.DataFrame:
    """Table 4 (long form): final-label SYC% per tutor × pressure mode × confidence level.

    Returns:
        Columns ``Tutor, Mode, C, SYC_pct, N, k`` for every observed
        combination. ``Mode`` holds the paper-facing label from
        ``PRESSURE_LABELS_PAPER`` and ``k`` is the number of SYC rows.
    """
    group_keys = ["model_clean", FIELD_PRESSURE, FIELD_CONF]
    display_names = dict(zip(group_keys, ["Tutor", "Mode", "C"]))

    stats = df.groupby(group_keys, observed=True).agg(
        SYC_pct=("is_syc_final", lambda s: pct(float(s.mean()))),
        N=("is_syc_final", "size"),
        k=("is_syc_final", "sum"),
    )
    stats = stats.reset_index().rename(columns=display_names)
    stats["Mode"] = stats["Mode"].map(PRESSURE_LABELS_PAPER)
    return stats


def table_subtypes_by_domain(df: pd.DataFrame) -> pd.DataFrame:
    r"""Table 5: final-label subtype counts and judge disagreement per domain, pooled over tutors.

    Only AUTH-SYC, CS-SYC and FACE-SYC are counted; other final labels
    (e.g. DIR-SYC) do not appear. ``Disagree Rate`` is a LaTeX string like
    ``"7.5\%"``, or ``"--"`` for a domain with no rows. Rows follow ``DOMAIN_ORDER``.
    """
    subtype_columns = (("Auth-Syc", "AUTH-SYC"), ("CS-Syc", "CS-SYC"), ("Face-Syc", "FACE-SYC"))

    def summarize(dom: str) -> Dict[str, Any]:
        sub = df[df[FIELD_DOMAIN] == dom]
        counts = sub["final_used"].value_counts().to_dict()
        record: Dict[str, Any] = {"Domain": DOMAIN_LABELS[dom]}
        for column, final_label in subtype_columns:
            record[column] = int(counts.get(final_label, 0))
        if len(sub):
            record["Disagree Rate"] = f"{pct(float(sub['is_disagree'].mean())):.1f}\\%"
        else:
            record["Disagree Rate"] = "--"
        return record

    return pd.DataFrame([summarize(dom) for dom in DOMAIN_ORDER])


def table_false_negatives_vs_humans(df: pd.DataFrame) -> pd.DataFrame:
    """Table 6: judge false negatives relative to human labels, per tutor.

    Only rows with a human label are considered. A false negative is a row
    where both judges output ``"PASS"`` but the human label is a SYC subtype.

    Returns:
        Columns ``Tutor, N_human, FN_count, FN_pct`` ordered GPT → Claude →
        others. When no row has a human label, an empty frame with those
        columns is returned.
    """
    labeled = df[df[FIELD_HUMAN].notna()].copy()
    if labeled.empty:
        return pd.DataFrame(columns=["Tutor", "N_human", "FN_count", "FN_pct"])

    both_pass = labeled["label_a"].eq("PASS") & labeled["label_b"].eq("PASS")
    labeled["is_fn"] = both_pass & labeled[FIELD_HUMAN].isin(SYC_LABELS)

    per_tutor = labeled.groupby("model_clean", observed=True).agg(
        N_human=("is_fn", "size"),
        FN_count=("is_fn", "sum"),
        FN_pct=("is_fn", lambda s: pct(float(s.mean()))),
    )
    per_tutor = per_tutor.reset_index().rename(columns={"model_clean": "Tutor"})
    return per_tutor.sort_values("Tutor", key=lambda col: col.map(tutor_sort_key))


def save_tables(df: pd.DataFrame) -> None:
    """Build Tables 1–6, write them as CSV and LaTeX into ``OUT_DIR``, then print them.

    Table 4 is saved twice as CSV: long form (tutor × mode × C) and a wide
    pivot with one ``C=<level>`` column per confidence level. Only the wide
    form is rendered to LaTeX. Table 6 files are written, and Table 6 is
    printed, only when at least one row carries a human label.
    """
    ensure_dir(OUT_DIR)

    def target(suffix: str) -> Path:
        return OUT_DIR / f"{PAPER_PREFIX}_{suffix}"

    t1 = table_overall_by_tutor(df)
    t2 = table_reliability_summary(df)
    t3 = table_mode_compact_wilson(df)
    t4 = table_confidence_by_mode(df)
    t5 = table_subtypes_by_domain(df)
    t6 = table_false_negatives_vs_humans(df)
    have_t6 = not t6.empty

    # CSV
    csv_jobs = [
        (t1, "table1_overall_by_tutor.csv"),
        (t2, "table2_reliability_summary.csv"),
        (t3, "table3_mode_compact_wilson.csv"),
        (t4, "table4_confidence_by_mode_long.csv"),
        (t5, "table5_subtypes_by_domain.csv"),
    ]
    if have_t6:
        csv_jobs.append((t6, "table6_false_negatives_vs_humans.csv"))
    for frame, suffix in csv_jobs:
        frame.to_csv(target(suffix), index=False)

    # LaTeX snippets for Tables 1-3
    for frame, suffix, caption, label in (
        (t1, "table1_overall_by_tutor.tex",
         r"\textbf{Overall rates by tutor (test, $T_2$ only).}",
         "tab:results_overall_by_tutor"),
        (t2, "table2_reliability_summary.tex",
         r"\textbf{Reliability summary (test, $T_2$ only).}",
         "tab:reliability_summary"),
        (t3, "table3_mode_compact.tex",
         r"\textbf{Sycophancy by pressure mode (test, $T_2$ only).} Values are failure rates with 95\% Wilson intervals.",
         "tab:mode_compact"),
    ):
        to_latex_table(frame, target(suffix), caption=caption, label=label)

    # Table 4, wide: one SYC% column per confidence level.
    wide = t4.pivot_table(index=["Tutor", "Mode"], columns="C", values="SYC_pct", observed=True)
    wide = wide.reset_index().rename(columns={level: f"C={level}" for level in CONF_ORDER})
    wide.to_csv(target("table4_confidence_by_mode_wide.csv"), index=False)

    later_tex = [
        (wide, "table4_confidence_by_mode_wide.tex",
         r"\textbf{Confidence-conditioned sycophancy rates by pressure mode (test, $T_2$ only).}",
         "tab:conf_by_mode"),
        (t5, "table5_subtypes_by_domain.tex",
         r"\textbf{Failure taxonomy by domain (final subtype counts).}",
         "tab:subtypes"),
    ]
    if have_t6:
        later_tex.append(
            (t6, "table6_false_negatives_vs_humans.tex",
             r"\textbf{Judge false negatives vs. human labels.} Both judges output PASS, but the human label indicates a sycophancy subtype.",
             "tab:false_negatives")
        )
    for frame, suffix, caption, label in later_tex:
        to_latex_table(frame, target(suffix), caption=caption, label=label)

    printouts = [
        ("[Table1] Overall by tutor", t1),
        ("[Table2] Reliability summary", t2),
        ("[Table3] Mode compact", t3),
        ("[Table4] Confidence wide", wide),
        ("[Table5] Subtypes by domain", t5),
    ]
    if have_t6:
        printouts.append(("[Table6] False negatives vs humans", t6))
    for title, frame in printouts:
        print(f"\n{title}\n", frame.to_string(index=False))


# ================================================================
# 5) FIGURES
# ================================================================
def summarize_rate_ci(df: pd.DataFrame, group_cols: List[str]) -> pd.DataFrame:
    """Final-label SYC rate with a 95% Wilson interval for every observed group.

    Returns:
        One row per non-empty group. The columns are ``group_cols`` followed by
        ``n_total, k_syc, rate, lo, hi`` (rate and bounds are fractions).
    """
    records = []
    for keys, sub in df.groupby(group_cols, observed=True):
        key_tuple = keys if isinstance(keys, tuple) else (keys,)
        total = len(sub)
        hits = int(sub["is_syc_final"].sum())
        share = hits / total if total else np.nan
        records.append((*key_tuple, total, hits, share, *wilson_ci(hits, total)))
    return pd.DataFrame(records, columns=[*group_cols, "n_total", "k_syc", "rate", "lo", "hi"])


def _figure1_panel(ax: Any, summary: pd.DataFrame, key_col: str, order: List[str],
                   tick_labels: List[str], tutors: List[Any], tutor_names: Dict[Any, str],
                   title: str) -> None:
    """Draw one Figure-1 panel: SYC-rate bars grouped over ``order``, one series per tutor."""
    x = np.arange(len(order))
    n_series = max(len(tutors), 1)
    # Split a fixed group width (0.72 = 2 x 0.36 for two tutors) across the tutors and
    # center the series on each tick, so any number of tutors lays out without overlap.
    width = 0.72 / n_series
    for i, mid in enumerate(tutors):
        sub = summary[summary[FIELD_TUTOR_MODEL] == mid].copy()
        # Plain-string index so reindex() inserts NaN rows for groups this tutor lacks
        # (summarize_rate_ci omits empty groups; .loc would raise KeyError).
        sub[key_col] = sub[key_col].astype(str)
        sub = sub.set_index(key_col).reindex(order)
        rate = sub["rate"].to_numpy(dtype=float)
        lo = sub["lo"].to_numpy(dtype=float)
        hi = sub["hi"].to_numpy(dtype=float)
        yerr = np.vstack([(rate - lo) * 100, (hi - rate) * 100])
        bars = ax.bar(x + (i - (n_series - 1) / 2) * width, rate * 100, width,
                      label=tutor_names[mid], yerr=yerr, capsize=4)
        for bar, k in zip(bars, sub["k_syc"].to_numpy(dtype=float)):
            if np.isnan(k):
                continue  # no data for this group: leave the slot empty
            # Place the SYC count near the bar base (at least 0.7 on the % axis).
            y = max(0.7, bar.get_height() * 0.25)
            ax.text(bar.get_x() + bar.get_width() / 2, y, f"n={int(k)}",
                    ha="center", va="bottom", fontsize=9)
    ax.set_title(title)
    ax.set_ylabel("SYC rate (%)")
    ax.set_xticks(x, tick_labels)
    ax.set_ylim(0, max(30, ax.get_ylim()[1]))
    ax.legend(frameon=False)


def figure1_syc_by_pressure_and_domain(df: pd.DataFrame) -> None:
    """Figure 1: SYC rate by pressure mode (top panel) and by domain (bottom panel).

    Each tutor is one bar series showing the final-label SYC rate (%) with
    95% Wilson error bars. The label near each bar base is the SYC count
    (``n=k_syc``). Groups without data for a tutor are left empty. Saves PNG
    (300 dpi) and PDF to ``OUT_DIR``.
    """
    tutors = sorted(df[FIELD_TUTOR_MODEL].unique(), key=tutor_sort_key)
    tutor_names = {mid: MODEL_MAP.get(str(mid), str(mid)) for mid in tutors}

    press = summarize_rate_ci(df, [FIELD_TUTOR_MODEL, FIELD_PRESSURE])
    dom = summarize_rate_ci(df, [FIELD_TUTOR_MODEL, FIELD_DOMAIN])

    fig, axes = plt.subplots(2, 1, figsize=(10.5, 7.5), constrained_layout=True)
    _figure1_panel(axes[0], press, FIELD_PRESSURE, PRESSURE_ORDER,
                   [PRESSURE_LABELS_PAPER[p] for p in PRESSURE_ORDER], tutors, tutor_names,
                   "Sycophancy rate by pressure mode (95% Wilson CI)")
    _figure1_panel(axes[1], dom, FIELD_DOMAIN, DOMAIN_ORDER,
                   [DOMAIN_LABELS[d] for d in DOMAIN_ORDER], tutors, tutor_names,
                   "Sycophancy rate by domain (95% Wilson CI)")

    out_png = OUT_DIR / f"{PAPER_PREFIX}_figure1_sycrate_pressure_and_domain.png"
    out_pdf = OUT_DIR / f"{PAPER_PREFIX}_figure1_sycrate_pressure_and_domain.pdf"
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    fig.savefig(out_pdf, bbox_inches="tight")
    plt.close(fig)
    print(f"[Figure1] {out_png}")


def figure2_fragility_heatmaps_domains(df: pd.DataFrame) -> None:
    """Figure 2: per-tutor domain × pressure heatmaps of the final-label SYC rate.

    Cells are colored on the shared 0..``SHARED_VMAX`` (%) scale and annotated
    with the rate and the SYC count (``n=k``).

    Cells with no rows are annotated ``--`` and stored as NaN, so they are
    drawn in the colormap's "bad" color instead of being rendered as 0%. This
    keeps missing data from being mistaken for a zero failure rate.

    All panels share one colorbar. Saves PNG (300 dpi) and PDF to ``OUT_DIR``.
    """
    tutors = sorted(df[FIELD_TUTOR_MODEL].unique(), key=tutor_sort_key)
    cmap = plt.get_cmap(CMAP_NAME)

    mats, anns = [], []
    for mid in tutors:
        sub = df[df[FIELD_TUTOR_MODEL] == mid]
        mat = np.full((len(DOMAIN_ORDER), len(PRESSURE_ORDER)), np.nan)
        ann = np.full(mat.shape, "--", dtype=object)
        for i, dom in enumerate(DOMAIN_ORDER):
            for j, pm in enumerate(PRESSURE_ORDER):
                cell = sub[(sub[FIELD_DOMAIN] == dom) & (sub[FIELD_PRESSURE] == pm)]
                n_cell = len(cell)
                if n_cell == 0:
                    continue  # keep NaN / "--": no data is not the same as 0%
                k = int(cell["is_syc_final"].sum())
                rate = k / n_cell * 100
                mat[i, j] = rate
                ann[i, j] = f"{rate:.1f}%\n(n={k})"
        mats.append(mat)
        anns.append(ann)

    fig, axes = plt.subplots(1, len(tutors), figsize=(14, 6), constrained_layout=True)
    if len(tutors) == 1:
        axes = [axes]

    im = None
    for ax, mat, ann, mid in zip(axes, mats, anns, tutors):
        im = ax.imshow(mat, cmap=cmap, vmin=0, vmax=SHARED_VMAX, aspect="auto")
        ax.set_title(f"{MODEL_MAP.get(str(mid), str(mid))}\nFragility Heatmap", fontsize=14, pad=10)
        ax.set_xticks(range(len(PRESSURE_ORDER)), [PRESSURE_LABELS_PAPER[p] for p in PRESSURE_ORDER])
        ax.set_yticks(range(len(DOMAIN_ORDER)), [d.replace("_", " ") for d in DOMAIN_ORDER])
        ax.set_xlabel("pressure mode")
        ax.set_ylabel("")  # no y-axis title; the domain tick labels are self-explanatory
        for i in range(mat.shape[0]):
            for j in range(mat.shape[1]):
                ax.text(j, i, ann[i, j], ha="center", va="center", fontsize=10, color="black")

    cbar = fig.colorbar(im, ax=axes, shrink=0.9)
    cbar.set_label("SYC rate (%)")

    out_png = OUT_DIR / f"{PAPER_PREFIX}_figure2_fragility_heatmaps_domains.png"
    out_pdf = OUT_DIR / f"{PAPER_PREFIX}_figure2_fragility_heatmaps_domains.pdf"
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    fig.savefig(out_pdf, bbox_inches="tight")
    plt.close(fig)
    print(f"[Figure2] {out_png}")


def figure3_judge_disagreement_by_domain(df: pd.DataFrame) -> None:
    """Figure 3: two-judge disagreement rate (as a fraction) per domain, one bar series per tutor.

    Domains with no rows for a tutor get a NaN rate and no bar. The y-axis
    spans at least 0–0.25. Saves PNG (300 dpi) and PDF to ``OUT_DIR``.
    """
    tutors = sorted(df[FIELD_TUTOR_MODEL].unique(), key=tutor_sort_key)
    tutor_names = {mid: MODEL_MAP.get(str(mid), str(mid)) for mid in tutors}

    rows = []
    for mid in tutors:
        sub = df[df[FIELD_TUTOR_MODEL] == mid]
        for dom in DOMAIN_ORDER:
            cell = sub[sub[FIELD_DOMAIN] == dom]
            n = len(cell)
            k = int(cell["is_disagree"].sum())
            rate = (k / n) if n else np.nan
            rows.append((mid, dom, rate))
    res = pd.DataFrame(rows, columns=["mid", "domain", "rate"])

    fig, ax = plt.subplots(figsize=(9.5, 4.5), constrained_layout=True)
    x = np.arange(len(DOMAIN_ORDER))
    n_series = max(len(tutors), 1)
    # Share a fixed group width (0.72 = 2 x 0.36 for two tutors) across any number of tutors.
    width = 0.72 / n_series
    for i, mid in enumerate(tutors):
        sub = res[res["mid"] == mid].set_index("domain").loc[DOMAIN_ORDER]
        ax.bar(x + (i - (n_series - 1) / 2) * width, sub["rate"].values, width, label=tutor_names[mid])

    ax.set_ylabel("Judge disagreement rate")
    ax.set_xticks(x, [DOMAIN_LABELS[d] for d in DOMAIN_ORDER])
    finite = res["rate"].dropna()
    # Avoid np.nanmax's all-NaN RuntimeWarning; fall back to the 0.25 floor.
    top = float(finite.max()) * 1.15 if not finite.empty else 0.0
    ax.set_ylim(0, max(0.25, top))
    ax.legend(frameon=False, fontsize=9)

    out_png = OUT_DIR / f"{PAPER_PREFIX}_figure3_judge_disagreement_by_domain.png"
    out_pdf = OUT_DIR / f"{PAPER_PREFIX}_figure3_judge_disagreement_by_domain.pdf"
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    fig.savefig(out_pdf, bbox_inches="tight")
    plt.close(fig)
    print(f"[Figure3] {out_png}")


def _figure4_panel(sub: pd.DataFrame, tutor: str, cmap: Any) -> Path:
    """Render one tutor's confidence × pressure heatmap to a PNG in OUT_DIR and return its path.

    Cells show the SYC rate and ``n=<syc>/<total>``. Cells with no rows are
    annotated ``--``; their NaN rate is drawn in the colormap's "bad" color.
    """
    pivot = sub.pivot_table(index=FIELD_CONF, columns=FIELD_PRESSURE, values="is_syc_final",
                            aggfunc=["mean", "sum", "count"], observed=True)
    mean = pivot["mean"].reindex(index=CONF_ORDER, columns=PRESSURE_ORDER)
    syc = pivot["sum"].reindex(index=CONF_ORDER, columns=PRESSURE_ORDER).fillna(0).astype(int)
    n = pivot["count"].reindex(index=CONF_ORDER, columns=PRESSURE_ORDER).fillna(0).astype(int)
    mat = (mean.values * 100.0).astype(float)

    fig, ax = plt.subplots(figsize=(5.4, 4.0))
    ax.imshow(mat, aspect="auto", cmap=cmap, vmin=0, vmax=SHARED_VMAX)
    ax.set_title(f"{tutor}: Confidence × Pressure (T2)", fontsize=14, pad=10)
    ax.set_xticks(range(len(PRESSURE_ORDER)))
    ax.set_xticklabels([PRESSURE_LABELS_PAPER[m] for m in PRESSURE_ORDER], rotation=15, ha="right", fontsize=10)
    ax.set_xlabel("pressure mode", fontsize=12)
    ax.set_yticks(range(len(CONF_ORDER)))
    ax.set_yticklabels([f"C={c}" for c in CONF_ORDER], fontsize=11)
    ax.set_ylabel("")  # no 'confidence' axis title; tick labels carry it

    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            total = int(n.values[i, j])
            text = f"{mat[i, j]:.1f}%\n(n={int(syc.values[i, j])}/{total})" if total else "--"
            ax.text(j, i, text, ha="center", va="center", fontsize=10, color="black")

    fig.tight_layout()
    # Same stem as before ("GPT-52", "Claude_45"); any other unsafe character becomes "_".
    stem = tutor.replace(" ", "_").replace(".", "")
    stem = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in stem)
    panel_png = OUT_DIR / f"{PAPER_PREFIX}_figure4_panel_{stem}.png"
    fig.savefig(panel_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    return panel_png


def _figure4_colorbar(cmap: Any) -> Path:
    """Render the shared 0..SHARED_VMAX (%) colorbar for the stitched Figure 4 and return its PNG path."""
    cbar_png = OUT_DIR / f"{PAPER_PREFIX}_figure4_shared_colorbar.png"
    fig, ax = plt.subplots(figsize=(0.8, 4.0))
    fig.subplots_adjust(left=0.0, right=0.6, top=0.98, bottom=0.02)
    norm = plt.Normalize(vmin=0, vmax=SHARED_VMAX)
    cb = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax)
    cb.set_label("SYC rate (%)", fontsize=12)
    cb.ax.tick_params(labelsize=10)
    fig.savefig(cbar_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    return cbar_png


def _stitch_horizontally(paths: List[Path], out_png: Path, out_pdf: Path,
                         gap: int = 20, left_margin: int = 10) -> None:
    """Paste images left-to-right on a white canvas (scaled to the tallest height); save PNG and PDF."""
    imgs = []
    for p in paths:
        with Image.open(p) as im:  # close the file handle once converted
            imgs.append(im.convert("RGB"))
    max_h = max(im.size[1] for im in imgs)
    # Pillow >= 9.1 exposes filters under Image.Resampling; older versions on Image itself.
    lanczos = getattr(Image, "Resampling", Image).LANCZOS

    def resize_to_h(im: Any) -> Any:
        if im.size[1] == max_h:
            return im
        new_w = int(im.size[0] * max_h / im.size[1])
        return im.resize((new_w, max_h), lanczos)

    imgs = [resize_to_h(im) for im in imgs]
    total_w = left_margin + sum(im.size[0] for im in imgs) + gap * (len(imgs) - 1)
    canvas = Image.new("RGB", (total_w, max_h), (255, 255, 255))
    x = left_margin
    for im in imgs:
        canvas.paste(im, (x, 0))
        x += im.size[0] + gap
    canvas.save(out_png)
    canvas.save(out_pdf, "PDF", resolution=300.0)


def figure4_confidence_x_pressure_stitched(df: pd.DataFrame) -> None:
    """Figure 4: per-tutor confidence × pressure heatmaps, stitched side by side with one colorbar.

    One panel PNG is written per tutor present in ``df``, ordered
    GPT → Claude → others. All panels use the shared 0..``SHARED_VMAX`` scale.

    If Pillow is available and at least one panel exists, the panels and a
    shared colorbar are stitched into a raster PNG and PDF in ``OUT_DIR``.
    Otherwise a ``[warn]`` is printed and only the per-tutor panels remain.
    """
    cmap = plt.get_cmap(CMAP_NAME)
    # Derive tutors from the data rather than hard-coding display names, so runs with other
    # tutor models (or only one of them) still produce panels.
    tutors = sorted(df["model_clean"].unique(), key=tutor_sort_key)
    panel_paths = [_figure4_panel(df[df["model_clean"] == tutor], tutor, cmap) for tutor in tutors]

    if not PIL_AVAILABLE:
        print("[warn] pillow not installed → skipping stitched Figure4 composite.")
        return
    if not panel_paths:
        print("[warn] not enough panels to stitch Figure4.")
        return

    cbar_png = _figure4_colorbar(cmap)
    out_png = OUT_DIR / f"{PAPER_PREFIX}_figure4_confidence_x_pressure_stitched.png"
    out_pdf = OUT_DIR / f"{PAPER_PREFIX}_figure4_confidence_x_pressure_stitched.pdf"
    _stitch_horizontally(panel_paths + [cbar_png], out_png, out_pdf)
    print(f"[Figure4] {out_png}")


# ================================================================
# 6) RUN (Thonny: just press Run)
# ================================================================
def main():
    """Run the full pipeline: choose the run log, preprocess it, write all tables and figures.

    Returns early (after printing an error) when preprocessing yields no
    usable rows. Ends by printing the FINAL-label distribution.
    """
    print("Project root:", PROJECT_ROOT)
    print("Runs dir:", RUNS_DIR)

    run_file = pick_run_file()
    print("\nUsing run file:", run_file)

    df = load_and_preprocess([run_file])
    if df.empty:
        print("[error] No usable rows. Check schema or split.")
        return

    print("Usable rows:", len(df))
    print("Tutors:", sorted(df["model_clean"].unique(), key=tutor_sort_key))

    ensure_dir(OUT_DIR)
    pipeline = (
        save_tables,
        figure1_syc_by_pressure_and_domain,
        figure2_fragility_heatmaps_domains,
        figure3_judge_disagreement_by_domain,
        figure4_confidence_x_pressure_stitched,
    )
    for step in pipeline:
        step(df)

    final_counts = df["final_used"].value_counts()
    print("\nFinal-label counts:\n", final_counts.to_string())
    if all(final_counts.get(lab, 0) == 0 for lab in ("DIR-SYC", "EVADE")):
        print("Note: DIR-SYC and EVADE do not occur as FINAL labels in this run.")

    print("\nDONE. Outputs in:", OUT_DIR)


# Script entry point (e.g. pressing Run in Thonny).
if __name__ == "__main__":
    main()
