import torch
import torch.utils.benchmark as benchmark

import sys
sys.path.append("../../")

from hollow.models.fno import NeuralOperator2d
from hollow.models.t1 import T1_2d
from hollow.models.layers.fdm2d import DFTConv2d, DCTConv2d
from hollow.losses.relative_l2 import RelativeL2

from hollow.utils.numerics import dct2d
from hollow.train import train

from omegaconf import DictConfig, OmegaConf, open_dict
import json
from tqdm.auto import tqdm
import os
import numpy as np

# Benchmark configuration. The sweep covers every combination of
# (model name, number of layers, signal resolution).
folder = '.'
file_name = "latent_2d_benchmark_results_layers_highres.json"
model_names = ['vanilla', 'latent', 'latent_dct']
layer_bench = np.arange(1, 9, 1, dtype=int)[::-1].tolist()  # 8 depths: 8, 7, ..., 1
signal_bench = np.linspace(32, 256, 16, dtype=int)[::-1].tolist()  # 16 resolutions, 256 down to 32
width = 32  # keep fixed

threads = 32

# Nested dicts (they serialise directly to JSON):
#   benchmark_results[model_name][nlayers][signal_res] -> list of timings
# Every leaf gets its own fresh list so timings never alias between entries.
benchmark_results = {
    m: {w: {s: [] for s in signal_bench} for w in layer_bench}
    for m in model_names
}


# GPU (index 2) on which every model and input tensor is placed.
device = torch.device("cuda", 2)

# How many times the full (nlayers x signal_res) sweep is run; timings from
# all runs are appended to the same result lists.
repetition_number = 5

for _ in range(repetition_number):

    for nlayers in tqdm(layer_bench, "Layer iter"):

        # Constructor settings shared by all three models; only the depth
        # (nlayers) varies across the sweep. signal_resolution=(1, 1) is a
        # dummy resolution used only for initialisation.
        common_kwargs = dict(
            modes1=12,
            modes2=12,
            padding=9,
            width=width,
            nlayers=nlayers,
            residual=True,
            keep_high=False,
            weight_init=4,
            signal_resolution=(1, 1),
        )

        # Construction order (latent, latent_dct, vanilla) matches the
        # original script in case initialisation consumes RNG state.
        latent_operator = T1_2d(
            perform_inverse=False,
            transform="dft",
            **common_kwargs,
        ).to(device)

        latent_dct_operator = T1_2d(
            perform_inverse=False,
            transform="dct",
            **common_kwargs,
        ).to(device)

        vanilla_operator = NeuralOperator2d(
            spectral_layer=DFTConv2d,
            **common_kwargs,
        ).to(device)

        # Single source of truth mapping each result key to the model it must
        # time. This prevents a timer from being handed the wrong model
        # (previously the 'latent_dct' timer measured the DFT latent model).
        operators = {
            'vanilla': vanilla_operator,
            'latent': latent_operator,
            'latent_dct': latent_dct_operator,
        }

        for signal_res in tqdm(signal_bench, leave=False, desc="Signal res iter"):
            print(f"Timing with signal length: {signal_res}")
            for name, model in operators.items():
                # The input is built in `setup`, outside the timed `stmt`.
                # NOTE(review): models are never put in eval mode and no
                # torch.no_grad() is used, so autograd bookkeeping is part of
                # the measured forward pass — confirm this is intended.
                timer = benchmark.Timer(
                    setup='x = torch.randn(32, signal_res, signal_res).to(device)',
                    stmt='model(x)',
                    globals={'signal_res': signal_res, 'device': device, 'model': model},
                    num_threads=threads)
                benchmark_results[name][nlayers][signal_res].extend(
                    timer.blocked_autorange().times)

# Save all accumulated timings as JSON. Note that json turns the integer
# nlayers / signal_res keys into strings. The `with` block closes the file.
with open(os.path.join(folder, file_name), "w") as f:
    json.dump(benchmark_results, f, indent=2)