import torch
import pandas as pd
import numpy as np
from einops import rearrange, repeat, reduce
import time
from tqdm import tqdm
from my_utils.utils import DATASET_NAMES, MASK_LENGTH, PRE_MODEL_TYPES, RS_LENS, RS_TYPES, TSC_MODEL_TYPES, get_dataset_params
import torchvision
import os

def warm_up(device, warm_up_iter:int=100):
    """Run throw-away ResNet-18 forward passes on ``device`` before benchmarking.

    A randomly initialised torchvision ResNet-18 is evaluated repeatedly on a
    fixed random batch of shape (32, 3, 512, 512); the outputs are discarded.

    Args:
        device: Device the model and the batch are moved to.
        warm_up_iter: Number of forward passes to run.
    """
    print('Warming up...')
    with torch.no_grad():
        # The network is built before the batch is drawn so the global RNG is
        # consumed in a fixed order.
        net = torchvision.models.resnet.resnet18()
        net.to(device)
        net.eval()
        batch = torch.randn(32, 3, 512, 512).to(device)
        progress = tqdm(range(warm_up_iter))
        for _ in progress:
            net(batch)


def test_speed_tsf_rs(model, ctx_len, tgt_len, device, rs_type, rs_len, rs_num, test_iter:int=100, warmup_iter:int=0):
    """Time ``model(ctx, tgt)`` on a batch of masked copies of one random context.

    One random context row of length ``ctx_len`` and one random target row of
    length ``tgt_len`` are drawn and repeated along the batch axis. Every
    context row is then multiplied by its own 0/1 mask that keeps ``rs_len``
    positions.

    Args:
        model: Object exposing ``.to()`` / ``.eval()`` and callable as
            ``model(ctx, tgt)``.
        ctx_len: Length of the context row.
        tgt_len: Length of the target row.
        device: Device the model and tensors are moved to.
        rs_type: ``'random'`` keeps ``rs_len`` positions sampled without
            replacement per row; ``'block'`` keeps one randomly chosen block
            of ``rs_len`` positions per row, starting at a multiple of
            ``rs_len``.
        rs_len: Number of positions kept in every masked row.
        rs_num: Number of masked rows (the batch size). Ignored when
            ``rs_type == 'block'``, where ``ctx_len - rs_len + 1`` is used.
        test_iter: Number of timed forward passes.
        warmup_iter: Number of untimed forward passes run first.

    Returns:
        Mean seconds per forward pass over the ``test_iter`` timed passes.

    Raises:
        NotImplementedError: If ``rs_type`` is neither ``'random'`` nor
            ``'block'``.
    """
    if rs_type not in ('random', 'block'):
        raise NotImplementedError(f'RS type {rs_type} is not implemented')

    with torch.no_grad():
        model.to(device)
        model.eval()

        ctx = torch.randn((1, ctx_len)).to(device)
        tgt = torch.randn((1, tgt_len)).to(device)

        ctx_len = ctx.shape[1]
        if rs_type == 'block':
            rs_num = ctx_len - rs_len + 1

        mask = np.zeros((rs_num, ctx_len))
        if rs_type == 'random':
            positions = np.arange(ctx_len)
            for i in range(rs_num):
                index = np.random.choice(positions, rs_len, replace=False)
                mask[i, index] = 1
        else:
            num_block = ctx_len // rs_len
            for i in range(rs_num):
                block_idx = np.random.randint(0, num_block)
                mask[i, block_idx*rs_len : (block_idx+1)*rs_len] = 1

        mask = torch.from_numpy(mask).to(dtype=ctx.dtype, device=ctx.device)
        ctx = repeat(ctx, '1 t -> n t', n=rs_num)
        tgt = repeat(tgt, '1 t -> n t', n=rs_num)
        ctx = ctx * mask

        on_cuda = ctx.is_cuda

        for _ in range(warmup_iter):
            model(ctx, tgt)

        # CUDA kernels are launched asynchronously: without a synchronize the
        # timer would miss queued work (or include leftover warm-up work).
        if on_cuda:
            torch.cuda.synchronize(ctx.device)
        s = time.perf_counter()
        for _ in tqdm(range(test_iter)):
            model(ctx, tgt)
        if on_cuda:
            torch.cuda.synchronize(ctx.device)
        e = time.perf_counter()
        return (e - s) / test_iter


def test_speed_tsf(imp_model, pre_model, ctx_len, tgt_len, device, mask_len, step, test_iter:int=100, warmup_iter:int=0):
    """Time the two-stage pipeline ``pre_model(imp_model(ctx, mask)[0], tgt)``.

    The batch size is ``(ctx_len - mask_len) // step + 1``, i.e. the number of
    length-``mask_len`` windows at stride ``step`` that fit into ``ctx_len``.
    The context and target are random, and the mask passed to ``imp_model``
    is all ones.

    Args:
        imp_model: Object exposing ``.to()`` / ``.eval()`` and callable as
            ``imp_model(ctx, mask)``; its result is unpacked into two values
            and the first is fed to ``pre_model``.
        pre_model: Object exposing ``.to()`` / ``.eval()`` and callable as
            ``pre_model(imp, tgt)``.
        ctx_len: Length of each context row.
        tgt_len: Length of each target row.
        device: Device the models and tensors are moved to.
        mask_len: Window length used to derive the batch size.
        step: Window stride used to derive the batch size.
        test_iter: Number of timed pipeline passes.
        warmup_iter: Number of untimed pipeline passes run first.

    Returns:
        Mean seconds per pipeline pass over the ``test_iter`` timed passes.
    """
    with torch.no_grad():
        imp_model.to(device)
        pre_model.to(device)
        imp_model.eval()
        pre_model.eval()

        num_mask = (ctx_len - mask_len) // step + 1
        ctx = torch.randn(num_mask, ctx_len).to(device)
        mask = torch.ones_like(ctx)  # already on ctx's device
        tgt = torch.randn(num_mask, tgt_len).to(device)

        on_cuda = ctx.is_cuda

        for _ in range(warmup_iter):
            imp, _ = imp_model(ctx, mask)
            pre_model(imp, tgt)

        # CUDA kernels are launched asynchronously: without a synchronize the
        # timer would miss queued work (or include leftover warm-up work).
        if on_cuda:
            torch.cuda.synchronize(ctx.device)
        s = time.perf_counter()
        for _ in tqdm(range(test_iter)):
            imp, _ = imp_model(ctx, mask)
            pre_model(imp, tgt)
        if on_cuda:
            torch.cuda.synchronize(ctx.device)
        e = time.perf_counter()

        return (e - s) / test_iter


def test_speed_tsc(imp_model, cls_model, ctx_len, device, mask_len, step, test_iter:int=100, warmup_iter:int=0):
    """Time the two-stage pipeline ``cls_model(imp_model(ctx, mask)[0])``.

    The batch size is ``(ctx_len - mask_len) // step + 1``, i.e. the number of
    length-``mask_len`` windows at stride ``step`` that fit into ``ctx_len``.
    The context is random, and the mask passed to ``imp_model`` is all ones.

    Args:
        imp_model: Object exposing ``.to()`` / ``.eval()`` and callable as
            ``imp_model(ctx, mask)``; its result is unpacked into two values
            and the first is fed to ``cls_model``.
        cls_model: Object exposing ``.to()`` / ``.eval()`` and callable as
            ``cls_model(imp)``.
        ctx_len: Length of each context row.
        device: Device the models and tensors are moved to.
        mask_len: Window length used to derive the batch size.
        step: Window stride used to derive the batch size.
        test_iter: Number of timed pipeline passes.
        warmup_iter: Number of untimed pipeline passes run first.

    Returns:
        Mean seconds per pipeline pass over the ``test_iter`` timed passes.
    """
    with torch.no_grad():
        imp_model.to(device)
        cls_model.to(device)
        imp_model.eval()
        cls_model.eval()

        num_mask = (ctx_len - mask_len) // step + 1
        ctx = torch.randn(num_mask, ctx_len).to(device)
        mask = torch.ones_like(ctx)  # already on ctx's device

        on_cuda = ctx.is_cuda

        for _ in range(warmup_iter):
            imp, _ = imp_model(ctx, mask)
            cls_model(imp)

        # CUDA kernels are launched asynchronously: without a synchronize the
        # timer would miss queued work (or include leftover warm-up work).
        if on_cuda:
            torch.cuda.synchronize(ctx.device)
        s = time.perf_counter()
        for _ in tqdm(range(test_iter)):
            imp, _ = imp_model(ctx, mask)
            cls_model(imp)
        if on_cuda:
            torch.cuda.synchronize(ctx.device)
        e = time.perf_counter()
        return (e - s) / test_iter


def test_speed_tsc_rs(model, ctx_len, device, rs_type, rs_len, rs_num, test_iter:int=100, warmup_iter:int=0):
    """Time ``model(context)`` on a batch of masked copies of one random context.

    One random context row of length ``ctx_len`` is drawn and repeated along
    the batch axis. Every row is then multiplied by its own 0/1 mask that
    keeps ``rs_len`` positions.

    Args:
        model: Object exposing ``.to()`` / ``.eval()`` and callable as
            ``model(context)``.
        ctx_len: Length of the context row.
        device: Device the model and tensors are moved to.
        rs_type: ``'random'`` keeps ``rs_len`` positions sampled without
            replacement per row; ``'block'`` keeps one randomly chosen block
            of ``rs_len`` positions per row, starting at a multiple of
            ``rs_len``.
        rs_len: Number of positions kept in every masked row.
        rs_num: Number of masked rows (the batch size). Ignored when
            ``rs_type == 'block'``, where ``ctx_len - rs_len + 1`` is used.
        test_iter: Number of timed forward passes.
        warmup_iter: Number of untimed forward passes run first.

    Returns:
        Mean seconds per forward pass over the ``test_iter`` timed passes.

    Raises:
        NotImplementedError: If ``rs_type`` is neither ``'random'`` nor
            ``'block'``.
    """
    if rs_type not in ('random', 'block'):
        raise NotImplementedError(f'RS type {rs_type} is not implemented')

    with torch.no_grad():
        model.to(device).eval()
        context = torch.randn(1, ctx_len).to(device)
        if rs_type == 'block':
            rs_num = ctx_len - rs_len + 1

        seq_len = context.shape[1]
        mask = np.zeros((rs_num, seq_len))
        if rs_type == 'random':
            positions = np.arange(seq_len)
            for i in range(rs_num):
                index = np.random.choice(positions, rs_len, replace=False)
                mask[i, index] = 1
        else:
            block_len = rs_len
            num_block = seq_len // block_len
            for i in range(rs_num):
                index = np.random.randint(0, num_block)
                mask[i, index*block_len : (index+1)*block_len] = 1

        context = repeat(context, '1 t -> n t', n=rs_num).clone()
        mask = torch.from_numpy(mask).to(dtype=context.dtype, device=context.device)
        context = mask * context

        on_cuda = context.is_cuda

        for _ in range(warmup_iter):
            model(context)

        # CUDA kernels are launched asynchronously: without a synchronize the
        # timer would miss queued work (or include leftover warm-up work).
        if on_cuda:
            torch.cuda.synchronize(context.device)
        s = time.perf_counter()
        for _ in tqdm(range(test_iter)):
            model(context)
        if on_cuda:
            torch.cuda.synchronize(context.device)
        e = time.perf_counter()

        return (e - s) / test_iter



if __name__ == '__main__':
    # Benchmark driver: times every checkpoint configuration and writes one
    # CSV per task. Rows are collected in a plain list and turned into a
    # DataFrame once, instead of re-copying a growing frame on every iteration.
    #
    # NOTE(review): the loaded checkpoints are called directly as models, so
    # they appear to be whole pickled modules. On PyTorch >= 2.6 torch.load
    # defaults to weights_only=True and may need weights_only=False - confirm
    # against the installed version.

    ##################################################### TSC
    device = 'cuda:0'
    END = 100  # upper slice bound applied to every TSC sweep list below
    warm_up(device)
    csv_path = 'record-tsc-time-0831.csv'
    # Refuse to overwrite results from an earlier run.
    if os.path.isfile(csv_path):
        raise RuntimeError(f'{csv_path} already exists')
    rs_num = 10000
    dataset_name = 'DistalPhalanxTW'
    ctx_len, _ = get_dataset_params(dataset_name)
    rows = []

    # Imputation model followed by classifier; rows are labelled 'MIA'.
    # Re-loading per configuration keeps every row measured from the same
    # freshly-loaded state (test_speed_* run no warm-up passes by default).
    for cls_type in TSC_MODEL_TYPES[0:END]:
        for mask_len in MASK_LENGTH[0:END]:
            for step in range(1,mask_len+1)[0:END]:
                imp_model = torch.load(f'result_imputation_0806/{dataset_name}/mixer/mask_length={mask_len}/filter=True/checkpoint/best.pt')
                cls_model = torch.load(f'result-tsc-0826/{dataset_name}/{cls_type}/filter=True/rt_noise=0.0/checkpoint/best.pt')
                duration = test_speed_tsc(imp_model, cls_model, ctx_len, device, mask_len, step)
                rows.append({
                    "cls_type": cls_type,
                    "def_len": mask_len,
                    "atk_len": mask_len - step + 1,
                    "time": duration,
                    "method": 'MIA',
                })

    # Single classifier on masked inputs; rows are labelled with rs_type and
    # carry no "atk_len" value.
    for cls_type in TSC_MODEL_TYPES[0:END]:
        for rs_type in RS_TYPES[0:END]:
            for rs_len in RS_LENS[0:END]:
                cls_model = torch.load(f'result_tsc_rs-0818/{dataset_name}/{cls_type}/{rs_type}/{rs_len}/checkpoint/final.pt')
                duration = test_speed_tsc_rs(cls_model, ctx_len, device, rs_type, rs_len, rs_num)
                rows.append({
                    "cls_type": cls_type,
                    "def_len": rs_len,
                    "time": duration,
                    "method": rs_type,
                })

    pd.DataFrame(rows).to_csv(csv_path, index=False)


    ##################################################### TSF
    csv_path = 'record-tsf-time-0831.csv'
    # Refuse to overwrite results from an earlier run.
    if os.path.isfile(csv_path):
        raise RuntimeError(f'{csv_path} already exists')

    device = 'cuda:0'
    warm_up(device)
    dataset_name = 'electricity_nips'
    rs_num = 10000

    # Settings shared by every imputation + prediction configuration.
    imp_type = 'mixer'
    rt_noise = 0.0
    len_clip = 1.0
    ctx_len, tgt_len = get_dataset_params(dataset_name,len_clip)

    rows = []
    # Imputation model followed by predictor; rows are labelled 'MIA'.
    for pre_type in PRE_MODEL_TYPES:
        for mask_len in MASK_LENGTH:
            for step in range(1,mask_len+1):
                imp_model = torch.load(f'result_imputation_0827/{dataset_name}/{imp_type}/mask_length={mask_len}/filter=True/rt_noise={rt_noise}/len_clip={len_clip}/checkpoint/best.pt')
                pre_model = torch.load(f'result_prediction_0827/{dataset_name}/{pre_type}/filter=True/rt_noise={rt_noise}/len_clip={len_clip}/checkpoint/best.pt')
                duration = test_speed_tsf(imp_model, pre_model, ctx_len, tgt_len, device, mask_len, step)
                rows.append({
                    "pre_type": pre_type,
                    "def_len": mask_len,
                    "atk_len": mask_len - step + 1,
                    "time": duration,
                    "method": 'MIA',
                })

    # The RS runs query the dataset parameters without len_clip.
    ctx_len, tgt_len = get_dataset_params(dataset_name)
    # Single predictor on masked inputs; rows are labelled with rs_type.
    for pre_type in PRE_MODEL_TYPES:
        for rs_type in RS_TYPES:
            for rs_len in RS_LENS:
                model = torch.load(f'result-pred-rs-0818/{dataset_name}/{pre_type}/{rs_type}/{rs_len}/checkpoint/best.pt')
                duration = test_speed_tsf_rs(model, ctx_len, tgt_len, device, rs_type, rs_len, rs_num)
                rows.append({
                    "pre_type": pre_type,
                    "def_len": rs_len,
                    "time": duration,
                    "method": rs_type,
                })

    pd.DataFrame(rows).to_csv(csv_path, index=False)