#!/usr/bin/env python3
"""Plot Hebbian alignment vs. weight decay for several activations and metrics.

All configuration lives in the constants block. The script scans the
experiment directory to discover the available activations and weight-decay
values, then renders one subplot per activation in a near-square grid. Every
subplot shows its own x-axis ticks, each metric is drawn as a line with a
distinct color and marker (white-filled centers), and a single shared legend
is placed at the right-hand side of the figure.

Notes on the current behavior:

* Runs are thinned by a weight-decay filter: a run is skipped when
  ``int(weight_decay * 1000) % 4 == 2``.
* Both axes use a linear scale and matplotlib's automatic limits. Zero weight
  decay is stored under a small positive sentinel (one tenth of the smallest
  non-zero weight decay) so that it would remain plottable on a logarithmic
  axis; log scaling itself is currently switched off.
* The mean and standard deviation over the last ``TAIL`` logged values are
  collected for every run, but only the means are plotted; no error bars are
  drawn.
* The processed data is pickled into ``CACHE_DIR`` on every run; loading from
  that cache is currently disabled, so the data is always recomputed.
"""
from __future__ import annotations

import json
from math import ceil, sqrt
from pathlib import Path
import pickle
from typing import Dict, List, DefaultDict
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter, LogLocator
from matplotlib.ticker import MaxNLocator

# ─── Constants ──────────────────────────────────────────────────────────────
# Root directory whose immediate sub-directories are the individual runs.
# NOTE(review): "__path__" looks like a placeholder — set the real path before running.
EXP_DIR = Path("__path__")
# Number of trailing logged values summarised per run (passed to _tail_align).
TAIL = 300
# Metrics to plot. NOTE(review): an earlier comment here said L7 had been added
# to match the second script, but the list stops at L6 — confirm which is
# intended.
METRICS: List[str] = [
    "L1", "L2", "L3", "L4", "L5", "L6",
]
# Output image path; when empty/falsy the figure is shown interactively instead.
OUTPUT = "activations_wd_curves_large.png"
FIGSIZE = (4, 3)  # inches – adjust for larger grid if needed
CACHE_DIR = Path("cleaned_plots/cache")  # directory to store/load cached data

# ─── Helpers ────────────────────────────────────────────────────────────────

def _parse_num(x: str):
    try:
        return float(x)
    except Exception:
        return None


def _get_param(run_dir: Path, key: str):
    cfg = run_dir / "metrics" / "config.json"
    if not cfg.exists():
        return None
    try:
        data = json.loads(cfg.read_text())
        val = data.get(key)
        if isinstance(val, (int, float)):
            return float(val)
        if isinstance(val, str):
            maybe = _parse_num(val)
            return maybe if maybe is not None else val
    except Exception:
        pass
    return None


def _collect_runs(exp_dir: Path) -> List[Path]:
    return [d for d in exp_dir.iterdir() if d.is_dir()]


def _tail_align(run: Path, tail: int, metric: str):
    path = run / "metrics" / "frac_pos_alignment.json"
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text())
    except json.JSONDecodeError:
        return None
    if not data:
        return None

    seq = (
        [d.get("alignments", {}).get(metric) for d in data]
        if isinstance(data[0], dict)
        else data
    )
    vals = [v for v in seq if v is not None]
    if not vals:
        return None
    return float(np.mean(vals[-tail:])), float(np.std(vals[-tail:]))

# ─── Plotting ───────────────────────────────────────────────────────────────

def _build_alignment_table(records, activations, metrics, sentinel):
    """Group per-run tail statistics by activation, metric and weight decay.

    Args:
        records: Iterable of ``(run_dir, activation, weight_decay)`` tuples;
            either parameter may be ``None`` when the run's config lacks it.
        activations: Activations that get an (initially empty) entry.
        metrics: Metric names that get an (initially empty) entry per activation.
        sentinel: Key used in place of a weight decay of (approximately) zero.

    Returns:
        ``table[activation][metric][weight_decay]`` -> list of ``(mean, std)``
        tuples, one per contributing run.
    """
    table: Dict[str, Dict[str, DefaultDict[float, List[tuple]]]] = {
        act: {m: defaultdict(list) for m in metrics} for act in activations
    }
    for run, act, wd in records:
        if act is None or wd is None:
            continue
        # Thinning filter: drop runs whose weight decay, in integer thousandths,
        # leaves remainder 2 modulo 4.
        # NOTE(review): this does not implement "skip weight decay > 0.05" as
        # older descriptions of this script claim — confirm the intent.
        if int(wd * 1000) % 4 == 2:
            continue
        # Zero weight decay is keyed by the sentinel rather than by 0.0. The
        # x-axis is currently linear, so such runs are drawn at `sentinel`.
        key = sentinel if np.isclose(float(wd), 0.0) else float(wd)
        for metric in metrics:
            res = _tail_align(run, TAIL, metric)
            if res is None:
                continue
            table[act][metric][key].append(res)
    return table


def plot_weight_decay_curves(exp_dir: Path, metrics: List[str], *, use_cache: bool = False):
    """Plot mean tail alignment against weight decay, one subplot per activation.

    Every immediate sub-directory of *exp_dir* is treated as a run. Its
    ``activation`` and ``weight_decay`` are read from the run's config, the
    tail statistics of each metric are collected, and for every activation the
    per-weight-decay mean (averaged over runs) is drawn as one line per metric.
    Only the means are drawn; the collected standard deviations are stored in
    the data table (and in the cache file) but are not plotted.

    Side effects: creates ``CACHE_DIR`` and writes ``<exp_dir.name>_data.pkl``
    into it whenever the data is (re)built, then saves the figure to ``OUTPUT``
    or, when ``OUTPUT`` is falsy, shows it interactively.

    Args:
        exp_dir: Experiment directory containing one sub-directory per run.
        metrics: Metric names to plot; each must have an entry in the color
            and marker tables below (``"L1"`` .. ``"L10"``).
        use_cache: When true and the cache file exists, load the data table
            from it instead of rebuilding it. The cache is keyed only by the
            experiment directory name, so it may be stale if the runs, the
            metric list or ``TAIL`` have changed. Defaults to ``False``
            (always rebuild), which matches the previous behavior.

    Raises:
        RuntimeError: If no run provides a ``weight_decay`` value.
    """
    # Read every run's config once instead of re-parsing it for each lookup.
    runs = _collect_runs(exp_dir)
    records = [
        (run, _get_param(run, "activation"), _get_param(run, "weight_decay"))
        for run in runs
    ]
    activations = sorted({act for _, act, _ in records if act is not None})
    wds_all = sorted({float(wd) for _, _, wd in records if wd is not None})
    if not wds_all:
        raise RuntimeError("No weight-decay values found in experiments.")

    # Stand-in x position for runs with zero weight decay: one tenth of the
    # smallest positive weight decay, or a fixed fallback if there is none.
    non_zero = [w for w in wds_all if w > 0]
    sentinel = min(non_zero) * 0.1 if non_zero else 0.0001

    # Prepare cache directory and file
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    cache_file = CACHE_DIR / f"{exp_dir.name}_data.pkl"
    if use_cache and cache_file.exists():
        # SECURITY: unpickling can execute arbitrary code. Only enable the
        # cache for files this script itself wrote.
        with open(cache_file, "rb") as f:
            data = pickle.load(f)
        print(f"Loaded cached data from {cache_file}")
    else:
        data = _build_alignment_table(records, activations, metrics, sentinel)
        with open(cache_file, "wb") as f:
            pickle.dump(data, f)
        print(f"Saved cached data to {cache_file}")

    # Near-square grid: start from ceil(sqrt(n)) per side and drop one column
    # when the remaining cells still hold every activation.
    nice = {"linear": "Linear", "relu": "ReLU", "sigmoid": "Sigmoid", "tanh": "6x512 tanh MLP"}
    n = len(activations)
    rows = cols = ceil(sqrt(n))
    if rows * (cols - 1) >= n:
        cols -= 1

    fig, axes = plt.subplots(rows, cols, figsize=FIGSIZE)
    # Leave space for the legend on the right.
    fig.subplots_adjust(right=0.75, top=.8, left=0.2, bottom=0.2)
    axes = np.array(axes).reshape(-1)
    fig.suptitle("Hebbian Alignment of Gradient", fontweight="bold")

    # Fixed color and marker per metric.
    colors = {
        "L1": "#e41a1c",  # red
        "L2": "#377eb8",  # blue
        "L3": "#4daf4a",  # green
        "L4": "#984ea3",  # purple
        "L5": "#ff7f00",  # orange
        "L6": "#E1E111",  # yellow
        "L7": "#a65628",  # brown
        "L8": "#f781bf",  # pink
        "L9": "#999999",  # gray
        "L10": "#66c2a5",  # turquoise
    }
    markers = {
        "L1": "o",
        "L2": "s",
        "L3": "^",
        "L4": "D",
        "L5": "v",
        "L6": "P",
        "L7": "*",
        "L8": "X",
        "L9": "h",
        "L10": "8",
    }

    for ax, act in zip(axes, activations):
        ax.axhline(0, color="black", lw=1, ls="--")
        for metric in metrics:
            series = data[act][metric]
            wd_vals = sorted(series)
            if not wd_vals:
                continue
            # Average the per-run tail means at each weight decay.
            means = [np.mean([m for m, _ in series[w]]) for w in wd_vals]
            ax.plot(
                wd_vals,
                means,
                markers[metric] + "-",  # marker + solid line
                color=colors[metric],
                markerfacecolor="white",
                label=metric,
            )

        # Axes formatting
        ax.tick_params(axis="both", which="major", labelsize=8)
        # Cap the number of x ticks on every subplot. This used to run once
        # after the loops, on whichever axis the loop variable last pointed to.
        ax.xaxis.set_major_locator(MaxNLocator(nbins=5))

        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        if ax.get_subplotspec().is_last_row():
            ax.set_xlabel("Weight Decay")
        if ax.get_subplotspec().is_first_col():
            ax.set_ylabel("Alignment")
        ax.set_title(nice.get(str(act), str(act)))

    # Hide grid cells that have no activation assigned.
    for spare in axes[len(activations):]:
        spare.set_visible(False)

    # Single shared legend at the right edge, built from the first subplot's lines.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="center right",
        ncol=1,
        frameon=True,
        bbox_to_anchor=(0.98, 0.5),
    )

    if OUTPUT:
        fig.savefig(OUTPUT, dpi=300)
        print(f"Saved: {OUTPUT}")
    else:
        plt.show()


# Script entry point: render the figure for the configured experiment
# directory (EXP_DIR) and metric list (METRICS).
if __name__ == "__main__":
    plot_weight_decay_curves(EXP_DIR, METRICS)
