import time
import torch
from models.channel_vit import ChannelVisionTransformer as vitmae_encoder
from models.model_heads.mae_decoder import MAEDecoder
from models.simMiM import ChannelVisionTransformer as simmim_encoder
from models.model_heads.classification_layer import ClassificationLayer
from models.model_heads.patch_reconstruction_head import PatchReconstructionHead as simmim_decoder
from criterion.reconstruction_norm import ReconstructionL2
from datasets.hdf5_dataset import HDF5Loader
from datasets.tuab_dataset import TUABDataset
import tqdm
import numpy as np
import csv
import torch.nn as nn

def get_cuda_alloc(device='cuda'):
    """Return the peak CUDA memory allocated on *device*, then reset the peak.

    Reads ``torch.cuda.max_memory_allocated`` for ``device`` and resets the
    peak-memory counter on that same device, so the next call measures a
    fresh window. It also calls ``torch.cuda.empty_cache()``, which frees
    cached, unused blocks.

    Args:
        device: CUDA device to query and reset (anything accepted by
            ``torch.cuda.max_memory_allocated``). The default ``'cuda'``
            selects the current CUDA device.

    Returns:
        Peak allocated bytes since the last peak reset (or program start).
    """
    max_alloc = torch.cuda.max_memory_allocated(device=device)
    # Reset the same device that was just read, so consecutive calls
    # measure disjoint windows on that device.
    torch.cuda.reset_peak_memory_stats(device=device)
    torch.cuda.empty_cache()
    return max_alloc

# Benchmark forward-pass runtime and peak CUDA memory of the SimMIM encoder
# plus classification head, across batch sizes and attention types.
# After every configuration, the results collected so far are rewritten to
# CSV_PATH (and NPY_PATH). A crash in a later configuration, e.g. an OOM,
# therefore keeps the earlier rows.
CSV_PATH = ("/cluster/work/cvl/eeg_foundation/datasets/"
            "simmim_wave_inference_stats_fused_few_tokens_training.csv")
# NumPy copy of the same rows, stored next to the CSV.
NPY_PATH = CSV_PATH[:-len(".csv")] + ".npy"
CSV_HEADER = ['attention_type', 'nr_iterations', 'init_alloc(MiB)',
              'runtime(s)', 'avg_max_mem_alloc(MiB)', 'batch_size']
BYTES_PER_MIB = 1024 ** 2  # the CSV header reports memory in MiB
BATCH_SIZES = [4, 32, 40, 64, 128]
ATTENTION_TYPES = ['default', 'alternating']  # other candidates: 'bottleneck', 'twoaxis'

data_gathered = []
for batch_size in BATCH_SIZES:
    for attention_type in ATTENTION_TYPES:
        # Pause between configurations (presumably to let the GPU settle).
        time.sleep(5)
        simmim_enc = simmim_encoder(in_chans=23,
                                    masking_ratio=0.0,
                                    img_size=1280,
                                    patch_size=256,
                                    embed_dim=768,
                                    num_heads=12,
                                    depth=12,
                                    using_spectrogram=False,
                                    attention_type=attention_type)

        classification_head = ClassificationLayer(num_classes=2, embed_dim=768)
        simmim_enc = nn.DataParallel(simmim_enc).cuda()
        classification_head = nn.DataParallel(classification_head).cuda()

        simmim_enc.eval()
        classification_head.eval()

        runtime = []
        max_memory_allocated = []
        iters = 1
        # Peak allocation since the previous reset. This window covers moving
        # the models to the GPU. The call also resets the peak counter.
        init_alloc = get_cuda_alloc()

        for _ in range(iters):
            ## inference
            # NOTE(review): autograd is left enabled, so peak memory includes
            # activations saved for backward. Wrap in torch.no_grad() for pure
            # inference numbers. Confirm which measurement is intended.
            x = torch.rand((batch_size, 23, 1280)).cuda()
            # CUDA kernels run asynchronously. Synchronizing on both sides
            # makes the interval cover the actual GPU work, not only the
            # kernel launches.
            torch.cuda.synchronize()
            start = time.perf_counter()
            x = simmim_enc(x, mask_tokens=False)
            x = classification_head(x)
            torch.cuda.synchronize()
            runtime.append(time.perf_counter() - start)
            max_memory_allocated.append(get_cuda_alloc())
            del x

        init_alloc_mib = init_alloc / BYTES_PER_MIB
        mean_runtime = np.mean(runtime)
        mean_max_alloc_mib = np.mean(max_memory_allocated) / BYTES_PER_MIB

        print('\nattention type:\t', attention_type)
        print('batch size:\t', batch_size)
        print('iterations:\t', iters)
        print('init alloc', init_alloc_mib)
        print('average runtime:\t', mean_runtime)
        print('average max alloc:\t', mean_max_alloc_mib)
        data_gathered.append([attention_type, iters, init_alloc_mib,
                              mean_runtime, mean_max_alloc_mib, batch_size])

        del simmim_enc
        del classification_head
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        # Rewrite both result files with every row gathered so far.
        np.save(NPY_PATH, data_gathered)
        with open(CSV_PATH, 'w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(CSV_HEADER)
            writer.writerows(data_gathered)
