#!/usr/bin/env python3
import json
import numpy as np
import matplotlib.pyplot as plt
import argparse
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import MultipleLocator

def plot_L2_comparison(cache_path: str, out_path: str, epoch: int = 50) -> None:
    """Plot normalized Hebbian and gradient updates side by side as heat-maps.

    Loads the cached records from *cache_path*, keeps those whose ``'epoch'``
    value equals *epoch*, takes the first matching record, and draws the
    top-left 20x20 corner of its norm-scaled ``hebbian_update[1]`` and negated
    ``gradient[1]`` arrays on a common color scale. The figure title shows the
    mean of the record's ``alignments['L2']`` values, and the figure is saved
    to *out_path* at 300 dpi.

    Args:
        cache_path: Path to a JSON file that decodes to a sequence of dicts,
            each carrying ``'epoch'``, ``'alignments'`` (with an ``'L2'``
            entry), ``'hebbian_update'`` and ``'gradient'``.
        out_path: Destination of the saved figure.
        epoch: Epoch whose records are considered. Defaults to 50.

    Raises:
        ValueError: If no cached record belongs to *epoch*.
    """
    # 1) Load the cached deltas.
    with open(cache_path, "r", encoding="utf-8") as f:
        cache = json.load(f)

    # 2) Keep only the records of the requested epoch.
    epoch_entries = [e for e in cache if e.get("epoch") == epoch]
    if not epoch_entries:
        raise ValueError(f"No records found for epoch {epoch} in {cache_path}")

    # 3) Use the first matching record; no ranking by alignment is performed.
    best_idx = 0
    entry = epoch_entries[best_idx]
    # Diagnostic output: number of elements stored under 'hebbian_update'.
    print(len(entry['hebbian_update']))

    # 4) Extract the updates at index 1.
    # NOTE(review): presumably index 1 is the layer the 'L2' alignment refers
    # to -- confirm against the code that writes the cache.
    hebb = np.array(entry["hebbian_update"][1])
    grad = np.array(entry["gradient"][1])

    # 5) Scale each update to unit norm; the epsilon guards against dividing
    #    by zero for an all-zero update. The gradient is negated, presumably
    #    so that it is shown as a descent step comparable to the Hebbian one.
    hebb_norm = hebb / (np.linalg.norm(hebb) + 1e-12)
    grad_norm = -grad / (np.linalg.norm(grad) + 1e-12)

    # 6) Difference of the two normalized updates.
    diff = hebb_norm - grad_norm

    # Only the top-left s x s window of each matrix is displayed.
    s = 20
    hebb_window = hebb_norm[:s, :s]
    grad_window = grad_norm[:s, :s]
    diff_window = diff[:s, :s]

    # 7) Shared color limits. NOTE: the difference window still widens the
    #    range even though it is not drawn, which keeps the colors unchanged
    #    should a third "difference" panel be added.
    windows = (hebb_window, grad_window, diff_window)
    vmin = min(w.min() for w in windows)
    vmax = max(w.max() for w in windows)

    # 8) Mean alignment of the selected record, reported in the title.
    mean_hebb = np.mean(entry["alignments"]["L2"])

    # 9) Plot the two panels and save.
    fig, (ax_h, ax_g) = plt.subplots(1, 2, figsize=(4, 3))
    try:
        tick_spacing = 5
        half_size = plt.rcParams["font.size"] * 0.5

        panels = ((ax_h, hebb_window, "Hebbian"), (ax_g, grad_window, "Gradient"))
        for ax, window, title in panels:
            ax.imshow(window, aspect="equal", vmin=vmin, vmax=vmax, cmap="coolwarm")
            ax.set_title(title)
            ax.set_xlabel("Dimension 2")
            ax.xaxis.set_major_locator(MultipleLocator(tick_spacing))
            ax.tick_params(axis="both", which="major", labelsize=half_size)

        # Only the left panel carries a y-axis label and y ticks.
        ax_h.set_ylabel("Dimension 1")
        ax_h.yaxis.set_major_locator(MultipleLocator(tick_spacing))
        ax_g.yaxis.set_ticks([])
        ax_g.set_yticklabels([])

        fig.suptitle(
            f"Normalized Weight Update Example \n Cosine Similarity: {mean_hebb:.3f}",
            y=0.925,
            fontweight="bold",
        )
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"Saved comparison heat-map (batch {best_idx}) to {out_path}")

if __name__ == "__main__":
    # Command-line entry point: read the cache location and output path,
    # then render the comparison figure.
    cli = argparse.ArgumentParser(
        description="Plot normalized L2 Hebb vs. grad and their difference with consistent color scales"
    )

    # (flag, default value, help text) for every string-valued option.
    option_table = (
        ("--cache", "__path__/metrics/cached_deltas.json", "Path to cached_deltas.json"),
        ("--out", "graph_L2_hebbian_grad_updates.png", "Where to save the comparison PNG"),
    )
    for flag, fallback, description in option_table:
        cli.add_argument(flag, type=str, default=fallback, help=description)

    options = cli.parse_args()
    plot_L2_comparison(options.cache, options.out)
