import torch
from torch.profiler import profile, record_function, ProfilerActivity
from nn import MLP
from nn.cola_nn import colafy


# Profile a single forward pass of a colafy-transformed MLP and print the
# ten most expensive operators by total CPU time.

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fixed seed so the randomly generated weights and input are reproducible.
torch.manual_seed(seed=21)

batch_size = 68
# Alternative batch size:
# batch_size = 1024

# Alternative, much wider network (width 2 ** 13):
# dim_in, dim_out, depth, width = 32 * 32 * 3, 10, 9, 2 ** 13
dim_in, dim_out, depth, width = 32 * 32 * 3, 10, 9, 2 ** 8
model = MLP(dim_in, dim_out, depth, width)

# Exactly one `struct` configuration should be active; the commented lines
# are the alternatives to toggle between.
# NOTE(review): rank_frac / tt_dim / tt_rank presumably only matter for the
# matching structure type -- confirm against nn.cola_nn.colafy.
# struct, layers, rank_frac, tt_dim, tt_rank = "none", "all_but_last", 0.1, 2, 1
struct, layers, rank_frac, tt_dim, tt_rank = "low_rank", "all_but_last", 0.1, 2, 1
# struct, layers, rank_frac, tt_dim, tt_rank = "block_tt", "all_but_last", 0.1, 2, 1
colafy(model, struct=struct, layers=layers, rank_frac=rank_frac, tt_dim=tt_dim, tt_rank=tt_rank)

x = torch.randn(batch_size, dim_in)
x = x.to(device)
model = model.to(device)

# Only ask the profiler for CUDA events when we are actually running on a
# GPU; requesting them on a CPU-only machine is at best a warning.
activities = [ProfilerActivity.CPU]
if device.type == "cuda":
    activities.append(ProfilerActivity.CUDA)

# The region is labelled "model_inference", so run without autograd to keep
# graph-recording overhead out of the measurement.
with torch.no_grad():
    # Warm-up pass: the first call pays one-time costs (CUDA context and
    # library initialisation, allocator growth) that would otherwise
    # dominate the profile.
    model(x)
    if device.type == "cuda":
        # Make sure the warm-up kernels have finished before profiling starts.
        torch.cuda.synchronize()

    with profile(activities=activities, record_shapes=True) as prof:
        with record_function("model_inference"):
            model(x)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
