# Train and evaluate a time-series prediction model under random smoothing, and certify its smoothed predictions.
from argparse import ArgumentParser
from copy import deepcopy
from os.path import isdir, join
import numpy as np
from dataset.dataset import MyDataset
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from my_utils.utils import DELTAS, MASK_LENGTH, PRE_MODEL_TYPES, RS_LENS, RS_TYPES, mkdir, save_json, Logger, get_dataset_params, DATASET_NAMES
from my_utils.utils import get_metric, setup_seed
from my_utils.utils import GenerateData2 as GD
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt
from tqdm import tqdm
import random
import pandas as pd
from prediction_model.model import MLPMixer as Mixer
from prediction_model.model import MLPModel3 as MLP
from prediction_model.model import LSTM3 as LSTM
from prediction_model.model import TemporalConvNet as TCN
from prediction_model.model import LinearModel, ResNet18, Transformer, ResNet34, RNN, GRU
import time
from einops import rearrange, reduce, repeat
from my_utils.certify_rs import certify_random, certify_block


def get_args():
    """Parse the command-line options for a single train/eval run.

    Options are declared as ``(flag, type, default, help)`` rows, grouped as:
    general run settings, less frequently changed training settings, model
    hyper-parameters, random (noise) training, random smoothing, and the input
    filter.  ``help`` is None where the option has no help text.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    option_groups = (
        # general run settings
        (
            ("--dataset_name", str, "electricity_nips", None),
            # alternative default: "result-pred-rs-debug-" + str(time.time())
            ("--output_dir", str, "result-pred-rs-debug", None),
            ("--model_type", str, "mlp", None),
            ("--device", str, "cuda:1", None),
            ("--num_epoch", int, 10, None),
        ),
        # not set very frequently
        (
            ("--epoch_gap", int, 100, None),
            ("--batch_size", int, 32, None),
            ("--lr", float, 1e-3, None),
            ("--lr_decay_rate", float, 0.95, None),
            ("--only_eval", int, 0, None),
            ("--checkpoint", str, "", None),
            ("--num_train_batch", int, 500, None),
            ("--num_test_batch", int, 100, None),
            ("--num_workers", int, 4, None),
            ("--optimizer", str, "adam", None),
            ("--record_path", str, "record-prediction-rs.csv", None),
        ),
        # model settings
        (
            ("--hidden_dim", int, 33, None),
            ("--num_layer", int, 4, None),
            ("--drop_out", float, 0.1, None),
            ("--norm", str, "ln", "normalization method"),
            ("--pred_weight", float, 1.0, None),
            ("--mode", str, 'padding', 'mode for rnn-based model'),
            ("--activation", str, 'relu', 'activation for mlp'),
        ),
        # random (noise) training
        (
            ("--RT", int, 0, "When training model, add noise"),
            ("--noise_strength", float, 0.0, "Noise for Random training"),
        ),
        # random smoothing
        (
            ("--rs_type", str, 'block', None),
            ("--rs_len", int, 16, None),
        ),
        # input filter
        (
            ("--use_filter", int, 1, None),
            ("--window", int, 15, None),
            ("--order", int, 5, None),
        ),
    )

    parser = ArgumentParser()
    for group in option_groups:
        for flag, value_type, default, help_text in group:
            parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args()


def post_process_args(args):
    """Derive and normalise run settings on ``args`` in place, then return it.

    - ``context_length`` / ``prediction_length`` come from
      get_dataset_params(args.dataset_name).
    - ``train`` is a bool: False when ``only_eval`` is truthy, in which case
      ``num_epoch`` is also forced to 1.
    - ``norm``, ``model_type`` and ``mode`` are lower-cased.
    - ``use_filter`` is coerced to bool.
    """
    args.context_length, args.prediction_length = get_dataset_params(args.dataset_name)

    eval_only = bool(args.only_eval)
    if eval_only:
        args.num_epoch = 1
    args.train = not eval_only

    for name in ("norm", "model_type", "mode"):
        setattr(args, name, getattr(args, name).lower())

    args.use_filter = bool(args.use_filter)
    return args


def draw(target,prediction,fig_path,summary_writer,epoch,mode):
    """Plot up to 12 target/prediction rows on a 3x4 grid and log it to TensorBoard.

    Row ``k`` of ``target`` (blue, label 'gt') and of ``prediction`` (red,
    label 'pred') are drawn against ``arange(target.shape[-1])`` in panel
    ``k`` (row-major order).  The figure is logged via
    ``summary_writer.add_figure`` under the tag ``mode`` at step ``epoch``.

    NOTE(review): ``fig_path`` is accepted but never used -- nothing is written
    to disk here.

    Returns:
        True.
    """
    grid_rows, grid_cols = 3, 4
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(8, 4.5))
    time_axis = np.arange(target.shape[-1])
    n_panels = min(target.shape[0], grid_rows * grid_cols)

    # axes.flat walks the grid row by row, matching panel -> (panel // cols, panel % cols)
    for panel, ax in zip(range(n_panels), axes.flat):
        ax.plot(time_axis, target[panel], '-b', label='gt')
        ax.plot(time_axis, prediction[panel], '-r', label='pred')
        ax.legend()

    summary_writer.add_figure(f'{mode}', fig, epoch)
    return True

def get_loss(y,y_hat):
    """Return the mean squared error between ``y`` and ``y_hat``, averaged over all elements."""
    return torch.nn.functional.mse_loss(y, y_hat)


def _rs_mask(rs_type, rs_len, batch_size, ctx_len, dtype, device):
    """Return a [batch_size, ctx_len] 0/1 random-smoothing mask (the same row for every sample).

    'random' keeps ``rs_len`` distinct positions sampled without replacement;
    'block' keeps one contiguous window of ``rs_len`` positions whose start is
    drawn uniformly from all valid offsets.  All other positions are 0.

    Raises:
        NotImplementedError: unknown ``rs_type``.
        ValueError: ``rs_len`` is longer than the context.
    """
    if rs_type not in ('random', 'block'):
        raise NotImplementedError(f'RS type {rs_type} is not implemented')
    if rs_len > ctx_len:
        raise ValueError(f'rs_len={rs_len} exceeds the context length {ctx_len}')

    row = np.zeros(ctx_len)
    if rs_type == 'random':
        row[np.random.choice(np.arange(ctx_len), rs_len, replace=False)] = 1
    else:
        start = np.random.randint(0, ctx_len - rs_len + 1)
        row[start : start + rs_len] = 1
    mask = repeat(row, 't -> b t', b=batch_size)
    return torch.from_numpy(mask).to(dtype=dtype, device=device)


def run_epoch(model,mode,data_loader,optimizer,logger,epoch,fig_dir,checkpoint_dir,args, summary_writer, lr_scheduler):
    """Run one epoch of training (``mode == 'train'``) or evaluation (any other mode).

    Each batch is moved to ``args.device`` and normalised with
    ``data_loader.dataset.normalize``.  In training with ``args.RT`` set, it is
    perturbed with ``args.noise_strength`` * standard-normal noise.  It is then
    multiplied by a random-smoothing mask (``args.rs_type`` / ``args.rs_len``)
    before the forward pass ``model(context, target)``.

    Per-batch loss (get_loss) and get_metric MAE/MSE are averaged over the
    epoch and logged to ``logger`` and ``summary_writer``.  One randomly chosen
    batch is plotted with draw().  In 'train' mode the optimizer steps every
    batch and ``lr_scheduler`` steps once per epoch.  The whole model is also
    saved to ``checkpoint_dir/epoch_<epoch+1>.pt`` every ``args.epoch_gap``
    epochs.

    Returns:
        The epoch-averaged MSE.
    Raises:
        ValueError: ``data_loader`` yields no batches, or ``args.rs_len`` exceeds
            the context length.
        NotImplementedError: unknown ``args.rs_type``.
    """
    is_train = mode == 'train'
    if is_train:
        model.train()
    else:
        model.eval()

    num_batches = len(data_loader)
    if num_batches == 0:
        # Previously surfaced as an opaque "empty range" error from random.randint.
        raise ValueError(f"run_epoch(mode={mode}): data_loader yields no batches")

    mse = mae = epoch_loss = 0.0
    count = 0

    with torch.set_grad_enabled(is_train):
        # randomly select a batch to draw
        draw_iter_index = random.randint(0, num_batches - 1)
        for idx, (context, target) in tqdm(enumerate(data_loader)):
            count = idx + 1
            context, target = context.to(args.device), target.to(args.device)
            context, target, mean, std = data_loader.dataset.normalize(context, target)

            if args.RT and is_train:
                # randn_like already allocates on context's device
                context = context + args.noise_strength * torch.randn_like(context)

            # context is indexed as [B, T]
            context = context * _rs_mask(
                args.rs_type, args.rs_len, context.shape[0], context.shape[1],
                context.dtype, context.device,
            )

            prediction = model(context, target)
            loss = get_loss(prediction, target)

            if is_train:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            metric = get_metric(prediction, target)
            mae += metric['mae']
            mse += metric['mse']

            if idx == draw_iter_index:
                draw(
                    target.detach().cpu().numpy(),
                    prediction.detach().cpu().numpy(),
                    join(fig_dir, f"epoch_{epoch}_{mode}.png"),
                    summary_writer,
                    epoch,
                    mode
                )

        epoch_loss /= count
        mse /= count
        mae /= count

        logger.log("Epoch=%d, mode=%s, loss=%.6f" % (
            epoch,
            mode,
            epoch_loss
        ))

        logger.log("Epoch=%d, mode=%s, MAE=%.6f, MSE=%.6f" % (
            epoch,
            mode,
            mae,
            mse
        ))

        summary_writer.add_scalar(f'MSE/{mode}', mse, epoch)
        summary_writer.add_scalar(f'MAE/{mode}', mae, epoch)
        summary_writer.add_scalar(f'Loss/{mode}', epoch_loss, epoch)

        if is_train and lr_scheduler is not None:
            lr_scheduler.step()

        if is_train and (epoch + 1) % args.epoch_gap == 0:
            # torch.save pickles the whole module; move it to CPU for a device-agnostic file
            torch.save(model.cpu(), join(checkpoint_dir, f"epoch_{epoch+1}.pt"))
            model.to(args.device)

        return mse
        

def _build_model(args):
    """Construct an untrained model selected by ``args.model_type``.

    The match is case-sensitive against lower-case names (post_process_args
    lower-cases ``model_type``).

    Raises:
        NotImplementedError: unknown ``args.model_type``.
    """
    factories = {
        "mixer": lambda: Mixer(args.context_length, args.prediction_length, args.hidden_dim, args.num_layer),
        "linear": lambda: LinearModel(args.context_length, args.prediction_length),
        "mlp": lambda: MLP(args.context_length, args.prediction_length, args.hidden_dim, args.num_layer, args.norm, args.drop_out, args.activation),
        "trans": lambda: Transformer(args.context_length, args.prediction_length, args.hidden_dim, max(args.hidden_dim // 8, 1), 2 * args.hidden_dim, args.drop_out, args.num_layer),
        "lstm": lambda: LSTM(args.hidden_dim, args.num_layer, args.drop_out, args.pred_weight, args.mode),
        "gru": lambda: GRU(args.hidden_dim, args.num_layer, args.drop_out, args.pred_weight, args.mode),
        "rnn": lambda: RNN(args.hidden_dim, args.num_layer, args.drop_out, args.pred_weight, args.mode),
        "tcn": lambda: TCN(args.context_length, args.prediction_length, args.hidden_dim, args.num_layer, args.drop_out),
        "resnet18": lambda: ResNet18(args.context_length, args.prediction_length),
        "resnet34": lambda: ResNet34(args.context_length, args.prediction_length),
    }
    if args.model_type not in factories:
        raise NotImplementedError(f"model_type = {args.model_type} not implemented")
    return factories[args.model_type]()


def main(args):
    """Train (unless ``args.only_eval``) and evaluate one prediction model.

    Side effects:
    - creates ``args.output_dir`` with ``fig/`` and ``checkpoint/`` sub-directories;
    - writes ``args.json``, the ``train.txt`` / ``test.txt`` logs and TensorBoard events there;
    - saves ``checkpoint/best.pt`` (lowest test MSE so far) and ``checkpoint/final.pt``;
    - appends one row (all args plus ``best_mse``) to the CSV at ``args.record_path``.

    ``args`` is expected to have been through post_process_args (it reads
    ``train``, ``context_length`` and ``prediction_length``).
    """
    setup_seed()

    # create dirs
    fig_dir = join(args.output_dir, "fig")
    checkpoint_dir = join(args.output_dir, "checkpoint")
    for path in (args.output_dir, fig_dir, checkpoint_dir):
        mkdir(path)

    # save args
    save_json(vars(args), join(args.output_dir, "args.json"))

    # generate data
    gd = GD(args.dataset_name, args.use_filter, args.window, args.order)
    gd.generate()

    train_logger = Logger(join(args.output_dir, "train.txt"))
    test_logger = Logger(join(args.output_dir, "test.txt"))
    summary_writer = SummaryWriter(args.output_dir)

    if args.checkpoint == "":
        model = _build_model(args)
    else:
        test_logger.log(f'Loading checkpoint: {args.checkpoint}')
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        model = torch.load(args.checkpoint)
    test_logger.log(args.output_dir)
    test_logger.log(f"Model type is {type(model)}")
    model.to(args.device)

    train_dataloader = DataLoader(
        MyDataset(args.dataset_name,"train",args.context_length,args.prediction_length,args.num_train_batch*args.batch_size,args.use_filter,args.window,args.order),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers
    ) if args.train else None

    test_dataloader = DataLoader(
        MyDataset(args.dataset_name,"test",args.context_length,args.prediction_length,args.num_test_batch*args.batch_size,args.use_filter,args.window,args.order),
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers
    )

    if args.optimizer == 'adam':
        optimizer = Adam(model.parameters(),lr=args.lr) if args.train else None
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),lr=args.lr) if args.train else None
    else:
        raise NotImplementedError(f'optimizer={args.optimizer} is not implemented')
    # ExponentialLR rejects a None optimizer, so eval-only runs used to crash here.
    lr_scheduler = ExponentialLR(optimizer, args.lr_decay_rate, verbose=True) if optimizer is not None else None
    test_mse_best = float('inf')

    try:
        for epoch in tqdm(range(args.num_epoch)):
            if args.train:
                run_epoch(model,"train",train_dataloader,optimizer,train_logger,epoch,fig_dir,checkpoint_dir,args, summary_writer, lr_scheduler)

            test_mse = run_epoch(model,"test",test_dataloader,None,test_logger,epoch,fig_dir,checkpoint_dir,args, summary_writer, lr_scheduler)

            if test_mse < test_mse_best:
                test_mse_best = test_mse
                torch.save(model.cpu(), join(checkpoint_dir, "best.pt"))
                model.to(args.device)
    finally:
        # flush pending TensorBoard events; the writer was previously never closed
        summary_writer.close()

    torch.save(model.cpu(), join(checkpoint_dir, "final.pt"))

    # Only "no record yet" is tolerated; a corrupt record file must not be
    # silently overwritten with a single row (the old bare except did that).
    try:
        df = pd.read_csv(args.record_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        df = pd.DataFrame()
    args.best_mse = test_mse_best
    df = pd.concat([df, pd.DataFrame([vars(args)])], axis=0, join='outer')
    df.to_csv(args.record_path, index=False)


def predict(model, dataloader, rs_type:str, rs_len:int, rs_num:int, device:str, output_dir:str):
    '''
    Run the model on many masked (smoothed) copies of every sample and save the raw outputs.

    ``dataloader`` must use batch_size 1 (the '1 t -> t' rearrange fails otherwise).
    For each sample the normalised context is replicated and multiplied by:
      - rs_type 'random': ``rs_num`` masks, each keeping ``rs_len`` positions
        sampled without replacement;
      - rs_type 'block': one mask per contiguous window of length ``rs_len``
        (``T - rs_len + 1`` masks).
    Any other rs_type raises NotImplementedError.

    Saves a list of {
        'ctx'   : ctx.cpu(),        # [T,]
        'tgt'   : tgt.cpu(),        # [T,]
        'pred'  : pred.cpu(),       # [num_masks, T]
    } to output_dir/ans-rs_num=<rs_num>.pt ('random') or output_dir/ans.pt ('block').
    '''
    mkdir(output_dir)
    with torch.no_grad():
        model.to(device)
        model.eval()
        records = []

        for raw_ctx, raw_tgt in dataloader:
            ctx, tgt, _, _ = dataloader.dataset.normalize(raw_ctx.to(device), raw_tgt.to(device))

            # unmasked single-sample copies kept for the saved record
            ctx_1d = rearrange(ctx.clone().detach(), '1 t -> t')
            tgt_1d = rearrange(tgt.clone().detach(), '1 t -> t')

            width = ctx.shape[1]
            if rs_type == 'random':
                masks = np.zeros((rs_num, width))
                for row in range(rs_num):
                    masks[row, np.random.choice(np.arange(width), rs_len, replace=False)] = 1
            elif rs_type == 'block':
                masks = np.zeros((width - rs_len + 1, width))
                for start in range(masks.shape[0]):
                    masks[start, start : start + rs_len] = 1
            else:
                raise NotImplementedError

            masks = torch.from_numpy(masks).to(dtype=ctx.dtype, device=ctx.device)
            n_views = masks.shape[0]
            ctx = repeat(ctx, '1 t -> n t', n=n_views) * masks
            tgt = repeat(tgt, '1 t -> n t', n=n_views)

            pred = model(ctx, tgt)
            records.append({
                'ctx'   : ctx_1d.cpu(),
                'tgt'   : tgt_1d.cpu(),
                'pred'  : pred.cpu(),
            })

        if rs_type == 'random':
            file_name = f'ans-rs_num={rs_num}.pt'
        elif rs_type == 'block':
            file_name = 'ans.pt'
        else:
            raise NotImplementedError
        torch.save(records, join(output_dir, file_name))


def evaluate(ans, delta, rs_type, rs_len, atk_len, atk_idx:int=-1):
    '''
    Certified accuracy and error of smoothed predictions at one horizon step.

    For each sample, the target value and every smoothed prediction at horizon
    index ``atk_idx`` are shifted so that the smallest of them equals 1. They are
    then floor-divided into bins of width ``delta``. The most frequent prediction
    bin is the smoothed prediction. A sample is counted as correct when that bin
    equals the target's bin AND the certification helper accepts the vote
    counts (``certify_random`` for rs_type 'random', ``certify_block`` for
    'block').

    ``ans`` is only read. Earlier versions shifted its tensors in place, which
    forced callers to deepcopy it before every call.

    Input:
        ans is a list of dict, where the dict is
            {
                "ctx": tensor, shape = [T,]
                "tgt": tensor, shape = [L,]
                "pred": tensor, shape = [num_rs, L]
            }
        delta: bin width.
        rs_type: 'random' or 'block'.
        rs_len, atk_len: forwarded to the certification helper.
        atk_idx: horizon index to evaluate (default: last step).
    Return:
        mae, mse, acc
        - mae and mse average |bin centre - target| and its square over the
          certified-correct samples. The per-sample shift cancels in the difference.
        - acc = certified-correct / len(ans).
        - Returns (np.inf, np.inf, 0.0) when no sample is certified correct.
    Raises:
        NotImplementedError: unknown rs_type (raised while processing the first sample).
    '''
    list_mse = []
    list_mae = []
    num_correct = 0
    for d in ans:
        tgt_raw = d['tgt'][atk_idx]        # 0-d view into the caller's tensor
        pred_raw = d['pred'][:, atk_idx]   # [num_rs] view into the caller's tensor

        # Shift so every value is >= 1; floor division then yields non-negative
        # bin indices, as torch.bincount requires.  Out-of-place to keep ans intact.
        min_value = min(pred_raw.min(), tgt_raw).item() - 1
        tgt_value = tgt_raw.item() - min_value
        tgt_class = torch.div(tgt_raw - min_value, delta, rounding_mode='floor').to(torch.int32).item()
        pred_class = torch.div(pred_raw - min_value, delta, rounding_mode='floor').to(torch.int32)
        bin_map = torch.bincount(pred_class)

        # top-2 vote counts; with a single bin the runner-up count is 0
        k = min(2, bin_map.shape[0])
        counts, classes = torch.topk(bin_map, k=k)
        top_class = classes[0].item()
        top_weight = counts[0].item()
        if k >= 2:
            second_weight = counts[1].item()
        else:
            second_weight = 0

        if rs_type == 'random':
            condition = certify_random(top_weight, pred_class.shape[0], rs_len, atk_len)
        elif rs_type == 'block':
            condition = certify_block(top_weight, second_weight, rs_len, atk_len)
        else:
            raise NotImplementedError

        if top_class == tgt_class and condition:
            num_correct += 1
            bin_center = top_class*delta + 0.5*delta
            list_mse.append((bin_center - tgt_value)**2)
            list_mae.append(abs(bin_center - tgt_value))

    if num_correct == 0:
        return np.inf, np.inf, 0.0
    mae = sum(list_mae) / len(list_mae)
    mse = sum(list_mse) / len(list_mse)
    acc = num_correct / len(ans)
    return mae, mse, acc
            



if __name__ == "__main__":
    import sys

    # ==================================== single run
    args = get_args()
    args = post_process_args(args)
    main(args)
    # Everything below is intentionally unreachable. It holds alternative
    # pipelines (batch training, prediction dumping, certified evaluation)
    # that are enabled by hand. sys.exit() replaces the site-module exit(),
    # which is meant for the interactive interpreter only.
    sys.exit()

    # ==================================== batch run
    # args = get_args()
    # result_dir = 'result-pred-rs-0818'
    # for pre_type in PRE_MODEL_TYPES:
    #     for dataset_name in DATASET_NAMES:
    #         for rs_type in RS_TYPES:
    #             for rs_len in RS_LENS:
    #                 args.dataset_name = dataset_name
    #                 args.output_dir = f"{result_dir}/{dataset_name}/{pre_type}/{rs_type}/{rs_len}"
    #                 args.model_type = pre_type
    #                 args.rs_type = rs_type
    #                 args.rs_len = rs_len
    #                 args = post_process_args(args)
    #                 main(args)

    # ==================================== predict and save
    # setup_seed()
    # result_dir = 'result-pred-rs-0818'
    # device = 'cuda:2'
    # num_sample = 101
    # rs_num = 10000
    # for dataset_name in DATASET_NAMES:
    #     ctx_len, tgt_len = get_dataset_params(dataset_name)
    #     dataloader = DataLoader(
    #         MyDataset(dataset_name,"test",ctx_len,tgt_len,num_sample,True,15,5),
    #         batch_size=1,
    #         shuffle=False,
    #         drop_last=True,
    #         num_workers=4
    #     )
    #     for pre_type in PRE_MODEL_TYPES:
    #         for rs_type in RS_TYPES:
    #             for rs_len in RS_LENS:
    #                 tgt_dir = f'{result_dir}/{dataset_name}/{pre_type}/{rs_type}/{rs_len}'
    #                 model = torch.load(join(tgt_dir,'checkpoint/best.pt'))
    #                 predict(model, dataloader, rs_type, rs_len, rs_num, device, tgt_dir)
    #                 print(tgt_dir)
    # exit()

    # ==================================== evaluate ans generated by predict
    rs_num = 10000
    result_dir = 'result-pred-rs-0818'
    csv_path = f'record-pred-rs-metric-rs_num={rs_num}-more_atks.csv'
    rows = []
    for dataset_name in DATASET_NAMES:
        for pre_type in PRE_MODEL_TYPES:
            for rs_type in RS_TYPES:
                for rs_len in RS_LENS:
                    run_dir = f'{result_dir}/{dataset_name}/{pre_type}/{rs_type}/{rs_len}'
                    if rs_type == 'random':
                        ans_path = f'{run_dir}/ans-rs_num={rs_num}.pt'
                    elif rs_type == 'block':
                        ans_path = f'{run_dir}/ans.pt'
                    else:
                        # previously fell through with a stale/unbound ans_path
                        raise NotImplementedError(f'RS type {rs_type} is not implemented')
                    ans = torch.load(ans_path)
                    for delta in DELTAS:
                        for atk_len in range(1, max(MASK_LENGTH) + 1):
                            # deepcopy guards against evaluate() versions that modify ans in place
                            mae, mse, acc = evaluate(deepcopy(ans), delta, rs_type, rs_len, atk_len)
                            rows.append({
                                "dataset_name": dataset_name,
                                "pre_type": pre_type,
                                "rs_type": rs_type,
                                "rs_len": rs_len,
                                "delta": delta,
                                "mae": mae,
                                "mse": mse,
                                "acc": acc,
                                "atk_len": atk_len,
                            })
                    print(ans_path)
    # build the table once instead of a quadratic pd.concat per row
    pd.DataFrame(rows).to_csv(csv_path, index=False)

