#!/usr/bin/env python3
"""Plot alignment_fraction curves for Section 3 “Hebbian-no-SGD” experiments
using a *sliding-window* mean ± standard-deviation.

This revision replaces the previous two-stage (block-average + grouped error-bar)
aggregation with a *single* rolling window of **200** raw iterations:

* For each seed, compute the mean and standard deviation over every consecutive
  window of 200 points (trailing window -- indexes *t-199 … t*).
* Across seeds, average the window means to obtain the curve, and average the
  window standard deviations to obtain the shaded band.
* Lines are plotted for each (update_rule, activation_update) combo and
  learning-rate.  Shaded regions show ± the average within-window SD.

All other styling (colour palette, layout, etc.) follows the house style.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import DefaultDict, Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd  # noqa: F401 – kept in case downstream utilities rely on it
from cycler import cycler

# ─── User-configurable constants ────────────────────────────────────────────
# Each sub-directory of EXP_DIR is treated as one run; the loader below reads
# `<run>/metrics/config.json` and the alignment JSON from the same folder.
EXP_DIR   = Path("__path__")  # parent of all run dirs
# File name of the saved figure; it is written inside EXP_DIR.  A falsy value
# (e.g. None) makes the script call plt.show() instead of saving.
OUTPUT    = "alignment_sliding_fill.png"                    # set None ➜ plt.show
# Key looked up first inside each entry's "alignments" mapping.
PREF_LAY  = "L2"                                             # preferred layer key

# Only runs whose config matches one of these values are kept.
UPDATE_RULES       = ["standard", "oja"]                   # cfg["update_rule"]
ACTIVATION_UPDATES = ["pre", "post"]                      # cfg["activation_update"]
# Every "<update_rule>_<activation_update>" pairing, update rule varying slowest.
# The plotting loop zips COMBOS with COLORS, so the i-th combo gets the i-th colour.
COMBOS             = [f"{u}_{a}" for u in UPDATE_RULES for a in ACTIVATION_UPDATES]

WINDOW      = 200      # length of sliding window (raw iterations)
FIGSIZE     = (4, 2)   # inches – figure size

# ─── Styling to match house style ───────────────────────────────────────────
COLORS = ["#d7191c", "#fdae61", "#2b83ba", "#abdda4"]  # one colour per combo
# Sets the global default colour cycle.  The curves drawn below pass `color=`
# explicitly, so this only affects artists that rely on the default cycle.
plt.rcParams["axes.prop_cycle"] = cycler("color", COLORS)

# ─── Helper functions ───────────────────────────────────────────────────────

def _sliding_mean_std(arr: np.ndarray, win: int = WINDOW) -> tuple[np.ndarray, np.ndarray]:
    """Return the sliding-window mean and sample SD of a 1-D series.

    Args:
        arr: Raw values.  The input is converted to a flat float array first,
            so lists and non-1-D arrays are handled safely.
        win: Window length in samples; must be at least 1.

    Returns:
        ``(means, stds)`` – two float arrays of length ``arr.size - win + 1``
        where element ``i`` summarises ``arr[i : i + win]``.  Both arrays are
        empty when the series is shorter than one window.  ``stds`` is the
        sample standard deviation (``ddof=1``); for ``win == 1`` the spread of
        a single sample is reported as 0 rather than NaN.

    Raises:
        ValueError: If ``win`` is smaller than 1.
    """
    if win < 1:
        raise ValueError(f"win must be >= 1, got {win}")

    series = np.asarray(arr, dtype=float).ravel()
    if series.size < win:
        return np.array([], dtype=float), np.array([], dtype=float)

    # sliding_window_view computes shape/strides itself and returns a
    # read-only view, avoiding the out-of-bounds risk of raw as_strided.
    windows = np.lib.stride_tricks.sliding_window_view(series, win)

    # ddof=1 needs at least two samples per window.
    ddof = 1 if win > 1 else 0
    return windows.mean(axis=1), windows.std(axis=1, ddof=ddof)


def _extract_layer(entry: dict, preferred: str = PREF_LAY) -> float | None:
    """Pick one alignment value out of a single logged entry.

    Args:
        entry: One record of the metrics file; expected to be a dict holding
            an ``"alignments"`` mapping of layer key -> value.
        preferred: Layer key returned whenever it is present in the mapping.

    Returns:
        ``alignments[preferred]`` if that key exists; otherwise the first
        ``int``/``float`` value of the mapping in its iteration order.
        ``None`` when ``entry`` is not a dict, when ``"alignments"`` is
        missing, empty or not a dict, or when no numeric value is found.
    """
    if not isinstance(entry, dict):
        return None

    aligns = entry.get("alignments")
    # A malformed payload (list, string, number, …) is treated like a missing
    # one instead of crashing on `.values()` further down.
    if not isinstance(aligns, dict) or not aligns:
        return None

    if preferred in aligns:
        return aligns[preferred]

    # Preferred layer absent: fall back to the first numeric value available.
    return next((v for v in aligns.values() if isinstance(v, (int, float))), None)

# ─── Data Load ─────────────────────────────────────────────────────────────

# Discover run directories.  Sorting makes the processing order independent
# of the filesystem's arbitrary listing order.
runs = sorted(p for p in EXP_DIR.iterdir() if p.is_dir())
if not runs:
    raise RuntimeError(f"No runs found under {EXP_DIR!s}")

max_epoch = 0     # largest "epochs" value among the accepted runs
meta_by_run = {}  # run dir -> {"lr": float, "combo": "<update_rule>_<activation_update>"}
for run in runs:
    cfg_path = run / "metrics" / "config.json"
    if not cfg_path.exists():
        continue
    try:
        cfg = json.loads(cfg_path.read_text())
    except json.JSONDecodeError:
        continue
    if not isinstance(cfg, dict):
        # Valid JSON but not an object: nothing to look up, skip like a bad file.
        continue

    lr = cfg.get("lr")
    upd = cfg.get("update_rule")
    act = cfg.get("activation_update")
    if lr is None or upd not in UPDATE_RULES or act not in ACTIVATION_UPDATES:
        continue
    try:
        lr = float(lr)
    except (TypeError, ValueError):
        # Non-numeric learning rate: the run cannot be grouped, so skip it.
        continue
    meta_by_run[run] = {"lr": lr, "combo": f"{upd}_{act}"}

    # Only accepted runs contribute to the epoch span, and a missing or
    # non-numeric "epochs" entry is ignored instead of breaking max().
    epochs = cfg.get("epochs")
    if isinstance(epochs, (int, float)):
        max_epoch = max(max_epoch, epochs)

if not meta_by_run:
    raise RuntimeError("No valid runs found.")
if max_epoch <= 0:
    # The x-axis is scaled by max_epoch; without it every point would sit at 0.
    raise RuntimeError("No valid run declares a positive 'epochs' value in its config.")

# learning rate -> combo -> list of per-run (seed) 1-D value series
align_data: Dict[float, Dict[str, List[np.ndarray]]] = {}

max_itr = 0  # longest raw series (in logged entries) among the accepted runs
for run, meta in meta_by_run.items():
    # Two file names are supported; the first one that exists wins.
    path = run / "metrics" / "frac_pos_alignment.json"
    if not path.exists():
        path = run / "metrics" / "align_frac.json"
    if not path.exists():
        continue
    try:
        data = json.loads(path.read_text())
    except json.JSONDecodeError:
        continue
    if not isinstance(data, list) or not data:
        # Empty or non-list payloads carry no series; `data[0]` would raise.
        continue

    # Entries are either dicts (value extracted per entry) or raw values.
    vals = [_extract_layer(d) for d in data] if isinstance(data[0], dict) else data
    vals = np.asarray([v for v in vals if v is not None], dtype=float)
    if vals.size == 0:
        continue

    # Count the run's length only once it is known to contribute data.
    max_itr = max(max_itr, len(data))
    lr, combo = meta["lr"], meta["combo"]
    align_data.setdefault(lr, {}).setdefault(combo, []).append(vals)

# Sorted so the panel order (left = smaller learning rate) is deterministic.
lrs = sorted(align_data)
if len(lrs) != 2:
    raise RuntimeError(f"Expected 2 learning rates, found {len(lrs)}")

# learning rate -> combo -> curve averaged across seeds
mu_by_lr, std_by_lr = {}, {}
for lr in lrs:
    mu_by_lr[lr], std_by_lr[lr] = {}, {}
    for combo in COMBOS:
        # A combo may have no runs at this learning rate; it is then simply
        # left out, which the plotting loop already tolerates.
        seeds = align_data[lr].get(combo, [])

        # Per-seed sliding statistics; seeds shorter than one window yield
        # empty arrays and are dropped.
        per_seed = [_sliding_mean_std(np.asarray(s)) for s in seeds]
        per_seed = [(mu, std) for mu, std in per_seed if mu.size]
        if not per_seed:
            continue

        # Truncate every seed to the shortest one so they can be stacked.
        min_len = min(mu.size for mu, _ in per_seed)
        mus = [mu[:min_len] for mu, _ in per_seed]
        stds = [std[:min_len] for _, std in per_seed]
        mu_by_lr[lr][combo] = np.vstack(mus).mean(axis=0)
        std_by_lr[lr][combo] = np.vstack(stds).mean(axis=0)

# ─── Plot ────────────────────────────────────────────────────────────────

fig, axes = plt.subplots(1, 2, figsize=FIGSIZE, dpi=300, sharey=True)
fig.suptitle(f"Hebbian Learning SGD Alignment \nRolling Window {WINDOW} (mean ± SD)", fontweight="bold")
fig.subplots_adjust(left=-0.1, right=1.1, wspace=0.1, top=0.60, bottom=0.25)

# Converts a raw iteration index into epochs by scaling with
# max_epoch / max_itr.
epochs_per_itr = max_epoch / max_itr

# combo -> first line drawn for it, collected across both panels so that a
# combo missing from the left panel still gets a legend entry.
line_by_combo = {}

for ax, lr in zip(axes, lrs):
    # Static per-panel styling: applied once, even if no curve is drawn.
    ax.axhline(0, color="black", linestyle="--", linewidth=1)
    ax.set_title(fr"Learning Rate $\eta$ = {lr:g}")
    ax.set_xlabel("Epoch")
    ax.grid(False)
    ax.tick_params(axis="both", labelsize=6)
    ax.set_ylim(-0.5, 0.5)

    for combo, color in zip(COMBOS, COLORS):
        mu = mu_by_lr[lr].get(combo)
        std = std_by_lr[lr].get(combo)
        if mu is None or std is None:
            continue
        # Window i covers raw indexes i … i + WINDOW - 1; place the point at
        # the window's last (trailing) index.
        x = epochs_per_itr * (np.arange(mu.size) + WINDOW - 1)
        line, = ax.plot(x, mu, lw=1.6, color=color)
        ax.fill_between(x, mu - std, mu + std, color=color, alpha=0.25)
        line_by_combo.setdefault(combo, line)

# Build the legend in COMBOS order so it matches the colour assignment.
legend_lines = []
legend_labels = []
for combo in COMBOS:
    if combo not in line_by_combo:
        continue
    upd, act = combo.rsplit("_", 1)
    upd = "Normalized Hebbian" if upd == "standard" else "Oja"
    legend_lines.append(line_by_combo[combo])
    legend_labels.append(f"{upd} • {act.capitalize()}")

axes[0].set_ylabel("Alignment")
axes[1].set_ylabel("")  # y-axis is shared; label only the left panel

fig.legend(legend_lines, legend_labels, loc="lower center", ncol=4, frameon=False, bbox_to_anchor=(0.5, -0.12))

if OUTPUT:
    fig.savefig(EXP_DIR / OUTPUT, dpi=300, bbox_inches="tight")
    print(f"Saved figure → {EXP_DIR / OUTPUT}")
else:
    plt.show()
