#!/usr/bin/env python3
"""Plot Hebbian alignment vs. weight-decay for several activations and metrics.

All configuration lives in the constants block. The script scans the experiment
directory to discover the available activations and weight-decay values, renders
one subplot per activation in a near-square grid, and draws each metric with a
distinct color and marker (white-filled centers).

For every run, the last ``TAIL`` alignment values are summarized as mean and std;
each plotted point is the mean of those per-run means. The std values are collected
but are not currently drawn as error bars. A weight decay of zero is placed at a
small positive sentinel (one tenth of the smallest non-zero value) so that it would
remain representable on a log-scaled x-axis; the log x-scale and the single bottom
legend are currently commented out, and no baseline shift or log scaling is applied
to the y-axis. The processed data is written to a pickle cache on every run, but
loading from that cache is currently disabled.
"""
from __future__ import annotations

import json
from math import ceil, sqrt
from pathlib import Path
import pickle
from typing import Dict, List, DefaultDict
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter, LogLocator

# ─── Constants ──────────────────────────────────────────────────────────────
# Experiment root: every immediate sub-directory is treated as one run.
# NOTE(review): "__path__" looks like a placeholder — set the real path before running.
EXP_DIR = Path("__path__")
# Number of trailing (most recent) alignment values averaged per run.
TAIL = 300
# Metric names looked up under the "alignments" key of each record in
# metrics/frac_pos_alignment.json; one curve is drawn per metric.
METRICS: List[str] = ["L1", "L2"]  # list of metrics stored in frac_pos_alignment.json
# Path of the saved figure; if falsy, the figure is shown interactively instead.
OUTPUT = "activations_wd_curves.png"
# Size of the whole figure (not of a single subplot), in inches.
FIGSIZE = (4,3)  # inches – adjust for larger grid if needed
# Directory (created on demand) that receives the pickled, aggregated data.
CACHE_DIR = Path("cleaned_plots/cache")  # directory to store/load cached data

# ─── Helpers ────────────────────────────────────────────────────────────────

def _parse_num(x: str):
    """Convert *x* to a ``float``; return ``None`` when the conversion fails."""
    try:
        number = float(x)
    except Exception:
        # Deliberately best-effort: any failure means "not a number".
        number = None
    return number


def _get_param(run_dir: Path, key: str):
    """Return the value of *key* from ``<run_dir>/metrics/config.json``.

    Numbers, and strings that parse as numbers, are returned as ``float``;
    any other string is returned unchanged.

    Returns ``None`` when the config file is missing or unreadable, is not
    valid JSON, does not hold a JSON object at top level, lacks *key*, or
    stores a value that is neither a number nor a string.
    """
    cfg = run_dir / "metrics" / "config.json"
    try:
        data = json.loads(cfg.read_text())
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON
        # (json.JSONDecodeError) or undecodable bytes (UnicodeDecodeError).
        # Only these expected failures are treated as "no parameter", so that
        # genuine programming errors are no longer silently swallowed.
        return None
    if not isinstance(data, dict):
        return None

    val = data.get(key)
    if isinstance(val, (int, float)):
        try:
            return float(val)
        except OverflowError:
            # An integer too large for a float: keep the best-effort contract.
            return None
    if isinstance(val, str):
        maybe = _parse_num(val)
        return val if maybe is None else maybe
    return None


def _collect_runs(exp_dir: Path) -> List[Path]:
    """Return the immediate sub-directories of *exp_dir*, in listing order.

    Plain files are ignored; nothing is searched recursively.
    """
    entries = exp_dir.iterdir()
    return list(filter(Path.is_dir, entries))


def _tail_align(run: Path, tail: int, metric: str) -> tuple[float, float] | None:
    """Summarize the last *tail* alignment values recorded for one run.

    Reads ``<run>/metrics/frac_pos_alignment.json``, which must hold a JSON
    list. If its first element is an object, the value for each record is
    taken from ``record["alignments"][metric]``; otherwise the list itself is
    used as the series (and *metric* is ignored). ``None`` entries are dropped
    before the tail is taken.

    Args:
        run: Run directory.
        tail: Number of trailing values to summarize. Expected to be positive;
            because of slice semantics a value of 0 selects the whole series.
        metric: Key inside each record's ``"alignments"`` mapping.

    Returns:
        ``(mean, std)`` of the selected values, or ``None`` when the file is
        missing, unreadable, not valid JSON, not a non-empty list, or contains
        no value for *metric*.
    """
    path = run / "metrics" / "frac_pos_alignment.json"
    try:
        data = json.loads(path.read_text())
    except (OSError, ValueError):
        # Missing/unreadable file, malformed JSON or undecodable bytes.
        return None
    # A JSON object (or scalar) at top level cannot be indexed as a series.
    if not isinstance(data, list) or not data:
        return None

    if isinstance(data[0], dict):
        seq = []
        for record in data:
            # Tolerate stray non-object records and null/absent "alignments".
            alignments = record.get("alignments") if isinstance(record, dict) else None
            seq.append(alignments.get(metric) if isinstance(alignments, dict) else None)
    else:
        seq = data

    vals = [v for v in seq if v is not None]
    if not vals:
        return None
    window = vals[-tail:]
    return float(np.mean(window)), float(np.std(window))

# ─── Plotting ───────────────────────────────────────────────────────────────

# Per-metric line/marker colors; metrics not listed fall back to the
# matplotlib color cycle.
_METRIC_COLORS = {
    "L1": "#e41a1c",  # red
    "L2": "#377eb8",  # blue
    "L3": "#4daf4a",  # green
    "L4": "#984ea3",  # purple
    "L5": "#ff7f00",  # orange
    "L6": "#E1E111",  # yellow
    "L7": "#a65628",  # brown
    "L8": "#f781bf",  # pink
    "L9": "#999999",  # gray
    "L10": "#66c2a5",  # turquoise
}

# Per-metric marker shapes; metrics not listed fall back to a circle.
_METRIC_MARKERS = {
    "L1": "o",
    "L2": "s",
    "L3": "^",
    "L4": "D",
    "L5": "v",
    "L6": "P",
    "L7": "*",
    "L8": "X",
    "L9": "h",
    "L10": "8",
}

# Subplot titles for known activation names; others are shown verbatim.
_ACTIVATION_TITLES = {
    "linear": "Linear",
    "relu": "ReLU",
    "sigmoid": "Sigmoid",
    "tanh": "Tanh",
}


def _grid_shape(n: int) -> tuple[int, int]:
    """Return ``(rows, cols)`` of a near-square grid with at least *n* cells."""
    rows = cols = ceil(sqrt(n))
    if rows * (cols - 1) >= n:
        cols -= 1
    return rows, cols


def _aggregate_runs(run_params, activations, metrics: List[str], sentinel: float):
    """Group each run's ``(mean, std)`` tail statistics by activation, metric and weight decay."""
    data: Dict[str, Dict[str, DefaultDict[float, List[tuple[float, float]]]]] = {
        act: {m: defaultdict(list) for m in metrics} for act in activations
    }
    for run, act, wd in run_params:
        # Skip runs without an activation or with a non-numeric weight decay.
        if act is None or not isinstance(wd, (int, float)):
            continue
        wd = float(wd)
        # NOTE(review): this drops every run whose weight decay, expressed in
        # thousandths and truncated, is congruent to 2 mod 4 (e.g. 0.002,
        # 0.006, 0.01). The reason is not visible in this file; kept as-is.
        if int(wd * 1000) % 4 == 2:
            continue
        # Zero weight decay is stored under the positive sentinel.
        key = sentinel if np.isclose(wd, 0.0) else wd
        for metric in metrics:
            res = _tail_align(run, TAIL, metric)
            if res is None:
                continue
            data[act][metric][key].append(res)
    return data


def _draw_activation(ax, act, per_metric, metrics: List[str]) -> None:
    """Draw one subplot: one mean-alignment curve per metric for activation *act*."""
    ax.axhline(0, color="black", lw=1, ls="--")
    peak = 0.0
    drew_curve = False
    for idx, metric in enumerate(metrics):
        per_wd = per_metric.get(metric, {})
        wd_vals = sorted(per_wd)
        if not wd_vals:
            continue
        # Each point is the mean of the per-run tail means. The per-run stds
        # are stored alongside but are not drawn.
        means = [np.mean([m for m, _ in per_wd[w]]) for w in wd_vals]

        # Drop-last-for-sigmoid toggle (disabled):
        # if str(act) == "sigmoid":
        #     wd_vals = wd_vals[:-1]
        #     means = means[:-1]

        peak = max(peak, max(means))
        drew_curve = True
        color = _METRIC_COLORS.get(metric, f"C{idx % 10}")
        ax.plot(
            wd_vals,
            means,
            _METRIC_MARKERS.get(metric, "o") + "-",  # marker + line style
            color=color,
            markerfacecolor="white",
            markeredgecolor=color,
            label=metric,
        )

    # Log-scaled x-axis toggle (disabled):
    # ax.set_xscale("log")
    # ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: "0" if np.isclose(x, sentinel) else f"{x:g}"))
    # ax.xaxis.set_minor_locator(LogLocator(base=10, subs=np.arange(2, 10) * 0.1))
    ax.tick_params(axis="both", which="major", labelsize=8)

    # Only rescale when this subplot actually has a positive peak; otherwise
    # the limits would be identical (or based on another subplot's data).
    if drew_curve and peak > 0:
        ax.set_ylim(-0.5 * peak, 1.5 * peak)

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if ax.get_subplotspec().is_last_row():
        ax.set_xlabel("Weight Decay")
    if ax.get_subplotspec().is_first_col():
        ax.set_ylabel("Alignment")
    ax.set_title(_ACTIVATION_TITLES.get(str(act), str(act)))


def plot_weight_decay_curves(exp_dir: Path, metrics: List[str]):
    """Plot mean tail alignment against weight decay, one subplot per activation.

    Every immediate sub-directory of *exp_dir* is treated as a run. Its
    ``activation`` and ``weight_decay`` are read from the run's config, and
    the tail statistics of each metric are grouped by activation and weight
    decay. Runs with zero weight decay are plotted at a small positive
    sentinel (one tenth of the smallest non-zero weight decay, or ``1e-6``
    when there is none).

    Side effects: creates ``CACHE_DIR`` and writes the aggregated data there
    as a pickle; saves the figure to ``OUTPUT`` (or shows it interactively
    when ``OUTPUT`` is falsy); prints the paths written.

    Args:
        exp_dir: Experiment directory containing one sub-directory per run.
        metrics: Metric names to draw, one curve each per subplot.

    Raises:
        RuntimeError: If no run provides a numeric weight decay, or no run
            provides an activation.
    """
    # Read each run's config once instead of once per use.
    run_params = [
        (run, _get_param(run, "activation"), _get_param(run, "weight_decay"))
        for run in _collect_runs(exp_dir)
    ]
    activations = sorted({act for _, act, _ in run_params if act is not None})
    wds_all = sorted(
        {float(wd) for _, _, wd in run_params if isinstance(wd, (int, float))}
    )
    if not wds_all:
        raise RuntimeError("No weight-decay values found in experiments.")
    if not activations:
        raise RuntimeError("No activation values found in experiments.")

    non_zero = [w for w in wds_all if w > 0]
    sentinel = min(non_zero) * 0.1 if non_zero else 1e-6

    # Prepare cache directory and file.
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    cache_file = CACHE_DIR / f"{exp_dir.name}_data.pkl"
    # Loading from the cache is switched off; the cache is still rewritten on
    # every run. NOTE(review): the cache is keyed only by directory name, so
    # re-enabling this could serve stale data. SECURITY: pickle.load executes
    # arbitrary code — only ever load cache files this script wrote itself.
    reuse_cache = False
    if reuse_cache and cache_file.exists():
        with open(cache_file, "rb") as f:
            data = pickle.load(f)
        print(f"Loaded cached data from {cache_file}")
    else:
        data = _aggregate_runs(run_params, activations, metrics, sentinel)
        with open(cache_file, "wb") as f:
            pickle.dump(data, f)
        print(f"Saved cached data to {cache_file}")

    # Plotting section
    rows, cols = _grid_shape(len(activations))
    fig, axes = plt.subplots(
        rows, cols, figsize=FIGSIZE, constrained_layout=True, sharex=True
    )
    axes = np.array(axes).reshape(-1)
    fig.suptitle("Hebbian Alignment of Gradient", fontweight="bold")

    for ax, act in zip(axes, activations):
        _draw_activation(ax, act, data.get(act, {}), metrics)

    # Hide grid cells that have no activation assigned.
    for ax in axes[len(activations):]:
        ax.set_visible(False)

    # Shared bottom legend toggle (disabled):
    # handles, labels = axes[0].get_legend_handles_labels()
    # fig.legend(handles, labels, loc='lower center', ncol=len(metrics), frameon=False, bbox_to_anchor=(0.5, -0.02))
    if OUTPUT:
        fig.savefig(OUTPUT, dpi=300)  # bbox_inches="tight"
        print(f"Saved: {OUTPUT}")
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()


# Script entry point: render the figure for the experiment directory and
# metrics configured in the constants block.
if __name__ == "__main__":
    plot_weight_decay_curves(EXP_DIR, METRICS)