#!/usr/bin/env python3
"""
Plot a heat-map of validation loss across hyper-parameter sweeps **and**
a scatter-plot relating the *final* validation loss to a Hebbian alignment
metric, side-by-side in one figure.
"""
from __future__ import annotations

import json
import math
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter, NullFormatter, ScalarFormatter   # ← added ScalarFormatter

# ─── Parameters you are likely to change ────────────────────────────────────
EXP_DIR = Path("__path__")  # experiment root; every sub-directory is treated as one run
X_PARAM = "gradient_noise_fraction"  # config key for the heat-map's horizontal axis
Y_PARAM = "weight_decay"  # config key for the heat-map's vertical axis
METRIC = "L1"  # key looked up in each record's "alignments" mapping
OUTPUT = EXP_DIR.joinpath("val_loss_heatmap_and_scatter.png")  # None -> plt.show()

# ----------------------------------------------------------------------------
#  Helpers to read JSON safely
# ----------------------------------------------------------------------------
def _load_json(path: Path):
    """Parse the JSON file at *path*, returning ``None`` if it is unusable.

    A missing file, bytes that are not valid UTF-8, and text that is not
    valid JSON all yield ``None``. This lets callers skip runs whose metric
    files are absent or corrupt. Any other I/O error, such as a permission
    problem or *path* being a directory, still propagates.
    """
    try:
        # EAFP: reading directly avoids the exists()/read race. "utf-8-sig"
        # decodes plain UTF-8 and also strips a leading BOM, which json.loads
        # would otherwise reject.
        return json.loads(path.read_text(encoding="utf-8-sig"))
    except (FileNotFoundError, UnicodeDecodeError, json.JSONDecodeError):
        return None


def _load_config(run_dir: Path):
    """Return the parsed ``metrics/config.json`` of *run_dir*, or ``None`` if unavailable."""
    config_file = run_dir.joinpath("metrics", "config.json")
    return _load_json(config_file)


def _final_val_loss(run_dir: Path):
    """Return the final validation loss recorded for *run_dir*, or ``None``.

    Reads ``metrics/metrics.json``, which is expected to be a list of
    record dicts, and uses the ``"val_loss"`` field of the **last** record:

    * a list value: its last element is converted to ``float``;
    * a bare number: it is converted to ``float``.

    ``None`` is returned when the file is missing or unparsable, the record
    list is empty, the last record is not a dict, or the value is missing,
    empty, or not convertible to a number.
    """
    data = _load_json(run_dir / "metrics" / "metrics.json")
    # Guard the data[-1] access below: an empty list would raise IndexError
    # and a dict payload would raise KeyError.
    if not isinstance(data, list) or not data:
        return None
    last = data[-1]
    if not isinstance(last, dict):
        return None
    vals = last.get("val_loss")
    if isinstance(vals, list):
        if not vals:
            return None
        try:
            return float(vals[-1])
        except (TypeError, ValueError):
            # e.g. a trailing None / non-numeric placeholder entry
            return None
    if isinstance(vals, (int, float)):
        return float(vals)
    return None


def _final_alignment(run_dir: Path, window: int = 100, metric: str | None = None):
    """Average the last *window* alignment values in frac_pos_alignment.json.

    Reads ``metrics/frac_pos_alignment.json``, expected to be a list of
    records. From each of the trailing *window* records it takes
    ``record["alignments"][metric][0]``.

    Args:
        run_dir: Run directory containing a ``metrics`` sub-folder.
        window: Number of trailing records to average (default 100).
        metric: Key into each record's ``"alignments"`` mapping; defaults to
            the module-level ``METRIC``.

    Returns:
        The mean over the records actually available, or ``None`` if the
        file is missing, unparsable, not a list, or empty.

    Raises:
        ValueError: if *window* is smaller than 1.
    """
    if window < 1:
        # data[-0:] would silently select the whole list.
        raise ValueError("window must be >= 1")
    data = _load_json(run_dir / "metrics" / "frac_pos_alignment.json")
    if not isinstance(data, list) or not data:
        return None
    key = METRIC if metric is None else metric
    tail = data[-window:]
    # Divide by the number of records actually averaged. Dividing by a fixed
    # `window` shrinks the mean towards zero for runs with fewer records.
    return sum(rec["alignments"][key][0] for rec in tail) / len(tail)

# ----------------------------------------------------------------------------
#  Data collection
# ----------------------------------------------------------------------------
def collect_heatmap(
    exp_dir: Path, x_param: str, y_param: str
) -> Tuple[List[float], List[float], np.ndarray]:
    """Build a grid of final validation losses over two config parameters.

    Every sub-directory of *exp_dir* is treated as a run. A run contributes
    when it meets all of these conditions:

    * its config is a dict;
    * the config has finite numeric values for *x_param* and *y_param*;
    * a final validation loss is available.

    Other runs are skipped silently.

    Runs sharing the same ``(x, y)`` pair (e.g. repeated seeds) are averaged
    rather than letting whichever directory is listed last overwrite the
    rest. Directory listing order is arbitrary, so that would be
    non-deterministic.

    Returns:
        ``(xs, ys, mat)``: *xs* and *ys* are the sorted unique parameter
        values. ``mat[i, j]`` is the mean loss at ``(xs[j], ys[i])``, and
        cells with no run are ``NaN``.

    Raises:
        RuntimeError: if no usable run was found.
    """
    grouped: dict[tuple[float, float], list[float]] = {}
    # Sorted iteration keeps the result reproducible across file systems.
    for run in sorted(p for p in exp_dir.iterdir() if p.is_dir()):
        cfg = _load_config(run)
        if not isinstance(cfg, dict):
            continue
        try:
            x = float(cfg.get(x_param))
            y = float(cfg.get(y_param))
        except (TypeError, ValueError):
            continue
        # NaN keys never compare equal, which would break the index lookup below.
        if not (math.isfinite(x) and math.isfinite(y)):
            continue
        vloss = _final_val_loss(run)
        if vloss is None:
            continue
        grouped.setdefault((x, y), []).append(vloss)

    if not grouped:
        raise RuntimeError("No runs with the requested keys + val_loss found.")

    xs = sorted({x for x, _ in grouped})
    ys = sorted({y for _, y in grouped})
    x_idx = {v: j for j, v in enumerate(xs)}
    y_idx = {v: i for i, v in enumerate(ys)}
    mat = np.full((len(ys), len(xs)), np.nan)
    for (x, y), losses in grouped.items():
        mat[y_idx[y], x_idx[x]] = float(np.mean(losses))
    return xs, ys, mat


def collect_scatter(exp_dir: Path) -> Tuple[List[float], List[float]]:
    """Pair each run's final alignment with its final validation loss.

    Sub-directories of *exp_dir* lacking either value are skipped.

    Returns:
        ``(alignments, losses)``: two equally long lists, one entry per run.

    Raises:
        RuntimeError: if no run provided both values.
    """
    pairs = []
    for run_dir in (p for p in exp_dir.iterdir() if p.is_dir()):
        loss = _final_val_loss(run_dir)
        align = _final_alignment(run_dir)
        if loss is not None and align is not None:
            pairs.append((align, loss))
    if not pairs:
        raise RuntimeError("No alignment + val_loss pairs found under EXP_DIR.")
    alignments = [a for a, _ in pairs]
    losses = [v for _, v in pairs]
    return alignments, losses

# ----------------------------------------------------------------------------
#  Formatting helpers
# ----------------------------------------------------------------------------
def _common_exponent(vals):
    """Return the base-10 exponent of the largest-magnitude finite value.

    Used to factor a shared ``10**exp`` out of tick labels. Magnitudes are
    used so that negative data is scaled correctly. Returns 0 when *vals*
    contains no finite non-zero value, so callers never rescale by an
    absurd factor (previously an all-zero input produced -99).
    """
    magnitudes = [abs(v) for v in vals if math.isfinite(v)]
    vmax = max(magnitudes, default=0.0)
    if vmax == 0.0:
        return 0
    return int(math.floor(math.log10(vmax)))


def _scaled_labels(vals, exponent, max_sig: int = 6):
    """Format *vals* divided by ``10**exponent`` as compact tick labels.

    Starts with one significant digit (the original style) and adds digits
    only while distinct values would otherwise share a label. For example,
    1.2 and 1.4 both render as "1" at one digit. Precision is capped at
    *max_sig* digits.
    """
    scale = 10.0 ** exponent
    scaled = [v / scale for v in vals]
    n_distinct = len(set(scaled))
    labels: list[str] = []
    for sig in range(1, max_sig + 1):
        labels = [f"{s:.{sig}g}" for s in scaled]
        if len(set(labels)) >= n_distinct:
            break
    return labels

# ----------------------------------------------------------------------------
#  Plotting
# ----------------------------------------------------------------------------
def _draw_heatmap(ax, xs: List[float], ys: List[float], mat: np.ndarray):
    """Render the loss grid *mat* on *ax* with scaled ticks; return the image artist."""
    im = ax.imshow(mat, origin="lower", aspect="equal", cmap="viridis")

    # At most 6 evenly spaced ticks per axis keep the small panel readable.
    xt = np.linspace(0, len(xs) - 1, min(6, len(xs)), dtype=int)
    yt = np.linspace(0, len(ys) - 1, min(6, len(ys)), dtype=int)
    x_ticks = [xs[i] for i in xt]
    y_ticks = [ys[i] for i in yt]
    x_exp, y_exp = _common_exponent(x_ticks), _common_exponent(y_ticks)

    ax.set_xticks(xt)
    ax.set_yticks(yt)
    ax.set_xticklabels(_scaled_labels(x_ticks, x_exp),
                       rotation=45, ha="right", fontsize=6)
    ax.set_yticklabels(_scaled_labels(y_ticks, y_exp), fontsize=6)
    # NOTE: axis labels are fixed strings; update them if X_PARAM/Y_PARAM change.
    ax.set_xlabel(r"Noise Scale ($\sigma$)")
    ax.set_ylabel(r"Weight Decay ($\gamma$)")
    ax.set_title("Validation Loss", pad=6)

    # Tick labels were divided by 10**exp; show that shared factor once per axis.
    if x_exp:
        ax.text(1.02, -0.02, rf"$\times10^{{{x_exp}}}$",
                transform=ax.transAxes, fontsize=6, ha="left", va="top")
    if y_exp:
        ax.text(-0.02, 1.02, rf"$\times10^{{{y_exp}}}$",
                transform=ax.transAxes, fontsize=6, ha="right",
                va="bottom", rotation="vertical")
    ax.set_box_aspect(1)  # square panel in physical space
    return im


def _draw_scatter(ax, alignments: List[float], losses: List[float]):
    """Scatter validation loss against alignment on *ax*.

    Losses are plotted divided by ``10**y_exp`` (their common exponent),
    with the factor shown once above the y-axis. This mirrors the heat-map
    tick style.
    """
    ax.set_xlabel("Hebbian Alignment")
    ax.set_ylabel("Validation Loss")
    ax.set_title("Validation Loss\nvs. Alignment", pad=6)

    y_exp = _common_exponent(losses)
    if y_exp:
        ax.text(-0.02, 1.02, rf"$\times10^{{{y_exp}}}$",
                transform=ax.transAxes, fontsize=6, ha="right",
                va="bottom", rotation="vertical")
    ax.scatter(np.asarray(alignments), np.asarray(losses) * 10.0 ** -y_exp,
               alpha=0.7, s=12, edgecolors="black", linewidths=0.3)

    # 6 pt tick labels, matching the heat-map.
    ax.tick_params(axis="both", which="major",
                   labelsize=6, length=3, width=0.5)
    # grid(False, <line kwargs>) makes matplotlib warn and *enable* the grid;
    # passing no line properties actually turns it off.
    ax.grid(False)
    ax.set_box_aspect(1)  # square panel in physical space


def make_plots(
    xs: List[float],
    ys: List[float],
    mat: np.ndarray,
    alignments: List[float],
    losses: List[float],
    output: str | Path | None = None,
):
    """Draw the loss heat-map (left) and the loss-vs-alignment scatter (right).

    Args:
        xs, ys: Sorted parameter values labelling the columns / rows of *mat*.
        mat: Loss grid indexed ``mat[row, col]``, as returned by
            ``collect_heatmap``.
        alignments, losses: Equal-length per-run values for the scatter.
        output: Destination file. If truthy, the figure is saved there and
            closed; otherwise it is shown with ``plt.show()``.
    """
    fig, (ax_h, ax_s) = plt.subplots(
        1, 2, figsize=(4, 3), dpi=300, constrained_layout=True
    )
    im = _draw_heatmap(ax_h, xs, ys, mat)
    _draw_scatter(ax_s, alignments, losses)

    fig.suptitle("Generalization Trends", fontweight="bold")
    # One horizontal colour bar under both panels; it maps the heat-map colours.
    cbar = fig.colorbar(im, ax=[ax_h, ax_s], orientation="horizontal",
                        pad=0.05, aspect=30, shrink=0.5)
    cbar.set_label("Validation Loss")
    cbar.ax.tick_params(labelsize=6, length=3, width=0.5)

    if output:
        fig.savefig(output, dpi=300)
        print(f"Saved figure → {output}")
        plt.close(fig)  # release the figure; it is not displayed in this branch
    else:
        plt.show()

# ----------------------------------------------------------------------------
#  Main
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Both collectors raise RuntimeError on missing data before any plotting starts.
    grid_x, grid_y, loss_grid = collect_heatmap(EXP_DIR, X_PARAM, Y_PARAM)
    align_vals, loss_vals = collect_scatter(EXP_DIR)
    make_plots(grid_x, grid_y, loss_grid, align_vals, loss_vals, output=OUTPUT)