#!/usr/bin/env python3
"""Plot alignment_fraction curves for Appendix “Regularizers” experiment
using a *sliding-window* mean ± standard-deviation, with one subplot per
weight-decay value.

This script is adapted from the Section 3 *Hebbian-no-SGD* plotter.
It now expects runs produced by the *appendix_regularizers* sweep:

  - regularization_mode ∈ {L2_weight_decay, L1_weight_decay, drop_out, batch_norm}
  - weight_decay ∈ {…} (various values)
  - lr = 0.01 (currently a single learning rate; runs are grouped only by
    weight decay and mode, so a sweep with several learning rates would be pooled)

A *single* rolling window (length ``WINDOW``) is used to smooth raw
iterations:

* For each seed, compute the mean and SD over every trailing window of
  ``WINDOW`` points (indexes *t-(WINDOW-1)…t*).
* Across seeds, average the window means → the curve, and average the
  window SDs → the shaded band.

One line is plotted for **each regularization mode**, and one subplot
for each weight-decay value.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import DefaultDict, Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd  # noqa: F401 – retained for downstream utilities
from cycler import cycler

# ─── User-configurable constants ────────────────────────────────────────────

# Parent directory that holds one sub-directory per run.
EXP_DIR = Path("__path__")

# Output filename, written inside EXP_DIR; set to None to call plt.show().
OUTPUT = "regularizers_alignment_sliding_fill_by_wd.png"

# Layer key looked up first in each entry's "alignments" mapping.
PREF_LAY = "L2"

# Accepted values of cfg["regularization_mode"]; this order also fixes the
# plotting / colour / legend order.
REG_MODES = ["L2_weight_decay", "L1_weight_decay", "drop_out", "batch_norm"]

# Length of the trailing sliding window, in logged points.
WINDOW = 200

# Base figure size in inches, as (width, height).
FIGSIZE = (4, 3)

# ─── Styling to match house style ───────────────────────────────────────────
# One colour per entry of REG_MODES, in the same order; the list is also
# installed as Matplotlib's default colour cycle.
COLORS = [
    "#d7191c",  # L2_weight_decay
    "#fdae61",  # L1_weight_decay
    "#2b83ba",  # drop_out
    "#abdda4",  # batch_norm
]
plt.rcParams["axes.prop_cycle"] = cycler(color=COLORS)

# Human-readable legend label for each regularization-mode key.
LABEL_MAP = dict(
    L2_weight_decay="L2 weight decay",
    L1_weight_decay="L1 weight decay",
    drop_out="Dropout (p = 0.5)",
    batch_norm="Batch Norm",
)

# ─── Helper functions ───────────────────────────────────────────────────────

def _sliding_mean_std(arr: np.ndarray, win: int = WINDOW) -> tuple[np.ndarray, np.ndarray]:
    """Return the mean and SD of every trailing *win*-length window of *arr*.

    Output element ``i`` summarises ``arr[i : i + win]``, i.e. the window that
    ends at index ``i + win - 1``. Both outputs therefore have length
    ``len(arr) - win + 1``. If *arr* is shorter than *win*, two empty float
    arrays are returned.

    Parameters
    ----------
    arr : array-like
        One-dimensional sequence of values (anything ``np.asarray`` accepts).
    win : int
        Window length; must be >= 1.

    Returns
    -------
    (mean, sd) : tuple of 1-D float arrays
        ``sd`` is the sample SD (``ddof=1``). For ``win == 1`` a single point
        has no sample spread, so ``sd`` is all zeros rather than NaN.

    Raises
    ------
    ValueError
        If *arr* is not one-dimensional or *win* < 1.
    """
    values = np.asarray(arr, dtype=float)
    if values.ndim != 1:
        raise ValueError(f"expected a 1-D array, got shape {values.shape}")
    if win < 1:
        raise ValueError(f"window length must be >= 1, got {win}")
    if values.size < win:
        return np.array([], dtype=float), np.array([], dtype=float)

    # Read-only (n - win + 1, win) view with no copy; unlike a hand-built
    # as_strided view it cannot be written through or overrun the buffer.
    windows = np.lib.stride_tricks.sliding_window_view(values, win)
    means = windows.mean(axis=1)
    if win == 1:
        # ddof=1 on a single sample would divide by zero and yield NaN.
        return means, np.zeros_like(means)
    return means, windows.std(axis=1, ddof=1)


def _extract_layer(entry: dict, preferred: str = PREF_LAY) -> float | None:
    """Return the alignment value for layer *preferred* from one metrics entry.

    Assumes *entry* looks like ``{"alignments": {layer_key: value, ...}}``.
    NOTE(review): this schema is inferred from the lookup below only; confirm
    it against the code that writes the metrics file.

    Lookup order:

    1. ``entry["alignments"][preferred]``, if that key exists. It is returned
       when numeric, otherwise ``None`` (e.g. for a JSON ``null``).
    2. Otherwise, the first numeric value in ``entry["alignments"]`` in dict
       order. This falls back to whichever layer happens to be listed first.

    JSON booleans are not treated as numbers. Returns ``None`` when:

    * *entry* is not a dict;
    * its ``"alignments"`` value is missing, empty or not a dict; or
    * no numeric value is found.
    """

    def _is_number(value: object) -> bool:
        # bool subclasses int, but JSON true/false is not an alignment value.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if not isinstance(entry, dict):
        return None
    aligns = entry.get("alignments")
    if not isinstance(aligns, dict) or not aligns:
        return None
    if preferred in aligns:
        value = aligns[preferred]
        return value if _is_number(value) else None
    return next((v for v in aligns.values() if _is_number(v)), None)

# ─── Data load ─────────────────────────────────────────────────────────────

if not EXP_DIR.is_dir():
    raise RuntimeError(f"Experiment directory not found: {EXP_DIR!s}")

# Sorted so that run order (and therefore seed order within each group) is
# deterministic regardless of filesystem listing order.
runs = sorted(p for p in EXP_DIR.iterdir() if p.is_dir())
if not runs:
    raise RuntimeError(f"No runs found under {EXP_DIR!s}")

# Largest "epochs" value among *accepted* runs; used to map iterations to
# epochs on the x-axis. Rejected runs must not influence it.
max_epoch = 0
meta_by_run: Dict[Path, Dict[str, object]] = {}

for run in runs:
    cfg_path = run / "metrics" / "config.json"
    if not cfg_path.exists():
        continue
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # unreadable file, bad UTF-8, malformed JSON
        continue
    if not isinstance(cfg, dict):
        continue

    lr = cfg.get("lr")
    reg_mode = cfg.get("regularization_mode")
    weight_decay = cfg.get("weight_decay")
    if lr is None or reg_mode not in REG_MODES or weight_decay is None:
        continue

    try:
        meta = {
            "lr": float(lr),
            "mode": reg_mode,
            "weight_decay": float(weight_decay),
        }
    except (TypeError, ValueError):  # non-numeric lr / weight_decay
        continue

    # "epochs" is only used for x-axis scaling: a missing, null or malformed
    # value counts as 0 instead of discarding an otherwise valid run.
    try:
        epochs = int(cfg.get("epochs") or 0)
    except (TypeError, ValueError):
        epochs = 0

    meta_by_run[run] = meta
    max_epoch = max(max_epoch, epochs)

if not meta_by_run:
    raise RuntimeError("No valid runs found (check REG_MODES & EXP_DIR).")

# align_data[weight_decay][mode] -> list of per-seed 1-D alignment arrays.
align_data: Dict[float, Dict[str, List[np.ndarray]]] = {}
# Longest raw entry count among runs that contributed data; together with
# max_epoch it maps iteration index to epoch on the x-axis.
max_itr = 0

for run, meta in meta_by_run.items():
    metrics_dir = run / "metrics"
    # Prefer the current filename; fall back to the legacy one.
    path = metrics_dir / "frac_pos_alignment.json"
    if not path.exists():
        path = metrics_dir / "align_frac.json"
    if not path.exists():
        continue
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # unreadable file, bad UTF-8, malformed JSON
        continue
    # Only a non-empty JSON array is usable; an empty list or an object
    # would otherwise crash on data[0] below.
    if not isinstance(data, list) or not data:
        continue

    # Entries are either metric dicts (alignment per layer) or bare numbers.
    raw = [_extract_layer(d) for d in data] if isinstance(data[0], dict) else data
    try:
        vals = np.asarray([v for v in raw if v is not None], dtype=float)
    except (TypeError, ValueError):  # non-numeric entries in the file
        continue
    if vals.size == 0:
        continue

    max_itr = max(max_itr, len(data))
    align_data.setdefault(meta["weight_decay"], {}).setdefault(meta["mode"], []).append(vals)

# Sort unique weight-decay values (one subplot each, left to right).
decays = sorted(align_data)
if not decays:
    raise RuntimeError("No alignment data found.")

# ─── Aggregate across seeds ────────────────────────────────────────────────
# For each (weight decay, mode) pair:
# 1. Smooth every seed with the trailing window, dropping seeds shorter
#    than the window.
# 2. Truncate the survivors to the shortest smoothed length.
# 3. Average the per-seed window means (the curve) and the window SDs (the
#    band) across seeds.

mu_by_wd: Dict[float, Dict[str, np.ndarray]] = {}
std_by_wd: Dict[float, Dict[str, np.ndarray]] = {}

for wd in decays:
    wd_means: Dict[str, np.ndarray] = {}
    wd_sds: Dict[str, np.ndarray] = {}
    mu_by_wd[wd] = wd_means
    std_by_wd[wd] = wd_sds

    for mode in REG_MODES:
        smoothed = [_sliding_mean_std(np.asarray(run_vals)) for run_vals in align_data[wd].get(mode, [])]
        smoothed = [(m, sd) for m, sd in smoothed if m.size != 0]
        if not smoothed:
            continue

        shortest = min(m.size for m, _ in smoothed)
        wd_means[mode] = np.vstack([m[:shortest] for m, _ in smoothed]).mean(axis=0)
        wd_sds[mode] = np.vstack([sd[:shortest] for _, sd in smoothed]).mean(axis=0)

# ─── Plot ──────────────────────────────────────────────────────────────────

n_wd = len(decays)
# One panel per weight decay; with several panels each gets half the base width.
fig_w = FIGSIZE[0] * n_wd / 2 if n_wd > 1 else FIGSIZE[0]
fig, axes = plt.subplots(
    1, n_wd,
    figsize=(fig_w, FIGSIZE[1]),
    dpi=300,
    sharey=True,
)
axes = np.atleast_1d(axes)  # plt.subplots returns a bare Axes when n_wd == 1

fig.suptitle(
    f"Hebbian Alignment of Gradient Update vs. Regularizers During Training\n(Rolling Window {WINDOW}, mean ± SD)",
    fontweight="bold",
)
# NOTE(review): left < 0 and right > 1 place the axes partly outside the
# figure. The saved PNG relies on bbox_inches="tight" to pull them back in.
# With OUTPUT = None (plt.show) the outer tick labels / y-label are likely
# clipped.
fig.subplots_adjust(left=-0.05, right=1.05, wspace=0.15, top=0.75, bottom=0.2)

# Map iteration index -> epoch. If no run reported an epoch count, plot raw
# iterations instead; otherwise every x value would collapse to 0.
if max_epoch > 0 and max_itr > 0:
    x_scale, x_label = max_epoch / max_itr, "Epoch"
else:
    x_scale, x_label = 1.0, "Iteration"

# First line drawn for each mode, whichever panel it appears in. A mode
# absent from the first panel still gets a legend entry.
handle_by_mode: Dict[str, plt.Line2D] = {}

for ax, wd in zip(axes, decays):
    ax.axhline(0, color="black", linestyle="--", linewidth=1)

    for mode, color in zip(REG_MODES, COLORS):
        mu = mu_by_wd[wd].get(mode)
        std = std_by_wd[wd].get(mode)
        if mu is None or std is None:
            continue
        # Each point sits at the *end* of its trailing window.
        x = x_scale * (np.arange(mu.size) + WINDOW - 1)
        line, = ax.plot(x, mu, lw=1.6, color=color)
        ax.fill_between(x, mu - std, mu + std, color=color, alpha=0.25)
        handle_by_mode.setdefault(mode, line)

    ax.set_title(fr"Weight Decay $\lambda$ = {wd:g}")
    ax.set_xlabel(x_label)
    ax.grid(False)
    ax.tick_params(axis="both", labelsize=6)

axes[0].set_ylabel("Alignment")

# Legend entries follow REG_MODES order. The legend is skipped when nothing
# was drawn, because ncol=0 is not a valid column count.
legend_modes = [m for m in REG_MODES if m in handle_by_mode]
legend_lines: List[plt.Line2D] = [handle_by_mode[m] for m in legend_modes]
legend_labels: List[str] = [LABEL_MAP.get(m, m) for m in legend_modes]
if legend_lines:
    fig.legend(
        legend_lines,
        legend_labels,
        loc="lower center",
        ncol=len(legend_labels),
        frameon=False,
        bbox_to_anchor=(0.5, -0.12),
    )

if OUTPUT:
    fig.savefig(EXP_DIR / OUTPUT, dpi=300, bbox_inches="tight")
    print(f"Saved figure → {EXP_DIR / OUTPUT}")
else:
    plt.show()