import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, AutoMinorLocator

def load_losses(path: str):
    """Load the x-axis spacing and the four per-mode loss curves from JSON.

    The file is expected to hold a JSON object with an optional ``"meta"``
    object (its ``"step"`` entry gives the x spacing; default 50) and up to
    four list-valued keys: ``on_easy_pos_loss``, ``on_hard_pos_loss``,
    ``off_easy_pos_loss`` and ``off_hard_pos_loss``. A missing or ``null``
    entry is returned as an empty array.

    Args:
        path: Path to the JSON file.

    Returns:
        ``(step, on_easy, on_hard, off_easy, off_hard)``: ``step`` is an
        ``int``; the other four are float ``numpy`` arrays.

    Raises:
        ValueError: If ``meta.step`` is not convertible to a positive int.
    """
    with open(path, "r", encoding="utf-8") as f:
        d = json.load(f)

    # ``or {}`` also covers an explicit ``"meta": null`` in the file.
    meta = d.get("meta") or {}
    step = int(meta.get("step", 50))
    if step <= 0:
        # A non-positive spacing cannot produce an increasing x axis.
        raise ValueError(f"meta.step must be a positive integer, got {step}")

    def arr(key):
        values = d.get(key)
        # Treat a missing key and an explicit ``null`` the same way; otherwise
        # np.asarray(None, dtype=float) yields a 0-d NaN array of size 1.
        if values is None:
            values = []
        return np.asarray(values, dtype=float)

    return (step, arr("on_easy_pos_loss"), arr("on_hard_pos_loss"),
            arr("off_easy_pos_loss"), arr("off_hard_pos_loss"))

def plot_series(ax, y, step, *, color, ls, marker, label):
    """Draw one loss curve on ``ax`` with the i-th sample at ``x = i * step``.

    Nothing is drawn when ``y`` is empty, so missing series are skipped.

    Args:
        ax: Matplotlib ``Axes`` (anything exposing a compatible ``plot``).
        y: Sequence or array of loss values.
        step: Spacing between consecutive x positions (int or float).
        color: Line colour forwarded to ``ax.plot``.
        ls: Line style forwarded to ``ax.plot``.
        marker: Marker style forwarded to ``ax.plot``.
        label: Legend label forwarded to ``ax.plot``.
    """
    y = np.atleast_1d(np.asarray(y, dtype=float))
    if y.size == 0:
        return
    n = len(y)
    # Scale an index range instead of calling np.arange(0, step * n, step):
    # with a non-integer step, arange can return n + 1 values because of
    # floating-point rounding, and ax.plot then rejects the length mismatch.
    x = step * np.arange(n)
    ax.plot(x, y, color=color, lw=2.2, ls=ls, marker=marker, ms=3,
            # Show roughly a dozen markers regardless of series length.
            markevery=max(1, n // 12), label=label)

def main(losses_path: str = "artifacts/losses.json",
         out_stem: str = "hard_vs_easy_modes") -> None:
    """Plot easy- vs hard-mode losses (on- and off-policy) and save the figure.

    Loads the curves with ``load_losses`` and writes ``<out_stem>.png`` and
    ``<out_stem>.pdf``. Easy modes are blue with circle markers, hard modes
    red with square markers; solid lines are on-policy, dashed off-policy.

    Args:
        losses_path: JSON file read by ``load_losses``.
        out_stem: Output path without extension; PNG and PDF are both written.
    """
    step, on_easy, on_hard, off_easy, off_hard = load_losses(losses_path)

    style = {
        "figure.dpi": 150,
        "savefig.dpi": 300,
        "font.size": 11,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }

    EASY_C = "#1f77b4"
    HARD_C = "#d62728"

    # Scope the style overrides to this figure rather than mutating the
    # process-wide rcParams, which would leak into later plots when main()
    # is called from a notebook or another script. Saving happens inside the
    # context because "savefig.dpi" is read at save time.
    with plt.rc_context(style):
        fig, ax = plt.subplots(figsize=(7, 3.6), constrained_layout=True)
        try:
            plot_series(ax, on_easy,  step, color=EASY_C, ls="-",  marker="o", label="Easy Modes (on-policy)")
            plot_series(ax, off_easy, step, color=EASY_C, ls="--", marker="o", label="Easy Modes (off-policy)")
            plot_series(ax, on_hard,  step, color=HARD_C, ls="-",  marker="s", label="Hard Modes (on-policy)")
            plot_series(ax, off_hard, step, color=HARD_C, ls="--", marker="s", label="Hard Modes (off-policy)")

            # NOTE(review): x positions are i * step (step defaults to 50), so
            # this axis may really be training steps rather than epochs —
            # confirm the label against how the losses were logged.
            ax.set_xlabel("Epoch")
            ax.set_ylabel("loss")

            ax.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=6))
            ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
            ax.xaxis.set_minor_locator(AutoMinorLocator())
            ax.yaxis.set_minor_locator(AutoMinorLocator())

            # Skip the legend when every series was empty; otherwise matplotlib
            # warns and draws an empty legend frame.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend(loc="upper center", bbox_to_anchor=(0.75, 1), ncol=1,
                          frameon=True, handlelength=3.2)

            fig.savefig(f"{out_stem}.png", bbox_inches="tight")
            fig.savefig(f"{out_stem}.pdf", bbox_inches="tight")
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

# Produce the figure only when run as a script, not when imported.
if __name__ == "__main__":
    main()