# %%
from pathlib import Path
import matplotlib.pyplot as plt
from time import time
import torch

from torch.utils.data import DataLoader
from symo.group import I, S, O
import symo.optim2 as optim
from symo.notebooks.plot_utils import default_rcparams
from symo.data import ShakespeareDataset
from symo.nanogpt import GPT, GPTConfig, symo_group_spec_v2

# from symo.experiments.utils import inverse_step_schedule

# Merge the mapping returned by default_rcparams() into matplotlib's global
# rcParams in place, so every figure created later in this session uses it.
plt.rcParams |= default_rcparams()


# ---------------------------------------------------------------------------
# Experiment constants.
# NOTE(review): several of these (momentum, decay, damping, adam_lr, muon_lr,
# n_iterations, n_freq_*, and the model-size values below) are not referenced
# anywhere in the visible part of the script — confirm whether they are still
# needed or should be wired into the model/optimizer construction.
# ---------------------------------------------------------------------------

# Seed torch's global RNG right away so that torch-driven randomness further
# down (e.g. the shuffled DataLoader) is repeatable between runs.
seed: int = 2025
torch.manual_seed(seed)

# Optimizer settings.
lr: float = 1.0
momentum: float = 0.95
decay: float = 0.98
damping: float = 0.0
adam_lr: float = 1e-3
muon_lr: float = 0.02

# Batch size and loop settings.
batch_size: int = 4
n_iterations: int = 1000
n_freq_eval: int = 10
n_freq_print: int = 10

# Model-size settings.
dropout_rate: float = 0.2
block_size: int = 64
num_layers: int = 1
embed_size: int = 64
num_heads: int = 2
head_size: int = 32

# %%

# The dataset supplies the vocabulary size that the model config needs.
dataset = ShakespeareDataset()
vocab_size = dataset.vocab_size

# %%

# NOTE(review): the architecture values below are literals and differ from the
# block_size / num_layers / embed_size / num_heads constants defined near the
# top of this script (64 / 1 / 64 / 2) — confirm which set is intended.
gpt_config = GPTConfig(
    vocab_size=vocab_size,
    block_size=256,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.2,
    bias=True,
)
nano = GPT(gpt_config)

# Materialise the parameter iterators once as tuples so they can be reused
# (the plain tuple is what gets handed to the optimizer).
named_params = tuple(nano.named_parameters())
params = tuple(nano.parameters())

# %%

# Group specification derived from the model; passed to Symo together with the
# parameter tuple.
nano_spec = symo_group_spec_v2(nano)

# %%

# Wall-clock timestamps bracket the optimizer construction.
# NOTE(review): `end - start` is never read in the visible code.
start = time()
opt = optim.Symo(params, nano_spec, lr=lr, block_diag=True, decomp_precision="fp64")
end = time()

# %%

# Shuffled loader over the dataset. The batch size comes from the
# `batch_size` constant defined with the other hyperparameters instead of a
# duplicated literal, so the two cannot drift apart.
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Pull a single batch; it is unpacked as a pair (x, y) and fed to the model
# in the next cell.
data_iter = iter(data_loader)
x, y = next(data_iter)

# %%

# One optimisation step on the batch drawn above. Order matters:
# clear stale gradients, run the forward pass, backpropagate, then update.
nano.zero_grad()
# The model is called with both tensors and returns a (logits, loss) pair;
# presumably x are inputs and y targets — TODO confirm against GPT.forward.
logits, loss = nano(x, y)
loss.backward()
# Apply the Symo update using the gradients just computed.
opt.step()

# %%

# Emits only a blank line.
# NOTE(review): nothing computed above (loss, optimizer build time) is
# reported here — confirm whether this is just a placeholder cell.
print()
