import time
import numpy as np
import torch
import itertools
from nn.fcnet import CoLAMLP
from nn.cola_nn import get_builder_fn
from trainkit.saving import save_object
from trainkit.saving import ask_save_output
from trainkit.timing import print_time_taken

# --- Run configuration -------------------------------------------------------
save_output = ask_save_output()
output_path = "./logs/bench.pkl"

# Benchmarks currently run on the CPU. To use a GPU when present, swap in:
#     torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
torch.manual_seed(seed=21)

# Width sweep: powers of two from 2**5 = 32 up to 2**12 = 4096.
width_s = [1 << exponent for exponent in range(5, 13)]
# Input size 32 * 32 * 3 (presumably a flattened 32x32 RGB image), 10 outputs,
# 9 layers.
dim_in = 32 * 32 * 3
dim_out = 10
depth = 9
# Alternative sweeps kept for reference:
#   width_s = [i ** 2 for i in range(15, 75, 4)] with
#       dim_in, dim_out, depth = 32 * 32 * 4, 10, 9
#   width_s = [2 ** i for i in range(1, 5)] with
#       dim_in, dim_out, depth = 32 * 32 * 3, 10, 2
batch_size = 2048
repeat_n = 3

# Structures to benchmark; "none" and "bfly" were also benchmarked previously.
struct_s = ["hbtt"]
# Every (width, struct) pair; a one-shot iterator consumed by the loop below.
cases = itertools.product(width_s, struct_s)
# Result rows accumulated by the benchmark loop.
data = []


def compute_elapsed_gpu(fn):
    """Time a single call of ``fn`` with CUDA events.

    Two timing-enabled events bracket the call on the current CUDA stream.
    The device is synchronized before the interval is read, so both events
    have completed.

    Args:
        fn: Zero-argument callable to time; its return value is discarded.

    Returns:
        Elapsed time between the two events, in seconds.
    """
    begin, finish = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    begin.record()
    fn()
    finish.record()
    # elapsed_time() requires both events to have completed on the device.
    torch.cuda.synchronize()
    milliseconds = begin.elapsed_time(finish)  # CUDA events report milliseconds
    return milliseconds / 1000


def compute_elapsed_cpu(fn):
    """Time a single call of ``fn`` with the host's high-resolution clock.

    Args:
        fn: Zero-argument callable to time; its return value is discarded.

    Returns:
        Elapsed time in seconds, as a float.
    """
    # perf_counter() is monotonic and has the highest available resolution.
    # time.time() is a wall clock that can jump with system-clock adjustments
    # and is coarser on some platforms, which distorts short measurements.
    t0 = time.perf_counter()
    fn()
    return time.perf_counter() - t0


# Choose the timer by where the work runs. CUDA work is queued asynchronously,
# so it is timed with CUDA events. Every other device type uses the host clock
# instead of the CUDA-only event timer.
elapsed_fn = compute_elapsed_gpu if device.type == "cuda" else compute_elapsed_cpu

# The first timing of every case is discarded as warm-up, so at least one more
# sample is required to report a statistic.
if repeat_n < 2:
    raise ValueError(f"repeat_n must be >= 2 (got {repeat_n})")

# Structure hyper-parameters shared by every benchmarked case.
layers, rank_frac, tt_dim, tt_rank = "all_but_last", 0.1, 2, 1

tic = time.perf_counter()
for width, struct in cases:
    builder_fn = get_builder_fn(struct=struct, layers=layers, rank_frac=rank_frac, kron_mult=1, tt_dim=tt_dim,
                                tt_rank=tt_rank)
    model = CoLAMLP(dim_in, dim_out, depth, width, builder_fn)
    x = torch.randn(batch_size, dim_in)
    x = x.to(device)
    model = model.to(device)

    def fn():
        """Run one forward + backward pass of ``model`` on the fixed batch ``x``."""
        y = model(x)
        loss = torch.mean(y)
        loss.backward()

    times = np.array([elapsed_fn(fn) for _ in range(repeat_n)])
    # Drop the first run (warm-up); statistics cover the remaining samples only.
    measured = times[1:]
    n_measured = len(measured)
    mean_time = np.mean(measured)
    # Standard error of the mean, from the sample standard deviation (ddof=1).
    # It is undefined for a single sample, so report NaN in that case.
    sem_time = np.std(measured, ddof=1) / np.sqrt(n_measured) if n_measured > 1 else float("nan")
    data.append((struct, width, str(device), mean_time, sem_time))
    print(f"Struct: {struct} | Width: {width:,d}")
    print(f"Mean: {mean_time:1.3e} over {n_measured:,d} times | Total {np.sum(measured):1.3e}")
print_time_taken(time.perf_counter() - tic)

if save_output:
    save_object(data, output_path)
