# %%
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from time import time
import torch

from torch.utils.data import DataLoader
from symo.group import I, S, O
from symo.factory2 import CovFactory
from symo.notebooks.plot_utils import default_rcparams, orange_blue
from symo.data import ShakespeareDataset
from symo.nanogpt import GPT, GPTConfig, symo_group_spec_v2, symo_filtered_spec

# Merge the project's default plotting settings into Matplotlib's global rcParams.
plt.rcParams |= default_rcparams()

# Seed for the generator used to draw the random token batch.
seed = 2025

# Number of sequences in the random batch.
batch_size = 500
# NOTE(review): not referenced in the visible cells; the model config below is
# built with ``dropout=0.0`` — confirm whether that is intentional.
dropout_rate = 0.2
# Tokens per sequence; also passed as ``block_size`` of the model config.
block_size = 50
# Passed to the model config as ``n_layer``.
num_layers = 1
# Passed to the model config as ``n_embd``.
embed_size = 4
# Passed to the model config as ``n_head``.
num_heads = 1
# Token ids are drawn from ``[0, vocab_size)``.
vocab_size = 7
# Passed to the model config as ``ff_mult`` (presumably the feed-forward width
# multiplier — confirm in symo.nanogpt).
ff_mult = 2

# %%

# Dedicated RNG so that the random batch is reproducible.
# NOTE(review): only the ``torch.randint`` calls receive this generator; model
# initialisation presumably uses the global RNG — confirm whether a global seed
# is wanted as well.
gen = torch.Generator().manual_seed(seed)

# %%

# dataset = ShakespeareDataset(block_size=block_size)
# vocab_size = dataset.vocab_size

# %%

# Names of the parameters that stay trainable; every other parameter is frozen
# when the model is built.
keep_only = [
    "transformer.wte.weight",
    "transformer.h.0.attn.c_attn_qk.weight",
    "transformer.h.0.attn.c_attn_v.weight",
    "transformer.h.0.attn.c_proj.weight",
    "transformer.h.0.mlp.c_fc.weight",
    "transformer.h.0.mlp.c_proj.weight",
]

# (parameter name, LaTeX label) pairs for every weight of the single block,
# trimmed right below to the entries listed in ``keep_only``.
aliases = [
    ("transformer.wte.weight", r"$W_\text{emb}$"),
    ("transformer.h.0.ln_1.weight", r"$\ell_1$"),
    ("transformer.h.0.attn.c_attn_qk.weight", r"$W_{QK}$"),
    ("transformer.h.0.attn.c_attn_v.weight", r"$W_V$"),
    ("transformer.h.0.attn.c_proj.weight", r"$W_\text{proj}$"),
    ("transformer.h.0.ln_2.weight", r"$\ell_2$"),
    ("transformer.h.0.mlp.c_fc.weight", r"$W_{\text{ff}_1}$"),
    ("transformer.h.0.mlp.c_proj.weight", r"$W_{\text{ff}_2}$"),
    ("transformer.ln_f.weight", r"$\ell_3$"),
]
aliases = [pair for pair in aliases if pair[0] in keep_only]

# %%

# Model hyper-parameters come from the constants defined at the top of the
# file; dropout and biases are switched off.
gpt_config = GPTConfig(
    block_size=block_size,
    vocab_size=vocab_size,
    n_layer=num_layers,
    n_head=num_heads,
    n_embd=embed_size,
    ff_mult=ff_mult,
    dropout=0.0,
    bias=False,
    pe="linear",
)

nano = GPT(gpt_config)

# Freeze every parameter that is not explicitly listed in ``keep_only`` so that
# only the listed weights receive gradients.
for n, p in nano.named_parameters():
    if n in keep_only:
        continue
    p.requires_grad_(False)

# %%

# Trainable tensors only, in the order in which ``nano`` yields them.
params = tuple(p for p in nano.parameters() if p.requires_grad)
named_params = tuple((n, p) for n, p in nano.named_parameters() if p.requires_grad)

# ``named_parameters()`` returns a one-shot iterator.  Materialise it so that
# it can be traversed more than once and so that re-running the next cell does
# not silently receive an already exhausted iterator.
full_params = tuple(nano.named_parameters())

# %%

# Build the group specification for the whole model, then filter it using the
# full and the trainable parameter lists.
# NOTE(review): presumably the filter restricts the spec to the trainable
# parameters — confirm in symo.nanogpt.
nano_spec = symo_group_spec_v2(nano, qk_group="O", heads_group="I")
nano_spec = symo_filtered_spec(nano_spec, full_params, named_params)

# %%

# Random token ids for inputs and targets, each of shape
# (batch_size, block_size) with values in [0, vocab_size), drawn from the
# seeded generator.  ``x`` is drawn before ``y``; the order affects the values.
x = torch.randint(0, vocab_size, (batch_size, block_size), generator=gen)
y = torch.randint(0, vocab_size, (batch_size, block_size), generator=gen)

# %%

# Clear stale gradients, then run one forward/backward pass on the random batch.
nano.zero_grad()
logits, loss = nano(x, y)
loss.backward()

# %%

# Gradients of the trainable tensors, in the same order as ``params``.
# NOTE(review): an entry is None if a parameter did not take part in the loss —
# confirm that this cannot happen here or that CovFactory tolerates it.
grads = [p.grad for p in params]
cov_factory = CovFactory(groups_spec=nano_spec)
# NOTE(review): presumably accumulates an outer-product estimate from ``grads``
# — confirm in symo.factory2.
cov_factory.outer_update(grads)
cov: torch.Tensor = cov_factory.cov()

# %%


def diag_scale(mat):
    """Rescale a square matrix symmetrically by its diagonal.

    Entry ``(i, j)`` of the result is ``mat[i, j] / sqrt(mat[i, i] * mat[j, j])``
    (up to rounding), so a covariance matrix with positive diagonal is turned
    into a correlation matrix with ones on the diagonal.  The input is not
    modified.

    Args:
        mat: Square 2-D NumPy array.

    Returns:
        A new array of the same shape as ``mat``.
    """
    inv_root_diag = 1 / np.sqrt(mat.diagonal())
    # Scale the columns first and the rows second.
    column_scaled = inv_root_diag[np.newaxis, :] * mat
    return column_scaled * inv_root_diag[:, np.newaxis]


# Detach the covariance from the autograd graph and copy it to the CPU as a
# NumPy array.
cov_numpy = cov.detach().cpu().numpy()
# Rescale by the diagonal (see ``diag_scale``) so that positive diagonal
# entries become 1; a zero diagonal entry would produce inf/NaN here.
cov_norm = diag_scale(cov_numpy)

# %%


def plot_matrix(
    ax,
    mat,
    subtitle,
    show_bar: bool = True,
    clim_max: float | None = None,
    clim_min: float | None = None,
):
    """Draw ``mat`` on ``ax`` as a heat map with a colour scale centred at zero.

    Args:
        ax: Matplotlib axes to draw into.
        mat: 2-D array of values to display.
        subtitle: Title shown on ``ax``; pass ``None`` to leave the axes
            untitled.
        show_bar: Whether to attach a colour bar to the figure that owns
            ``ax``.
        clim_max: Upper colour limit; defaults to ``np.max(mat)``.
        clim_min: Lower colour limit; defaults to ``np.min(mat)``.

    NOTE(review): the norm is centred at 0, so the colour limits presumably
    have to straddle 0 for the image to render — confirm against the
    Matplotlib ``TwoSlopeNorm`` documentation.
    """
    cmap = orange_blue().reversed()
    norm = mcolors.TwoSlopeNorm(vcenter=0)

    im = ax.matshow(mat, cmap=cmap, norm=norm, aspect="equal")
    if show_bar:
        # Use the figure that owns ``ax`` instead of relying on a module-level
        # ``fig`` that may not exist or may belong to a different plot.
        ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    if subtitle is not None:
        ax.set_title(subtitle)

    # Tick marks carry no information for a matrix plot.
    ax.set_xticks([])
    ax.set_yticks([])

    # Fall back to the data range for whichever limit was not given.
    if clim_min is None:
        clim_min = np.min(mat)
    if clim_max is None:
        clim_max = np.max(mat)

    im.set_clim(clim_min, clim_max)


# %%

# Look-up from parameter name to its LaTeX label, taken before ``aliases`` is
# rebound below.
alias_by_name = dict(aliases)

# NOTE: ``aliases`` changes from a list of (name, label) pairs to a tuple of
# labels here, so this cell cannot be re-run without re-running the cell that
# defines the pairs.
names, aliases = zip(*aliases)
param_sizes = [p.numel() for _, p in named_params]
# Running totals: entry i is the index one past the last element of block i.
param_cumsizes = np.cumsum(param_sizes).tolist()

# (name, label, size, end index) per trainable parameter.  Labels are matched
# to sizes by parameter name rather than by list position, so a different
# ordering of the alias list cannot mislabel a block; a parameter without an
# alias falls back to its raw name.
names_tuple = [
    (name, alias_by_name.get(name, name), size, end)
    for (name, _), size, end in zip(named_params, param_sizes, param_cumsizes)
]

# %%


def matshow_named_axes(
    ax,
    param_names: list[tuple[str, str, int, int]],
    shift: float = -8,
    width: float = 0.1,
    fontsize: int = 12,
):
    """Mark parameter blocks on a matrix plot with separator lines and labels.

    For every entry a horizontal and a vertical gray line are drawn half a
    cell before the block's first index, and the block's label is written
    twice: once beside the rows and once above the columns, each centred on
    the block.

    Args:
        ax: Axes holding the matrix image.
        param_names: ``(name, label, size, end)`` tuples, where ``end`` is the
            index one past the block's last element; ``name`` is not used.
        shift: Coordinate at which the labels are placed (x for the row
            labels, y for the column labels).
        width: Line width of the separators.
        fontsize: Font size of the labels.
    """
    separator_style = {"color": "gray", "linestyle": "-", "linewidth": width}
    label_style = {
        "horizontalalignment": "center",
        "verticalalignment": "center",
        "fontsize": fontsize,
    }

    for _name, label, block_len, block_end in param_names:
        block_start = block_end - block_len
        edge = block_start - 0.5
        centre = (block_start + block_end) / 2

        ax.axhline(y=edge, **separator_style)
        ax.axvline(x=edge, **separator_style)

        # Row label first, column label second.
        ax.text(shift, centre, label, **label_style)
        ax.text(centre, shift, label, **label_style)


# %%

# NOTE(review): ``cmap`` and ``norm`` are not passed to ``plot_matrix``, which
# builds its own colour map and norm internally; they only matter if other
# cells read them.
cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(
    vcenter=0,
)

# Single-axes figure showing the diagonally rescaled covariance without a
# colour bar, overlaid with block separators and parameter labels.
fig, ax = plt.subplots()
plot_matrix(ax, cov_norm, subtitle=None, show_bar=False)
matshow_named_axes(ax, names_tuple)
fig.show()

# %%
