import torch
import torch.utils.benchmark as benchmark
from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import CopyIfCallgrind

import matplotlib.pyplot as plt
import sys
import math
sys.path.append("../../")

from hollow.datamodules.basic_pde import Burgers
from hollow.tasks.default import DefaultModel
from hollow.models.fno import NeuralOperator1d
from hollow.models.t1 import T1_1d
from hollow.models.layers.fdm1d import DFTConv1d, DCTConv1d
from hollow.losses.relative_l2 import RelativeL2

from hollow.utils.numerics import dct1d
from hollow.train import train
from hollow.utils import utils

from omegaconf import DictConfig, OmegaConf, open_dict
import json
from tqdm.auto import tqdm
import os
import numpy as np

# Benchmark configuration and the container that collects every timing.
folder = '.'
file_name = "latent_1d_benchmark_results.json"
model_names = ['vanilla', 'latent', 'latent_dct']
# 16 evenly spaced sizes from 64 to 1024, largest first.
# Earlier runs used e.g. [128, 256, 512, 1024, 2048] (widths also [4, ..., 256]).
width_bench = np.linspace(64, 1024, 16, dtype=int)[::-1].tolist()
signal_bench = np.linspace(64, 1024, 16, dtype=int)[::-1].tolist()
n_layers = 4
threads = 32

# benchmark_results[model][width][signal_res] -> list of timings.
# A plain nested dict so it can be dumped straight to JSON; the comprehension
# gives every leaf its own list object (no shared mutable state between cells).
benchmark_results = {
    m: {w: {s: [] for s in signal_bench} for w in width_bench}
    for m in model_names
}


device = "cuda:2"  # as string; used for both the models and the benchmark inputs

# Repeat the experiment n times
repetition_number = 5

def _build_operators(width):
    """Build the three operators compared by this benchmark at the given hidden width.

    Returns a dict keyed by the same model names used in ``benchmark_results``;
    its insertion order fixes the order in which the operators are timed.
    """
    # Hyper-parameters shared by every operator so only the architecture differs.
    common = dict(
        modes=12,
        padding=9,
        width=width,
        nlayers=n_layers,
        residual=True,
        keep_high=False,
        weight_init=4,
        signal_resolution=1,  # dummy res for init
    )
    latent_operator = T1_1d(perform_inverse=False, transform="dft", **common).to(device)
    latent_dct_operator = T1_1d(perform_inverse=False, transform="dct", **common).to(device)
    vanilla_operator = NeuralOperator1d(spectral_layer=DFTConv1d, **common).to(device)
    return {
        'vanilla': vanilla_operator,
        'latent': latent_operator,
        'latent_dct': latent_dct_operator,
    }


def _time_forward(operator, signal_res):
    """Time ``operator(x)`` on a random (32, signal_res) input placed on ``device``.

    Returns the per-run times reported by ``blocked_autorange``. The input is
    created in ``setup``, so its allocation/transfer is not part of the timing.
    """
    timer = benchmark.Timer(
        setup='x = torch.randn(32, signal_res).to(torch.device(device))',
        stmt='operator(x)',
        # Pass the configured device instead of hard-coding it in the setup
        # string, so inputs always live on the same device as the models.
        globals={'signal_res': signal_res, 'operator': operator, 'device': device},
        num_threads=threads,
    )
    return timer.blocked_autorange().times


# Experimental: operators can be wrapped with CopyIfCallgrind(...) to profile
# them under callgrind instead of wall-clock timing.
# NOTE(review): forward passes run with autograd enabled; wrap them in
# torch.no_grad() if only inference latency should be measured — confirm intent.
for _ in range(repetition_number):
    for width in tqdm(width_bench, "Width iter"):
        # Models are built once per width and reused for every signal length,
        # since signal_resolution is only a dummy value at init.
        operators = _build_operators(width)

        for signal_res in tqdm(signal_bench, leave=False, desc="Signal res iter"):
            # Each name is timed with its own operator. The original code
            # mistakenly timed the DFT latent model under 'latent_dct'.
            for name, operator in operators.items():
                benchmark_results[name][width][signal_res].extend(
                    _time_forward(operator, signal_res)
                )


# Save all collected timings. json.dump converts the int width/signal keys to
# strings; the `with` block closes the file, so no explicit close() is needed.
with open(os.path.join(folder, file_name), "w") as f:
    json.dump(benchmark_results, f, indent=2)