from os.path import isdir
import torch
from argparse import ArgumentParser
from imputation_model.model import PyPOTS_IMP_Model
from my_utils.utils import PYPOTS_IMP_MODEL_TYPES, setup_seed, save_json, mkdir
from einops import rearrange, repeat
from dataset.dataset import DatasetTSC, DATA_ROOT_TSC
from torch.utils.data import DataLoader
import os
import pandas as pd


def get_mask(context, mask_len, step):
    '''
    Build sliding-window masks over a single 1-D series.

    Window ``k`` hides positions ``[k*step, k*step + mask_len)``. Windows are
    generated only while they fit entirely inside the series, so
    ``num_mask = (L - mask_len) // step + 1``.

    Input:
        context.shape == [L]
        mask_len: number of consecutive positions hidden per window, 1 <= mask_len <= L
        step: stride between consecutive window starts, step >= 1
    Return:
        context.shape == [num_mask,L]   (the input series copied once per window)
        mask.shape == [num_mask,L]      float32 on context.device; 0 = masked, 1 = kept
    Raises:
        ValueError: if ``context`` is not 1-D or ``mask_len``/``step`` are out of range
            (out-of-range values previously produced zero/negative window counts,
            i.e. a crash or silently empty tensors).
    '''
    if context.dim() != 1:
        raise ValueError(f'context must be 1-D, got shape {tuple(context.shape)}')
    context_len = context.shape[0]
    if step < 1:
        raise ValueError(f'step must be >= 1, got {step}')
    if not 1 <= mask_len <= context_len:
        raise ValueError(f'mask_len must be in [1, {context_len}], got {mask_len}')

    num_mask = (context_len - mask_len) // step + 1
    # One (materialised) copy of the series per mask window.
    context = context.unsqueeze(0).repeat(num_mask, 1)
    mask = torch.ones_like(context, dtype=torch.float32)

    for mask_index in range(num_mask):
        start = mask_index * step
        mask[mask_index, start:start + mask_len] = 0
    return context, mask


def get_args():
    '''Parse the command-line options for a single IAC evaluation run.

    ``--use_filter`` and ``--check_acc`` are parsed as ints here; they are
    turned into bools by ``post_process_args``.
    '''
    # (flag, type, default) for every supported option, in declaration order.
    option_specs = (
        ('--imp_model', str, 'result_imputation_0806/DistalPhalanxTW/saits/mask_length=8/filter=True/checkpoint/best.pt'),
        ('--cls_model', str, 'result-tsc-0814/DistalPhalanxTW/fcn/filter=True/rt_noise=0.0/checkpoint/best.pt'),
        ('--dataset_name', str, 'DistalPhalanxTW'),
        ('--imp_type', str, 'saits'),
        ('--cls_type', str, 'fcn'),
        ('--mask', int, 15),
        ('--step', int, 2),
        ('--num_sample', int, 101),
        ('--use_filter', int, 1),
        ('--window', int, 15),
        ('--order', int, 5),
        ('--output_dir', str, 'result-iac-debug'),
        ('--device', str, 'cuda:0'),
        ('--check_acc', int, 1),
        ('--acc_record_path', str, 'tmp.csv'),
    )
    parser = ArgumentParser()
    for flag, flag_type, default_value in option_specs:
        parser.add_argument(flag, type=flag_type, default=default_value)
    return parser.parse_args()


def post_process_args(args):
    '''Normalise parsed options in place and derive ``attack_len``.

    ``use_filter`` and ``check_acc`` are coerced to real bools (they arrive as
    ints from the command line), and ``attack_len`` is set to
    ``mask - step + 1``. The same namespace object is returned.
    '''
    for flag_name in ('use_filter', 'check_acc'):
        setattr(args, flag_name, bool(getattr(args, flag_name)))
    args.attack_len = args.mask - args.step + 1
    return args


def _load_models(args):
    '''Load and return ``(imp_model, cls_model)`` as selected by ``args``.

    NOTE(review): ``torch.load`` unpickles arbitrary objects -- only point
    ``args.imp_model`` / ``args.cls_model`` at trusted checkpoints.
    '''
    if args.imp_type == 'mixer':
        # map_location loads tensors straight onto the target device instead of
        # first restoring them on the device they were saved from.
        imp_model = torch.load(args.imp_model, map_location=args.device).eval().to(args.device)
    elif args.imp_type in PYPOTS_IMP_MODEL_TYPES:
        imp_model = PyPOTS_IMP_Model(torch.load(args.imp_model))
    else:
        raise NotImplementedError(f'imputation model type {args.imp_type} is not implemented')

    cls_model = torch.load(args.cls_model, map_location=args.device).eval().to(args.device)
    return imp_model, cls_model


def _append_acc_record(record, path):
    '''Append ``record`` (a dict) as one row of the CSV at ``path``, creating it if needed.'''
    new_row = pd.DataFrame([record])
    try:
        previous = pd.read_csv(path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # First run or an empty file: the new row is the whole table.
        # Other errors (e.g. a corrupt CSV) propagate instead of the file being
        # silently overwritten with a single row.
        table = new_row
    else:
        table = pd.concat([previous, new_row])
    table.to_csv(path, index=False)


def main(args):
    '''Run the sliding-mask imputation + classification evaluation for one configuration.

    For each test series (at most ``args.num_sample`` of them), every mask
    window produced by ``get_mask`` is filled in by the imputation model and
    the imputed series are classified. Per-sample results are saved to
    ``<output_dir>/ans.pt`` and the run arguments to ``<output_dir>/args.json``.
    If ``args.check_acc`` is set, ``args.acc`` is set to ``check_iac_acc(ans)``
    and a row with all arguments is appended to ``args.acc_record_path``.
    '''
    setup_seed()
    # ``args`` may be reused across runs by a batch driver; drop an accuracy left
    # over from a previous run so it is not written into this run's args.json.
    vars(args).pop('acc', None)

    with torch.no_grad():
        imp_model, cls_model = _load_models(args)

        # normalization already done in dataset
        dataset = DatasetTSC(DATA_ROOT_TSC, args.dataset_name, 'test', 'instance',
                             args.use_filter, args.window, args.order)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=True)

        ans = []
        for i, (context, label) in enumerate(dataloader, start=1):
            context = rearrange(context.to(args.device), '1 t -> t')

            # [N,L]: one copy of the series per mask window
            context, mask = get_mask(context, args.mask, args.step)
            num_mask = context.shape[0]
            # one label per mask window -> [N,1]
            label = repeat(label, '1 -> n 1', n=num_mask)

            context_imp, _ = imp_model(context, mask)
            label_pred = cls_model(context_imp).argmax(dim=-1)

            ans.append({
                "context_imp": context_imp.cpu(),
                "context": context[0].cpu(),
                "label": label.cpu(),
                "label_pred": label_pred.cpu(),
            })
            if i >= args.num_sample:
                break

        mkdir(args.output_dir)
        torch.save(ans, os.path.join(args.output_dir, 'ans.pt'))
        save_json(vars(args), os.path.join(args.output_dir, 'args.json'))

        if args.check_acc:
            args.acc = check_iac_acc(ans)
            _append_acc_record(vars(args), args.acc_record_path)


def check_iac_acc(ans:list):
    '''
    Return the fraction of samples whose every masked variant is classified correctly.

    ans is a list of d, where d is like (as built by ``main``):
        d = {
            "context_imp": context_imp.cpu(),  # [num_mask,L]
            "context": context[0].cpu(),       # [L,]
            "label": label.cpu(),              # [num_mask,1] (one label repeated); [1,] is also accepted
            "label_pred": label_pred.cpu()     # [num_mask,]
        }
    A sample counts as correct only if ``label_pred`` equals the label for all
    of its mask windows.

    Raises:
        ValueError: if ``ans`` is empty (accuracy is undefined).
    '''
    if not ans:
        raise ValueError('cannot compute accuracy of an empty result list')

    num_correct = 0
    for d in ans:
        # Flatten both sides so a [num_mask,1] label is compared element-wise
        # with the [num_mask] predictions instead of broadcasting to an
        # [num_mask,num_mask] matrix; a single-element label still broadcasts.
        pred = d['label_pred'].reshape(-1)
        target = d['label'].reshape(-1)
        if torch.all(pred == target):
            num_correct += 1
    return num_correct / len(ans)


if __name__ == "__main__":
    ################################## run a single
    # args = get_args()
    # args = post_process_args(args)
    # main(args)
    # exit()

    ################################## run a batch
    from my_utils.utils import DATASET_CLASSIFICATION, TSC_MODEL_TYPES, MASK_LENGTH, ABLATION_STUDY_IMP_MODEL_TYPES
    # imp_dir = 'result_imputation_0806'
    imp_dir = 'result_imputation_random_mask-0921'
    tsc_dir = 'result-tsc-0826'
    iac_dir = 'result-iac-random_mask_imp-0921'

    args = get_args()
    args = post_process_args(args)
    args.num_sample = 101
    args.device = 'cuda:2'
    args.check_acc = True
    args.acc_record_path = 'record-iac-random_mask_imp-0921.csv'

    for dataset_name in DATASET_CLASSIFICATION:
        # for imp_type in ABLATION_STUDY_IMP_MODEL_TYPES:
        for imp_type in ['mixer']:
            for cls_type in TSC_MODEL_TYPES:
                for use_filter in [True]:
                    for rt_noise in [0.1]:
                        for mask in MASK_LENGTH:
                            for step in range(1,mask+1):
                                # args.imp_model = f'{imp_dir}/{dataset_name}/{imp_type}/mask_length={mask}/filter={use_filter}/checkpoint/best.pt'
                                args.imp_model = f'{imp_dir}/{dataset_name}/{imp_type}/mask_length={mask}/filter={use_filter}/rt_noise={rt_noise}/len_clip=1.0/checkpoint/best.pt'
                                args.cls_model = f'{tsc_dir}/{dataset_name}/{cls_type}/filter={use_filter}/rt_noise={rt_noise}/checkpoint/best.pt'
                                args.dataset_name = dataset_name
                                args.cls_type = cls_type
                                args.imp_type = imp_type
                                args.mask = mask
                                args.step = step
                                args.rt_noise = rt_noise
                                args.use_filter = use_filter
                                args.output_dir = f'{iac_dir}/{dataset_name}/imp={args.imp_type}/cls={cls_type}/mask={mask}/step={step}/use_filter={use_filter}/rt_noise={rt_noise}'

                                if isdir(args.output_dir):
                                    continue

                                args = post_process_args(args)
                                main(args)
                                print(args.output_dir)


