import os

import torch
from trainkit.saving import load_object
from nn.gpt2 import StructGPT as GPT
from nn.cola_nn import cola_parameterize
from nn.cola_nn import save_spectrum
from nn.cola_nn import get_model_summary_and_flops

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
block_size, vocab_size = 1024, 50_304
# Other (n_layer, n_head, n_embd) settings that have been used:
#   (12, 12, 768) and (12, 6, 384).
n_layer, n_head, n_embd = 12, 4, 248
dropout, bias = 0.0, False
init_lr, weight_decay, beta1, beta2 = 6e-4, 1e-1, 0.9, 0.95
device = 'cuda'
# Other `struct` values that have been used: "btt_spect", "dense_spect".
struct = "dense"
tt_dim, tt_rank, num_blocks, rank_frac = 2, 1, 2, 0.1
every_n_fwds = 1
layers, input_lr_mult = "all_but_last", 1.
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
num_fwds = 2  # number of forward passes run before the spectrum is saved
log_dir = "./"  # directory handed to save_spectrum and read back from below

# ---------------------------------------------------------------------------
# Build the structured model
# ---------------------------------------------------------------------------
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, bias=bias, vocab_size=vocab_size,
                  dropout=dropout)
# Forward the configured `every_n_fwds` instead of a literal so the setting
# above actually takes effect.
cola_kwargs = dict(tt_dim=tt_dim, tt_rank=tt_rank, num_blocks=num_blocks, rank_frac=rank_frac,
                   every_n_fwds=every_n_fwds)
optim_kwargs = {"weight_decay": weight_decay, "lr": init_lr, "betas": (beta1, beta2), "device_type": device_type}
model, _ = cola_parameterize(GPT, model_args, init_lr, target_config=None, struct=struct, layer_select_fn=layers, device=device,
                             cola_kwargs=cola_kwargs, optim_kwargs=optim_kwargs)

# Summarize the model / count FLOPs on a random token batch of shape (1, block_size);
# the same tensor is passed as both inputs and targets.
input_shape = (1, block_size)
fake_input = torch.randint(low=0, high=vocab_size, size=input_shape).to(device)
_ = get_model_summary_and_flops(model, (fake_input, fake_input))

# ---------------------------------------------------------------------------
# Run forward passes, then dump and print the recorded spectrum
# ---------------------------------------------------------------------------
X = torch.randint(low=0, high=vocab_size, size=input_shape).to(device)
Y = torch.randint(low=0, high=vocab_size, size=input_shape).to(device)
# NOTE(review): the forwards are presumably needed so the structured layers
# record their spectral estimates before save_spectrum runs — confirm.
for _ in range(num_fwds):
    model(X, Y)

save_spectrum(model, log_dir=log_dir)
# NOTE(review): assumes save_spectrum writes "sigmas.pkl" into `log_dir`;
# deriving the path from `log_dir` keeps the two in sync.
results = load_object(filepath=os.path.join(log_dir, "sigmas.pkl"))
for name, sigmas in results.items():
    # Print the most recent entry recorded for each name.
    print(f"{name}: {sigmas[-1]:1.3e}")
