#!/usr/bin/env python3
"""
plot_grad_update_heatmaps_v3.py  (matplotlib-only, *minor tweaks 2025-05-08*)
----------------------------------------------------------------------------

Changes requested
~~~~~~~~~~~~~~~~~
* **Extra whitespace** between heat-maps and their colour-bars.
* **Smaller tick labels** – numbers on axes *and* colour-bars are now
  **½ ×** the base font size.
* **Mid-point ticks** added for both *Epoch* (x-axis) and *Neuron Index* (y-axis).

Other functionality and defaults are unchanged.

Dependencies
~~~~~~~~~~~~
* numpy
* matplotlib ≥ 3.5
* scipy (imported unconditionally; only used for Gaussian smoothing when ``window`` > 0)
"""
from __future__ import annotations

import argparse
import json
from dataclasses import dataclass, asdict
from pathlib import Path
from time import perf_counter
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from scipy.ndimage import gaussian_filter1d

# ───────────────────────────────────────────────────────────────────────────────
# Configuration & global font handling
# ───────────────────────────────────────────────────────────────────────────────

def _apply_font_rcparams(base: int) -> None:
    """Derive matplotlib's main font-size rcParams from the single size *base*.

    Body text and axis labels use *base* itself; the figure title, axes titles
    and tick labels are set to 1.3x, 1.05x and 0.9x of it respectively.
    """
    tick_size = base * 0.9
    font_settings = {
        "font.size": base,
        "axes.labelsize": base,
        "axes.titlesize": base * 1.05,
        "figure.titlesize": base * 1.3,
    }
    # Both tick axes share the same reduced size.
    for axis_key in ("xtick.labelsize", "ytick.labelsize"):
        font_settings[axis_key] = tick_size
    plt.rcParams.update(font_settings)


@dataclass
class Config:
    """Options for the heat-map script; every field doubles as a ``--<field>`` CLI flag."""

    # Directory whose immediate sub-directories are the sub-experiments to plot.
    exp_root: Path = Path("__path__")
    # Only records whose "epoch" is <= this value are drawn.
    show_epoch: int = 50
    window: int = 0                   # σ for Gaussian smoothing (0 → none)
    neuron_limit: int | None = 10     # first N neurons; None or 0 → all neurons
    dpi: int = 300

    # Figure fine-tuning
    figsize: tuple[float, float] = (4, 3)   # (width, height) handed to plt.figure
    font_size: int = 9                      # base size; tick labels use half of it

    # Output file name; "{subexp}" is replaced by the sub-experiment directory name.
    out_name: str = "{subexp}_per_neuron_heatmaps.png"

    # Which statistic types to plot. Any combination of "gradient",
    # "weight_update". DEFAULT: only weight updates.
    # Annotated as a tuple to match the (immutable) default value.
    to_plot: tuple[str, ...] = ("weight_update",)

# ───────────────────────────────────────────────────────────────────────────────
# Colour-maps & drawing helper
# ───────────────────────────────────────────────────────────────────────────────
# Diverging colour-map for the cosine-similarity ("Alignment") panels, which
# plot_heatmaps draws with fixed colour limits of -1 … 1.
DIVERGING = plt.get_cmap("RdBu_r")
# Sequential colour-map for the magnitude panels.
MAG_SCALE = plt.get_cmap("Greens")


def _heatmap(ax: plt.Axes, data: np.ndarray, cmap, vmin: float | None, vmax: float | None):
    """Render *data* as an un-interpolated image on *ax* and return the image artist.

    *data* is drawn exactly as given (no transposition happens here), with rows
    running top-to-bottom and the cells stretched to fill the axes. The x ticks
    are cleared and every spine is hidden; callers add ticks back where needed.
    """
    image_options = {
        "aspect": "auto",
        "origin": "upper",
        "cmap": cmap,
        "vmin": vmin,
        "vmax": vmax,
        "interpolation": "none",
    }
    image = ax.imshow(data, **image_options)

    # Start with a bare x-axis; tick positions/labels are assigned by the caller.
    ax.set_xticks([])
    for side in ax.spines.values():
        side.set_visible(False)

    return image

# ───────────────────────────────────────────────────────────────────────────────
# Main plotting routine
# ───────────────────────────────────────────────────────────────────────────────

def _three_point_ticks(n: int) -> list[int]:
    """Return the first, middle and last index of a length-*n* axis, without duplicates."""
    return sorted({0, (n - 1) // 2, n - 1})


def plot_heatmaps(subexp_dir: Path, cfg: Config):
    """Draw the per-neuron heat-map figure for one sub-experiment and save it.

    Reads ``<subexp_dir>/metrics/saved_updates.json``, keeps the records whose
    ``epoch`` is at most ``cfg.show_epoch`` and, for every entry of
    ``cfg.to_plot``, draws a cosine-similarity panel followed by a magnitude
    panel (neurons on the y-axis, records on the x-axis). The image is written
    to ``<subexp_dir>/figure/`` under the name given by ``cfg.out_name``.

    The sub-experiment is skipped with a printed ``[skip]`` message (and
    ``None`` is returned) when the metrics file is missing, no record survives
    the epoch filter, there is no neuron to show, or ``cfg.to_plot`` is empty.

    Args:
        subexp_dir: Directory of a single sub-experiment.
        cfg: Plot settings (epoch cut-off, smoothing σ, neuron limit, figure
            size, font size, dpi, output name and statistics to plot).

    Raises:
        ValueError: If ``cfg.to_plot`` contains a name other than
            ``"gradient"`` or ``"weight_update"``.
    """
    metrics_path = subexp_dir / "metrics" / "saved_updates.json"
    if not metrics_path.exists():
        print(f"[skip] {subexp_dir.name}: metrics file missing")
        return

    # ------------------------------------------------------------------
    # Parse metrics
    # ------------------------------------------------------------------
    with metrics_path.open(encoding="utf-8") as fp:
        data = json.load(fp)

    # The file holds either a flat list of records or a list of per-epoch
    # lists; the nested form is flattened in order.
    records: Sequence[dict] = (
        [rec for epoch_list in data for rec in epoch_list]
        if data and isinstance(data[0], list)
        else data
    )
    records = [r for r in records if r.get("epoch", 0) <= cfg.show_epoch]
    if not records:
        print(f"[skip] {subexp_dir.name}: no records ≤ epoch {cfg.show_epoch}")
        return

    n_iters   = len(records)
    max_epoch = max(r.get("epoch", 0) for r in records)
    n_neurons = len(records[0]["neuron_update_cos_sim"])
    k         = min(cfg.neuron_limit or n_neurons, n_neurons)  # enforce limit
    if k < 1:
        # Zero-width arrays cannot be reduced with .min()/.max() below.
        print(f"[skip] {subexp_dir.name}: no neurons to plot")
        return

    # Pre-allocate & populate: one row per record, one column per neuron
    grad_cos, update_cos = (np.empty((n_iters, k)) for _ in range(2))
    grad_mag, update_mag = (np.empty((n_iters, k)) for _ in range(2))

    for i, rec in enumerate(records):
        grad_cos[i]   = rec["neuron_grad_cos_sim"][:k]
        update_cos[i] = rec["neuron_update_cos_sim"][:k]
        grad_mag[i]   = rec["per_neuron_grad_mag"][:k]
        update_mag[i] = rec["neuron_update_mag"][:k]

    if cfg.window > 0:
        # Smooth along the record axis only; neurons stay independent.
        for arr in (grad_cos, update_cos, grad_mag, update_mag):
            arr[:] = gaussian_filter1d(arr, sigma=cfg.window, axis=0, mode="reflect")

    # Magnitude panels share one colour scale spanning both statistics.
    mag_min, mag_max = (min(grad_mag.min(), update_mag.min()),
                        max(grad_mag.max(), update_mag.max()))

    # Build panels list in the same order we'll draw: (data, cmap, vmin, vmax)
    panels: list[tuple[np.ndarray, object, float | None, float | None]] = []
    for stat in cfg.to_plot:
        if stat == "gradient":
            panels += [(grad_cos, DIVERGING, -1, 1),
                       (grad_mag, MAG_SCALE, mag_min, mag_max)]
        elif stat == "weight_update":
            panels += [(update_cos, DIVERGING, -1, 1),
                       (update_mag, MAG_SCALE, mag_min, mag_max)]
        else:
            raise ValueError(f"Unknown stat: {stat}")
    if not panels:
        print(f"[skip] {subexp_dir.name}: nothing selected in to_plot")
        return

    # ------------------------------------------------------------------
    # Figure layout – manual GridSpec so heat-maps use the *entire* width
    # ------------------------------------------------------------------
    rows = len(panels)  # cosine & magnitude per type

    fig = plt.figure(figsize=cfg.figsize, dpi=cfg.dpi)
    try:
        gs = GridSpec(rows, 1, figure=fig, hspace=0.06)
        axes = [fig.add_subplot(gs[i, 0]) for i in range(rows)]
        im_artists = [
            _heatmap(ax, arr.T, cmap, vmin, vmax)
            for ax, (arr, cmap, vmin, vmax) in zip(axes, panels)
        ]

        # --------------------------------------------------------------
        # Axis labelling (x & y)
        # --------------------------------------------------------------
        # Epoch ticks: start, mid, end (collapsed when there are < 3 records)
        epoch_ticks = _three_point_ticks(n_iters)
        if n_iters > 1:
            # Label = max_epoch scaled by the tick's relative position, which
            # yields "0" at the start and max_epoch at the end.
            epoch_labs = [f"{int(max_epoch * (t / (n_iters - 1)))}" for t in epoch_ticks]
        else:
            # A single record: avoid dividing by n_iters - 1 == 0.
            epoch_labs = [f"{int(max_epoch)}"]

        # Neuron index ticks: start, mid, end
        neuron_ticks = _three_point_ticks(k)

        # Bottom plot gets x-axis labels
        axes[-1].set_xticks(epoch_ticks)
        axes[-1].set_xticklabels(epoch_labs)
        axes[-1].set_xlabel("Epoch")

        # All plots get y-axis ticks & label, with tick labels at ½ base font
        half_size = cfg.font_size * 0.5
        for ax in axes:
            ax.set_yticks(neuron_ticks)
            ax.set_yticklabels([str(t) for t in neuron_ticks])
            ax.set_ylabel("Neuron Index")
            ax.tick_params(axis="both", which="both", labelsize=half_size)

        # --------------------------------------------------------------
        # Settle the subplot geometry first: tight_layout moves the GridSpec
        # axes but not axes created with fig.add_axes, so colour-bars placed
        # before this call would no longer line up with their heat-maps.
        # --------------------------------------------------------------
        fig.tight_layout(rect=[0, 0, 1, 0.90])

        # --------------------------------------------------------------
        # Per-panel vertical colour-bars – free-floating beside each axes
        # --------------------------------------------------------------
        cax_width = 0.02      # fraction of figure width
        cax_pad   = 0.03      # gap between heat-map and colour-bar
        for idx, (ax, im) in enumerate(zip(axes, im_artists)):
            bbox = ax.get_position()
            cax = fig.add_axes([bbox.x1 + cax_pad, bbox.y0, cax_width, bbox.height])

            cb = fig.colorbar(im, cax=cax, orientation="vertical")
            # Panels alternate: even index = cosine similarity, odd = magnitude.
            label = "Alignment" if idx % 2 == 0 else "Magnitude"
            cb.set_label(label, labelpad=5)

            # Smaller numbers on colour-bar
            cb.ax.tick_params(length=2, labelsize=half_size)

        # --------------------------------------------------------------
        # Finishing touches & save
        # --------------------------------------------------------------
        fig.suptitle("Per-Neuron Weight Update\nHebbian Alignment and Magnitude", fontweight="bold", y=1.02)

        out_dir = subexp_dir / "figure"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / cfg.out_name.format(subexp=subexp_dir.name)
        # bbox_inches="tight" grows the canvas to include the colour-bars and title.
        fig.savefig(out_path, bbox_inches="tight")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"Saved → {out_path}")

# ───────────────────────────────────────────────────────────────────────────────
# CLI
# ───────────────────────────────────────────────────────────────────────────────

def _parse_args(defaults: Config) -> Config:
    """Build a CLI from the fields of *defaults* and return the parsed configuration.

    Every field becomes a ``--<field>`` option whose default is the field's
    current value. Scalar fields are converted with the type of that default
    (``str`` when the default is ``None``). Tuple/list fields accept one or
    more space-separated values (``--to_plot gradient weight_update``), each
    converted with the type of the default's first element, and are returned
    in the same container type as the default.

    Args:
        defaults: Configuration instance supplying field names and defaults.

    Returns:
        A new instance of ``type(defaults)`` populated from ``sys.argv``.
    """
    p = argparse.ArgumentParser(description="Generate per-neuron heat-maps (matplotlib-only)")
    containers: dict[str, type] = {}
    for field, default in asdict(defaults).items():
        if isinstance(default, (list, tuple)):
            # type=tuple would split a single argument into characters, so
            # parse the elements individually and rebuild the container below.
            elem_type = type(default[0]) if default else str
            containers[field] = type(default)
            p.add_argument(f"--{field}", type=elem_type, nargs="+", default=default)
        else:
            arg_type = type(default) if default is not None else str
            p.add_argument(f"--{field}", type=arg_type, default=default)
    values = vars(p.parse_args())
    for field, container in containers.items():
        # argparse yields a list for nargs="+"; restore the default's type.
        values[field] = container(values[field])
    return type(defaults)(**values)


def main():
    """Parse the CLI options and plot every sub-experiment directory under ``exp_root``."""
    cfg = _parse_args(Config())
    # _apply_font_rcparams(cfg.font_size) is intentionally not called here:
    # tick label sizes are set explicitly while plotting instead.

    root = cfg.exp_root.expanduser().resolve()
    # Only directories count as sub-experiments; process them in sorted order.
    subdirs = sorted(entry for entry in root.iterdir() if entry.is_dir())
    print(f"Found {len(subdirs)} sub-experiments under {root}")

    started_at = perf_counter()
    for subexp_dir in subdirs:
        plot_heatmaps(subexp_dir, cfg)
    elapsed = perf_counter() - started_at

    print(f"All done in {elapsed:.1f}s")


if __name__ == "__main__":
    main()