from copy import deepcopy
import torch
from argparse import ArgumentParser
from tqdm import tqdm
from my_utils.utils import MASK_LENGTH, save_json, Logger, setup_seed, mkdir
from os.path import join
from classification_model.model import MLPModel as MLP
from classification_model.model import ResNet
from classification_model.model import FCN
from classification_model.mixer import MLPMixer
from dataset.dataset import DatasetTSC
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import time
import numpy as np
from einops import rearrange, repeat, reduce
from my_utils.utils import DATASET_CLASSIFICATION, RS_LENS, RS_TYPES, TSC_MODEL_TYPES
from my_utils.certify_rs import certify_block, certify_random


def get_args():
    '''
    Parse the command-line options of this script.

    Every option is declared as (flag, type, default) in the table below and
    registered in that order. Returns the namespace produced by
    ``ArgumentParser.parse_args()``, which reads ``sys.argv``.
    '''
    option_table = (
        ('--dataset_name', str, 'DistalPhalanxTW'),
        ('--output_dir', str, 'result-classify-rs-debug'),
        ('--device', str, 'cuda:0'),
        ('--num_epoch', int, 10),
        ('--num_layer', int, 4),
        ('--hidden_dim', int, 96),
        ('--batch_size', int, 16),
        ('--num_workers', int, 2),
        ('--use_filter', int, 1),
        ('--window', int, 15),
        ('--order', int, 5),
        ('--model_type', str, 'mlp'),
        ('--checkpoint', str, ''),
        ('--train', int, 1),
        ('--lr', float, 1e-3),
        ('--rt_noise', float, 0.0),
        ('--dropout', float, 0.3),
        ('--lr_decay_rate', float, 0.98),
        ('--optimizer', str, 'adam'),
        ('--normalization_dataset', str, 'instance'),
        ('--normalization_model', str, 'ln'),
        ('--record_path', str, 'record-tsc-rs-0818.csv'),
        ('--activation', str, 'relu'),
        ('--rs_len', int, 10),
        ('--rs_num_sample', int, 100),
        ('--rs_type', str, 'block'),
    )

    parser = ArgumentParser()
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)

    return parser.parse_args()

def post_process_args(args):
    '''
    Normalize the parsed options in place and return the same object.

    The integer switches ``use_filter`` and ``train`` become real booleans,
    and ``num_epoch`` is forced to 1 when training is disabled.
    '''
    args.use_filter = bool(args.use_filter)
    args.train = bool(args.train)
    if not args.train:
        args.num_epoch = 1
    return args


def create_dataloader(dataset_name, mode, normalization_method, use_filter, window, order, batch_size, num_workers, shuffle, drop_last):
    '''
    Build a DataLoader over a DatasetTSC.

    dataset_name, mode, normalization_method, use_filter, window and order
    are passed to the DatasetTSC constructor; batch_size, num_workers,
    shuffle and drop_last are passed to the DataLoader.
    '''
    dataset = DatasetTSC(
        dataset_name=dataset_name,
        mode=mode,
        normalization_method=normalization_method,
        use_filter=use_filter,
        window=window,
        order=order,
    )
    loader_options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'shuffle': shuffle,
        'drop_last': drop_last,
    }
    return DataLoader(dataset, **loader_options)


def create_model(args, context_length, num_class):
    '''
    Instantiate the classifier selected by args.model_type.

    args fields that are read (depending on the model type)
        model_type
        hidden_dim
        num_layer
        normalization_model
        dropout
        activation

    Raises NotImplementedError for a model type other than
    'mlp', 'resnet18', 'fcn' or 'mixer'.
    '''
    kind = args.model_type

    if kind == 'mlp':
        return MLP(
            context_length,
            args.hidden_dim,
            num_class,
            args.num_layer,
            args.normalization_model,
            args.dropout,
            args.activation,
        )

    if kind == 'resnet18':
        return ResNet(kind, num_class)

    if kind == 'fcn':
        return FCN(args.hidden_dim, args.num_layer, args.dropout, num_class, args.activation)

    if kind == 'mixer':
        return MLPMixer(context_length, args.hidden_dim, args.num_layer, num_class, args.dropout)

    raise NotImplementedError(f'Model type: {args.model_type} is not implemented')


def main(args):
    '''
    Run one experiment: build the data loaders and the model, optionally
    train, evaluate after every epoch, save checkpoints and append the
    run's settings plus its best test accuracy to the record CSV.

    args
        namespace as produced by get_args() + post_process_args().
        The attribute ``acc`` (best test accuracy) is added to it.

    Files written
        <output_dir>/args.json, test.txt, train.txt, tensorboard events
        <output_dir>/checkpoint/best.pt   (whenever test accuracy improves)
        <output_dir>/checkpoint/final.pt
        args.record_path                  (one row appended)
    '''
    setup_seed()
    mkdir(args.output_dir)
    ck_dir = join(args.output_dir,'checkpoint')
    mkdir(ck_dir)
    save_json(vars(args),join(args.output_dir,'args.json'))

    test_logger = Logger(join(args.output_dir,'test.txt'))
    train_logger = Logger(join(args.output_dir,'train.txt'))
    summary_writer = SummaryWriter(args.output_dir)

    train_dataloader = create_dataloader(
        args.dataset_name,
        'train',
        args.normalization_dataset,
        args.use_filter,
        args.window,
        args.order,
        args.batch_size,
        args.num_workers,
        True,
        True,
    )

    # drop_last must be False here: every test sample has to count toward the
    # reported accuracy, and a test set smaller than batch_size would
    # otherwise produce no batch at all.
    test_dataloader = create_dataloader(
        args.dataset_name,
        'test',
        args.normalization_dataset,
        args.use_filter,
        args.window,
        args.order,
        args.batch_size,
        args.num_workers,
        False,
        False,
    )

    context_length, num_class = test_dataloader.dataset.get_info()

    test_logger.log('{:<20} == {}'.format('dataset_name',args.dataset_name))
    test_logger.log('{:<20} == {}'.format('num_class',num_class))
    test_logger.log('{:<20} == {}'.format('context_length',context_length))
    test_logger.log('{:<20} == {}'.format('use_filter',args.use_filter))
    test_logger.log('{:<20} == {}'.format('length_train',len(train_dataloader.dataset)))
    test_logger.log('{:<20} == {}'.format('length_test',len(test_dataloader.dataset)))
    if args.use_filter:
        test_logger.log('{:<20} == {}'.format('window',args.window))
        test_logger.log('{:<20} == {}'.format('order',args.order))

    if args.checkpoint != '':
        test_logger.log(f'Loading checkpoint {args.checkpoint}...')
        # NOTE(review): this unpickles a whole model object, which executes
        # arbitrary code from the file -- only load checkpoints you trust.
        model = torch.load(args.checkpoint)
    else:
        model = create_model(args, context_length, num_class)

    test_logger.log(f'Model type is {str(type(model))}')
    model.to(args.device)

    optimizer = None
    scheduler = None
    if args.train:
        if args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),lr=args.lr)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
        elif args.optimizer == 'adadelta':
            optimizer = torch.optim.Adadelta(model.parameters(),lr=args.lr)
        elif args.optimizer == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(),lr=args.lr)
        else:
            raise NotImplementedError(f'Optimizer {args.optimizer} is not implemented')
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,args.lr_decay_rate,verbose=True)

    best_acc = 0
    try:
        for epoch in range(args.num_epoch):
            if args.train:
                run_epoch('train',epoch,model,train_dataloader,optimizer,scheduler,train_logger,summary_writer,args)
            loss, acc = run_epoch('test',epoch,model,test_dataloader,None,None,test_logger,summary_writer,args)
            if acc > best_acc:
                best_acc = acc
                # Checkpoints are stored on CPU so they load on any machine.
                model.cpu()
                torch.save(model,join(ck_dir,'best.pt'))
                model.to(args.device)
    finally:
        # Flush and release the tensorboard event file even if an epoch fails.
        summary_writer.close()

    torch.save(model.cpu(), join(ck_dir,'final.pt'))
    args.acc = best_acc

    # Append this run to the record file. Only a missing or empty file starts
    # a fresh table; any other read error propagates so that an existing
    # record file is never silently overwritten.
    try:
        df = pd.read_csv(args.record_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        df = pd.DataFrame()
    df = pd.concat([df,pd.DataFrame([vars(args)])])
    df.to_csv(args.record_path,index=False)
    

def compute_loss(y,y_hat):
    '''
    Cross-entropy between predicted logits and integer class labels,
    using the default (mean) reduction.

    y.shape = [B,]
    y_hat.shape = [B,num_class]
    '''
    return torch.nn.functional.cross_entropy(y_hat, y)


def _sample_rs_mask(length, rs_type, rs_len):
    '''
    Draw one random keep-mask (numpy array of shape [length], 1 = keep).

    rs_type == 'random': rs_len positions chosen without replacement.
    rs_type == 'block' : one contiguous block of rs_len positions whose
                         start is drawn uniformly from all valid starts.

    Raises NotImplementedError for any other rs_type and ValueError when
    rs_len is larger than the series length.
    '''
    if rs_type not in ('random', 'block'):
        raise NotImplementedError(f'RS type {rs_type} is not implemented')
    if rs_len > length:
        raise ValueError(f'rs_len={rs_len} exceeds the series length {length}')

    mask = np.zeros(length)
    if rs_type == 'random':
        index = np.random.choice(np.arange(length), rs_len, replace=False)
        mask[index] = 1
    else:
        num_block = length - rs_len + 1
        block_idx = np.random.randint(0, num_block)
        mask[block_idx : block_idx + rs_len] = 1
    return mask


def run_epoch(
    mode,
    epoch,
    model,
    dataloader,
    optimizer,
    scheduler,
    logger,
    summary_writer,
    args,
):
    '''
    Run one pass over dataloader with a random smoothing mask on the input.

    mode            'train' enables gradients and parameter updates; any
                    other value runs the model in eval mode without gradients
    epoch           step index used for the tensorboard scalars and the log
    optimizer,
    scheduler       only used when mode == 'train'
    args            reads device, rs_type and rs_len

    One mask is drawn per batch and shared by all samples of that batch;
    context is indexed as [B, T].

    return (mean loss over batches, accuracy over samples)

    Raises ValueError when the dataloader yields no batch.
    '''
    time_start = time.time()
    train = mode == 'train'
    if train:
        model.train()
    else:
        model.eval()

    num_sample = num_true = total_loss = num_batch = 0

    with torch.set_grad_enabled(train):
        for context, label in dataloader:
            context = context.to(args.device)
            label = label.to(args.device)

            # random smoothing: zero out everything outside the sampled mask
            mask = _sample_rs_mask(context.shape[1], args.rs_type, args.rs_len)
            mask = torch.from_numpy(mask).to(dtype=context.dtype, device=context.device)
            # [1, T] broadcasts over the batch dimension.
            context = context * mask.unsqueeze(0)

            logits = model(context)
            loss = compute_loss(label,logits)

            if train:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()
            num_batch += 1
            num_sample += label.shape[0]
            num_true += (logits.argmax(dim=-1) == label).sum().item()

    if num_batch == 0:
        # Happens e.g. when drop_last=True and the dataset is smaller than
        # batch_size; fail with a clear message instead of dividing by zero.
        raise ValueError(f'The {mode} dataloader produced no batches')

    if train:
        scheduler.step()

    acc = num_true / num_sample
    total_loss /= num_batch
    summary_writer.add_scalar(f'Loss/{mode}',total_loss,epoch)
    summary_writer.add_scalar(f'ACC/{mode}',acc,epoch)

    duration = time.time() - time_start
    logger.log('Mode={:<5}\tEpoch={:<3}\tLoss={:<6.3f}\tACC={:.3f}\tTime={:<.3f}(s)'.format(mode,epoch,total_loss,acc,duration))

    return total_loss, acc


def compute_loss(y,y_hat):
    '''
    y.shape = [B,]
    y_hat.shape = [B,num_class]
    '''
    loss_fn = torch.nn.CrossEntropyLoss()
    return loss_fn(y_hat,y)


def _build_rs_masks(rs_type, rs_len, rs_num_sample, length):
    '''
    Build the keep-masks (numpy array [n, length], 1 = keep) used to vote
    on one series.

    rs_type == 'random': n = rs_num_sample rows, each keeping rs_len
                         positions drawn without replacement.
    rs_type == 'block' : n = length - rs_len + 1 rows; row i keeps the
                         contiguous block [i, i + rs_len), i.e. every
                         block position exactly once.

    Raises NotImplementedError for any other rs_type and ValueError when
    rs_len is larger than the series length in block mode.
    '''
    if rs_type == 'random':
        masks = np.zeros((rs_num_sample, length))
        for row in range(rs_num_sample):
            index = np.random.choice(np.arange(length), rs_len, replace=False)
            masks[row, index] = 1
        return masks

    if rs_type == 'block':
        num_block = length - rs_len + 1
        if num_block <= 0:
            raise ValueError(f'rs_len={rs_len} exceeds the series length {length}')
        masks = np.zeros((num_block, length))
        # Enumerate the block positions. Drawing the start at random here
        # would duplicate some positions and leave other rows all-zero.
        for start in range(num_block):
            masks[start, start:start + rs_len] = 1
        return masks

    raise NotImplementedError(f'RS type {rs_type} is not implemented')


def classify_and_save(model, dataloader, rs_type, rs_len, rs_num_sample, device, output_dir):
    '''
    Classify every series of dataloader under all smoothing masks and save
    the per-mask predictions.

    dataloader must yield batches of size 1:
        context.shape == [1,T]
        label.shape == [1,]

    The saved object is a list with one dict per series:
        {
            "label": scalar tensor
            "pred": shape = [n,]   (n masks, see _build_rs_masks)
        }
    It is written to output_dir as 'ans-rs_num=<rs_num_sample>.pt' for
    rs_type 'random' and as 'ans.pt' otherwise.
    '''
    with torch.no_grad():
        model.to(device)
        model.eval()
        ans = []
        for context, label in dataloader:
            context = context.to(device)
            label = label.to(device)
            if context.shape[0] != 1:
                raise ValueError(f'classify_and_save expects batch size 1, got {context.shape[0]}')

            masks = _build_rs_masks(rs_type, rs_len, rs_num_sample, context.shape[1])
            masks = torch.from_numpy(masks).to(dtype=context.dtype, device=context.device)
            # [n, T] * [1, T]: the single series is broadcast over all masks.
            logits = model(masks * context)
            pred = logits.argmax(dim=-1)

            ans.append({
                'label': label[0].cpu(),
                'pred': pred.cpu()
            })
        mkdir(output_dir)
        if rs_type == 'random':
            torch.save(ans, join(output_dir, f'ans-rs_num={rs_num_sample}.pt'))
        else:
            torch.save(ans, join(output_dir, 'ans.pt'))



def evaluate(ans, rs_type, rs_len, atk_len):
    '''
    Certified accuracy of saved smoothed predictions.

    ans is a non-empty list of dict, dict is like
        {
            "label": scalar tensor (true class)
            "pred": shape = [rs_num,] (predicted class under each mask)
        }

    A sample counts as correct when the most frequently predicted class
    equals the label AND the certificate for (rs_len, atk_len) holds:
        rs_type == 'random' -> certify_random(top votes, number of masks, rs_len, atk_len)
        rs_type == 'block'  -> certify_block(top votes, runner-up votes, rs_len, atk_len)

    return acc (fraction of samples that are correct and certified)

    Raises NotImplementedError for any other rs_type.
    '''
    if rs_type not in ('random', 'block'):
        raise NotImplementedError(f'RS type {rs_type} is not implemented')

    num_samples = len(ans)
    acc = 0

    for item in ans:
        label = item['label']
        pred = item['pred']
        votes = torch.bincount(pred)

        # bincount has a single entry when only class 0 was ever predicted,
        # so ask topk for at most as many entries as exist.
        v, i = torch.topk(votes, k=min(2, votes.numel()))

        top_class = i[0]
        top_weight = v[0]
        if not bool(top_class == label):
            # A misclassified sample can never count; skip the certificate.
            continue

        # No runner-up class means zero votes against the winner.
        second_weight = v[1] if v.shape[0] >= 2 else 0

        if rs_type == 'random':
            condition = certify_random(top_weight, pred.shape[0], rs_len, atk_len)
        else:
            condition = certify_block(top_weight, second_weight, rs_len, atk_len)

        if condition:
            acc += 1

    return acc / num_samples



def evaluate_saved_predictions(result_dir='result_tsc_rs-0818', rs_num=10000):
    '''
    Compute certified accuracy for every saved prediction file and write
    one CSV row per (dataset, model, rs_type, rs_len, atk_len).

    Prediction files are read from
        <result_dir>/<dataset>/<model>/<rs_type>/<rs_len>/ans-rs_num=<rs_num>.pt   (random)
        <result_dir>/<dataset>/<model>/<rs_type>/<rs_len>/ans.pt                   (block)
    for all combinations of DATASET_CLASSIFICATION, TSC_MODEL_TYPES,
    RS_TYPES and RS_LENS; attack lengths run from 1 to max(MASK_LENGTH).

    Raises RuntimeError when the output CSV already exists, so earlier
    results are never overwritten.
    '''
    import os

    print('evaluate')
    csv_path = f'record-tsc-rs-metric-rs_num={rs_num}-more_atks.csv'
    if os.path.isfile(csv_path):
        raise RuntimeError(f'{csv_path} already exists')

    # Rows are collected in a list and turned into a DataFrame once;
    # concatenating a frame per row is quadratic in the number of rows.
    rows = []
    for dataset_name in DATASET_CLASSIFICATION:
        for model_type in TSC_MODEL_TYPES:
            for rs_type in RS_TYPES:
                for rs_len in RS_LENS:
                    run_dir = f'{result_dir}/{dataset_name}/{model_type}/{rs_type}/{rs_len}'
                    if rs_type == 'random':
                        ans_path = f'{run_dir}/ans-rs_num={rs_num}.pt'
                    elif rs_type == 'block':
                        ans_path = f'{run_dir}/ans.pt'
                    else:
                        raise NotImplementedError
                    ans = torch.load(ans_path)
                    for atk_len in range(1, max(MASK_LENGTH) + 1):
                        # evaluate() only reads ans, so no copy is needed.
                        acc = evaluate(ans, rs_type, rs_len, atk_len)
                        rows.append({
                            "dataset_name": dataset_name,
                            "model_type": model_type,
                            "rs_type": rs_type,
                            "rs_len": rs_len,
                            "atk_len": atk_len,
                            "acc": acc,
                        })
                    print(ans_path)
    pd.DataFrame(rows).to_csv(csv_path, index=False)


if __name__ == '__main__':
    ################################################ run a single train
    args = get_args()
    args = post_process_args(args)
    main(args)

    # The other workflows below are kept for reference. The evaluation sweep
    # that used to sit (unreachable) after an exit(0) here is now the
    # function evaluate_saved_predictions(); call it instead of main(args)
    # to produce the certified-accuracy CSV.


    ################################################ run a batch train
    # result_dir = "result_tsc_rs-0818"
    # args = get_args()
    # for dataset_name in DATASET_CLASSIFICATION:
    #     for rs_len in RS_LENS:
    #         for rs_type in RS_TYPES:
    #             for tsc_model in TSC_MODEL_TYPES:
    #                 args.dataset_name = dataset_name
    #                 args.output_dir = f'{result_dir}/{dataset_name}/{tsc_model}/{rs_type}/{rs_len}'
    #                 args.model_type = tsc_model
    #                 args.rs_len = rs_len
    #                 args.rs_type = rs_type
    #                 args = post_process_args(args)
    #                 main(args)


    ################################################ run a batch for classify and save
    # print('classify and save')
    # args = get_args()
    # args.device = 'cuda:3'
    # args.rs_num_sample = 10000
    # args = post_process_args(args)
    # result_dir = 'result_tsc_rs-0818'
    # for dataset_name in DATASET_CLASSIFICATION:
    #     args.dataset_name = dataset_name
    #     dataloader = create_dataloader(
    #         args.dataset_name,
    #         'test',
    #         args.normalization_dataset,
    #         args.use_filter,
    #         args.window,
    #         args.order,
    #         1,
    #         args.num_workers,
    #         False,
    #         True
    #     )
    #     ctx_len, num_class = dataloader.dataset.get_info()
    #     for model_type in TSC_MODEL_TYPES:
    #         for rs_type in RS_TYPES:
    #             for rs_len in RS_LENS:
    #                 args.model_type = model_type
    #                 args.rs_type = rs_type
    #                 args.rs_len = rs_len
    #                 model = create_model(args,ctx_len,num_class)
    #                 output_dir = f'{result_dir}/{dataset_name}/{model_type}/{rs_type}/{rs_len}'
    #                 classify_and_save(model, dataloader, args.rs_type, args.rs_len, args.rs_num_sample, args.device, output_dir)
    #                 print(output_dir)


    ####################################### evaluate
    # evaluate_saved_predictions(result_dir='result_tsc_rs-0818', rs_num=10000)

