#!/usr/bin/env python3
"""
plot_loss_and_alignment_heatmaps.py

Create publication-ready heatmaps of final validation accuracy (`val_acc`)
and tail-averaged alignment (L2) over a batch-size × weight-decay grid.

Now with caching of the collected matrix data; if a cache file exists,
it will be loaded instead of re-parsing all runs.
"""

from pathlib import Path
import json
import math
from typing import List, Tuple

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# ── CONFIGURATION ───────────────────────────────────────────────────────────
EXP_DIR       = Path("results")              # parent folder containing runs
# Glob pattern, evaluated relative to EXP_DIR, selecting experiment roots;
# every sub-directory of a matched root is treated as one run.
# NOTE(review): "__path__" looks like an unfilled placeholder — confirm.
FILE_PATTERN  = "__path__"
# Number of trailing alignment entries averaged per run.
TAIL_STEPS    = 100
# Key looked up inside each alignment entry's "alignments" dict.
ALIGN_METRIC  = "L2"
# config.json keys used for the heatmap columns (x) and rows (y).
X_PARAM       = "batch_size"
Y_PARAM       = "weight_decay"

# Figure size and resolution handed to plt.subplots / fig.savefig.
FIGSIZE       = (4, 3)
DPI           = 300
# FONT_SIZE styles the colorbar ticks/labels; TITLE_SIZE styles the titles.
FONT_SIZE     = 8
TITLE_SIZE    = 10
# Colormaps for the left (validation metric) and right (alignment) panels.
CMAP_LOSS     = "viridis"
CMAP_ALIGN    = "bwr"
# Output image path; a falsy value makes main() call plt.show() instead.
OUTPUT        = "batch_size_wd.png"

# ── CACHING CONFIG ──────────────────────────────────────────────────────────
# NOTE: the directory below is created as a side effect of importing this
# module, and the path is machine-specific.
CACHE_DIR     = Path("/om2/user/dkoplow/gradient_hebbian/"
                     "clean_expirement_code/cleaner_code/"
                     "cleaned_plots/cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# The cache file name encodes the pattern, tail length, metric and axis
# parameters, but not EXP_DIR: changing only EXP_DIR reuses the same cache.
CACHE_FILE    = CACHE_DIR / (
    f"{FILE_PATTERN}_tail{TAIL_STEPS}_metric{ALIGN_METRIC}_"
    f"x{X_PARAM}_y{Y_PARAM}.npz"
)
# ────────────────────────────────────────────────────────────────────────────

def _parse_num(x):
    try:
        return float(x)
    except Exception:
        return None

def _get_param(run_dir: Path, key: str):
    cfg = run_dir / "metrics" / "config.json"
    if not cfg.exists():
        return None
    data = json.loads(cfg.read_text())
    val = data.get(key)
    if isinstance(val, (int, float)):
        return float(val)
    if isinstance(val, str):
        num = _parse_num(val)
        return num if num is not None else val
    return None

def _tail_alignment(run_dir: Path, tail: int, metric: str):
    f = run_dir / "metrics" / "frac_pos_alignment.json"
    if not f.exists():
        return None
    data = json.loads(f.read_text())
    if not data:
        return None
    seq = (
        [d.get("alignments", {}).get(metric) for d in data]
        if isinstance(data[0], dict)
        else data
    )
    vals = [v for v in seq if v is not None]
    if not vals:
        return None
    return float(np.mean(vals[-tail:]))

def _final_val_loss(run_dir: Path):
    f = run_dir / "metrics" / "metrics.json"
    if not f.exists():
        return None
    data = json.loads(f.read_text())
    if not data:
        return None
    v = data[-1].get("val_acc")
    if isinstance(v, list) and v:
        return float(v[-1])
    if isinstance(v, (int, float)):
        return float(v)
    return None

def collect_matrices(
    exp_dir: Path,
    pattern: str,
    tail: int,
    metric: str,
    x_param: str,
    y_param: str,
) -> Tuple[List[float], List[float], np.ndarray, np.ndarray]:
    """Aggregate per-run metrics into two (y, x) grids.

    Every sub-directory of every directory matching ``exp_dir.glob(pattern)``
    is treated as a run. Runs sharing the same (x, y) parameter pair are
    averaged; grid cells without any run stay NaN.

    Args:
        exp_dir: Parent folder that is globbed for experiment roots.
        pattern: Glob pattern selecting the experiment roots.
        tail: Number of trailing alignment values averaged per run.
        metric: Alignment metric key forwarded to ``_tail_alignment``.
        x_param: Config key providing the column coordinate.
        y_param: Config key providing the row coordinate.

    Returns:
        ``(x_vals, y_vals, loss_mat, align_mat)`` where ``x_vals`` and
        ``y_vals`` are the sorted distinct parameter values and both
        matrices have shape ``(len(y_vals), len(x_vals))``.

    Raises:
        RuntimeError: If no run provides both parameters and both metrics.
    """
    # A glob can also match plain files; only directories can hold runs.
    roots = sorted(p for p in exp_dir.glob(pattern) if p.is_dir())
    runs = [sub for root in roots for sub in root.iterdir() if sub.is_dir()]

    records = []
    for run in runs:
        x = _get_param(run, x_param)
        y = _get_param(run, y_param)
        # _get_param yields None for missing values and the raw string for
        # non-numeric ones; neither can be placed on a numeric axis.
        if not all(isinstance(v, (int, float)) for v in (x, y)):
            continue
        loss = _final_val_loss(run)
        align = _tail_alignment(run, tail, metric)
        if loss is None or align is None:
            continue
        records.append((float(x), float(y), loss, align))

    if not records:
        raise RuntimeError("No valid runs found!")

    x_vals = sorted({rec[0] for rec in records})
    y_vals = sorted({rec[1] for rec in records})
    col_of = {v: j for j, v in enumerate(x_vals)}
    row_of = {v: i for i, v in enumerate(y_vals)}

    # Group the metrics of all runs that fall into the same grid cell.
    cells = {}
    for x, y, loss, align in records:
        losses, aligns = cells.setdefault((row_of[y], col_of[x]), ([], []))
        losses.append(loss)
        aligns.append(align)

    loss_mat = np.full((len(y_vals), len(x_vals)), np.nan)
    align_mat = np.full_like(loss_mat, np.nan)
    for (i, j), (losses, aligns) in cells.items():
        loss_mat[i, j] = np.mean(losses)
        align_mat[i, j] = np.mean(aligns)

    return x_vals, y_vals, loss_mat, align_mat

def _format_sci(x: float) -> str:
    return f"{x:.1e}".replace("e-0", "e-").replace("e+0", "e+")

def plot_heatmap(ax, x_vals, y_vals, mat, title, cmap):
    """Draw one heatmap panel on *ax* and return the image artist.

    Args:
        ax: Matplotlib axes to draw on.
        x_vals: Column coordinates, one per column of *mat*.
        y_vals: Row coordinates, one per row of *mat*.
        mat: 2-D float array of shape ``(len(y_vals), len(x_vals))``; NaN
            marks cells without data.
        title: Panel title. The exact title ``"Alignment"`` additionally
            selects a fixed colour range and suppresses the y axis.
        cmap: Colormap name passed to ``imshow``.

    Returns:
        The object returned by ``ax.imshow`` (used for the colorbar).
    """
    is_alignment = title == "Alignment"

    # The alignment panel uses a fixed range that is symmetric about zero;
    # every other panel lets imshow autoscale to the data.
    limits = {"vmin": -0.5, "vmax": 0.5} if is_alignment else {}
    im = ax.imshow(mat, origin="lower", cmap=cmap, **limits)

    # Hatch every cell that has no data so it is distinguishable from a
    # real value; each cell is a unit square centred on its (column, row).
    for i, j in zip(*np.where(np.isnan(mat))):
        ax.add_patch(
            Rectangle((j - 0.5, i - 0.5), 1, 1,
                      facecolor='none', edgecolor='grey',
                      hatch='///', linewidth=0)
        )

    # Integer-valued coordinates (e.g. batch sizes) are labelled as integers;
    # anything else falls back to scientific notation instead of being
    # truncated by int().
    x_labels = [
        str(int(x)) if float(x).is_integer() else _format_sci(x)
        for x in x_vals
    ]
    ax.set_xticks(np.arange(len(x_vals)))
    ax.set_xticklabels(x_labels, rotation=45, ha="right", fontsize=6)
    ax.set_xlabel("Batch Size")

    if is_alignment:
        # The alignment panel is drawn without y ticks or a y label.
        ax.set_yticks([])
        ax.set_ylabel("")
    else:
        ax.set_yticks(np.arange(len(y_vals)))
        ax.set_yticklabels([_format_sci(y) for y in y_vals], fontsize=6)
        ax.set_ylabel(r"Weight Decay $\gamma$")

    ax.set_title(title, fontsize=TITLE_SIZE)
    return im

def main():
    """Load or compute the metric grids, plot both heatmaps, save or show.

    If ``CACHE_FILE`` exists its contents are used as-is; otherwise the runs
    under ``EXP_DIR`` are parsed and the result is written to ``CACHE_FILE``.
    The figure is saved to ``OUTPUT`` when that is set, else shown
    interactively.
    """
    # ── Load from cache if available ────────────────────────────────────────
    # The cache key 'loss_mat' is kept for compatibility with existing cache
    # files even though the matrix holds validation accuracy.
    if CACHE_FILE.exists():
        print(f"Loading cached data from {CACHE_FILE}")
        with np.load(CACHE_FILE) as data:
            x_vals    = data['x_vals'].tolist()
            y_vals    = data['y_vals'].tolist()
            acc_mat   = data['loss_mat']
            align_mat = data['align_mat']
    else:
        print("Cache not found; computing matrices...")
        x_vals, y_vals, acc_mat, align_mat = collect_matrices(
            EXP_DIR, FILE_PATTERN, TAIL_STEPS,
            ALIGN_METRIC, X_PARAM, Y_PARAM
        )
        np.savez_compressed(
            CACHE_FILE,
            x_vals=np.array(x_vals),
            y_vals=np.array(y_vals),
            loss_mat=acc_mat,
            align_mat=align_mat
        )
        print(f"Saved cache to {CACHE_FILE}")

    fig, axes = plt.subplots(
        1, 2, figsize=FIGSIZE, dpi=DPI, constrained_layout=True
    )
    fig.suptitle(
        "Validation Accuracy and\nAlignment for Layer 2",
        fontsize=TITLE_SIZE, fontweight="bold", y=.9
    )

    im0 = plot_heatmap(axes[0], x_vals, y_vals, acc_mat,
                       title="Validation Accuracy", cmap=CMAP_LOSS)
    im1 = plot_heatmap(axes[1], x_vals, y_vals, align_mat,
                       title="Alignment", cmap=CMAP_ALIGN)

    # One horizontal colorbar under each panel. The left panel shows
    # accuracy, so its colorbar is labelled accordingly (it used to say
    # "Loss", contradicting the panel title).
    for ax, im, label in zip(axes, (im0, im1), ("Accuracy", "Alignment")):
        cbar = fig.colorbar(
            im, ax=ax, location="bottom",
            fraction=0.07, pad=0.08, aspect=30,
            orientation="horizontal"
        )
        cbar.ax.tick_params(labelsize=FONT_SIZE)
        cbar.set_label(label, fontsize=FONT_SIZE)

    if OUTPUT:
        fig.savefig(OUTPUT, dpi=DPI)
        print(f"Saved → {OUTPUT}")
    else:
        plt.show()

# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()