from argparse import ArgumentParser
from os.path import join, isfile
from dataset.dataset import MyDataset
import torch
from torch import linalg as LA
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from my_utils.utils import Logger
from my_utils.utils import DATASET_NAMES, get_dataset_params, mkdir, save_json, setup_seed, PRE_MODEL_TYPES
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import random
from torch.nn import functional as F
import shutil


def get_args():
    """
    Declare the command-line options of the attack script and parse them.

    Returns:
        The ``argparse.Namespace`` produced by parsing the process arguments.
    """
    parser = ArgumentParser()
    add = parser.add_argument

    # attack / optimisation settings
    add("--atk_rate", type=float, default=0.1, help="the rate of attack length")
    add(
        "--step_method",
        type=str,
        default='pgd',
        choices=['pgd', 'normal'],
        help='step method for optimizer',
    )
    add(
        "--atk_method",
        type=str,
        default='adv',
        choices=['adv', 'ga'],
        help='attack method for optimizer, adv=adversarial, ga=gradient ascending',
    )
    add("--pre_type", type=str, default='lstm', choices=PRE_MODEL_TYPES)
    add("--pre_dir", type=str, default='result_prediction_0827')
    add("--norm_weight", type=float, default=0.0, help="loss weight of perturbation norm")
    add("--dataset_name", type=str, default='traffic_nips', choices=DATASET_NAMES)
    add("--output_dir", type=str, default="result-attack-debug")
    add("--max_norm", type=float, default=3.5, help='max norm for PGD')
    add("--seed", type=int, default=1024)
    add("--fontsize", type=int, default=25)
    add("--lr", type=float, default=0.5)
    add("--lr_decay_rate", type=float, default=0.98)

    # run-time settings
    add("--perturbation", type=str, default="", help="checkpoint of perturbation")
    add("--device", type=str, default="cuda:0")
    add("--num_epoch", type=int, default=300)
    add("--log_epoch", type=int, default=10)
    add("--batch_size", type=int, default=1)
    # integer flags; post_process_args converts --train / --use_filter to bool
    add("--train", type=int, default=1)
    add("--save", type=int, default=1)

    # filter settings
    add("--use_filter", type=int, default=1)
    add("--window", type=int, default=15)
    add("--order", type=int, default=5)

    add("--rt_noise", type=float, default=0.0)
    add("--len_clip", type=float, default=1.0)

    return parser.parse_args()


def post_process_args(args):
    """
    Derive the dependent settings from the parsed command-line options.

    The integer flags ``train`` and ``use_filter`` become booleans, the number of
    epochs is forced to 1 when not training, and the context/prediction lengths,
    the prediction-model checkpoint path and the attacked window
    (``atk_len`` / ``atk_start``) are stored on ``args``.

    Returns:
        The same ``args`` object, updated in place.
    """
    args.train = bool(args.train)
    args.use_filter = bool(args.use_filter)

    # a pure evaluation run needs a single pass only
    if not args.train:
        args.num_epoch = 1

    ctx_len, pred_len = get_dataset_params(args.dataset_name, args.len_clip)
    args.ctx_len = ctx_len
    args.pred_len = pred_len

    args.prediction_model = (
        f'{args.pre_dir}/{args.dataset_name}/{args.pre_type}'
        f'/filter={args.use_filter}/rt_noise={args.rt_noise}/len_clip={args.len_clip}'
        '/checkpoint/best.pt'
    )

    # the attacked window is the last ``atk_len`` steps of the context
    args.atk_len = int(args.atk_rate * args.ctx_len)
    args.atk_start = args.ctx_len - args.atk_len
    assert args.atk_start >= 0
    assert args.atk_start + args.atk_len <= args.ctx_len
    return args


class AttackModule(torch.nn.Module):
    """
    Perturb a window of the input context with a trainable perturbation.

    The perturbation is added to ``context[:, atk_start:atk_start + atk_len]``;
    every other time step is passed through unchanged.
    """

    def __init__(self,perturbation_path:str, batch_size:int, atk_len:int, atk_start:int) -> None:
        '''
        Args:
            perturbation_path: checkpoint holding a saved perturbation tensor. When
                the path is not an existing file, a random perturbation of shape
                ``(batch_size, atk_len)`` is created instead.
            batch_size: number of rows of a freshly initialised perturbation.
            atk_len: number of perturbed time steps.
            atk_start: index of the first perturbed time step.

        Raises:
            TypeError: if the checkpoint does not contain a tensor.
        '''
        super().__init__()
        self.atk_len = atk_len
        self.atk_start = atk_start
        if isfile(perturbation_path):
            # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
            loaded = torch.load(perturbation_path)
            if not isinstance(loaded, torch.Tensor):
                raise TypeError(
                    f'perturbation checkpoint {perturbation_path!r} must contain a tensor, '
                    f'got {type(loaded).__name__}'
                )
            # A plain tensor assigned to a module attribute is NOT registered as a
            # parameter: it would be skipped by .parameters() (the optimizer would get
            # an empty list) and by .to(device). Wrap it so it behaves exactly like a
            # freshly initialised perturbation.
            if not isinstance(loaded, torch.nn.Parameter):
                loaded = torch.nn.Parameter(loaded)
            self.perturbation = loaded
        else:
            self.perturbation = torch.nn.Parameter(torch.randn(batch_size,atk_len))

    def get_norm(self):
        """Return the L2 norm of each perturbation row, keeping dim 1 (one column per row)."""
        return LA.norm(self.perturbation,dim=1,keepdim=True)

    def forward(self,context):
        """Return a copy of ``context`` with the perturbation added to the attacked window."""
        start = self.atk_start
        stop = start + self.atk_len
        perturbed_context = torch.cat(
            [
                context[:,0:start],
                context[:,start:stop] + self.perturbation,
                context[:,stop:]
            ],
            dim = 1
        )
        return perturbed_context


def draw(context,perturbed_context,clean_prediction,perturbed_prediction,target_adv,fig_dir,epoch,args):
    """
    Plot the clean and the perturbed series (context followed by prediction) and
    save the figure as ``epoch_{epoch}.png`` inside ``fig_dir``.

    context, perturbed_context, clean_prediction, perturbed_prediction: shape == [B,T]

    Only the first sample of the batch is drawn. ``target_adv`` is converted like
    the other tensors but is not plotted.
    """
    def as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    clean_ctx, attacked_ctx, clean_pred, attacked_pred, _ = [
        as_numpy(tensor)
        for tensor in (context, perturbed_context, clean_prediction, perturbed_prediction, target_adv)
    ]

    ctx_end = args.ctx_len
    horizon_end = ctx_end + args.pred_len
    time_axis = np.arange(horizon_end)

    plt.figure(figsize=(8,6))
    n_shown = min(clean_ctx.shape[0], 1)
    for row in range(n_shown):
        plt.subplot(1, 1, row + 1)
        clean_series = np.concatenate([clean_ctx[row], clean_pred[row]])
        attacked_series = np.concatenate([attacked_ctx[row], attacked_pred[row]])

        y_span = [
            min(clean_series.min(), attacked_series.min()),
            max(clean_series.max(), attacked_series.max()),
        ]

        # shade the attacked part of the context, then the prediction horizon
        plt.fill_betweenx(y_span, args.atk_start, ctx_end, color='#ffe0e9')
        plt.fill_betweenx(y_span, ctx_end, horizon_end, color='#dfe4f3')

        plt.plot(time_axis, attacked_series, '-r', label='TLA')
        plt.plot(time_axis, clean_series, '-b', label='normal')
        # vertical line at the boundary between context and prediction
        plt.axvline(ctx_end, color='black')
        plt.grid(which='both')
        plt.legend(fontsize=args.fontsize, loc='upper left')
        plt.xticks(fontsize=args.fontsize)
        plt.yticks(fontsize=args.fontsize)

    target_file = join(fig_dir, f'epoch_{epoch}.png')
    plt.savefig(target_file)
    plt.close()


def _pgd_step(attack_module, lr, max_norm):
    """Take one normalised gradient step on the perturbation, then rescale every row whose L2 norm exceeds ``max_norm``."""
    with torch.no_grad():
        perturbation = attack_module.perturbation
        grad = perturbation.grad
        grad_norm = LA.norm(grad, dim=1, keepdim=True)
        # An all-zero gradient row would give 0/0 = NaN and poison the perturbation
        # for good; clamping turns it into a zero step and leaves every other
        # gradient untouched.
        grad_norm = grad_norm.clamp_min(torch.finfo(grad_norm.dtype).tiny)
        direction = grad / grad_norm
        perturbation.data = perturbation.data - lr * direction

        norm = attack_module.get_norm().repeat(1, perturbation.shape[1])
        perturbation.data = torch.where(
            (norm > max_norm),
            perturbation.data / norm * max_norm,
            perturbation.data
        )


def run(
    model,
    mode:str,
    attack_module,
    optimizer,
    context,
    clean_prediction,
    target_adv,
    epoch:int,
    logger,
    fig_dir:str,
    args,
    scheduler
):
    """
    Run one attack iteration: perturb the context, predict, compute the loss and,
    in ``'train'`` mode, update the perturbation.

    Args:
        model: callable invoked as ``model(perturbed_context, target_adv)``.
        mode: ``'train'`` enables gradients and the update step; any other value
            only evaluates.
        attack_module: module exposing ``perturbation``, ``get_norm()`` and a
            forward pass that perturbs ``context``.
        optimizer: used for ``zero_grad`` (and ``step`` when
            ``args.step_method == 'normal'``); not touched outside train mode.
        context: clean input context.
        clean_prediction: model prediction on the clean context.
        target_adv: target the ``'adv'`` attack pulls the prediction towards.
        epoch: zero-based epoch index (used for logging / figure names).
        logger: object with a ``log(str)`` method.
        fig_dir: directory the figures are written to.
        args: settings namespace; ``args.lr`` is decayed in place by the PGD step.
        scheduler: stepped after the optimizer when ``args.step_method == 'normal'``.

    Returns:
        dict with the keys ``bias``, ``loss``, ``norm``, ``relative_norm``,
        ``context``, ``clean_prediction``, ``perturbation``, ``target_adv``,
        ``args``, ``perturbed_context`` and ``perturbed_prediction``. Every entry
        describes the perturbation that produced ``bias``/``loss`` (i.e. the state
        before this call's update); ``perturbation`` is a snapshot copy of
        ``attack_module`` in that state.

    Raises:
        NotImplementedError: for an unknown ``args.atk_method`` or (in train mode)
            ``args.step_method``.
    """
    import copy

    is_train = mode == 'train'
    with torch.set_grad_enabled(is_train):
        perturbed_context = attack_module(context)
        perturbed_prediction = model(perturbed_context,target_adv)
        bias = F.mse_loss(perturbed_prediction,clean_prediction)

        norm_penalty = args.norm_weight * attack_module.get_norm().mean()
        if args.atk_method == 'adv':
            loss = F.mse_loss(perturbed_prediction, target_adv) + norm_penalty
        elif args.atk_method == 'ga':
            # A larger bias is better for the attack and a smaller loss is better
            # for the optimiser, so the loss is the negated bias and plain gradient
            # descent on it performs the gradient ascent on the bias.
            loss = -bias + norm_penalty
        else:
            raise NotImplementedError(f'atk_method = {args.atk_method} not implemented')

        # Snapshot the perturbation BEFORE it is updated. bias, loss and the
        # perturbed tensors all belong to this state, so the norms and the module
        # stored in the record must too; otherwise a "best" record kept by the
        # caller would silently end up holding the final perturbation.
        with torch.no_grad():
            snapshot = copy.deepcopy(attack_module)
            mean_norm = snapshot.get_norm().mean().item()
            attacked_window = context[:,args.atk_start:args.atk_start + args.atk_len]
            relative_perturbation = snapshot.perturbation / attacked_window
            relative_norm = LA.norm(relative_perturbation,dim=1).mean().item()

        if is_train:
            if args.step_method == 'pgd':
                loss.backward()
                _pgd_step(attack_module, args.lr, args.max_norm)
                # manual learning-rate decay; the scheduler is not used by PGD
                args.lr *= args.lr_decay_rate
                optimizer.zero_grad()
            elif args.step_method == 'normal':
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
            else:
                raise NotImplementedError(f'step_method = {args.step_method} not implemented')

        if (epoch + 1) % args.log_epoch == 0:
            # norm is logged as a plain float, like loss and bias
            logger.log(f'mode={mode}, epoch={epoch+1}, loss={loss.item()}, bias={bias.item()}, norm={mean_norm}')
            draw(context,perturbed_context,clean_prediction,perturbed_prediction,target_adv,fig_dir,epoch+1,args)

        return {
            "bias": bias.item(),
            "loss": loss.item(),
            "norm": mean_norm,
            "relative_norm": relative_norm,
            "context": context,
            "clean_prediction" : clean_prediction,
            "perturbation": snapshot,
            "target_adv": target_adv,
            "args": args,
            "perturbed_context": perturbed_context,
            "perturbed_prediction": perturbed_prediction
        }


def main(args):
    """
    Run the attack against one randomly chosen test batch and keep the best result.

    The output directory is recreated from scratch, the prediction model and the
    test dataset are loaded, one batch is picked at random, and ``run`` is called
    for ``args.num_epoch`` epochs. The record with the largest ``bias`` is plotted
    as ``epoch_best.png`` and, when ``args.save`` is set, stored in
    ``best_record.pt`` inside ``args.output_dir``.
    """
    setup_seed(args.seed)

    # Start from a clean output directory. On the very first run it does not
    # exist yet, and an unconditional rmtree would raise FileNotFoundError.
    if os.path.isdir(args.output_dir):
        shutil.rmtree(args.output_dir)
    fig_dir = join(args.output_dir, 'fig')
    ck_dir = join(args.output_dir, 'ck')
    for d in [args.output_dir, fig_dir, ck_dir]:
        mkdir(d)
    save_json(vars(args),join(args.output_dir,'args.json'))

    logger = Logger(join(args.output_dir,'log.txt'))

    logger.log(args.prediction_model)
    # NOTE(review): the model is put in train() mode although only the perturbation
    # is optimised - presumably so gradients can flow back through it; confirm.
    model = torch.load(args.prediction_model).to(args.device).train()

    logger.log(args.dataset_name)
    dataset = MyDataset(
        args.dataset_name,
        "test",
        args.ctx_len,
        args.pred_len,
        1000*args.batch_size,
        args.use_filter,
        args.window,
        args.order
    )
    dataloader = DataLoader(dataset,args.batch_size,shuffle=True,drop_last=True)

    # materialise every batch, then attack a single randomly selected one
    batches = list(iter(dataloader))
    context, target = batches[random.randint(0,len(batches)-1)]
    context, target = context.to(args.device), target.to(args.device)
    context, target, _, _ = dataset.normalize(context,target)
    # the adversarial target is an all-zero series shaped like the real target
    target_adv = target * 0

    with torch.no_grad():
        clean_prediction = model(context, target)

    attack_module = AttackModule(args.perturbation, args.batch_size, args.atk_len, args.atk_start)
    attack_module.to(args.device).train()
    optimizer = Adam(attack_module.parameters(),lr=args.lr)
    scheduler = ExponentialLR(optimizer,args.lr_decay_rate,verbose=True)

    # -inf (instead of 0) so that the first finite record is always kept, even
    # when its bias is exactly 0.
    bias_best = float('-inf')
    current_norm = 0
    best_record = None
    for epoch in tqdm(range(args.num_epoch)):
        if args.train:
            record = run(model,"train",attack_module,optimizer,context,clean_prediction,target_adv,epoch,logger,fig_dir,args,scheduler)
        else:
            record = run(model,"test",attack_module,None,context,clean_prediction,target_adv,epoch,logger,fig_dir,args,None)

        if record['bias'] > bias_best:
            bias_best = record['bias']
            current_norm = record['norm']
            best_record = record
            draw(record['context'], record['perturbed_context'], record['clean_prediction'], record['perturbed_prediction'], record['target_adv'], fig_dir, 'best', record['args'])

    # No epoch ran, or every bias was NaN: there is nothing to report or save.
    if best_record is None:
        logger.log('No valid record was produced; nothing to report or save.')
        return

    logger.log('Best bias = {:<3.3f}, norm = {:<3.3f}, relative_norm = {:<3.3f}'.format(bias_best, current_norm, best_record["relative_norm"]))

    if args.save:
        # Move everything that can live on a device to the CPU before pickling;
        # plain Python values (floats, the args namespace) have no .cpu() and are
        # kept as they are. Real errors raised by .cpu() are no longer swallowed.
        for key, value in list(best_record.items()):
            to_cpu = getattr(value, 'cpu', None)
            if callable(to_cpu):
                best_record[key] = to_cpu()
        torch.save(best_record,join(args.output_dir,'best_record.pt'))


if __name__ == "__main__":
    # Script entry point: parse the command line, derive the dependent settings
    # and run the attack.
    args = get_args()
    args = post_process_args(args)
    main(args)
    # Terminate explicitly with the builtin exception. The original exit() is a
    # helper injected by the site module for interactive use and is not defined
    # when the interpreter runs without it (e.g. python -S).
    raise SystemExit(0)
