#!/usr/bin/env python
"""
layer_plots.py – Fraction-positive alignment, Layers 1 & 3 only

Rows  : initial-scale order 0.5× → 1× → 2×
Cols  : [SGD-L1 | SGD-L3 | Oja-L1 | Oja-L3]

Adds curly brackets and labels over optimiser pairs, a centred legend
at the bottom, and a figure-wide title.
"""

import json, glob
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

# One figure is produced per (update rule, metric-key suffix) combination.
for update_rule in ("standard", "oja"):
    for ext in ("", "_wd"):

        # ── matplotlib minimalist look ──────────────────────────────────────────────
        minimal_style = {
            "font.family":      "sans-serif",
            "font.size":        8,
            "figure.dpi":       300,
            "figure.facecolor": "white",
            "axes.linewidth":   0.8,
            "axes.edgecolor":   "black",
            "lines.linewidth":  1.2,
            "lines.markersize": 3.5,
        }
        plt.rcParams.update(minimal_style)

        # Window (in samples) handed to moving_average() for every curve.
        smoothing_window = 200
        # Runs whose weight decay lies strictly between 0 and this value are
        # not plotted (0.00256 is the alternative cutoff noted previously).
        wd_cutoff = 0.01
        # ── helper ──────────────────────────────────────────────────────────────────
        def moving_average(x: np.ndarray, window: int) -> np.ndarray:
            """Return a centred moving average of *x*.

            Args:
                x: Samples to smooth. On the smoothing path the array is
                    flattened first, so an (N, 1) input yields an (N,) result.
                window: Number of samples in each averaging window.

            Returns:
                ``x`` itself (not flattened, not copied) when ``window <= 1``;
                otherwise a 1-D array with one value per sample. Windows that
                hang over either end are averaged over the samples that exist.
            """
            # Only a window of a single sample (or less) is a no-op; every
            # larger window is smoothed so the result always matches the
            # window size the caller asked for.
            if window <= 1:
                return x
            smoothed = pd.Series(x.flatten()).rolling(   # flatten (N,1) ➜ (N,)
                window, center=True, min_periods=1
            ).mean()
            return smoothed.to_numpy()

        # ── locate all run directories ──────────────────────────────────────────────
        # Every path matching the glob is treated as a candidate run directory;
        # the matches are sorted so later processing sees them in a stable order.
        base_pattern = "__path__/*"
        run_dirs = sorted(glob.iglob(base_pattern))
        if not run_dirs:
            raise ValueError(f"No runs found matching {base_pattern!r}")


        # ── group runs by (initialisation, optimiser) ───────────────────────────────
        # groups maps (initialization, optimizer) -> [(weight_decay, run_dir), ...]
        # for the update rule currently being plotted.
        groups: dict[tuple[str, str], list[tuple[float, Path]]] = {}
        for rd_path in run_dirs:
            rd = Path(rd_path)
            cfg_path = rd / "metrics" / "config.json"
            if not cfg_path.exists():
                continue

            cfg = json.loads(cfg_path.read_text())

            # A config without an "update_rule" entry is skipped like any other
            # non-matching run instead of aborting the script with a KeyError.
            if cfg.get("update_rule") != update_rule:
                continue

            # Skip runs with no recorded weight decay, and runs whose decay is
            # positive but below the cutoff (a decay of zero is kept).
            wd = cfg.get("weight_decay")
            if wd is None or 0 < wd < wd_cutoff:
                continue

            init = cfg.get("initialization", "standard")
            opt  = cfg.get("optimizer", "default")          # "SGD" or "Hebb"
            groups.setdefault((init, opt), []).append((wd, rd))

        # Both filters above can empty the selection, so name both in the error.
        if not groups:
            raise ValueError(
                f"No valid runs for update_rule={update_rule!r} after filtering "
                f"weight_decay ≥ {wd_cutoff}"
            )

        # Within each group, order the runs by ascending weight decay.
        for run_list in groups.values():
            run_list.sort(key=lambda t: t[0])

        # ── load alignment series into memory ───────────────────────────────────────
        # all_data[(init, opt)][weight_decay][layer_key] holds one value per
        # entry of that run's frac_pos_alignment.json.
        all_data: dict[tuple[str, str], dict[float, dict[str, np.ndarray]]] = {}
        layer_keys = (f"L1{ext}", f"L3{ext}")               # only L1 & L3
        # Iterate over a snapshot because groups is pruned inside the loop.
        for (init, opt), run_list in list(groups.items()):
            init_dict = {}
            for wd, rd in run_list:
                f = rd / "metrics" / "frac_pos_alignment.json"
                if not f.exists():
                    continue
                raw = json.loads(f.read_text())

                # Entries without an "alignments" dict, or without a given
                # layer key, contribute NaN for that step.
                per_entry = [entry.get("alignments", {}) for entry in raw]
                init_dict[wd] = {
                    key: np.asarray([a.get(key, np.nan) for a in per_entry],
                                    dtype=float)
                    for key in layer_keys
                }

            # Keep groups in step with what was actually loaded: a run without
            # an alignment file must not stay listed, otherwise looking up
            # all_data[(init, opt)][wd] for it later raises KeyError.
            loaded_runs = [(wd, rd) for wd, rd in run_list if wd in init_dict]
            if loaded_runs:
                groups[(init, opt)] = loaded_runs
                all_data[(init, opt)] = init_dict
            else:
                del groups[(init, opt)]

        if not all_data:
            raise ValueError(
                "No frac_pos_alignment.json found for any selected run "
                f"(update_rule={update_rule!r})"
            )

        # ── determine global min length across all series ───────────────────────────
        # Every curve is truncated to the shortest loaded series so that all
        # panels can share a single step axis.
        min_len = min(
            (
                arr.shape[0]
                for combo in all_data.values()
                for layer_dict in combo.values()
                for arr in layer_dict.values()
            ),
            default=None,
        )
        # With nothing loaded, min() alone would fail with an "empty sequence"
        # error that does not say what is missing.
        if min_len is None:
            raise ValueError(
                "No alignment series were loaded "
                "(no frac_pos_alignment.json in the selected runs)"
            )
        steps = np.arange(min_len)

        # ── colour palette ──────────────────────────────────────────────────────────
        # One rainbow colour per run position, sized for the largest group.
        cmap       = plt.cm.rainbow
        max_runs   = max(map(len, groups.values()))
        last_index = max(1, max_runs - 1)               # avoids 0/0 with a single run
        color_seq  = [cmap(i / last_index) for i in range(max_runs)]

        # ── mapping & custom row order ──────────────────────────────────────────────
        # Display label and numeric scale for each known initialisation name.
        init_name   = {"low": "0.5×", "standard": "1×", "default": "1×", "high": "2×"}
        init_value  = {"low": 0.5,   "standard": 1,   "default": 1,   "high": 2}

        # collect unique inits then sort by numeric scale (0.5 ×, 1 ×, 2 ×).
        # Unknown names sort last (999). The name itself breaks ties, because
        # set iteration order for strings is not stable between interpreter
        # runs and would otherwise leak into the row order.
        unique_inits = {init for init, _ in groups}
        inits_ordered = sorted(
            unique_inits,
            key=lambda k: (init_value.get(k, 999), str(k)),
        )

        optimizers = ["SGD", "Hebb"]                    # optimizer names as stored in the configs
        layers     = [f"L1{ext}", f"L3{ext}"]           # order: L1 then L3

        # Grid: one row per initialisation, four fixed columns
        # (two optimisers × two layers).
        n_rows = len(inits_ordered)
        n_cols = 4
        fig, axes = plt.subplots(
            nrows=n_rows,
            ncols=n_cols,
            figsize=(12, 2.0 * n_rows),                 # reduced row height
            sharex=True,
            sharey=True,
            dpi=300,
            squeeze=False,                              # axes is always a 2-D array
        )

        # ── draw curves ─────────────────────────────────────────────────────────────
        # Colour is keyed on the weight-decay value itself rather than on a run's
        # position inside its group, so a given γ has the same colour in every
        # panel and one shared legend stays valid even when groups hold different
        # γ sets. When all groups share the same γ values this yields the same
        # colours as indexing by position.
        all_wds  = sorted({wd for runs in groups.values() for wd, _ in runs})
        wd_span  = max(1, len(all_wds) - 1)             # avoids 0/0 with a single γ
        wd_color = {wd: cmap(rank / wd_span) for rank, wd in enumerate(all_wds)}

        for r, init in enumerate(inits_ordered):
            disp_init = init_name.get(init, init)
            for c_opt, opt in enumerate(optimizers):
                run_list = groups.get((init, opt), [])
                loaded   = all_data.get((init, opt), {})

                for c_layer, layer in enumerate(layers):
                    # Columns are laid out optimiser-major: [opt0-L1, opt0-L3, opt1-L1, opt1-L3].
                    col = c_opt * 2 + c_layer
                    ax  = axes[r, col]

                    for wd, _ in run_list:
                        series = loaded.get(wd)
                        if series is None:              # no alignment data loaded for this run
                            continue
                        arr = series[layer][:min_len]
                        arr = arr[:, 0] if arr.ndim == 2 else arr
                        ax.plot(steps,
                                moving_average(arr, smoothing_window),
                                color=wd_color[wd],
                                label=f"γ = {wd:g}")

                    # cosmetics
                    ax.axhline(0, ls="--", lw=0.8, color="gray")
                    ax.set_ylim(-1, 1)
                    ax.spines["top"].set_visible(False)
                    ax.spines["right"].set_visible(False)
                    ax.grid(False)

                    if r == 0:                          # column titles = layer number
                        layer_label = "Layer " + layer.replace(ext, "")[1:]
                        ax.set_title(layer_label)

                    if col == 0:                        # y-label leftmost only
                        ax.set_ylabel(f"Init scale: {disp_init}", labelpad=8)

                    if r == n_rows - 1:                 # x-axis label bottom row
                        ax.set_xlabel("Step")

        # ── legend at the bottom ────────────────────────────────────────────────────
        # Gather one handle per distinct label from every panel (first occurrence
        # wins), so the legend is complete even when the top-left panel is empty
        # or lacks some of the weight-decay values shown elsewhere.
        legend_entries = {}
        for panel in axes.flat:
            for handle, label in zip(*panel.get_legend_handles_labels()):
                legend_entries.setdefault(label, handle)
        labels  = list(legend_entries)
        handles = list(legend_entries.values())

        fig.legend(handles, labels,
                frameon=False, fontsize=12,
                ncol=max(1, min(len(labels), 6)),       # never ask for zero columns
                loc="lower center",
                bbox_to_anchor=(0.5, 0.025))

        # ── figure-wide title ───────────────────────────────────────────────────────
        # The "_wd" metric variant gets its own headline.
        if ext == "_wd":
            title = "Hebbian-Full Weight Update Alignment"
        else:
            title = "Hebbian-Gradient Update Alignment"

        headline = f"{title} While Training\n(Windowed Mean w={smoothing_window})"
        fig.suptitle(headline, fontsize=14, y=0.88)     # larger & lowered

        # ── layout before brackets ──────────────────────────────────────────────────
        # Subplots are confined to the band between 8 % and 80 % of the figure
        # height, leaving the strips below and above it free.
        layout_rect = [0, 0.08, 1, 0.80]
        plt.tight_layout(rect=layout_rect)

        fig.canvas.draw_idle()                           # request a redraw before positions are read

        # ── helper to draw brackets and labels ──────────────────────────────────────
        def add_bracket(ax_left, ax_right, label, pad=0.07):
            """Span two axes with a labelled bracket drawn in figure coordinates.

            Args:
                ax_left: Axes whose left edge starts the bracket; its top edge
                    is the vertical reference.
                ax_right: Axes whose right edge ends the bracket.
                label: Text centred above the bracket.
                pad: Gap, in figure-fraction units, between the top of
                    ``ax_left`` and the bracket line. The end ticks are half
                    this long.
            """
            left_box = ax_left.get_position()
            right_box = ax_right.get_position()
            x_start, x_end = left_box.x0, right_box.x1
            y_line = left_box.y1 + pad                 # gap above top axis
            tick_len = pad * 0.5

            # One horizontal bar, then a downward tick at each end.
            segments = (
                ([x_start, x_end], [y_line, y_line]),
                ([x_start, x_start], [y_line, y_line - tick_len]),
                ([x_end, x_end], [y_line, y_line - tick_len]),
            )
            for xs, ys in segments:
                fig.add_artist(mlines.Line2D(xs, ys, transform=fig.transFigure,
                                             lw=1.0, color="gray", clip_on=False))

            # Label sits slightly above the bar, centred between its ends.
            fig.text((x_start + x_end) / 2, y_line + tick_len * 1.2, label,
                     ha="center", va="bottom", fontsize=12)


        # ── brackets for optimisers (use first row as reference) ────────────────────
        # The right-hand pair is named after the update rule being plotted.
        second_optimizer = {"oja": "Oja"}.get(update_rule, "Hebbian")
        top_row = axes[0]
        add_bracket(top_row[0], top_row[1], "Optimizer: SGD")                   # columns 0-1
        add_bracket(top_row[2], top_row[3], f"Optimizer: {second_optimizer}")   # columns 2-3

        # ── save figure ─────────────────────────────────────────────────────────────
        Path("figures").mkdir(parents=True, exist_ok=True)
        # NOTE: ext is "" or "_wd", so the names end in "_.png" / "__wd.png".
        # The pattern is left unchanged so existing paths keep resolving.
        fname = f"figures/alignment_by_params_{update_rule}_{ext}.png"
        fig.savefig(fname, facecolor="white", dpi=300, bbox_inches="tight")
        print(f"saved to {fname}")

        # pyplot keeps every figure it created alive until it is closed;
        # release this one before the next loop iteration builds another.
        plt.close(fig)