import torch
from nn import MLP
from nn.cola_nn import colafy


# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Fixed seed so parameter init and the random input batch are reproducible.
torch.manual_seed(21)

# Number of forward/backward iterations requested for the profiling run.
repeat_n = 1000

# Problem size. A lighter variant used previously: batch_size=68, width=2**8.
batch_size = 1024
dim_in = 32 * 32 * 3  # presumably a flattened 32x32x3 image -- confirm
dim_out = 10
depth = 9
width = 2 ** 13
model = MLP(dim_in, dim_out, depth, width)

# Structure settings handed to colafy. Other `struct` values this script has
# been run with: "low_rank" and "block_tt" (same remaining settings).
struct = "none"
layers = "all_but_last"
rank_frac = 0.1
tt_dim = 2
tt_rank = 1
# colafy's return value is discarded, so it is expected to modify `model`
# in place (NOTE(review): confirm in nn.cola_nn).
colafy(
    model,
    struct=struct,
    layers=layers,
    rank_frac=rank_frac,
    tt_dim=tt_dim,
    tt_rank=tt_rank,
)

# Random input batch drawn from torch's default (CPU) generator, then placed
# on the same device as the model.
x = torch.randn(batch_size, dim_in).to(device)
model = model.to(device)

# Profile a few forward+backward passes of `model` on the batch `x`.
# Schedule: skip 1 step (the first iteration likely includes one-time
# CUDA/cuBLAS/allocator setup), warm up for 1, then record 3. repeat=1 means
# a single cycle, after which tensorboard_trace_handler writes the trace to
# ./logs/prof/mvms_<struct>.
prof_wait, prof_warmup, prof_active = 1, 1, 3
# Iterations past the single scheduled cycle are never recorded, so stop the
# loop there instead of running all `repeat_n` expensive passes.
prof_steps = min(repeat_n, prof_wait + prof_warmup + prof_active)
# As a context manager the profiler is stopped even if a step raises.
with torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=prof_wait, warmup=prof_warmup, active=prof_active, repeat=1
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./logs/prof/mvms_{struct}"),
    record_shapes=True,
    with_stack=True,
) as prof:
    for _ in range(prof_steps):
        # Drop gradients from the previous step so each profiled backward
        # writes fresh .grad tensors instead of accumulating into old ones.
        model.zero_grad(set_to_none=True)
        y = model(x)
        loss = torch.mean(y)
        loss.backward()
        # prof.step() marks the END of an iteration; calling it after the work
        # keeps ProfilerStep#N aligned with iteration N.
        prof.step()
print("*=" * 50 + "\nDone\n" + "*=" * 50)
