#!/usr/bin/env python3
"""
plot_loss_and_alignment_heatmaps.py

Create publication-ready heatmaps of the final validation metric (read from
the ``val_acc`` field and plotted as "Validation Accuracy") and tail-averaged
alignment (L2) over learning-rate × weight-decay grids.

Changes vs. previous version
----------------------------
* Alignment heatmap now uses a diverging red–blue colormap (“bwr”).
* Figure uses `constrained_layout=True`; no overlap between title,
  panels, or colour-bars.
* Tick labels are always formatted in scientific notation (e.g. 2e-4);
  per-axis exponent ‘×10^k’ annotations have been removed.
"""

from pathlib import Path
import json
import math
from typing import List, Tuple

import numpy as np
import matplotlib.pyplot as plt

# ── CONFIGURATION ───────────────────────────────────────────────────────────
# Where to look for runs: FILE_PATTERN is globbed relative to EXP_DIR, and
# every sub-directory of each match is treated as one run.
EXP_DIR       = Path("results")              # parent folder containing runs
# NOTE(review): "__path__" looks like a placeholder — confirm the real glob.
FILE_PATTERN  = "__path__"       # top-level glob pattern (relative to EXP_DIR)
TAIL_STEPS    = 100                          # trailing alignment values averaged per run
ALIGN_METRIC  = "L2"                         # key looked up in each record's "alignments" dict
X_PARAM       = "lr"                         # config.json key mapped to heatmap columns
Y_PARAM       = "weight_decay"               # config.json key mapped to heatmap rows

FIGSIZE       = (4, 3)     # overall figure size (inches)
DPI           = 300        # used both when creating and when saving the figure
FONT_SIZE     = 8          # colour-bar tick and label font size
TITLE_SIZE    = 10         # figure title and per-panel title font size
CMAP_LOSS     = "viridis"  # colormap of the left ("Validation Accuracy") panel
CMAP_ALIGN    = "bwr"      # <- red–blue diverging
OUTPUT        = "lr_wd.png"  # save path; a falsy value shows the figure instead
# ────────────────────────────────────────────────────────────────────────────

# NOTE(review): no global rcParams are set here; font sizes are passed
# explicitly at each call site.

# ── Helpers for parsing run folders ─────────────────────────────────────────
def _parse_num(x):
    """Convert *x* to ``float``, or return ``None`` if it is not numeric.

    Only the failures that ``float()`` itself signals for bad input are
    treated as "not a number": ``ValueError`` (unparseable text),
    ``TypeError`` (unsupported type) and ``OverflowError`` (integer too
    large).  Any other exception indicates a real bug and is allowed to
    propagate instead of being silently turned into ``None``.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return None

def _get_param(run_dir: Path, key: str):
    """Look up *key* in ``<run_dir>/metrics/config.json``.

    Returns
    -------
    float
        For JSON numbers (JSON booleans count as ints in Python and are
        converted too) and for strings that parse as a number.
    str
        The raw string when it does not parse as a number.
    None
        When the config file is missing, or the value is absent or of any
        other type.
    """
    config_path = run_dir / "metrics" / "config.json"
    if not config_path.exists():
        return None

    value = json.loads(config_path.read_text()).get(key)

    if isinstance(value, str):
        parsed = _parse_num(value)
        # Keep the original text when it is not a number.
        return value if parsed is None else parsed
    if isinstance(value, (int, float)):
        return float(value)
    return None

def _tail_alignment(run_dir: Path, tail: int, metric: str):
    """Average the last *tail* logged values of alignment *metric*.

    Reads ``<run_dir>/metrics/frac_pos_alignment.json``.  When the first
    entry is a dict, each entry contributes ``entry["alignments"][metric]``
    (``None`` if either key is missing); otherwise the entries themselves
    are taken as the values.  ``None`` values are dropped before the tail
    is selected.

    Returns ``None`` when the file is missing, is empty, or yields no
    usable value.  Note that ``tail=0`` averages *every* value, because the
    slice ``[-0:]`` selects the whole list.
    """
    path = run_dir / "metrics" / "frac_pos_alignment.json"
    if not path.exists():
        return None

    entries = json.loads(path.read_text())
    if not entries:
        return None

    if isinstance(entries[0], dict):
        series = [e.get("alignments", {}).get(metric) for e in entries]
    else:
        series = entries

    kept = [value for value in series if value is not None]
    return float(np.mean(kept[-tail:])) if kept else None

def _final_val_loss(run_dir: Path):
    """Return the final validation metric recorded for *run_dir*.

    Despite the function's name, the value read is the ``"val_acc"`` field
    of the last record in ``<run_dir>/metrics/metrics.json``.  A non-empty
    list contributes its last element; a bare number is used directly.

    Returns ``None`` when the file is missing or empty, or when the field
    is absent, an empty list, or of any other type.
    """
    path = run_dir / "metrics" / "metrics.json"
    if not path.exists():
        return None

    records = json.loads(path.read_text())
    if not records:
        return None

    value = records[-1].get("val_acc")
    if isinstance(value, list):
        # An empty history carries no final value.
        return float(value[-1]) if value else None
    if isinstance(value, (int, float)):
        return float(value)
    return None

# ── Collect matrices for heatmaps ───────────────────────────────────────────
def collect_matrices(
    exp_dir: Path,
    pattern: str,
    tail: int,
    metric: str,
    x_param: str,
    y_param: str,
) -> Tuple[List[float], List[float], np.ndarray, np.ndarray]:
    """Gather per-run results into (y × x) grids for the two heatmaps.

    Every sub-directory of every directory matching
    ``exp_dir.glob(pattern)`` is treated as one run.  A run is used only
    if both *x_param* and *y_param* resolve to numbers and both the final
    validation metric and the tail-averaged alignment are available;
    all other runs are skipped silently.

    Parameters
    ----------
    exp_dir, pattern
        Root folder and glob pattern selecting the parent folders of runs.
    tail, metric
        Forwarded to ``_tail_alignment``.
    x_param, y_param
        Config keys whose values index the columns and rows respectively.

    Returns
    -------
    x_vals, y_vals
        Sorted distinct parameter values found among the usable runs.
    loss_mat, align_mat
        Arrays of shape ``(len(y_vals), len(x_vals))``.  Runs sharing a
        cell are averaged; cells without any run stay NaN.

    Raises
    ------
    RuntimeError
        If no usable run is found.
    """
    # A glob pattern can match plain files too; calling iterdir() on one
    # would raise NotADirectoryError, so keep directories only.
    roots = sorted(p for p in exp_dir.glob(pattern) if p.is_dir())
    runs = [sub for rd in roots for sub in rd.iterdir() if sub.is_dir()]

    # (x, y) -> values of every run that landed on that grid point.
    cells = {}
    for r in runs:
        x = _get_param(r, x_param)
        y = _get_param(r, y_param)
        # _get_param may return None or a non-numeric string; neither can
        # be placed on a numeric grid, so skip the run rather than letting
        # float() raise and abort the whole collection.
        if not isinstance(x, (int, float)) or not isinstance(y, (int, float)):
            continue
        loss = _final_val_loss(r)
        align = _tail_alignment(r, tail, metric)
        if loss is None or align is None:
            continue
        cell = cells.setdefault((float(x), float(y)), {"loss": [], "align": []})
        cell["loss"].append(loss)
        cell["align"].append(align)

    if not cells:
        raise RuntimeError("No valid runs found!")

    x_vals = sorted({x for x, _ in cells})
    y_vals = sorted({y for _, y in cells})
    xi = {v: i for i, v in enumerate(x_vals)}
    yi = {v: i for i, v in enumerate(y_vals)}

    loss_mat = np.full((len(y_vals), len(x_vals)), np.nan)
    align_mat = np.full_like(loss_mat, np.nan)

    # Rows follow y_vals, columns follow x_vals.
    for (x, y), values in cells.items():
        i, j = yi[y], xi[x]
        loss_mat[i, j] = np.mean(values["loss"])
        align_mat[i, j] = np.mean(values["align"])

    return x_vals, y_vals, loss_mat, align_mat

# ── Plotting utilities ──────────────────────────────────────────────────────
def _format_sci(x: float) -> str:
    """Return *x* in compact scientific notation, e.g. 1e-4 (not 1e-04).

    The mantissa is rounded to a single digit and one leading zero is
    stripped from the exponent, so 0.0001 -> '1e-4' and 100.0 -> '1e+2'.
    """
    text = format(x, ".0e")
    for padded, bare in (("e-0", "e-"), ("e+0", "e+")):
        text = text.replace(padded, bare)
    return text

def plot_heatmap(ax, x_vals, y_vals, mat, title, cmap):
    """Draw *mat* as a heatmap on *ax* and return the image object.

    Rows of *mat* are drawn bottom-up (``origin="lower"``).  At most six
    evenly spaced tick positions per axis are labelled, formatted with
    ``_format_sci``.

    The panel titled exactly ``"Alignment"`` is special-cased: its colour
    scale is pinned to [-1, 1] and its y ticks and y label are removed.
    Any other title gets an auto-scaled colour range and a labelled
    weight-decay axis.
    """
    is_alignment = title == "Alignment"

    image_kwargs = {"origin": "lower", "cmap": cmap, "aspect": "equal"}
    if is_alignment:
        image_kwargs.update(vmin=-1, vmax=1)
    im = ax.imshow(mat, **image_kwargs)

    def tick_positions(count):
        # Integer indices spread over the axis, capped at six ticks.
        return np.linspace(0, count - 1, min(6, count), dtype=int)

    xt = tick_positions(len(x_vals))
    yt = tick_positions(len(y_vals))

    ax.set_xticks(xt)
    ax.set_xticklabels(
        [_format_sci(x_vals[i]) for i in xt],
        rotation=45, ha="right", fontsize=6,
    )
    ax.set_xlabel(r"Learning Rate $\eta$")

    if is_alignment:
        ax.set_yticks([])
        ax.set_ylabel("")
    else:
        ax.set_yticks(yt)
        ax.set_yticklabels([_format_sci(y_vals[i]) for i in yt], fontsize=6)
        ax.set_ylabel(r"Weight Decay $\gamma$")

    ax.set_title(title, fontsize=TITLE_SIZE)
    return im

# ── Main ────────────────────────────────────────────────────────────────────
def main():
    """Collect the grids, draw both heatmaps, then save or show the figure.

    The figure is written to ``OUTPUT`` when that setting is truthy;
    otherwise it is displayed interactively.
    """
    x_vals, y_vals, loss_mat, align_mat = collect_matrices(
        EXP_DIR, FILE_PATTERN, TAIL_STEPS, ALIGN_METRIC, X_PARAM, Y_PARAM
    )

    fig, (ax_left, ax_right) = plt.subplots(
        1, 2,
        figsize=FIGSIZE,
        dpi=DPI,
        constrained_layout=True,  # avoids overlap automatically
    )
    fig.suptitle(
        "Validation Accuracy and Alignment for Layer 2",
        fontsize=TITLE_SIZE,
        fontweight="bold",
        y=1.05,  # push title a bit above panels
    )

    # (axis, data, panel title, colormap, colour-bar label)
    panels = [
        (ax_left, loss_mat, "Validation Accuracy", CMAP_LOSS, "Accuracy"),
        (ax_right, align_mat, "Alignment", CMAP_ALIGN, "Alignment"),
    ]

    # Draw both heatmaps first, then attach the colour-bars.
    images = [
        plot_heatmap(ax, x_vals, y_vals, mat, title=title, cmap=cmap)
        for ax, mat, title, cmap, _ in panels
    ]

    # colour-bars (horizontal, under each panel)
    for (ax, _, _, _, label), im in zip(panels, images):
        cbar = fig.colorbar(
            im, ax=ax, location="bottom",
            fraction=0.07, pad=0.08, aspect=30,
            orientation="horizontal",
        )
        cbar.ax.tick_params(labelsize=FONT_SIZE)
        cbar.set_label(label, fontsize=FONT_SIZE)

    if not OUTPUT:
        plt.show()
        return

    fig.savefig(OUTPUT, dpi=DPI, bbox_inches="tight")
    print(f"Saved → {OUTPUT}")

# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()