import time
from pathlib import Path
import numpy as np
import torch
from trainkit.saving import save_object
from trainkit.saving import ask_save_output
from trainkit.timing import print_time_taken
from benchmarking.bench_fns import compute_elapsed_cpu, compute_elapsed_gpu
from benchmarking.bench_fns import build_op
from benchmarking.bench_fns import get_flops

# Benchmark the product `X @ W` for the operator built by `build_op` over a
# range of square sizes (dim_in == dim_out), reporting per-size FLOPs (from
# `get_flops`) and mean elapsed time (from `compute_elapsed_cpu/gpu`), and
# optionally pickling one record per size to `output_dir`.
save_output = ask_save_output()
struct = "btt"
# struct = "dense"
output_dir = Path("./logs")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
dim_s = [50_000, 10_000, 5_000, 1_000, 500, 100, 50, 10]
# dim_s = [250_000, 100_000, 50_000, 10_000, 5_000, 1_000, 500, 100, 50, 10]
batch_size, repeat_n = 100, 3
# The first `warmup_n` timings of every size are discarded (warm-up);
# statistics use only the remaining `repeat_n - warmup_n` samples.
warmup_n = 1
if repeat_n <= warmup_n:
    # Otherwise the slice of kept samples is empty and mean/std become nan.
    raise ValueError(
        f"repeat_n ({repeat_n}) must exceed warmup_n ({warmup_n}) "
        "to leave at least one timed sample"
    )
struct_args = {"tt_rank": 1, "tt_dim": 2}
torch.manual_seed(seed=21)

# Pick the timing helper matching the device the tensors live on.
elapsed_fn = compute_elapsed_cpu if device.type == "cpu" else compute_elapsed_gpu
# One tuple per size:
# (struct, dim_in, dim_out, batch_size, flops, device, mean_time, sterr)
data = []

tic = time.time()
for dim in dim_s:
    dim_in = dim_out = dim
    # Sample on the CPU generator, then move, so seeded inputs do not depend
    # on the target device.
    X = torch.randn(batch_size, dim_in)
    X = X.to(device)
    W = build_op(struct, device, dim_in, dim_out, **struct_args)
    flops = get_flops(struct, batch_size, dim_in, dim_out, **struct_args)

    def fn():
        # Reads the X and W bound in this iteration; only called within it.
        _ = X @ W

    times = np.zeros(repeat_n)
    for idx in range(repeat_n):
        times[idx] = elapsed_fn(fn)
    timed = times[warmup_n:]
    n_timed = len(timed)
    mean_times = np.mean(timed)
    # NOTE(review): population std (ddof=0) is kept so saved results stay
    # comparable with earlier runs; a sample std (ddof=1) is the usual choice
    # for a standard error.
    mean_sterr = np.std(timed) / np.sqrt(n_timed)
    print(f"Struct: {struct} | FLOPs: {flops:1.3e}")
    print(
        f"Mean: {mean_times:1.3e} over {n_timed:,d} times "
        f"({warmup_n:,d} warm-up discarded) | Total {np.sum(timed):1.3e}"
    )

    data.append((struct, dim_in, dim_out, batch_size, flops, str(device), mean_times, mean_sterr))

output_path = output_dir / Path("mvm_" + struct + "_" + str(device) + ".pkl")
if save_output:
    # ./logs may not exist on a fresh checkout; create it before writing.
    output_dir.mkdir(parents=True, exist_ok=True)
    save_object(data, output_path)

print_time_taken(time.time() - tic)
