# export_icml_figs_simple.py
from __future__ import annotations

import json
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.ticker import FuncFormatter

# Python 3.11+
import tomllib

from diffusionGrid.diffusionGrid_rewards import build_log_reward_fn


def millions(x, pos):
    """Tick formatter that renders ``x`` in millions with an ``M`` suffix.

    ``pos`` is the tick position required by the ``FuncFormatter`` callback
    signature; it is not used.
    """
    in_millions = x / 1e6
    # The ``g`` presentation type avoids an unnecessary trailing ``.0``.
    return "{:g}M".format(in_millions)


# -----------------------------
# Matplotlib preamble
# -----------------------------
# Global Matplotlib style applied to every figure created below: serif fonts,
# inward-pointing ticks, a faint grid drawn below the data, and LaTeX text
# rendering (text.usetex), so every label in this file is LaTeX source.
matplotlib.rcParams.update({
    "font.family": "serif",
    "font.size": 14.0,
    "lines.linewidth": 2,
    "lines.antialiased": True,
    "axes.facecolor": "fdfdfd",
    "axes.edgecolor": "777777",
    "axes.linewidth": 1,
    "axes.titlesize": "medium",
    "axes.labelsize": "medium",
    "axes.axisbelow": True,
    "xtick.color": "333333",
    "xtick.labelsize": "medium",
    "xtick.direction": "in",
    # y tick marks are hidden (size 0); only their labels remain, padded by 6.
    "ytick.major.size": 0,
    "ytick.minor.size": 0,
    "ytick.major.pad": 6,
    "ytick.minor.pad": 6,
    "ytick.color": "333333",
    "ytick.labelsize": "medium",
    "ytick.direction": "in",
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linewidth": 1,
    "legend.fancybox": True,
    "legend.fontsize": "Small",
    "figure.facecolor": "1.0",
    "figure.edgecolor": "0.5",
    "hatch.linewidth": 0.1,
    "text.usetex": True,
    # Font type 42 requests TrueType embedding in PDF/PS output.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})
# LaTeX packages made available to every text element rendered through usetex.
plt.rcParams["text.latex.preamble"] = r"\usepackage{times, amsmath, amssymb}"


def my_formatter(x, pos):
    """Format a tick value compactly, dropping the leading zero of fractions.

    The value is rendered with the ``g`` presentation type; when the result
    is a plain decimal fraction its leading zero is removed, e.g. ``0.5`` ->
    ``".5"`` and ``-0.25`` -> ``"-.25"``. Every other value is returned as
    formatted.

    Args:
        x: Tick value.
        pos: Tick position, required by the ``FuncFormatter`` callback
            signature; unused.

    Returns:
        The formatted tick label.
    """
    val_str = "{:g}".format(x)
    # Strip only the zero that sits directly before the decimal point.
    # Removing "the first 0 anywhere" would corrupt values that ``g`` renders
    # in scientific notation (1.05e-05 would become "1.5e-05").
    if val_str.startswith("0."):
        return val_str[1:]
    if val_str.startswith("-0."):
        return "-" + val_str[2:]
    return val_str


# Single formatter instance wrapping my_formatter; apply_format installs it on
# both the x and y axes of an Axes.
major_formatter = FuncFormatter(my_formatter)


def apply_format(ax):
    """Install the module-level ``major_formatter`` on both axes of ``ax``."""
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(major_formatter)


# -----------------------------
# Names + colors
# -----------------------------
# Method key -> legend label (LaTeX source, since text.usetex is enabled).
# ``\\`` inside a label is a LaTeX line break before the citation.
field_to_name = dict(
    std=r"Trajectory Balance\\(Malkin et al., NeurIPS 2022)",
    div=r"Divergent (\textbf{Ours})",
    random=r"Random",
    teacher=r"Adaptive Teacher\\(Kim et al., ICLR 2025)",
    sa=r"Sibling Augmented\\(Madan et al., ICLR 2025)",
)

# Method key -> line/band color (hex), using the same keys as field_to_name.
field_to_color = dict(
    std="#b22222",
    div="#1f77b4",
    random="#808080",
    teacher="#911eb4",
    sa="#f58231",
)


# -----------------------------
# Helpers
# -----------------------------
def save_legend(handles, labels, out_path: str | Path, ncol=2, fontsize=10):
    """Render a standalone, frameless legend and save it to ``out_path``.

    A small dedicated figure is created that contains only the legend, so it
    can be placed independently of the plots it describes.

    Args:
        handles: Artists (e.g. the lines returned by ``ax.plot``) to show.
        labels: One label per handle.
        out_path: Destination file; the output format follows its extension.
        ncol: Number of legend columns.
        fontsize: Legend font size.
    """
    fig_leg = plt.figure(figsize=(6.2, 0.55))
    try:
        fig_leg.legend(
            handles, labels,
            loc="center",
            ncol=ncol,
            frameon=False,
            fontsize=fontsize,
            handlelength=2.2,
            columnspacing=1.4,
        )
        fig_leg.canvas.draw()
        fig_leg.savefig(out_path, bbox_inches="tight", pad_inches=0.0)
    finally:
        # Always release the figure, even if rendering or saving fails
        # (e.g. a LaTeX error under usetex), so it does not stay registered
        # with pyplot.
        plt.close(fig_leg)


def load_eval(run_dir: str | Path,
              x_key: str = "epoch",
              l1_key: str = "l1_tv",
              easy_key: str = "easy_pos_loss",
              hard_key: str = "hard_pos_loss",
              filename: str = "eval.jsonl"):
    """
    Lê .../eval.jsonl e retorna:
      steps_l1, l1_tv, steps_loss, easy_loss, hard_loss

    - Aceita tanto o diretório quanto o caminho direto pro arquivo.
    """
    p = Path(run_dir)
    if p.is_dir():
        p = p / filename

    steps_l1, l1s = [], []
    steps_loss, easy_losses, hard_losses = [], [], []

    if not p.exists():
        raise FileNotFoundError(f"Arquivo não encontrado: {p}")

    with open(p, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)

            if row.get(x_key) is None:
                continue
            ep = int(row[x_key])

            if row.get(l1_key) is not None:
                steps_l1.append(ep)
                l1s.append(float(row[l1_key]))

            if row.get(easy_key) is not None and row.get(hard_key) is not None:
                steps_loss.append(ep)
                easy_losses.append(float(row[easy_key]))
                hard_losses.append(float(row[hard_key]))

    return (np.array(steps_l1), np.array(l1s),
            np.array(steps_loss), np.array(easy_losses), np.array(hard_losses))


def list_seed_runs(base_run_dir: str | Path) -> list[Path]:
    """
    Se existir runs/.../run_id/seed_*/ -> retorna todas as seed dirs.
    Senão, retorna [base_run_dir] (fallback single-run).
    """
    base = Path(base_run_dir)
    if not base.is_dir():
        raise FileNotFoundError(f"Base run dir não existe: {base}")
    seeds = sorted([p for p in base.glob("seed_*") if p.is_dir()])
    return seeds if len(seeds) > 0 else [base]


def _align_and_stack(steps_list: list[np.ndarray], vals_list: list[np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    """
    Alinha por interseção de steps e retorna:
      steps_common (sorted), Y (K x T)
    """
    if len(steps_list) != len(vals_list) or len(steps_list) == 0:
        raise ValueError("steps_list/vals_list inválidos")

    common = set(steps_list[0].tolist())
    for s in steps_list[1:]:
        common &= set(s.tolist())
    if len(common) == 0:
        raise ValueError("Sem interseção de steps entre seeds (verifique log frequency/arquivos).")

    steps_common = np.array(sorted(common), dtype=np.int64)

    Y = []
    for steps, vals in zip(steps_list, vals_list):
        idx = {int(e): i for i, e in enumerate(steps.tolist())}
        Y.append(np.array([vals[idx[int(e)]] for e in steps_common], dtype=float))
    Y = np.stack(Y, axis=0)  # (K, T)
    return steps_common, Y


def aggregate_eval_over_seeds(base_run_dir: str | Path,
                             batch_size: int,
                             x_key: str = "epoch",
                             l1_key: str = "l1_tv",
                             easy_key: str = "easy_pos_loss",
                             hard_key: str = "hard_pos_loss",
                             filename: str = "eval.jsonl"):
    """
    Load the eval log of every seed directory and aggregate across seeds.

    Returns the tuple
      steps_l1_traj, l1_mean, l1_std,
      steps_loss_traj, easy_mean, easy_std, hard_mean, hard_std,
      num_seeds

    The ``steps_*_traj`` arrays are the steps common to all contributing
    seeds multiplied by ``batch_size``. ``num_seeds`` is the number of seed
    directories found, whether or not each one contributed data. A metric
    with no data in any seed yields empty arrays.
    """
    def _mean_and_std(stacked):
        # Sample std (ddof=1) needs at least two seeds; a single seed uses ddof=0.
        ddof = 1 if stacked.shape[0] > 1 else 0
        return stacked.mean(axis=0), stacked.std(axis=0, ddof=ddof)

    seed_dirs = list_seed_runs(base_run_dir)

    l1_steps_per_seed, l1_per_seed = [], []
    loss_steps_per_seed, easy_per_seed, hard_per_seed = [], [], []

    for seed_dir in seed_dirs:
        s_l1, l1, s_loss, easy, hard = load_eval(
            seed_dir,
            x_key=x_key,
            l1_key=l1_key,
            easy_key=easy_key,
            hard_key=hard_key,
            filename=filename,
        )
        # Seeds without any point for a metric do not take part in its alignment.
        if len(s_l1) > 0:
            l1_steps_per_seed.append(s_l1.astype(np.int64))
            l1_per_seed.append(l1.astype(float))
        if len(s_loss) > 0:
            loss_steps_per_seed.append(s_loss.astype(np.int64))
            easy_per_seed.append(easy.astype(float))
            hard_per_seed.append(hard.astype(float))

    num_seeds = len(seed_dirs)

    # --- L1 ---
    if l1_steps_per_seed:
        l1_steps_common, l1_stack = _align_and_stack(l1_steps_per_seed, l1_per_seed)
        l1_mean, l1_std = _mean_and_std(l1_stack)
        steps_l1_traj = l1_steps_common * int(batch_size)
    else:
        steps_l1_traj = np.array([], dtype=np.int64)
        l1_mean = np.array([], dtype=float)
        l1_std = np.array([], dtype=float)

    # --- Losses ---
    if loss_steps_per_seed:
        loss_steps_common, easy_stack = _align_and_stack(loss_steps_per_seed, easy_per_seed)
        _, hard_stack = _align_and_stack(loss_steps_per_seed, hard_per_seed)
        easy_mean, easy_std = _mean_and_std(easy_stack)
        hard_mean, hard_std = _mean_and_std(hard_stack)
        steps_loss_traj = loss_steps_common * int(batch_size)
    else:
        steps_loss_traj = np.array([], dtype=np.int64)
        easy_mean = easy_std = hard_mean = hard_std = np.array([], dtype=float)

    return (
        steps_l1_traj, l1_mean, l1_std,
        steps_loss_traj, easy_mean, easy_std, hard_mean, hard_std,
        num_seeds,
    )


def find_experiments_toml() -> Path:
    """Return the first existing ``experiments.toml`` among the known locations.

    Locations are relative to the current working directory and are tried in
    the order listed. Raises FileNotFoundError when none of them exists.
    """
    locations = (
        "diffusionGrid/experiments.toml",
        "experiments.toml",
        "diffusionGrid/experiments/experiments.toml",
    )
    found = next((path for path in map(Path, locations) if path.exists()), None)
    if found is None:
        raise FileNotFoundError(
            "Não encontrei experiments.toml. Tente colocar em diffusionGrid/experiments.toml "
            "ou ajuste os candidates em find_experiments_toml()."
        )
    return found


def load_batch_size_from_toml(toml_path: Path, run_id: str) -> int:
    with open(toml_path, "rb") as f:
        cfg = tomllib.load(f)

    defaults = cfg.get("defaults", {}) or {}
    runs = cfg.get("runs", []) or []

    for r in runs:
        if str(r.get("run_id", "")) == run_id:
            if "batch_size" in r:
                return int(r["batch_size"])
            if "batch_size" in defaults:
                return int(defaults["batch_size"])
            raise KeyError(
                f"run_id={run_id} encontrado, mas batch_size não existe nem no [[runs]] nem em [defaults]."
            )

    raise KeyError(f"run_id={run_id} não encontrado em {toml_path}.")


def run_id_from_run_dir(run_dir: str | Path) -> str:
    return Path(run_dir).name


# -----------------------------
# Paths (adjust here)
# -----------------------------
# Output directory for every exported figure; created when the script runs.
OUT = Path("icml_figs")
OUT.mkdir(parents=True, exist_ok=True)

# One run directory per method. The last path component is used as the
# run_id when looking up the batch size in experiments.toml.
# NOTE(review): TB points at a run id ending in "_test", unlike the other
# three - confirm this is the intended run.
OURS = "runs/diffusionGrid/dtb/run_8g_reward_alpha0.25_eps0.1"
SA   = "runs/diffusionGrid/sa/run_8g_reward_alpha0.25_eps0.1"
TB   = "runs/diffusionGrid/tb/run_8g_reward_alpha0.25_eps0.1_test"
TS   = "runs/diffusionGrid/teacher_student/run_8g_reward_alpha0.25_eps0.1"

# -----------------------------
# Batch sizes (from experiments.toml)
# -----------------------------
# Located relative to the current working directory; raises if not found.
TOML_PATH = find_experiments_toml()

# Each lookup raises KeyError if the run_id (last component of the run dir)
# is missing from the TOML or has no batch size configured.
bs_ours = load_batch_size_from_toml(TOML_PATH, run_id_from_run_dir(OURS))
bs_sa   = load_batch_size_from_toml(TOML_PATH, run_id_from_run_dir(SA))
bs_tb   = load_batch_size_from_toml(TOML_PATH, run_id_from_run_dir(TB))
bs_ts   = load_batch_size_from_toml(TOML_PATH, run_id_from_run_dir(TS))

# -----------------------------
# Load + aggregate evals across seeds
# -----------------------------
# For each method: x values are epochs scaled by that method's batch size,
# followed by the mean/std across seeds of the L1 metric and of the
# easy/hard losses, and finally the number of seed directories found.
(ours_steps_l1, ours_l1_mean, ours_l1_std,
 ours_steps_loss, ours_easy_mean, ours_easy_std, ours_hard_mean, ours_hard_std,
 K_ours) = aggregate_eval_over_seeds(OURS, batch_size=bs_ours)

(sa_steps_l1, sa_l1_mean, sa_l1_std,
 sa_steps_loss, sa_easy_mean, sa_easy_std, sa_hard_mean, sa_hard_std,
 K_sa) = aggregate_eval_over_seeds(SA, batch_size=bs_sa)

(tb_steps_l1, tb_l1_mean, tb_l1_std,
 tb_steps_loss, tb_easy_mean, tb_easy_std, tb_hard_mean, tb_hard_std,
 K_tb) = aggregate_eval_over_seeds(TB, batch_size=bs_tb)

(ts_steps_l1, ts_l1_mean, ts_l1_std,
 ts_steps_loss, ts_easy_mean, ts_easy_std, ts_hard_mean, ts_hard_std,
 K_ts) = aggregate_eval_over_seeds(TS, batch_size=bs_ts)

# -----------------------------
# (1) L1 figure (export) — mean + std band
# -----------------------------
fig = plt.figure(figsize=(6.2, 3.8))
ax = fig.add_subplot(1, 1, 1)

# One mean curve per method plus a translucent +/- 1 std band in the same
# color. The line handles are kept in h_*; no legend is drawn on this figure.
h_ours, = ax.plot(ours_steps_l1, ours_l1_mean, label=field_to_name["div"], color=field_to_color["div"])
ax.fill_between(ours_steps_l1, ours_l1_mean - ours_l1_std, ours_l1_mean + ours_l1_std,
                color=field_to_color["div"], alpha=0.18, linewidth=0)

h_sa, = ax.plot(sa_steps_l1, sa_l1_mean, label=field_to_name["sa"], color=field_to_color["sa"])
ax.fill_between(sa_steps_l1, sa_l1_mean - sa_l1_std, sa_l1_mean + sa_l1_std,
                color=field_to_color["sa"], alpha=0.18, linewidth=0)

h_tb, = ax.plot(tb_steps_l1, tb_l1_mean, label=field_to_name["std"], color=field_to_color["std"])
ax.fill_between(tb_steps_l1, tb_l1_mean - tb_l1_std, tb_l1_mean + tb_l1_std,
                color=field_to_color["std"], alpha=0.18, linewidth=0)

h_ts, = ax.plot(ts_steps_l1, ts_l1_mean, label=field_to_name["teacher"], color=field_to_color["teacher"])
ax.fill_between(ts_steps_l1, ts_l1_mean - ts_l1_std, ts_l1_mean + ts_l1_std,
                color=field_to_color["teacher"], alpha=0.18, linewidth=0)

ax.set_xlabel(r"Sampled trajectories")
ax.set_ylabel(r"$\ell_1$ TV distance")
apply_format(ax)
# The x formatter set by apply_format is replaced here, so only the y axis
# keeps the compact formatter; x ticks are shown in millions.
ax.xaxis.set_major_formatter(FuncFormatter(millions))

fig.canvas.draw()
fig.savefig(OUT / "l1_tv.pdf", bbox_inches="tight", pad_inches=0.0)
fig.savefig(OUT / "l1_tv.png", dpi=300, bbox_inches="tight", pad_inches=0.0)
plt.close(fig)

# -----------------------------
# (2) True target (Rings) + grid-aligned circular highlight
# -----------------------------
N = 18
r_easy = 9.0  # radius of the "easy" circle, in grid units

# Integer grid coordinates from -N to N on both axes ((2N+1) x (2N+1) cells),
# flattened into a ((2N+1)^2, 2) tensor of (x, y) positions.
xs = torch.arange(-N, N + 1, dtype=torch.get_default_dtype())
ys = torch.arange(-N, N + 1, dtype=torch.get_default_dtype())
X, Y = torch.meshgrid(xs, ys, indexing="ij")
pos = torch.stack([X.flatten(), Y.flatten()], dim=-1)

# Presumably returns the log-reward at each position of the env passed in
# (project helper, not visible here); it is exponentiated and reshaped to the
# grid. The stand-in env only carries pos/width/height.
logR_fn = build_log_reward_fn("rings", radii=[0.2 * N, 0.8 * N], sigma_r=1.0, weights=None)
dummy_env = SimpleNamespace(pos=pos, width=N, height=N)
Rmat = logR_fn(dummy_env).exp().view(2 * N + 1, 2 * N + 1).detach().cpu().numpy()

# Cell edges at half-integers so each unit cell is centered on a grid point.
x_edges = np.arange(-N - 0.5, N + 1.5, 1.0)
y_edges = np.arange(-N - 0.5, N + 1.5, 1.0)
Xe, Ye = np.meshgrid(x_edges, y_edges, indexing="ij")

# Cells whose center lies strictly inside the circle of radius r_easy.
inside = (X**2 + Y**2) < (r_easy**2)
inside_np = inside.detach().cpu().numpy()

# 1.0 inside the circle, NaN elsewhere.
overlay = np.full_like(Rmat, np.nan, dtype=float)
overlay[inside_np] = 1.0

# -----------------------------
# (3) Panel A on its own (square)
# -----------------------------
figA = plt.figure(figsize=(5.2, 5.2))
axA = figA.add_subplot(1, 1, 1)

# Reward heatmap, one unit cell per grid point, with thin black cell borders.
im = axA.pcolormesh(Xe, Ye, Rmat, shading="flat", edgecolors="k", linewidth=0.25)
axA.set_xlim([-N - 0.5, N + 0.5])
axA.set_ylim([-N - 0.5, N + 0.5])

axA.set_xticks([])
axA.set_yticks([])
axA.set_box_aspect(1)
apply_format(axA)

# Faint white wash over the cells inside the "easy" circle.
overlay_color = (1.0, 1.0, 1.0, 1.0)
axA.pcolormesh(
    Xe, Ye, overlay,
    shading="flat",
    cmap=ListedColormap([overlay_color]),
    vmin=0.0, vmax=1.0,
    edgecolors="none",
    alpha=0.10,
)

# Outline of the highlighted region: the 0.5 level set of the 0/1 mask.
axA.contour(
    X.detach().cpu().numpy(),
    Y.detach().cpu().numpy(),
    inside_np.astype(float),
    levels=[0.5],
    colors=[overlay_color],
    linewidths=2.0,
)

# text.usetex is enabled, and in that mode the ``weight`` font property is
# not applied to the rendered text; bold has to be requested in the LaTeX
# source itself (as done for "Ours" in field_to_name).
axA.text(
    0.0, r_easy,
    r"\textbf{Easy Mode}",
    color="w",
    fontsize=9,
    ha="center",
    va="bottom",
    weight="bold",
)

figA.canvas.draw()
figA.savefig(OUT / "panel_A_true_reward.pdf", bbox_inches="tight", pad_inches=0.0)
figA.savefig(OUT / "panel_A_true_reward.png", dpi=300, bbox_inches="tight", pad_inches=0.0)
plt.close(figA)

# -----------------------------
# (4) Panel B: easy loss — mean + std band (no legend, no title)
# -----------------------------
figB = plt.figure(figsize=(6.0, 3.0))
axB = figB.add_subplot(1, 1, 1)

# h_div / h_std are reused further down to build the standalone legend.
h_div, = axB.plot(ours_steps_loss, ours_easy_mean, label=field_to_name["div"], color=field_to_color["div"])
axB.fill_between(ours_steps_loss, ours_easy_mean - ours_easy_std, ours_easy_mean + ours_easy_std,
                 color=field_to_color["div"], alpha=0.18, linewidth=0)

h_std, = axB.plot(tb_steps_loss, tb_easy_mean, label=field_to_name["std"], color=field_to_color["std"])
axB.fill_between(tb_steps_loss, tb_easy_mean - tb_easy_std, tb_easy_mean + tb_easy_std,
                 color=field_to_color["std"], alpha=0.18, linewidth=0)

apply_format(axB)
axB.grid(True)
# Replaces the x formatter set by apply_format: x ticks are shown in millions.
axB.xaxis.set_major_formatter(FuncFormatter(millions))

figB.canvas.draw()
figB.savefig(OUT / "panel_B_loss_easy.pdf", bbox_inches="tight", pad_inches=0.0)
figB.savefig(OUT / "panel_B_loss_easy.png", dpi=300, bbox_inches="tight", pad_inches=0.0)
plt.close(figB)

# -----------------------------
# (5) Panel C: hard loss — mean + std band (no legend, no title)
# -----------------------------
figC = plt.figure(figsize=(6.0, 3.0))
axC = figC.add_subplot(1, 1, 1)

axC.plot(ours_steps_loss, ours_hard_mean, label=field_to_name["div"], color=field_to_color["div"])
axC.fill_between(ours_steps_loss, ours_hard_mean - ours_hard_std, ours_hard_mean + ours_hard_std,
                 color=field_to_color["div"], alpha=0.18, linewidth=0)

axC.plot(tb_steps_loss, tb_hard_mean, label=field_to_name["std"], color=field_to_color["std"])
axC.fill_between(tb_steps_loss, tb_hard_mean - tb_hard_std, tb_hard_mean + tb_hard_std,
                 color=field_to_color["std"], alpha=0.18, linewidth=0)

# Unlike panel B, this panel carries the x-axis label.
axC.set_xlabel(r"Sampled trajectories", fontsize=14)
apply_format(axC)
axC.grid(True)
# Replaces the x formatter set by apply_format: x ticks are shown in millions.
axC.xaxis.set_major_formatter(FuncFormatter(millions))

figC.canvas.draw()
figC.savefig(OUT / "panel_C_loss_hard.pdf", bbox_inches="tight", pad_inches=0.0)
figC.savefig(OUT / "panel_C_loss_hard.png", dpi=300, bbox_inches="tight", pad_inches=0.0)
plt.close(figC)

# -----------------------------
# (6) Standalone legend for the loss panels
# -----------------------------
# Same legend written twice, once per output format, using the line handles
# created for panel B.
# NOTE(review): save_legend calls savefig without an explicit dpi, so this
# PNG is not saved at dpi=300 like the other PNGs above - confirm intended.
save_legend(
    [h_div, h_std],
    [field_to_name["div"], field_to_name["std"]],
    OUT / "legend_losses.pdf",
    ncol=2,
    fontsize=10,
)
save_legend(
    [h_div, h_std],
    [field_to_name["div"], field_to_name["std"]],
    OUT / "legend_losses.png",
    ncol=2,
    fontsize=10,
)

# Summary of the run; only the PDF names are listed, although PNG versions
# are written alongside them.
print(f"Saved figures to: {OUT.resolve()}")
print(" - l1_tv.pdf")
print(" - panel_A_true_reward.pdf")
print(" - panel_B_loss_easy.pdf")
print(" - panel_C_loss_hard.pdf")
print(" - legend_losses.pdf")
print(f"Using batch sizes: ours={bs_ours}, sa={bs_sa}, tb={bs_tb}, ts={bs_ts} (from {TOML_PATH})")
print(f"Seeds detected: ours={K_ours}, sa={K_sa}, tb={K_tb}, ts={K_ts}")