#!/usr/bin/env python3
"""Plot alignment & weight‑norm curves for Section 5 “Track Phases – Init” experiments

This script creates a 1 × 3 grid (3 sub‑plots), one panel per
initialisation scheme **low / default / high** (titled 0.5x / 1x / 2x):

* **Alignment** – each panel shows the sliding‑window mean ± SD of
  *alignment*. Inside each panel a separate line is drawn for every
  distinct **weight‑decay** value that occurred in the runs.
* **Norm metric** – a second series (the product of the recorded
  ``grad_norms`` for layers L1–L3, originally described as the weight
  ℓ²‑norm) is loaded and smoothed the same way, but is not currently plotted.

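Runs are intended to be restricted to a learning rate of **0.01**, but no
learning‑rate filter is currently applied (and the figure title says
lr = 0.001) — confirm the intended value.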

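The sliding‑window length is given by ``WINDOW`` (currently 300 samples).
It was meant to match the ``track_updates: 30`` configuration used during
training, which suggests 30 — confirm the intended value.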

All other visual styling follows the house style used in previous
figures.
"""
from __future__ import annotations

import json
import math
from pathlib import Path
from typing import DefaultDict, Dict, List

import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler

# ─── User‑configurable constants ────────────────────────────────────────────
# Parent directory whose immediate sub-directories are individual runs.
EXP_DIR = Path("__path__")
# Output file name, written inside EXP_DIR; set to None to call plt.show() instead.
OUTPUT = "section5_alignment_weightnorm.png"
# Sliding-window length, in samples of the loaded metric series
# (one sample per JSON record).
WINDOW = 300
# Figure (width, height) in inches.
FIGSIZE = (3, 2)
# Key looked up in each record's "alignments" mapping.
LAYER = "L2"
# Initialisation schemes to plot, in panel order (left to right).
INITIALISATIONS = ["low", "default", "high"]

# ─── Styling ────────────────────────────────────────────────────────────────
# Colour palette, installed below as Matplotlib's default colour cycle.
# Extend the list if more distinct series are needed.
PALETTE = [
    "#d7191c",
    "#fdae61",
    "#2b83ba",
    "#abdda4",
    "#bebada",
    "#ffd92f",
    "#80b1d3",
    "#8dd3c7",
    "#fb8072",
    "#b3de69",
]
plt.rcParams["axes.prop_cycle"] = cycler(color=PALETTE)
# Font size is intentionally left at the Matplotlib default.

# ─── Helper functions ───────────────────────────────────────────────────────

def _sliding_mean_std(arr: np.ndarray, win: int = WINDOW) -> tuple[np.ndarray, np.ndarray]:
    """Return the sliding-window mean and sample SD of a 1-D series.

    Element ``i`` of each output summarises ``arr[i : i + win]``, i.e. the
    trailing window that ends at sample ``i + win - 1``.

    Parameters
    ----------
    arr:
        One-dimensional series (anything ``np.asarray`` accepts).
    win:
        Window length in samples; must be >= 1.

    Returns
    -------
    (mean, sd):
        Arrays of length ``len(arr) - win + 1``; both empty (float) when the
        series is shorter than ``win``. ``sd`` uses ``ddof=1``, so it is NaN
        when ``win == 1``.

    Raises
    ------
    ValueError
        If ``win < 1`` or ``arr`` is not one-dimensional.
    """
    if win < 1:
        raise ValueError(f"window length must be >= 1, got {win}")
    values = np.asarray(arr)
    if values.ndim != 1:
        # A hand-built strided view over N-D data would walk past the end of
        # the buffer; refuse rather than return garbage.
        raise ValueError(f"expected a 1-D series, got shape {values.shape}")
    if values.size < win:
        return np.array([], dtype=float), np.array([], dtype=float)
    # Read-only, bounds-safe view of shape (len - win + 1, win).
    windows = np.lib.stride_tricks.sliding_window_view(values, win)
    return windows.mean(axis=1), windows.std(axis=1, ddof=1)


def _load_metric(path: Path) -> tuple[np.ndarray, np.ndarray] | None:
    """Load the per-record alignment and norm series from a metrics JSON file.

    The file is expected to hold a JSON list of records. Each record has an
    ``"alignments"`` mapping, indexed by the module-level ``LAYER``. Each
    record also has a ``"grad_norms"`` mapping whose ``"L1"``, ``"L2"`` and
    ``"L3"`` entries are sequences; only element ``[0]`` of each is used.

    Returns
    -------
    (aligns, norms) or None
        Two float arrays with one entry per record:
        ``aligns[i] = record["alignments"][LAYER]``, and ``norms[i]`` is the
        product of ``record["grad_norms"][k][0]`` for k in L1, L2, L3.
        ``None`` when the file does not exist or is not valid UTF-8 JSON.

    Raises
    ------
    KeyError, IndexError, TypeError
        If a record lacks the expected structure.
    """
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        data = json.loads(path.read_text(encoding="utf-8"))
    except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
        return None

    aligns = [record["alignments"][LAYER] for record in data]
    # NOTE(review): callers treat this as the "weight-norm" series, but it is a
    # product of per-layer gradient norms — confirm this is the intended metric.
    norms = [
        math.prod(record["grad_norms"][layer][0] for layer in ("L1", "L2", "L3"))
        for record in data
    ]
    return np.array(aligns, dtype=float), np.array(norms, dtype=float)
# ─── Data load ─────────────────────────────────────────────────────────────
# Scan every run directory under EXP_DIR. Keep the runs whose config names one
# of INITIALISATIONS and whose metrics file loads, and group their metric
# series (one array per run) by initialisation × weight decay.
#
# NOTE(review): the learning rate is read from the config but is NOT used to
# filter runs. The docstring refers to lr = 0.01 while the figure title says
# lr = 0.001 — confirm which filter (if any) is intended before adding one.

# Sorted so runs (and hence seeds within each group) are visited in a
# deterministic order; Path.iterdir() order is filesystem-dependent.
runs = sorted(p for p in EXP_DIR.iterdir() if p.is_dir())
if not runs:
    raise RuntimeError(f"No runs found under {EXP_DIR!s}")

a_data: DefaultDict[str, DefaultDict[str, List[np.ndarray]]] = DefaultDict(lambda: DefaultDict(list))  # alignment
# Second series returned by _load_metric, stored per init × wd like a_data.
w_data: DefaultDict[str, DefaultDict[str, List[np.ndarray]]] = DefaultDict(lambda: DefaultDict(list))

max_epoch = 0  # largest configured "epochs" among accepted runs
max_itr = 0    # longest alignment series among accepted runs
weight_decays: set[str] = set()  # str(weight_decay) of every accepted run

for run in runs:
    metrics_dir = run / "metrics"
    cfg_path = metrics_dir / "config.json"
    if not cfg_path.exists():
        continue
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError):
        continue

    # Unwrap a list-valued lr to its first entry (currently unused; see NOTE above).
    lr = cfg.get("lr")
    if isinstance(lr, list):
        lr = lr[0]

    init = cfg.get("initialization") or cfg.get("init") or cfg.get("initializer")
    if init not in INITIALISATIONS:
        continue

    # Metrics file – try both known filenames.
    align_path = metrics_dir / "frac_pos_alignment.json"
    if not align_path.exists():
        align_path = metrics_dir / "align_frac.json"
    loaded = _load_metric(align_path)
    if loaded is None:
        # Missing/unreadable metrics: skip the run (unpacking None would raise).
        continue
    aligns, norms = loaded
    if aligns.size == 0:
        continue

    # Bookkeeping only for runs that actually contribute data, so
    # weight_decays lists exactly the values that have something to plot.
    wd_key = str(cfg.get("weight_decay"))  # stringified for consistent dict keys
    weight_decays.add(wd_key)
    max_epoch = max(max_epoch, cfg.get("epochs", 0))
    max_itr = max(max_itr, aligns.size)

    # Aggregate per init × wd (one entry per run/seed).
    a_data[init][wd_key].append(aligns)
    w_data[init][wd_key].append(norms)

if not a_data:
    raise RuntimeError(
        f"No valid runs with initialisation in {INITIALISATIONS} were found under {EXP_DIR!s}."
    )

# ─── Pre‑compute sliding stats ─────────────────────────────────────────────
# For every init × weight-decay cell: smooth each run with _sliding_mean_std,
# truncate all runs to the shortest smoothed length, and average across runs.
# The resulting SD curve is the across-run mean of each run's within-window
# SD (not the SD across runs).


def _aggregate_runs(series: List[np.ndarray]) -> tuple[np.ndarray, np.ndarray] | None:
    """Average per-run sliding mean/SD curves; None if no run is long enough."""
    curves = [_sliding_mean_std(s) for s in series]
    curves = [(mu, sd) for mu, sd in curves if mu.size]  # drop runs shorter than WINDOW
    if not curves:
        return None
    n = min(mu.size for mu, _ in curves)  # runs can differ in length; align to the shortest
    mean_mu = np.vstack([mu[:n] for mu, _ in curves]).mean(axis=0)
    mean_sd = np.vstack([sd[:n] for _, sd in curves]).mean(axis=0)
    return mean_mu, mean_sd


def _wd_sort_key(key: str) -> tuple[int, float, str]:
    """Order weight-decay keys numerically; non-numeric keys (e.g. "None") last."""
    try:
        value = float(key)
    except ValueError:
        return (1, math.inf, key)
    if math.isnan(value):
        return (1, math.inf, key)
    return (0, value, key)


align_mu: Dict[str, Dict[str, np.ndarray]] = {}
align_sd: Dict[str, Dict[str, np.ndarray]] = {}

norm_mu: Dict[str, Dict[str, np.ndarray]] = {}
norm_sd: Dict[str, Dict[str, np.ndarray]] = {}

for init in INITIALISATIONS:
    align_mu[init], align_sd[init] = {}, {}
    norm_mu[init], norm_sd[init] = {}, {}

    for wd_key in sorted(weight_decays, key=_wd_sort_key):
        # Cells without a long-enough run are simply left out of the dicts.
        stats = _aggregate_runs(a_data.get(init, {}).get(wd_key, []))
        if stats is not None:
            align_mu[init][wd_key], align_sd[init][wd_key] = stats

        stats = _aggregate_runs(w_data.get(init, {}).get(wd_key, []))
        if stats is not None:
            norm_mu[init][wd_key], norm_sd[init][wd_key] = stats

# ─── Plot ──────────────────────────────────────────────────────────────────
# One panel per initialisation scheme. Each panel has one line (sliding mean)
# plus a shaded band (± sliding SD) per weight-decay value.

wd_order = sorted(weight_decays)
# Fixed colour per weight decay, cycling through PALETTE if there are more
# values than colours. Passing colours explicitly keeps a weight decay the
# same colour in every panel even when some init × wd cells have no data.
# Relying on the axes' colour cycle would shift colours in those panels.
wd_colour = {wd: PALETTE[i % len(PALETTE)] for i, wd in enumerate(wd_order)}
init_labels = {"low": "0.5x", "default": "1x", "high": "2x"}  # panel titles

fig, axes = plt.subplots(
    1, len(INITIALISATIONS), figsize=FIGSIZE, dpi=300,
    sharex="col", sharey="row", constrained_layout=True,
    squeeze=False,  # always a 2-D array of axes, even for one initialisation
)
# NOTE(review): the docstring refers to lr = 0.01; confirm the lr shown in this title.
fig.suptitle("Alignment while training with different initializations  |  lr = 0.001", fontweight="bold")

legend_lines: dict[str, object] = {}  # wd_key -> first line drawn for it, in any panel

for col, init in enumerate(INITIALISATIONS):
    ax = axes[0, col]
    ax.set_title(f"Init: {init_labels.get(init, init)}")
    ax.axhline(0, color="black", linestyle="--", linewidth=0.8)

    for wd_key in wd_order:
        mu_a = align_mu.get(init, {}).get(wd_key)
        if mu_a is None:
            continue
        sd_a = align_sd.get(init, {}).get(wd_key)
        colour = wd_colour[wd_key]
        # Point i summarises samples i .. i+WINDOW-1; plot it at the window's last sample.
        x = np.arange(mu_a.size) + WINDOW - 1
        line, = ax.plot(x, mu_a, lw=1.4, color=colour, label=wd_key)
        ax.fill_between(x, mu_a - sd_a, mu_a + sd_a, facecolor=colour, alpha=0.25)
        legend_lines.setdefault(wd_key, line)

    ax.set_ylabel("Alignment" if col == 0 else "")
    ax.tick_params(axis="both", labelsize=6)
    ax.grid(False)

# Legend covers every weight decay drawn in any panel, in wd_order.
legend_keys = [wd for wd in wd_order if wd in legend_lines]
fig.legend(
    [legend_lines[wd] for wd in legend_keys],
    [f"wd={wd}" for wd in legend_keys],
    loc="lower center",
    ncol=max(1, min(len(legend_keys), 4)),
    frameon=False,
)

if OUTPUT:
    fig.savefig(EXP_DIR / OUTPUT, dpi=300)
    print(f"Saved figure → {EXP_DIR / OUTPUT}")
else:
    plt.show()
