# -*- coding: utf-8 -*-
"""
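Plot label-shift results from saved JSON (no retraining).

Usage (from repo root):
  python -m src.plot_labelshift_from_json --in runs/label_shift_runs_raw.json --outdir runs [--suffix TAG]

If the --in file does not exist, the script falls back to the single-run file
runs/label_shift_results_N100.json and plots that instead.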
"""

import os, json, argparse
import numpy as np
import matplotlib.pyplot as plt

# Global figure style applied to every plot this script produces:
# larger fonts for paper-ready figures, no top/right spines, frameless
# legends, and a light dotted background grid.
plt.rcParams.update({
    "axes.labelsize": 18,   # x and y axis labels
    "axes.titlesize": 18,   # titles (if any are added)
    "xtick.labelsize": 14,  # x-axis tick numbers
    "ytick.labelsize": 14,  # y-axis tick numbers
    "legend.fontsize": 14,  # legend text

    # aesthetics
    "axes.spines.top": False,
    "axes.spines.right": False,
    "legend.frameon": False,
    "axes.grid": True,
    "grid.alpha": 0.35,
    "grid.linestyle": ":",
    "grid.linewidth": 0.8,
})

def _ensure_dir(p):
    """Make sure the directory that will contain file path *p* exists.

    Paths with no directory component (a bare filename) are left alone.
    """
    parent = os.path.dirname(p)
    if not parent:
        return
    os.makedirs(parent, exist_ok=True)

def plot_mean_std(KL, mean, std, label, ax, ylabel, highlight=False, *,
                  marker="s", linestyle="--"):
    """Draw a mean curve with a colour-matched +/-1 std band on *ax*.

    Args:
        KL: x-values (target KL divergences), one per point.
        mean: per-point mean of the metric; same length as ``KL``.
        std: per-point standard deviation; same length as ``KL``.
        label: legend label for the curve.
        ax: matplotlib Axes to draw into.
        ylabel: y-axis label. The x-axis label is always the KL divergence.
        highlight: if True, use a thicker line to emphasize this method
            (e.g. IRS).
        marker: matplotlib marker style for the data points (default: square).
        linestyle: matplotlib line style (default: dashed).

    Axis labels, minor ticks, grid and margins are (re)applied on every call,
    so calling this once per method on the same ``ax`` is safe.
    """
    # Coerce so plain lists also work: the band below needs array arithmetic.
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)

    # Thicker line for the method you want to emphasize (e.g., IRS)
    lw = 2.8 if highlight else 2.0
    ms = 6

    line, = ax.plot(
        KL, mean,
        linestyle=linestyle,
        marker=marker,
        linewidth=lw,
        markersize=ms,
        markeredgewidth=1.2,
        markerfacecolor="white",  # hollow markers stay readable in print
        label=label,
        zorder=3,  # draw the curve above its uncertainty band
    )

    # Colour-matched uncertainty band, drawn underneath the curve.
    c = line.get_color()
    ax.fill_between(KL, mean - std, mean + std, color=c, alpha=0.15, linewidth=0, zorder=2)

    ax.set_xlabel(r"KL($\mathbb{P}_{\mathrm{test}} \,\|\, \mathbb{P}_{\mathrm{train}}$)")
    ax.set_ylabel(ylabel)

    ax.set_axisbelow(True)
    ax.minorticks_on()
    ax.grid(True, which="major")
    ax.grid(True, which="minor", alpha=0.18, linestyle=":", linewidth=0.6)

    # Slight horizontal padding so markers don't touch the border
    ax.margins(x=0.02)


def _plot_single_run(single_path, outdir, suffix):
    """Plot accuracy and CE curves from a legacy single-run JSON file.

    The JSON is expected to hold ``kl_targets`` plus ``{erm,sam,irs}_acc`` and
    ``{erm,sam,irs}_ce`` lists; one PDF per metric is written to *outdir*.
    """
    with open(single_path, "r") as f:
        R = json.load(f)
    KL = np.array(R["kl_targets"], dtype=float)

    panels = [
        ("acc", "Accuracy (test)", "label_shift_acc_single_N100"),
        ("ce", "Cross-entropy loss (test)", "label_shift_ce_single_N100"),
    ]
    for metric, ylabel, stem in panels:
        fig, ax = plt.subplots(figsize=(7, 5))
        for method in ("erm", "sam", "irs"):
            ax.plot(KL, R[f"{method}_{metric}"], marker="o", label=method.upper())
        # Same x-label as plot_mean_std so all figures agree.
        ax.set_xlabel(r"KL($\mathbb{P}_{\mathrm{test}} \,\|\, \mathbb{P}_{\mathrm{train}}$)")
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.tight_layout()
        out_path = os.path.join(outdir, f"{stem}{suffix}.pdf")
        _ensure_dir(out_path)
        fig.savefig(out_path, dpi=160)
        plt.close(fig)  # free the figure; the original leaked both
    print("[saved] single-run plots.")


def _save_mean_std_figure(KL, runs_by_label, ylabel, out_path):
    """Draw one mean +/- std panel (one curve per method) and save it to *out_path*.

    *runs_by_label* maps a legend label to an array of shape (N_runs, num_points).
    A 1-D array is treated as a single run. Curves are drawn in dict order.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    for label, runs in runs_by_label.items():
        # atleast_2d: a single run (1-D) must not be collapsed by mean(axis=0).
        A = np.atleast_2d(np.asarray(runs, dtype=float))
        m, s = A.mean(axis=0), A.std(axis=0, ddof=0)
        plot_mean_std(KL, m, s, label, ax, ylabel=ylabel, highlight=(label == "IRS"))
    ax.legend(loc="best")
    fig.tight_layout()
    _ensure_dir(out_path)
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)


def _plot_aggregated(R, outdir, suffix):
    """Plot mean +/- std across runs from an N-run aggregated results dict.

    Accuracy arrays ``acc_{erm,sam,irs}_runs`` are required (KeyError if
    missing). CE arrays ``ce_{erm,sam,irs}_runs`` are optional: the CE figure
    is skipped unless all three are present.
    """
    KL = np.array(R["kl_targets"], dtype=float)
    methods = ("erm", "sam", "irs")

    acc_runs = {m.upper(): R[f"acc_{m}_runs"] for m in methods}
    _save_mean_std_figure(
        KL, acc_runs, "Accuracy (test)",
        os.path.join(outdir, f"label_shift_acc_mean_std_N100{suffix}.pdf"),
    )

    ce_keys = [f"ce_{m}_runs" for m in methods]
    if all(k in R for k in ce_keys):
        ce_runs = {m.upper(): R[k] for m, k in zip(methods, ce_keys)}
        _save_mean_std_figure(
            KL, ce_runs, "Cross-entropy loss (test)",
            os.path.join(outdir, f"label_shift_ce_mean_std_N100{suffix}.pdf"),
        )
    else:
        # The original crashed here with AttributeError (mean_std(None)).
        print("[skip] CE arrays not found in JSON; CE plot not produced.")


def main():
    """Parse CLI arguments and write label-shift figures from saved JSON.

    Uses the N-run aggregated file given by ``--in``. If that file does not
    exist, falls back to the single-run file
    ``runs/label_shift_results_N100.json``.

    Raises:
        FileNotFoundError: if neither the ``--in`` file nor the fallback exists.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in", dest="infile", type=str, default="runs/label_shift_runs_raw_N100.json",
                    help="JSON file produced by the 100-run experiment.")
    ap.add_argument("--outdir", type=str, default="runs",
                    help="Where to save the figures.")
    ap.add_argument("--suffix", type=str, default="", help="Optional suffix for filenames.")
    args = ap.parse_args()

    infile = args.infile
    outdir = args.outdir
    suffix = f"_{args.suffix}" if args.suffix else ""

    if not os.path.exists(infile):
        # Fallback: single-run JSON from an earlier step.
        single_path = "runs/label_shift_results_N100.json"
        if not os.path.exists(single_path):
            raise FileNotFoundError(f"Could not find {infile} or {single_path}")
        print(f"[warn] {infile} not found; falling back to {single_path}")
        _plot_single_run(single_path, outdir, suffix)
        return

    with open(infile, "r") as f:
        R = json.load(f)
    _plot_aggregated(R, outdir, suffix)
    print("[saved] plots to:", outdir)


if __name__ == "__main__":
    main()
