#!/usr/bin/env python3
"""Plot Hebbian alignment heatmaps + noise/weight‑decay curves.

This adapts the original two‑figure layout to the data‑loading, caching, and
styling conventions of *plot_weight_decay_curves.py* – i.e. everything lives in
constants, data are lazily parsed from the *EXP_DIR* experiment tree and then
cached, and we follow the same clean, publication‑ready aesthetics (white‑filled
markers, single shared legend, etc.).

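Directory structure – every run folder under *EXP_DIR* must contain:
    metrics/config.json              (with keys "weight_decay", "gradient_noise_fraction",
                                      "optimizer" and "activation")
    metrics/frac_pos_alignment.json  (list of dicts whose "alignments" entry has "L1"
                                      and "L2"; a flat list of floats is also accepted)

The noise/weight‑decay grid is built from runs whose "optimizer" equals
*OPTIMIZER* and whose "activation" is "tanh". Only the final *TAIL*
iterations of each run are averaged (±std) to give its convergence
alignment value.

The script writes two PNGs inside *EXP_DIR* and prints their paths, or
shows them interactively if *OUTPUT_HEATMAP* / *OUTPUT_CURVES* is None.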
"""
from __future__ import annotations

import json
import pickle
from pathlib import Path
from typing import DefaultDict, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from matplotlib.ticker import FuncFormatter, LogLocator
from collections import defaultdict
from matplotlib.ticker import ScalarFormatter
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import NullFormatter
import math

def _common_exponent(vals: List[float]) -> int:
    """Return ``floor(log10(max(vals)))`` as an ``int``.

    The maximum is clamped from below at ``1e-99`` so ``log10`` stays defined
    when every value is zero or negative.  An empty *vals* propagates the
    ``ValueError`` raised by ``max``.
    """
    lower_bound = 1e-99
    peak = max(vals)
    if peak < lower_bound:
        peak = lower_bound
    exponent = math.floor(math.log10(peak))
    return int(exponent)

def _scaled_labels(vals: List[float], exponent: int) -> List[str]:
    """Format each value divided by ``10**exponent`` with one significant digit.

    Paired with :func:`_common_exponent`, the scaled labels fall roughly in
    the 0–10 range while the power of ten is printed separately.
    """
    divisor = 10.0 ** exponent
    labels: List[str] = []
    for value in vals:
        labels.append(format(value / divisor, ".1g"))
    return labels

# Generic scientific-notation tick formatter (e.g. 0.00123 -> "1.2e-03").
# NOTE(review): not referenced by any plotting function here – possibly a leftover.
sci_formatter = FuncFormatter(lambda x, _: f"{x:.1e}")
# ─── Constants ──────────────────────────────────────────────────────────────
EXP_DIR = Path("__path__")  # directory with run sub‑folders (looks like a placeholder – set before running)
TAIL = 300                                # last N iters to average
LAYERS = ["L1", "L2"]                     # keys inside "alignments"
OUTPUT_HEATMAP = "heatmaps.png"           # written inside EXP_DIR (set to None to plt.show)
OUTPUT_CURVES = "noise_wd_curves.png"     # written inside EXP_DIR (set to None to plt.show)
FIGSIZE_HMAP = (4, 3)                   # inches – heatmap fig
FIGSIZE_CURV = (4, 3)                   # inches – curve fig
CACHE_DIR = Path("cleaned_plots/cache")  # where *.pkl files live (relative path → resolved against the CWD)
OPTIMIZER = "SGD"                         # only runs whose config "optimizer" equals this are loaded
THRESHOLD = 0.5                           # heatmap colour scale spans [-THRESHOLD, THRESHOLD]
ZERO_FIT_THRESH = 0.008                   # cells with |alignment| <= this feed the quadratic zero-alignment fit
# Five-colour palette used as the default line colour cycle for every axes.
line_colors = ['#d7191c', '#fdae61', '#ffec99', '#abdda4', '#2b83ba']
# Global matplotlib side effect: applied as soon as this module is imported.
plt.rcParams["axes.prop_cycle"] = cycler("color", line_colors)

# ─── Helpers ────────────────────────────────────────────────────────────────

def _parse_num(x: str | float | int | None):
    try:
        return float(x) if x is not None else None
    except Exception:
        return None


def _load_config(run: Path):
    cfg = run / "metrics" / "config.json"
    if not cfg.exists():
        return None
    try:
        data = json.loads(cfg.read_text())
        

        return data
    except Exception:
        return None


def _tail_align(run: Path, tail: int, layer: str) -> Tuple[float, float] | None:
    path = run / "metrics" / "frac_pos_alignment.json"
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text())
    except json.JSONDecodeError:
        return None
    if not data:
        return None

    # list‑of‑dicts or list‑of‑floats depending on storage format
    seq = (
        [d.get("alignments", {}).get(layer) for d in data]
        if isinstance(data[0], dict)
        else data
    )
    vals = [v for v in seq if v is not None]
    if not vals:
        return None
    arr = np.array(vals[-tail:])
    return float(arr.mean()), float(arr.std())


# ─── Data collection & cache ────────────────────────────────────────────────

def _collect_runs(exp_dir: Path) -> List[Path]:
    """Return the immediate sub-directories of *exp_dir*, sorted by path.

    ``Path.iterdir`` yields entries in an arbitrary, filesystem-dependent
    order; sorting makes seed aggregation (and therefore the cached floats)
    reproducible across machines and runs.
    """
    return sorted(d for d in exp_dir.iterdir() if d.is_dir())


    
def fit_zero_alignment_quadratic(noises: List[float], wds: List[float], data_mean: Dict[str, np.ndarray],
                                 threshold: float = ZERO_FIT_THRESH):
    """Fit the curve γ(σ) = a·σ² along which each layer's alignment is ~0.

    For every layer in ``LAYERS``, each grid cell ``data_mean[layer][i, j]``
    (row i ↔ ``wds[i]``, column j ↔ ``noises[j]``) that is not NaN and has
    ``|alignment| <= threshold`` contributes the sample (σ=noises[j],
    γ=wds[i]).  The single coefficient ``a`` is the least-squares solution of
    γ ≈ a·σ² (no linear or constant term).

    Returns:
        dict mapping layer → ``np.poly1d([a, 0, 0])``, or → ``None`` when
        fewer than three cells qualify.
    """
    curves: Dict[str, np.poly1d | None] = {}
    for layer in LAYERS:
        grid = data_mean[layer]
        # (σ, γ) of every near-zero cell, visited row by row.
        hits = [
            (noises[col], wds[row])
            for row, col in np.ndindex(len(wds), len(noises))
            if not np.isnan(grid[row, col]) and abs(grid[row, col]) <= threshold
        ]
        if len(hits) < 3:
            curves[layer] = None
            continue
        sigmas, gammas = zip(*hits)
        design = (np.array(sigmas) ** 2)[:, np.newaxis]  # single column: σ²
        solution = np.linalg.lstsq(design, np.array(gammas), rcond=None)[0]
        curves[layer] = np.poly1d([solution[0], 0, 0])  # a·σ² + 0·σ + 0
    return curves

    
def load_data(exp_dir: Path, layers: List[str], tail: int, *, activation: str = "tanh"):
    """Aggregate converged alignments onto a (weight decay × noise) grid.

    Only runs whose ``config.json`` has ``"optimizer" == OPTIMIZER`` and
    ``"activation" == activation`` contribute, and only if both
    ``"gradient_noise_fraction"`` and ``"weight_decay"`` parse to finite
    numbers.  The distinct noise values (sorted) index the columns; the
    distinct weight decays (sorted) index the rows.  Per cell and layer, the
    tail means are averaged over repeated seeds and the per-run tail stds are
    pooled as a root-mean-square; cells without data stay NaN.

    The result is pickled under ``CACHE_DIR``.  The cache file name encodes
    the experiment folder, optimizer, activation, *tail* and *layers*, so
    changing any of them rebuilds the grid instead of reusing stale data.

    Args:
        exp_dir: directory whose immediate sub-folders are run folders.
        layers: layer keys to extract (e.g. ``["L1", "L2"]``).
        tail: number of final iterations averaged per run.
        activation: value the config's ``"activation"`` must equal.

    Returns:
        ``(noise_list, wd_list, data_mean_dict, data_std_dict)`` where
        ``data_mean_dict[layer]`` is a 2-D array of shape
        ``(len(wd_list), len(noise_list))``.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    layer_tag = "-".join(layers)
    cache_file = CACHE_DIR / (
        f"{exp_dir.name}_noise_wd_{OPTIMIZER}_{activation}_tail{tail}_{layer_tag}.pkl"
    )
    if cache_file.exists():
        # NOTE: pickle can execute code on load – CACHE_DIR must be trusted.
        try:
            with open(cache_file, "rb") as f:
                return pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError):
            print(f"Ignoring unreadable cache {cache_file}; rebuilding")

    # ------------------------------------------------------------------ #
    #  Read every config.json exactly once and keep only the plotted runs.
    #  The activation filter is applied here (not just when building the
    #  axes) so every selected run is guaranteed to have a grid cell.
    # ------------------------------------------------------------------ #
    all_cfgs = {run: _load_config(run) for run in _collect_runs(exp_dir)}
    coords: Dict[Path, Tuple[float, float]] = {}
    for run, cfg in all_cfgs.items():
        if not isinstance(cfg, dict):
            continue
        if cfg.get("optimizer") != OPTIMIZER or cfg.get("activation") != activation:
            continue
        σ = _parse_num(cfg.get("gradient_noise_fraction"))
        γ = _parse_num(cfg.get("weight_decay"))
        if σ is None or γ is None or not (math.isfinite(σ) and math.isfinite(γ)):
            continue
        coords[run] = (σ, γ)

    noises = sorted({σ for σ, _ in coords.values()})
    wds = sorted({γ for _, γ in coords.values()})
    noise_idx = {n: j for j, n in enumerate(noises)}
    wd_idx = {g: i for i, g in enumerate(wds)}

    # ------------------------------------------------------------------ #
    #  Pre-allocate result arrays
    # ------------------------------------------------------------------ #
    shape = (len(wds), len(noises))
    data_mean: Dict[str, np.ndarray] = {l: np.full(shape, np.nan) for l in layers}
    data_std: Dict[str, np.ndarray] = {l: np.full(shape, np.nan) for l in layers}
    agg: Dict[str, DefaultDict[Tuple[int, int], List[Tuple[float, float]]]] = {
        l: defaultdict(list) for l in layers
    }

    # ------------------------------------------------------------------ #
    #  Populate aggregates
    # ------------------------------------------------------------------ #
    for run, (σ, γ) in coords.items():
        cell = (wd_idx[γ], noise_idx[σ])
        for layer in layers:
            res = _tail_align(run, tail, layer)
            if res is not None:
                agg[layer][cell].append(res)

    # ------------------------------------------------------------------ #
    #  Reduce over repeated seeds
    # ------------------------------------------------------------------ #
    for layer in layers:
        for (i, j), results in agg[layer].items():
            means, stds = zip(*results)
            data_mean[layer][i, j] = float(np.mean(means))
            # pooled within-run spread: RMS of per-run stds
            data_std[layer][i, j] = float(np.sqrt(np.mean(np.square(stds))))

    # ------------------------------------------------------------------ #
    #  Cache to disk (write-then-rename so an interrupted run never leaves
    #  a truncated pickle behind) and return
    # ------------------------------------------------------------------ #
    payload = (noises, wds, data_mean, data_std)
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")
    with open(tmp_file, "wb") as f:
        pickle.dump(payload, f)
    tmp_file.replace(cache_file)
    print(f"Saved cache → {cache_file}")

    return payload

# ─── Plotting ────────────────────────────────────────────────────────────────

def _mantissa_exponent(value: float) -> Tuple[float, int] | None:
    """Split a positive *value* into ``(m, e)`` with value = m·10**e, 1 ≤ m < 10.

    Returns ``None`` for zero, negative or non-finite values, which have no
    such decomposition (``log10`` would be -inf or NaN).
    """
    if not (math.isfinite(value) and value > 0):
        return None
    exponent = math.floor(math.log10(value))  # floor, not truncation toward 0
    return value / 10.0 ** exponent, exponent


def plot_heatmaps(noises: List[float], wds: List[float], data_mean: Dict[str, np.ndarray], zero_curves):
    """Draw one alignment heatmap per layer on a shared diverging colour scale.

    Rows of each ``data_mean[layer]`` grid correspond to *wds*, columns to
    *noises*; colours are clipped to ±``THRESHOLD``.  When
    ``zero_curves[layer]`` is a polynomial γ(σ), its in-range part is overlaid
    as a dashed line (mapped into cell-index coordinates) and labelled with
    its leading coefficient in scientific notation.

    The figure is saved to ``EXP_DIR / OUTPUT_HEATMAP`` and closed, or shown
    interactively when ``OUTPUT_HEATMAP`` is falsy.
    """
    cmap = "bwr"
    vmin, vmax = -THRESHOLD, THRESHOLD
    fig, axes = plt.subplots(1, len(LAYERS), figsize=FIGSIZE_HMAP, dpi=300,
                             constrained_layout=False, gridspec_kw={"top": 0.99})
    fig.subplots_adjust(top=0.95)
    fig.suptitle("Noise vs. Weight Decay Alignment", fontweight="bold")

    if len(LAYERS) == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel

    # Three evenly spaced ticks per axis; the shared power of ten is factored
    # out of the labels and printed once beside the axis.
    xtick_idx = np.linspace(0, len(noises) - 1, 3, dtype=int)
    ytick_idx = np.linspace(0, len(wds) - 1, 3, dtype=int)
    noise_exp = _common_exponent([noises[i] for i in xtick_idx])
    wd_exp = _common_exponent([wds[i] for i in ytick_idx])
    xlabels = _scaled_labels([noises[i] for i in xtick_idx], noise_exp)
    ylabels = _scaled_labels([wds[i] for i in ytick_idx], wd_exp)
    titles = {"L1": "Layer 1", "L2": "Layer 2"}

    for ax, layer in zip(axes, LAYERS):
        first = ax is axes[0]
        im = ax.imshow(data_mean[layer], origin="lower", cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_aspect("equal")
        ax.set_title(titles.get(layer, layer))

        ax.set_xticks(xtick_idx)
        ax.set_yticks(ytick_idx)
        ax.set_xticklabels(xlabels, rotation=45, ha="right", fontsize=6)
        if first:
            ax.set_yticklabels(ylabels, fontsize=6)
            ax.set_ylabel(r"Weight Decay ($\gamma$)")
        else:
            ax.set_yticklabels([])
            ax.tick_params(axis="y", which="both", left=False)
            ax.yaxis.set_major_formatter(NullFormatter())
        ax.set_xlabel(r"Noise Scale ($\sigma$)")
        if noise_exp != 0:
            ax.text(1.02, -0.02, rf"$\times10^{{{noise_exp}}}$", transform=ax.transAxes,
                    ha='left', va='top', fontsize=6)
        if first and wd_exp != 0:
            ax.text(-0.02, 1.02, rf"$\times10^{{{wd_exp}}}$", transform=ax.transAxes,
                    ha='right', va='bottom', rotation='vertical', fontsize=6)

        # ── Overlay quadratic zero‑alignment curve ───────────────────────
        p = zero_curves.get(layer)
        if p is None:
            continue
        σ_dense = np.linspace(min(noises), max(noises), 500)
        γ_dense = p(σ_dense)
        # keep only points inside the plotted weight-decay range
        mask = (γ_dense >= min(wds)) & (γ_dense <= max(wds))
        if not mask.any():
            continue
        # convert σ, γ to heatmap index space (axes are sorted ascending)
        x_idx = np.interp(σ_dense[mask], noises, np.arange(len(noises)))
        y_idx = np.interp(γ_dense[mask], wds, np.arange(len(wds)))
        lead = float(p.coeffs[0])
        parts = _mantissa_exponent(lead)
        if parts is None:
            # e.g. a == 0 (poly1d strips zeros) – log10 is undefined
            label = rf"$\frac{{c_0}}{{c_1}}={lead:.1g}$"
        else:
            num, exp = parts
            # braces around the exponent so mathtext superscripts all of it
            label = rf"$\frac{{c_0}}{{c_1}}={num:.1f} \times 10^{{{exp}}}$"
        ax.plot(x_idx, y_idx, color="black", linestyle="--", lw=1.0, label=label)
        ax.legend(loc="lower right", fontsize=6, frameon=True)

    cbar = fig.colorbar(im, ax=axes, orientation="horizontal",
                        fraction=0.05, pad=0.22, aspect=30, shrink=0.75)
    cbar.set_label("Alignment")
    cbar.set_ticks(np.linspace(vmin, vmax, 5))
    cbar.ax.tick_params(length=2, labelsize=6)

    if OUTPUT_HEATMAP:
        out_path = EXP_DIR / OUTPUT_HEATMAP
        fig.savefig(out_path, dpi=300)
        plt.close(fig)
        print(f"Saved heatmaps → {out_path}")
    else:
        plt.show()

def plot_noise_wd_curves(noises: List[float], wds: List[float], data_mean: Dict[str, np.ndarray],
                         layer: str = "L2"):
    """Plot 1-D alignment slices through one layer's (weight decay × noise) grid.

    Left panel: alignment vs. noise scale for every third weight decay.
    Right panel: alignment vs. weight decay for every third noise scale.
    The panels share their y axis, so the right panel (whose y labels are
    hidden) is drawn on the same alignment scale as the left one.

    Args:
        noises: sorted noise values (columns of ``data_mean[layer]``).
        wds: sorted weight decays (rows of ``data_mean[layer]``).
        data_mean: per-layer 2-D grids of shape (len(wds), len(noises)).
        layer: which grid to plot; defaults to ``"L2"``.

    Saves to ``EXP_DIR / OUTPUT_CURVES`` (and closes the figure), or shows it
    interactively when ``OUTPUT_CURVES`` is falsy.
    """
    grid = data_mean[layer]
    layer_name = {"L1": "Layer 1", "L2": "Layer 2"}.get(layer, layer)
    gamma_idxs = np.arange(0, len(wds), 3)
    sigma_idxs = np.arange(0, len(noises), 3)

    # sharey: the right panel hides its y labels, so it must reuse the left scale
    fig, axes = plt.subplots(1, 2, figsize=FIGSIZE_CURV, dpi=300,
                             constrained_layout=True, sharey=True)
    fig.suptitle(f"Noise and Weight Decay\nCurves for {layer_name}", fontweight="bold")
    for ax in axes:
        ax.axhline(0, color="black", lw=1, ls="--")
        ax.set_box_aspect(1)  # forces height = width in physical size

    # Factor a common power of ten out of each x axis.
    se = _common_exponent(wds)
    ne = _common_exponent(noises)
    scaled_noises = [n / 10**ne for n in noises]
    scaled_wds = [g / 10**se for g in wds]
    xticks_noise = [scaled_noises[j] for j in sigma_idxs]
    xticks_wd = [scaled_wds[j] for j in gamma_idxs]
    legend_kw = dict(loc="upper center", bbox_to_anchor=(0.5, -0.31),
                     frameon=False, ncols=2, fontsize=6)

    # Left panel – effect of noise for several γ
    ax = axes[0]
    for idx in gamma_idxs:
        γ = wds[idx] / 10**se
        ax.plot(scaled_noises, grid[idx, :], label=rf"$\gamma$ = {γ:.3g}",
                marker="o", markerfacecolor="white")
    ax.set_title("Varying\nNoise")
    ax.set_xlabel(r"Noise Scale ($\sigma$)")
    ax.set_ylabel("Alignment")
    ax.set_xticks(xticks_noise)
    ax.set_xticklabels(_scaled_labels([noises[j] for j in sigma_idxs], ne), rotation=45, fontsize=6)
    if ne != 0:
        ax.text(1.02, -0.08, rf'$\times10^{{{ne}}}$', transform=ax.transAxes, ha='left', va='top', fontsize=6)
    ax.legend(**legend_kw)

    # Right panel – effect of weight decay for several σ
    ax = axes[1]
    for idx in sigma_idxs:
        σ = noises[idx] / 10**ne
        ax.plot(scaled_wds, grid[:, idx], label=rf"$\sigma$ = {σ:.3g}",
                linestyle="--", marker="s", markerfacecolor="white")
    ax.set_title("Varying\nWeight Decay")
    ax.set_xlabel(r"Weight Decay ($\gamma$)")
    ax.set_xticks(xticks_wd)
    ax.set_xticklabels(_scaled_labels([wds[j] for j in gamma_idxs], se), rotation=45, fontsize=6)
    # hide y ticks/labels without touching the shared locator (set_yticks([])
    # would also wipe the left panel's ticks)
    ax.tick_params(axis="y", which="both", left=False, labelleft=False)
    if se != 0:
        ax.text(1.02, -0.08, rf'$\times10^{{{se}}}$', transform=ax.transAxes, ha='left', va='top', fontsize=6)
    ax.legend(**legend_kw)

    if OUTPUT_CURVES:
        out_path = EXP_DIR / OUTPUT_CURVES
        fig.savefig(out_path, dpi=300)
        plt.close(fig)
        print(f"Saved curves → {out_path}")
    else:
        plt.show()

# ─── Main ──────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # 1) scan runs (or reuse the cache), 2) fit the zero-alignment curve,
    # 3) render both figures.  The per-cell std grid is not plotted.
    noises, wds, data_mean, _ = load_data(EXP_DIR, LAYERS, TAIL)
    grid_args = (noises, wds, data_mean)
    zero_curves = fit_zero_alignment_quadratic(*grid_args)
    plot_heatmaps(*grid_args, zero_curves)
    plot_noise_wd_curves(*grid_args)
