#!/usr/bin/env python3
"""
pretty_loss_vs_accuracy.py
──────────────────────────
Square (6×6 in) Loss‑vs‑Accuracy plot for PKU‑SafeRLHF experiments.

Elements:
  • Primary ★  |  Guardian ★
  • Random‑sampling ◆  (p = 0.2,0.4,0.5,0.6,0.8)
  • Conformal‑arbitration ○   (one per α) with cross‑hair error bars
  • Interpolation dashed line (Guardian → Primary)
  • Vertical dashed α reference lines, colour‑matched to the conformal points

Axes fixed:  x ∈ [-0.05, 0.75] (avg severity‑loss)
             y ∈ [0.1, 0.6]    (accuracy)
"""
import argparse, glob, json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Mixing probabilities p whose "random-sampling" baseline points are drawn.
MIX_PS = set((0.2, 0.4, 0.5, 0.6, 0.8))
# Light-grey fill used for the random-sampling diamonds.
DIAMOND_COLOUR = "#b3b3b3"

# ─────────────────── utils ─────────────────── #
def load_trial(fp):
    """Load one trial summary JSON and split it into α results and mix results.

    Parameters
    ----------
    fp : str or path-like
        Path to a ``trial_*_summary.json`` file.

    Returns
    -------
    (alpha_block, mix_block) : tuple of dict
        ``alpha_block`` maps every top-level key other than ``"mix"``
        (converted to ``float``) to its value.  ``mix_block`` maps every key
        of the ``"mix"`` entry (converted to ``float``) to its value, and is
        empty when the file has no ``"mix"`` entry.

    Raises
    ------
    ValueError
        If a top-level key other than ``"mix"``, or a key inside ``"mix"``,
        cannot be converted to ``float``.
    """
    # Use a context manager so the handle is closed deterministically rather
    # than left for the garbage collector; JSON text is UTF-8.
    with open(fp, encoding="utf-8") as fh:
        js = json.load(fh)

    alpha_block, mix_block = {}, {}
    for key, value in js.items():
        if key == "mix":
            mix_block = {float(p): stats for p, stats in value.items()}
        else:
            alpha_block[float(key)] = value
    return alpha_block, mix_block

def aggregate_alpha(trials):
    """Average the per-α strategy metrics across all trials.

    The α grid is taken from the first trial (sorted ascending); every other
    trial is expected to contain the same α keys.

    Parameters
    ----------
    trials : list of dict
        One ``alpha_block`` per trial (as produced by ``load_trial``);
        ``trial[a][strategy]`` must provide ``"avg_acc"`` and ``"avg_loss"``
        for the strategies ``"conf"``, ``"score_max"`` and ``"safety"``.

    Returns
    -------
    (agg, alphas)
        ``alphas`` is the sorted list of α values.  ``agg[a][strategy]`` holds
        ``"acc_m"``/``"loss_m"`` (means) and ``"acc_s"``/``"loss_s"`` (sample
        standard deviations with ddof=1, or 0.0 when there is a single trial).
    """
    def spread(values):
        # A ddof=1 std is undefined for one sample, so a lone trial reports 0.0.
        return np.std(values, ddof=1) if len(values) > 1 else 0.0

    alpha_grid = sorted(trials[0].keys())
    summary = {}
    for alpha in alpha_grid:
        per_strategy = {}
        for strategy in ("conf", "score_max", "safety"):
            acc_vals = [trial[alpha][strategy]["avg_acc"] for trial in trials]
            loss_vals = [trial[alpha][strategy]["avg_loss"] for trial in trials]
            per_strategy[strategy] = {
                "acc_m": np.mean(acc_vals),
                "acc_s": spread(acc_vals),
                "loss_m": np.mean(loss_vals),
                "loss_s": spread(loss_vals),
            }
        summary[alpha] = per_strategy
    return summary, alpha_grid

def _mean_and_spread(values):
    """Return ``(mean, sample std)`` of ``values``.

    The std uses ddof=1 and is 0.0 for fewer than two values; the mean is NaN
    for an empty list (returned directly, avoiding NumPy's empty-slice warning).
    """
    if not values:
        return float("nan"), 0.0
    spread = np.std(values, ddof=1) if len(values) > 1 else 0.0
    return np.mean(values), spread


def aggregate_mix(trials, ps=None):
    """Average the random-sampling ("mix") baseline metrics across trials.

    Parameters
    ----------
    trials : list of dict
        One ``mix_block`` per trial (as produced by ``load_trial``), mapping a
        mixing probability ``p`` to a dict with ``"avg_acc"`` and
        ``"avg_loss"``.  Trials that lack a given ``p`` are ignored for it, and
        NaN values are dropped before averaging.
    ps : iterable of float, optional
        Mixing probabilities to aggregate; defaults to ``MIX_PS``.

    Returns
    -------
    dict
        ``{p: {"acc_m", "acc_s", "loss_m", "loss_s"}}`` in ascending ``p``
        order.  A ``p`` that appears in no trial is omitted instead of being
        reported as a NaN point.
    """
    ps = MIX_PS if ps is None else ps
    agg = {}
    # Sort so key order (and hence plot draw order) is deterministic; a set's
    # iteration order is not.
    for p in sorted(ps):
        present = [t[p] for t in trials if p in t]
        if not present:
            continue
        accs = [s["avg_acc"] for s in present if not np.isnan(s["avg_acc"])]
        losses = [s["avg_loss"] for s in present if not np.isnan(s["avg_loss"])]
        acc_m, acc_s = _mean_and_spread(accs)
        loss_m, loss_s = _mean_and_spread(losses)
        agg[p] = {
            "acc_m": acc_m,
            "acc_s": acc_s,
            "loss_m": loss_m,
            "loss_s": loss_s,
        }
    return agg

# ─────────────────── plotting ───────────────── #
def _alpha_colour_positions(n):
    """Return ``n`` colormap positions in [0, 1] for the α points.

    The first ``n // 2`` positions are evenly spaced over [0, 0.35) and the
    remaining ones over [0.65, 1.0] (inclusive), so the middle band of the
    colormap is never used.
    """
    n_low = n // 2
    n_high = n - n_low
    low_ts = np.linspace(0.0, 0.35, n_low, endpoint=False)
    high_ts = np.linspace(0.65, 1.0, n_high, endpoint=True)
    return np.concatenate([low_ts, high_ts])


def make_plot(alpha_agg, alphas, mix_agg, out_png):
    """Draw the loss-vs-accuracy figure and save it to ``out_png``.

    Parameters
    ----------
    alpha_agg : dict
        ``alpha_agg[a][strategy]`` holds ``acc_m``/``acc_s``/``loss_m``/
        ``loss_s`` for strategies ``"conf"``, ``"score_max"`` and ``"safety"``
        (the structure returned by ``aggregate_alpha``).
    alphas : sequence of float
        α values to plot, in colouring order; must be non-empty because the
        baselines are read from the first α block.
    mix_agg : dict
        ``{p: {acc_m, acc_s, loss_m, loss_s}}`` random-sampling baselines (the
        structure returned by ``aggregate_mix``).
    out_png : str or path-like
        Destination image file, saved at 300 dpi.
    """
    # One RGBA colour per α, sampled from matplotlib's ``bwr`` colormap.
    colours = plt.cm.bwr(_alpha_colour_positions(len(alphas)))

    conf = [alpha_agg[a]["conf"] for a in alphas]

    # Baselines are treated as α-independent, so the first α block is used.
    sm = alpha_agg[alphas[0]]["score_max"]   # Primary
    sa = alpha_agg[alphas[0]]["safety"]      # Guardian

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        # Dashed vertical reference line at x = α, colour-matched to its point.
        for a, colour in zip(alphas, colours):
            ax.axvline(a, ls="--", lw=1, color=colour, alpha=0.45)

        # ★ Primary (score-max)
        ax.errorbar(sm["loss_m"], sm["acc_m"],
                    xerr=sm["loss_s"], yerr=sm["acc_s"],
                    fmt="*", ms=13, capsize=4,
                    color="red", mec="k", mew=1.2,
                    label="Primary", zorder=5)

        # ★ Guardian (safety-first).
        # NOTE(review): drawn at x = 0 without xerr, ignoring sa["loss_m"] and
        # sa["loss_s"]; presumably its loss is zero by construction — confirm.
        ax.errorbar(0, sa["acc_m"],
                    yerr=sa["acc_s"],
                    fmt="*", ms=13, capsize=4,
                    color="blue", mec="k", mew=1.2,
                    label="Guardian", zorder=5)

        # Dashed straight line from Guardian to Primary.
        ax.plot([0, sm["loss_m"]], [sa["acc_m"], sm["acc_m"]],
                ls="--", lw=1.3, color="grey",
                label="Interpolation", zorder=1)

        # ◆ Random-sampling baselines (light-grey diamonds)
        for val in mix_agg.values():
            ax.errorbar(val["loss_m"], val["acc_m"],
                        xerr=val["loss_s"], yerr=val["acc_s"],
                        fmt="D", ms=7, capsize=4,
                        color=DIAMOND_COLOUR, mec="k", mew=0.8,
                        alpha=0.8, zorder=3)

        # ○ Conformal-arbitration points; black edges keep them legible in B/W.
        for colour, stats in zip(colours, conf):
            ax.errorbar(stats["loss_m"], stats["acc_m"],
                        xerr=stats["loss_s"], yerr=stats["acc_s"],
                        fmt="o", ms=8, capsize=4, elinewidth=1.5,
                        mfc=colour, mec="k", mew=0.8,
                        color=colour, zorder=4)

        ax.set_xlabel("Safety Violation Loss (Harmfulness)", fontsize=12)
        ax.set_ylabel("Empirical Human Alignment (Helpfulness)", fontsize=12)
        ax.set_xlim(-0.05, 0.75)
        ax.set_ylim(0.1, 0.6)
        ax.grid(alpha=0.25, linestyle=':')

        # Hand-built legend so the α-coloured conformal points appear as a single
        # neutral entry. Order: Primary, Guardian, Conformal, Random, Interpolation.
        legend_handles = [
            Line2D([], [], marker="*", color="red", mec="k", mew=1, ms=12, lw=0,
                   label="Primary"),
            Line2D([], [], marker="*", color="blue", mec="k", mew=1, ms=12, lw=0,
                   label="Guardian"),
            Line2D([], [], marker="o", color="black", mfc="white", ms=8, lw=0,
                   label="Conformal Arbitration"),
            Line2D([], [], marker="D", color=DIAMOND_COLOUR, mec="k", mew=0.8,
                   ms=8, lw=0, label="Random sampling"),
            Line2D([], [], color="grey", ls="--", lw=1.3, label="Interpolation"),
        ]
        ax.legend(handles=legend_handles, loc="lower right", fontsize=10)
        fig.tight_layout()
        fig.savefig(out_png, dpi=300)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    print("saved →", out_png)

# ─────────────────────── CLI ─────────────────────── #
if __name__ == "__main__":
    # CLI: aggregate every trial_*_summary.json in --results_dir and save the
    # loss-vs-accuracy figure.
    pa = argparse.ArgumentParser(
        description="Plot loss vs accuracy from trial_*_summary.json files.")
    pa.add_argument("--results_dir", required=True,
                    help="directory containing trial_*_summary.json files")
    pa.add_argument("--out_png", default=None,
                    help="output PNG file (default: loss_vs_accuracy_plot.png in results_dir)")
    args = pa.parse_args()

    results_dir = Path(args.results_dir)
    # Sorted lexicographically; aggregate_alpha takes its α grid from the
    # first file in this list.
    paths = sorted(glob.glob(str(results_dir / "trial_*_summary.json")))
    if not paths:
        raise SystemExit("No trial summaries found in " + str(results_dir))

    # Default output sits next to the inputs.
    out_png = (results_dir / "loss_vs_accuracy_plot.png"
               if args.out_png is None else Path(args.out_png))
    # savefig does not create missing folders, so ensure a custom --out_png
    # destination directory exists before plotting.
    out_png.parent.mkdir(parents=True, exist_ok=True)

    trials_alpha, trials_mix = [], []
    for path in paths:
        alpha_block, mix_block = load_trial(path)
        trials_alpha.append(alpha_block)
        trials_mix.append(mix_block)

    alpha_agg, alphas = aggregate_alpha(trials_alpha)
    mix_agg = aggregate_mix(trials_mix)
    make_plot(alpha_agg, alphas, mix_agg, str(out_png))
