import json
import os
import shutil
from time import time

import numpy as np

import paddle.vision.transforms as transforms
import paddle
from paddle import nn
import paddle.nn.functional as F

import config
from utils.dataloader import PostTensorTransform, get_dataloader, DictDataset, get_dataset, get_transform
from utils.utils import progress_bar
from autoencoders.models import NetC_MNIST, AutoEncoder_MNIST
from autoencoders.models import AutoencoderCifar
from autoencoders.models import AutoencoderMnist
from autoencoders.models import AutoencoderCeleba

from attacks.attacks import SmoothAttacker, CleanLabelAttacker, HardAttacker
from attacks.backdoors import BadNetTrigger, SIGTrigger, WaNetTrigger, TcbTrigger, NaiveTcbTrigger, GeneTcbTrigger, AETcbTrigger
from defences import StripDefencer, EntropyDefencer, ShrinkPad
from classifier_models import VGG, PreActResNet18, ResNet18, AutoEncoder_VGG


class BaseTrainer(object):
    """
    Shared training, evaluation and prediction loops for a classifier.

    The dataloaders handed to the methods must yield dict batches with an
    'input' and a 'target' entry. Subclasses implement `train`.
    """

    # Share of the probability mass spread uniformly over all classes when
    # `opt.use_label_smooth` is set; the labelled class keeps the rest.
    _LABEL_SMOOTHING = 0.2

    def __init__(self, opt, *args):
        """
        Store the configuration object.

        @Args:
            opt: configuration object whose attributes are read by the other methods
            *args: accepted and ignored
        """
        self.opt = opt

    def get_model(self, opt, mode=None, **kwargs):
        """
        Build the classifier together with its optimizer and LR scheduler.

        @Args:
            opt: configuration; reads dataset, num_classes, lr_C,
                schedulerC_milestones and schedulerC_lambda
            mode (str): 'source' (VGG11) or 'target' (PreActResNet18); only
                consulted for the cifar10 and gtsrb datasets
        @Returns:
            tuple (netC, optimizerC, schedulerC)
        @Raises:
            ValueError: no classifier is defined for this dataset/mode pair
        """
        netC = None
        if opt.dataset in ("cifar10", "gtsrb"):
            if mode == 'source':
                netC = VGG('VGG11', num_classes=opt.num_classes)
            elif mode == 'target':
                netC = PreActResNet18(num_classes=opt.num_classes)
        elif opt.dataset == "celeba":
            netC = ResNet18()
        elif opt.dataset == "mnist":
            netC = NetC_MNIST()

        # Fail here with a readable message instead of later on
        # `None.parameters()`.
        if netC is None:
            raise ValueError(
                "No classifier defined for dataset={!r}, mode={!r}".format(opt.dataset, mode)
            )

        # set optimizer and scheduler
        clip = paddle.nn.ClipGradByNorm(clip_norm=5.0)
        schedulerC = paddle.optimizer.lr.MultiStepDecay(opt.lr_C, opt.schedulerC_milestones, opt.schedulerC_lambda)
        optimizerC = paddle.optimizer.Momentum(schedulerC, momentum=0.9, parameters=netC.parameters(), weight_decay=5e-4, grad_clip=clip)

        return netC, optimizerC, schedulerC

    def train(self):
        """
        Perform the training and evaluating process.
        """

        raise NotImplementedError

    def eval(self, netC, valid_dl, epoch, opt, *args):
        """
        Compute the top-1 accuracy (in percent) of `netC` on `valid_dl`.

        The network is put in eval mode for the pass and switched back to
        train mode before returning.

        @Args:
            netC: classifier to evaluate
            valid_dl: dataloader yielding {'input', 'target'} batches
            epoch: unused
            opt: unused
        @Returns:
            dict {'Acc': float}; 0.0 when the dataloader yields no batch
        """
        netC.eval() # dropout/normalization
        total_correct = 0
        total_sample = 0
        acc = 0.0  # defined up front so an empty dataloader does not crash
        with paddle.no_grad():
            for batch_idx, batch in enumerate(valid_dl):
                inputs, targets = batch['input'], batch['target']
                # Flatten instead of squeeze(): squeeze() turns a batch of one
                # sample into a 0-d tensor.
                targets = targets.reshape([-1])
                total_sample += inputs.shape[0]
                preds = netC(inputs)
                hits = (paddle.argmax(preds, axis=1) == targets).astype('int64').sum()
                total_correct += int(hits.item())
                acc = total_correct * 100.0 / total_sample
                progress_bar(batch_idx, len(valid_dl), "Acc: {:.4f}".format(acc))
        netC.train()
        return {'Acc': acc}

    def predict(self, netC, test_dl, opt, *args):
        """
        Predict the probability

        The network is put in eval mode for the pass and switched back to
        train mode before returning.

        @Args:
            netC: classifier used for prediction
            test_dl: dataloader yielding {'input', 'target'} batches
            opt: unused
        @Returns:
            inputs (np.ndarray): input images
            labels (np.ndarray): output labels
            probs (np.ndarray): predicting probability
        """
        print("Predicting the probability:")
        netC.eval()
        list_inputs = []
        list_probs = []
        list_labels = []
        with paddle.no_grad():
            for batch_idx, batch in enumerate(test_dl):
                inputs, targets = batch['input'], batch['target']
                probs = nn.functional.softmax(netC(inputs), axis=-1)
                list_inputs.append(inputs)
                list_labels.append(targets)
                list_probs.append(probs)
                progress_bar(batch_idx, len(test_dl), "")
        inputs = paddle.concat(list_inputs).cpu().numpy()
        labels = paddle.concat(list_labels).cpu().numpy()
        probs = paddle.concat(list_probs).cpu().numpy()
        netC.train()
        return inputs, labels, probs

    def train_one_epoch(self, netC, optimizerC, schedulerC, train_dl, epoch, opt, *args):
        """
        Train the model one epoch.

        Runs one optimizer step per batch with a cross-entropy loss (soft
        labels when `opt.use_label_smooth` is set), reports the running loss
        and accuracy through `progress_bar`, and steps the LR scheduler once
        at the end of the epoch.

        @Args:
            netC: classifier to train
            optimizerC: optimizer updating `netC`
            schedulerC: LR scheduler, stepped once per call
            train_dl: dataloader yielding {'input', 'target'} batches
            epoch: unused
            opt: configuration; reads use_label_smooth and num_classes
        """
        print(" Train:")
        netC.train()
        if opt.use_label_smooth:
            criterion_CE = paddle.nn.CrossEntropyLoss(soft_label=True)
        else:
            criterion_CE = paddle.nn.CrossEntropyLoss()

        total_loss_ce = 0.0
        total_sample = 0
        total_correct = 0

        for batch_idx, batch in enumerate(train_dl):
            optimizerC.clear_grad()
            inputs, targets = batch['input'], batch['target']
            # Flatten instead of squeeze() so a batch of one sample keeps its
            # batch dimension.
            targets = targets.reshape([-1])

            preds = netC(inputs)
            if opt.use_label_smooth:
                one_hot = paddle.eye(opt.num_classes)[targets]
                soft_targets = (
                    one_hot * (1 - self._LABEL_SMOOTHING)
                    + self._LABEL_SMOOTHING / opt.num_classes
                )
                loss_ce = criterion_CE(preds, soft_targets)
            else:
                loss_ce = criterion_CE(preds, targets)

            loss_ce.backward()
            optimizerC.step()

            total_sample += inputs.shape[0]
            total_loss_ce += float(loss_ce.item())
            hits = (paddle.argmax(preds, axis=1) == targets).astype('int64').sum()
            total_correct += int(hits.item())

            avg_acc = total_correct * 100.0 / total_sample
            avg_loss_ce = total_loss_ce / (batch_idx + 1)

            progress_bar(
                batch_idx,
                len(train_dl),
                "CE Loss: {:.4f} | Acc: {:.4f} ".format(avg_loss_ce, avg_acc),
            )
        schedulerC.step()


class BaseAttackTrainer(BaseTrainer):
    """
    Trainer that poisons data with a list of attackers and trains or
    evaluates a victim classifier on it.
    """

    # dataset -> (per-channel mean, per-channel std); shared by
    # get_transformer and denormalize so both always agree.
    _NORM_STATS = {
        'cifar10': ([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]),
        'mnist': ([0.5], [0.5]),
    }

    # dataset -> (num_classes, input_height, input_width, input_channel)
    _DATASET_SPECS = {
        'mnist': (10, 28, 28, 1),
        'cifar10': (10, 32, 32, 3),
        'gtsrb': (43, 32, 32, 3),
        'celeba': (8, 64, 64, 3),
    }

    # Lower bound applied to probabilities before log2 so that p == 0
    # contributes 0 to the entropy instead of NaN.
    _ENTROPY_EPS = 1e-12

    def __init__(self, opt, attackers, *args):
        """
        @Args:
            opt: configuration object
            attackers (list): attacker objects; the methods below call their
                `attack` method and their `trigger.apply_all` method
            *args: accepted and ignored
        """
        super(BaseAttackTrainer, self).__init__(opt)
        self.attackers = attackers

    @staticmethod
    def _restore_checkpoint(netC, optimizerC, schedulerC, ckpt_path):
        """Load `ckpt_path` into the three objects and return the raw state dict."""
        state_dict = paddle.load(ckpt_path)
        netC.set_state_dict(state_dict["netC"])
        optimizerC.set_state_dict(state_dict["optimizerC"])
        schedulerC.set_state_dict(state_dict["schedulerC"])
        return state_dict

    @staticmethod
    def _format_perf(perf_dict):
        """Render a metric dict as tab-separated 'key:value' pairs."""
        return '\t'.join('{}:{}'.format(key, value) for key, value in perf_dict.items())

    @staticmethod
    def _collect_probs(netC, dataloader):
        """Concatenate the softmax outputs of `netC` over every batch of `dataloader`."""
        probs = [nn.functional.softmax(netC(batch['input']), axis=-1) for batch in dataloader]
        return paddle.concat(probs)

    @classmethod
    def _entropy(cls, probs):
        """Base-2 entropy along the last axis."""
        safe_probs = paddle.clip(probs, min=cls._ENTROPY_EPS)
        return -(probs * paddle.log2(safe_probs)).sum(axis=-1)

    def get_transformer(self, opt=None, train=True):
        """
        Build the input transform for the configured dataset.

        @Args:
            opt: configuration; defaults to `self.opt`
            train (bool): add the random crop (and, for cifar10, the random
                horizontal flip) used for training
        @Returns:
            a `transforms.Compose`, or None when no transform applies
        """
        if opt is None:
            opt = self.opt
        transforms_list = []
        if train:
            transforms_list.append(transforms.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop))
            if opt.dataset == "cifar10":
                transforms_list.append(transforms.RandomHorizontalFlip(prob=0.5))
        if opt.dataset in self._NORM_STATS:
            mean, std = self._NORM_STATS[opt.dataset]
            transforms_list.append(transforms.Normalize(mean, std))
        if not transforms_list:
            return None
        return transforms.Compose(transforms_list)

    def denormalize(self, inputs, opt=None):
        """
        Undo the normalisation applied by `get_transformer`.

        Datasets without normalisation constants are returned scaled by 1 and
        shifted by 0.

        @Args:
            inputs (np.ndarray or paddle tensor): batch of images; the
                constants are broadcast as (1, C, 1, 1)
            opt: configuration; defaults to `self.opt`
        @Returns:
            denormalized batch, same container type as `inputs`
        """
        if opt is None:
            opt = self.opt
        mean, std = self._NORM_STATS.get(opt.dataset, ([0.], [1.]))
        mean = np.array(mean, dtype='float32')[:, None, None]
        std = np.array(std, dtype='float32')[:, None, None]
        if isinstance(inputs, np.ndarray):
            return inputs * std[None, ...] + mean[None, ...]
        return inputs * paddle.to_tensor(std).unsqueeze(0) + paddle.to_tensor(mean).unsqueeze(0)

    def get_poisoned_train_data(self, inputs, labels, pred_probs, opt):
        """
        Perform attack to training data.

        @Args:
            inputs (np.ndarray): input images (unnormalized)
            labels (np.ndarray): output labels
            pred_probs (np.ndarray): predicted probability of input images
            opt (dict): configuration
        @Return:
            training dataloader
        """
        assert len(inputs) == len(labels)
        assert len(labels) == len(pred_probs)
        assert pred_probs.shape[1] >= 2
        target_probs = np.clip(pred_probs.max(axis=1) + 0.1, 1./opt.num_classes, .6) # beta
        source_probs = pred_probs[:, opt.target_label] # alpha

        # insert backdoors to some of the training data
        attacked_inputs, attacked_labels = inputs, labels
        for attacker in self.attackers:
            attacked_inputs, attacked_labels = attacker.attack(attacked_inputs, attacked_labels, source_probs, target_probs)

        # build dataset and dataloader
        transformer = self.get_transformer(opt, train=True)
        dataset = DictDataset({'input':attacked_inputs, 'target':attacked_labels, 'origin_input':inputs, 'origin_target':labels},
        input_transform=transformer)
        dataloader = paddle.io.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)
        return dataloader

    def get_poisoned_eval_data(self, eval_dl, opt, *args, attackers=None, pred_probs=None, mode='all'):
        """
        Build a dataloader of triggered samples relabelled as the target class.

        @Args:
            eval_dl: dataloader yielding normalized {'input', 'target'} batches
            opt: configuration; reads target_label, bs, num_workers
            attackers (list): attackers whose triggers are applied in order;
                defaults to `self.attackers`
            pred_probs (np.ndarray): per-sample class probabilities; required
                for mode 'pred-non-target'
            mode (str): 'all' poisons every sample, 'non-target' only samples
                whose label differs from the target class, 'pred-non-target'
                only samples whose predicted class differs from it
        @Returns:
            dataloader over the poisoned samples; 'origin_input' and
            'origin_target' hold the clean counterparts of the same samples
        @Raises:
            ValueError: `pred_probs` is missing or has the wrong length
            Exception: unsupported `mode`
        """
        if not attackers:
            attackers = self.attackers
        batch_inputs = []
        batch_labels = []
        for batch in eval_dl:
            batch_inputs.append(batch['input'])
            batch_labels.append(batch['target'])
        inputs = np.concatenate(batch_inputs)
        labels = np.concatenate(batch_labels)
        # Compare per-sample lengths (after concatenation), and test against
        # None because the truth value of a multi-element array is ambiguous.
        if pred_probs is not None and len(pred_probs) != len(inputs):
            raise ValueError(
                'pred_probs has {} rows but eval_dl holds {} samples'.format(len(pred_probs), len(inputs))
            )

        #Denormalize
        inputs = self.denormalize(inputs)

        if mode == 'all': # poison all samples
            indices = None
        elif mode == 'non-target': # poison non-target class samples
            flat_labels = labels[:, 0] if labels.ndim > 1 else labels
            indices = flat_labels != opt.target_label
        elif mode == 'pred-non-target': # poison predicted non-target class samples by the source model
            if pred_probs is None:
                raise ValueError("mode 'pred-non-target' requires pred_probs")
            # argmax gives the predicted class; max() would give a probability.
            indices = np.asarray(pred_probs).argmax(axis=1) != opt.target_label
        else:
            raise Exception('Unsupport mode: {}'.format(mode))

        # Keep the clean references aligned with the attacked samples.
        if indices is None:
            origin_inputs, origin_labels = inputs, labels
        else:
            origin_inputs, origin_labels = inputs[indices], labels[indices]
        attacked_inputs = origin_inputs
        for attacker in attackers:
            attacked_inputs = attacker.trigger.apply_all(attacked_inputs)
        attacked_labels = np.ones_like(origin_labels)*opt.target_label

        # Build testing dataloader
        transformer = self.get_transformer(opt, train=False)
        dataset = DictDataset({'input':attacked_inputs, 'target':attacked_labels, 'origin_input':origin_inputs, 'origin_target':origin_labels}, input_transform=transformer)
        dataloader = paddle.io.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)
        return dataloader

    def get_clean_data(self, opt, attack=False):
        """
        Fill in the dataset-dependent fields of `opt` and build the clean
        train and test dataloaders.

        Sets opt.num_classes, opt.input_height, opt.input_width and
        opt.input_channel in place.

        @Args:
            opt: configuration; reads dataset
            attack (bool): forwarded to `get_dataloader` for the training split
        @Returns:
            tuple (train_dl, test_dl)
        @Raises:
            Exception: unknown dataset name
        """
        spec = self._DATASET_SPECS.get(opt.dataset)
        if spec is None:
            raise Exception("Invalid Dataset")
        opt.num_classes, opt.input_height, opt.input_width, opt.input_channel = spec

        train_dl = get_dataloader(opt, train=True, attack=attack)
        test_dl = get_dataloader(opt, train=False)

        return train_dl, test_dl

    def train_a_model(self, netC, optimizerC, schedulerC, train_dl, valid_dl,  ckpt_path, opt, *args, current_epoch=0):
        """
        Train `netC` for `opt.n_iters` epochs and keep the best checkpoint.

        After each epoch the model is evaluated on `valid_dl` and on a fully
        poisoned copy of it. Whenever `opt.model_check_metric` improves, the
        model, optimizer, scheduler, epoch number and `opt` are saved to
        `ckpt_path`. The saved state is loaded back before returning.

        @Args:
            netC, optimizerC, schedulerC: objects returned by `get_model`
            train_dl: training dataloader
            valid_dl: clean validation dataloader
            ckpt_path (str): checkpoint file
            opt: configuration; reads model_check_metric and n_iters
            current_epoch (int): offset used for logging and for the saved
                epoch number
        @Returns:
            dict with the best value of the checked metric and, under 'ASR',
            the accuracy on the poisoned data at that epoch
        """
        netC.train()
        best_perf_dict = {}
        metric = opt.model_check_metric
        poisoned_dl = self.get_poisoned_eval_data(valid_dl, attackers=self.attackers, mode='all', opt=self.opt)
        for epoch in range(opt.n_iters):
            print("Epoch {}:".format(current_epoch + epoch + 1))
            self.train_one_epoch(netC, optimizerC, schedulerC, train_dl, epoch, opt)
            perf_dict = self.eval(netC=netC, valid_dl=valid_dl, epoch=epoch, opt=opt)
            attack_acc = self.eval(netC=netC, valid_dl=poisoned_dl, epoch=epoch, opt=opt)["Acc"]
            # Save checkpoint
            if metric not in best_perf_dict or perf_dict[metric] > best_perf_dict[metric]:
                print(" Saving...")
                best_perf_dict[metric] = perf_dict[metric]
                best_perf_dict["ASR"] = attack_acc
                state_dict = {
                    "netC": netC.state_dict(),
                    "schedulerC": schedulerC.state_dict(),
                    "optimizerC": optimizerC.state_dict(),
                    "epoch_current": epoch + current_epoch + 1,
                    'config': opt
                }
                paddle.save(state_dict, ckpt_path)
            print("Best performance:")
            print(best_perf_dict)
        # recover the best model; nothing to load if no checkpoint exists
        # (e.g. opt.n_iters == 0 on a fresh run)
        if os.path.exists(ckpt_path):
            self._restore_checkpoint(netC, optimizerC, schedulerC, ckpt_path)
        return best_perf_dict

    def train(self):
        """
        Train f_clean -> train f_vim
        """
        # build the source model (f_clean)
        train_dl, test_dl = self.get_clean_data(self.opt, attack=False)
        netS, optimizerS, schedulerS = self.get_model(self.opt, mode='source')
        # NOTE(review): the source model is not trained here, so the
        # predictions below come from its freshly built weights and the dict
        # is a placeholder for its metrics.
        source_clean_perf_dict = {'acc':1}
        # perform prediction using the source model
        train_dl, test_dl = self.get_clean_data(self.opt, attack=True)
        inputs, labels, probs = self.predict(netS, train_dl, self.opt)
        #Denormalize
        inputs = self.denormalize(inputs)

        # get poisoned training data
        attacked_train_dl = self.get_poisoned_train_data(inputs, labels, probs, opt=self.opt)
        # train the victim model
        netT, optimizerT, schedulerT = self.get_model(self.opt, mode='target')
        target_clean_perf_dict = self.train_a_model(netT, optimizerT, schedulerT, attacked_train_dl,
                                                    test_dl, self.opt.target_ckpt_path, self.opt)
        # print performance
        print('Source Clean Performance:')
        print(self._format_perf(source_clean_perf_dict))
        print('Target Clean Performance:')
        print(self._format_perf(target_clean_perf_dict))
        return 0

    def train_attack(self):
        """
        Train the victim model on data poisoned with a saved source model.

        Loads the source model from `opt.source_ckpt_path`, poisons the
        training data with its predictions, trains the target model
        (resuming from `opt.target_ckpt_path` when `opt.continue_training`
        is set and the file exists) and prints the clean metrics plus the
        accuracy on poisoned test data for each attacker and for all
        attackers combined.

        @Returns:
            0
        """
        # load the source model
        train_dl, test_dl = self.get_clean_data(self.opt, attack=True)
        netS, optimizerS, schedulerS = self.get_model(self.opt, mode='source')
        self._restore_checkpoint(netS, optimizerS, schedulerS, self.opt.source_ckpt_path)
        # perform prediction using the source model
        inputs, labels, probs = self.predict(netS, train_dl, self.opt)
        #Denormalize
        inputs = self.denormalize(inputs)
        # get poisoned training data
        attacked_train_dl = self.get_poisoned_train_data(inputs, labels, probs, self.opt)
        # train the attacked model
        netT, optimizerT, schedulerT = self.get_model(self.opt, mode='target')
        current = 0
        if self.opt.continue_training and os.path.exists(self.opt.target_ckpt_path):
            print("continue")
            state_dict = self._restore_checkpoint(netT, optimizerT, schedulerT, self.opt.target_ckpt_path)
            current = state_dict['epoch_current']
        target_clean_perf_dict = self.train_a_model(netT, optimizerT, schedulerT, attacked_train_dl,
                                                    test_dl, self.opt.target_ckpt_path, self.opt, current_epoch = current)
        # get the attack success rate on poisoned test dataset:
        # one entry per attacker, then one for all attackers combined
        target_attack_perf_dicts = []
        for attacker_group in [[attacker] for attacker in self.attackers] + [self.attackers]:
            attacked_test_dl = self.get_poisoned_eval_data(test_dl, attackers=attacker_group, opt=self.opt, mode='all')
            target_attack_perf_dicts.append(self.eval(netT, valid_dl=attacked_test_dl, epoch=self.opt.n_iters, opt=self.opt))

        # print performance
        print('Target Clean Performance:')
        print(self._format_perf(target_clean_perf_dict))
        for i, target_attack_perf_dict in enumerate(target_attack_perf_dicts):
            print('Target Attack Performance for Attacker {}:'.format(i))
            print(self._format_perf(target_attack_perf_dict))
        return 0

    def eval_attack(self):
        """
        Evaluate a saved target model on clean and poisoned test data.

        Evaluates on the clean test set, on the test set poisoned by all
        attackers together, then on the test set poisoned by each attacker
        alone; only the per-attacker results are printed as a summary.

        @Returns:
            0
        """
        # load the target model
        train_dl, test_dl = self.get_clean_data(self.opt)
        netT, optimizerT, schedulerT = self.get_model(self.opt, mode='target')
        self._restore_checkpoint(netT, optimizerT, schedulerT, self.opt.target_ckpt_path)

        # get the attack success rate on poisoned test dataset
        self.eval(netT, valid_dl=test_dl, epoch=self.opt.n_iters, opt=self.opt)
        attacked_test_dl = self.get_poisoned_eval_data(test_dl, attackers=self.attackers, opt=self.opt, mode='all')
        self.eval(netT, valid_dl=attacked_test_dl, epoch=self.opt.n_iters, opt=self.opt)
        target_attack_perf_dicts = []
        for attacker in self.attackers:
            attacked_test_dl = self.get_poisoned_eval_data(test_dl, attackers=[attacker], opt=self.opt, mode='all')
            target_attack_perf_dicts.append(self.eval(netT, valid_dl=attacked_test_dl, epoch=self.opt.n_iters, opt=self.opt))

        for i, target_attack_perf_dict in enumerate(target_attack_perf_dicts):
            print('Target Attack Performance for Attacker {}:'.format(i))
            print(self._format_perf(target_attack_perf_dict))
        return 0

    def analyse_prediction(self, netC, attackers, analyse_dl, opt, *args):
        """
        Compare the predictions of `netC` on clean and on poisoned data.

        @Args:
            netC: classifier to analyse
            attackers (list): attackers used to poison `analyse_dl`
            analyse_dl: clean dataloader yielding {'input', 'target'} batches
            opt: unused; `self.opt` is used to build the poisoned data
        @Returns:
            tuple (benign_probs, poison_probs, benign_entropy, poison_entropy)
            of tensors; the entropies are in bits. The poisoned dataloader is
            shuffled, so rows of the benign and poisoned results are not
            guaranteed to refer to the same sample.
        """
        attacked_analyse_dl = self.get_poisoned_eval_data(analyse_dl, attackers=attackers, opt=self.opt, mode='all')
        # Pure inference: no autograd graph is needed.
        with paddle.no_grad():
            benign_probs = self._collect_probs(netC, analyse_dl)
            poison_probs = self._collect_probs(netC, attacked_analyse_dl)
            benign_entropy = self._entropy(benign_probs)
            poison_entropy = self._entropy(poison_probs)
        return benign_probs, poison_probs, benign_entropy, poison_entropy


class DefenceAttackTrainer(BaseAttackTrainer):
    """
    Attack trainer that additionally runs a defence against the poisoned
    test data.
    """

    def __init__(self, defencer, *args, **kwargs):
        """
        @Args:
            defencer: defence object; `detect_poisoned_data` and `defend` are
                called on it
            *args, **kwargs: forwarded to `BaseAttackTrainer.__init__`
        """
        super(DefenceAttackTrainer, self).__init__(*args, **kwargs)
        self.defencer = defencer

    def get_limit_poisoned_eval_data(self, eval_dl, opt, netC, *args, attackers=None, pred_probs=None, mode='all'):
        """
        Get the poisoned sample using dynamic number of attackers.

        Once the prediction label of the sample is the target class, we stop activating the rest LSBAs.

        Every sample starts from the version carrying all triggers; if an
        earlier prefix of the triggers already makes `netC` predict the target
        class, that shorter version replaces it. `netC` is left in eval mode.

        @Args:
            eval_dl: dataloader yielding normalized {'input', 'target'} batches
            opt: configuration; reads target_label, bs, num_workers
            netC: classifier queried after each trigger
            attackers (list): attackers applied in order; defaults to
                `self.attackers`
            pred_probs: only validated for length when given
            mode: unused
        @Returns:
            unshuffled dataloader over the poisoned samples
        @Raises:
            ValueError: `pred_probs` does not have one row per sample
        """
        if not attackers:
            attackers = self.attackers
        batch_inputs = []
        batch_labels = []
        netC.eval()
        for batch in eval_dl:
            batch_inputs.append(batch['input'])
            batch_labels.append(batch['target'])
        inputs = np.concatenate(batch_inputs)
        labels = np.concatenate(batch_labels)
        # Compare per-sample lengths (after concatenation), and test against
        # None because the truth value of a multi-element array is ambiguous.
        if pred_probs is not None and len(pred_probs) != len(inputs):
            raise ValueError(
                'pred_probs has {} rows but eval_dl holds {} samples'.format(len(pred_probs), len(inputs))
            )
        #Denormalize
        inputs = self.denormalize(inputs)
        attacked_labels = np.ones_like(labels)*opt.target_label
        attacked_inputs = inputs
        for attacker in attackers:
            attacked_inputs = attacker.trigger.apply_all(attacked_inputs)
        # Build testing dataloader
        transformer = self.get_transformer(opt, train=False)
        num_samples = len(inputs)
        # Pure inference: avoid building one autograd graph per query.
        with paddle.no_grad():
            for j, attacked_input in enumerate(inputs):
                progress_bar(j, num_samples, "")
                for attacker in attackers:
                    attacked_input = attacker.trigger.apply(attacked_input)
                    img = paddle.to_tensor(attacked_input, dtype=paddle.float32)
                    if transformer:
                        img = transformer(img)
                    pred_label = paddle.argmax(netC(img.unsqueeze(0)), axis=-1).item()
                    if pred_label == opt.target_label: # stop activating other LSBAs once the label being the target class
                        attacked_inputs[j] = attacked_input
                        break

        dataset = DictDataset({'input':attacked_inputs, 'target':attacked_labels, 'origin_input':inputs, 'origin_target':labels}, input_transform=transformer)
        dataloader = paddle.io.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=False)
        return dataloader

    def eval_attack(self, limit_triggers=False):
        """
        Get the results of Table 2 of the manuscript.

        Evaluates the saved target model on clean data, prints its accuracy on
        the poisoned test data and the mean of the defencer's detection
        result.

        @Args:
            limit_triggers (bool): when True, poison with
                `get_limit_poisoned_eval_data` (stop adding triggers once the
                target class is predicted); when False, apply every trigger
                to every sample. The default matches what this method
                effectively evaluated before, without computing the limited
                data set only to discard it.
        @Returns:
            0
        """
        # load the target model
        train_dl, test_dl = self.get_clean_data(self.opt, attack=True)
        netT, optimizerT, schedulerT = self.get_model(self.opt, mode='target')
        state_dict = paddle.load(self.opt.target_ckpt_path)
        netT.set_state_dict(state_dict["netC"])
        optimizerT.set_state_dict(state_dict["optimizerC"])
        schedulerT.set_state_dict(state_dict["schedulerC"])
        # get performance on clean data set
        self.eval(netT, valid_dl=test_dl, epoch=self.opt.n_iters, opt=self.opt)
        # get attack success rate on attacked test dataset
        if limit_triggers:
            attacked_test_dl = self.get_limit_poisoned_eval_data(test_dl, attackers=self.attackers, netC=netT, opt=self.opt, mode='all')
        else:
            attacked_test_dl = self.get_poisoned_eval_data(test_dl, attackers=self.attackers, opt=self.opt, mode='all')
        perf = self.eval(netT, valid_dl=attacked_test_dl, epoch=self.opt.n_iters, opt=self.opt)
        mean_of_two_attacks = perf['Acc']
        print("Mean ASR with attacks:{}".format(mean_of_two_attacks))

        detected_poisoned_indices = self.defencer.detect_poisoned_data(netT, test_dl, attacked_test_dl)
        print(detected_poisoned_indices.mean())
        return 0

    def eval_defence(self):
        """
        Get the results of Table 2 of the manuscript.

        Loads the saved target model, evaluates it on the fully poisoned test
        data and then runs the defencer's `defend` on the clean and poisoned
        dataloaders.

        @Returns:
            0
        """
        # load the target model
        train_dl, test_dl = self.get_clean_data(self.opt, attack=True)
        netT, optimizerT, schedulerT = self.get_model(self.opt, mode='target')
        state_dict = paddle.load(self.opt.target_ckpt_path)
        netT.set_state_dict(state_dict["netC"])
        # get attack success rate on attacked test dataset
        attacked_test_dl = self.get_poisoned_eval_data(test_dl, attackers=self.attackers, opt=self.opt, mode='all')
        self.eval(netT, valid_dl=attacked_test_dl, epoch=self.opt.n_iters, opt=self.opt)
        self.defencer.defend(netT, test_dl, attacked_test_dl)
        return 0


class AETcbAttackTrainer(BaseAttackTrainer):
    """
    Attack trainer that poisons the training data directly from a dataloader,
    without predictions from a source model.
    """

    def get_poisoned_train_data(self, train_dl, opt, *args, attackers=None):
        """
        Perform attack to training data.

        @Args:
            train_dl: dataloader yielding normalized {'input', 'target'} batches
            opt (dict): configuration; reads bs and num_workers
            attackers (list): attackers applied in order; defaults to
                `self.attackers`
        @Return:
            training dataloader
        """
        if not attackers:
            attackers = self.attackers
        batch_inputs = []
        batch_labels = []
        for batch in train_dl:
            batch_inputs.append(batch['input'])
            batch_labels.append(batch['target'])
        inputs = np.concatenate(batch_inputs)
        labels = np.concatenate(batch_labels)
        #Denormalize
        inputs = self.denormalize(inputs)
        attacked_inputs, attacked_labels = inputs, labels
        # Iterate the resolved list so an explicit `attackers` argument is
        # honoured. The clean labels are passed in both probability slots of
        # `attack`.
        for attacker in attackers:
            attacked_inputs, attacked_labels = attacker.attack(attacked_inputs, attacked_labels, labels, labels)
        # Build training dataloader
        transformer = self.get_transformer(opt, train=True)
        dataset = DictDataset({'input':attacked_inputs, 'target':attacked_labels, 'origin_input':inputs, 'origin_target':labels}, input_transform=transformer)
        dataloader = paddle.io.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)
        return dataloader

    def eval_attack(self):
        """
        Evaluate the saved target model on the clean test data and on the
        test data poisoned by all attackers.

        @Returns:
            0
        """
        # load the target model
        train_dl, test_dl = self.get_clean_data(self.opt, attack=True)
        netT, optimizerT, schedulerT = self.get_model(self.opt, mode='target')
        state_dict = paddle.load(self.opt.target_ckpt_path)
        netT.set_state_dict(state_dict["netC"])
        optimizerT.set_state_dict(state_dict["optimizerC"])
        schedulerT.set_state_dict(state_dict["schedulerC"])
        attacked_test_dl = self.get_poisoned_eval_data(test_dl, attackers=self.attackers, opt=self.opt, mode='all')
        self.eval(netT, test_dl, 0, self.opt)
        self.eval(netT, attacked_test_dl, 0, self.opt)
        return 0

    def train(self):
        """
        Train the target model on the poisoned training data.

        The best checkpoint is written to `opt.target_ckpt_path`.

        @Returns:
            0
        """
        train_dl, test_dl = self.get_clean_data(self.opt, attack=False)
        netT, optimizerT, schedulerT = self.get_model(self.opt, mode='target')
        # get poisoned training data
        attacked_train_dl = self.get_poisoned_train_data(train_dl, opt=self.opt)
        # train the victim model
        self.train_a_model(netT, optimizerT, schedulerT, attacked_train_dl,
                           test_dl, self.opt.target_ckpt_path, self.opt)

        return 0

class NaiveTcbAttackTrainer(AETcbAttackTrainer):
    """
    Train a model on the clean data one epoch.
    Select the input with high confidence as the trigger pattern.
    """
    def __init__(self, trigger_pattern=None, *args, **kwargs):
        """
        @Args:
            trigger_pattern: pattern installed on every attacker's trigger;
                extracted with `extract_trigger_pattern` when None
            *args, **kwargs: forwarded to the parent constructor
        """
        # The parent must be named through this class: `TcbTrainer` is not
        # a base class of this trainer.
        super(NaiveTcbAttackTrainer, self).__init__(*args, **kwargs)
        if trigger_pattern is None:
            trigger_pattern = self.extract_trigger_pattern()
        for attacker in self.attackers:
            attacker.trigger.set_trigger_pattern(trigger_pattern)

    def extract_trigger_pattern(self):
        """
        Pick a clean image of the target class to serve as the trigger pattern.

        With `opt.attack_type == 'RandTCB'` the first target-class image found
        in the test data is used and also saved as a .npy file under
        `opt.data_root`. Otherwise a source model is trained for one epoch and
        the training image with the highest predicted target-class
        probability is returned.

        @Returns:
            np.ndarray: the selected (denormalized, squeezed) image
        @Raises:
            ValueError: 'RandTCB' is requested but the test data holds no
                sample of the target class
        """
        train_dl, test_dl = self.get_clean_data(self.opt, attack=False)
        if self.opt.attack_type == 'RandTCB':
            target_inputs = None
            for batch in test_dl:
                inputs, targets = batch['input'], batch['target']
                # Flatten instead of squeeze() so a batch of one sample still
                # yields a 1-d mask.
                selected = inputs[targets.reshape([-1]) == self.opt.target_label]
                if len(selected) > 0:
                    target_inputs = selected.cpu().numpy()
                    break
            if target_inputs is None:
                raise ValueError(
                    'No test sample with target label {}'.format(self.opt.target_label)
                )
            target_image = self.denormalize(target_inputs)[0].squeeze()
            np.save(os.path.join(self.opt.data_root, '{}_target_image_{}.npy'.format(self.opt.dataset, self.opt.target_label)),
                target_image)
            return target_image
        # train the source model for one epoch
        netS, optimizerS, schedulerS = self.get_model(self.opt, mode='source')
        self.train_one_epoch(netS, optimizerS, schedulerS, train_dl, 0, self.opt)
        inputs, labels, probs = self.predict(netS, train_dl, self.opt)
        inputs = self.denormalize(inputs)
        # get the input with the highest predicted probability for the target class
        idx = probs[:, self.opt.target_label].argmax()
        return inputs[idx].squeeze()



class FrequencyTrainer(BaseAttackTrainer):
    """
    Trainer for a binary detector that works on DCT-transformed images.

    Clean images get label 0 and tampered images get label 1. The training
    set is built by pairing every clean training image with a randomly
    perturbed copy (see `patching_train`); evaluation uses the clean test
    set and a test set poisoned by the configured attackers.
    """
    def __init__(self, train_attackers, *args, **kwargs):
        """
        @Args:
            train_attackers: attackers used by `get_poisoned_train_data` and `eval_attack`
            *args, **kwargs: forwarded unchanged to the parent trainer
        """
        super().__init__(*args, **kwargs)
        self.train_attackers = train_attackers

    def get_model(self, opt, mode=None, **kwargs):
        """
        Build the detector together with its optimizer and scheduler.

        `mode` and `kwargs` are accepted for signature compatibility and are not used.
        The scheduler is created and returned but is not passed to the optimizer,
        which runs Adadelta with a fixed learning rate of 0.05.

        @Return:
            (netC, optimizerC, schedulerC)
        """
        # NOTE(review): FrequencyDetector is not among the imports visible at the
        # top of this file; confirm it is in scope.
        in_channels = 1 if opt.dataset == 'mnist' else 3
        netC = FrequencyDetector(opt, in_channels=in_channels)
        # set optimizer and scheduler
        schedulerC = paddle.optimizer.lr.MultiStepDecay(opt.lr_C, opt.schedulerC_milestones, opt.schedulerC_lambda)
        optimizerC = paddle.optimizer.Adadelta(learning_rate=0.05, parameters=netC.parameters())
        return netC, optimizerC, schedulerC

    def get_poisoned_train_data(self, inputs, labels, pred_probs, opt):
        """
        Perform attack to training data and relabel it for detection.

        Samples whose index is in the attacker's change-label set get label 1,
        every other sample gets label 0.

        @Args:
            inputs (np.ndarray): input images (unnormalized)
            labels (np.ndarray): output labels
            pred_probs (np.ndarray): predicted probability of input images
            opt (dict): configuration
        @Return:
            training dataloader
        """
        assert len(inputs) == len(labels)
        assert len(labels) == len(pred_probs)

        # insert backdoors to some of the training data
        attacked_inputs, attacked_labels = inputs, labels
        # Every attacker starts from the original inputs/labels, so with several
        # attackers only the result of the last one is kept.
        for attacker in self.train_attackers:
            # the same probabilities serve as both source and target probabilities
            add_trigger_indices, change_label_indices = attacker.get_indices_to_poison(pred_probs, pred_probs)
            attacked_inputs, attacked_labels = attacker.attack_from_indices(inputs, labels, add_trigger_indices, change_label_indices)
            # NOTE(review): written in place; presumably attack_from_indices returns a
            # fresh label array, otherwise `labels` is overwritten too -- confirm.
            attacked_labels[:] = 0
            attacked_labels[change_label_indices] = 1

        # build dataset and dataloader
        transformer = self.get_transformer(opt, train=True)
        dataset = DictDataset({'input':attacked_inputs, 'target':attacked_labels, 'origin_input':inputs, 'origin_target':labels},
        input_transform=transformer)
        dataloader = paddle.io.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)
        return dataloader

    def get_raw_data(self, dl):
        """
        Collect every batch of a dataloader into two arrays.

        @Args:
            dl: dataloader yielding dicts with 'input' and 'target' entries
        @Return:
            (inputs, labels): concatenated inputs after `self.denormalize`, and concatenated targets
        """
        inputs = []
        labels = []
        # a single pass keeps inputs and labels aligned even if the loader shuffles
        for batch in dl:
            inputs.append(batch['input'])
            labels.append(batch['target'])
        inputs = np.concatenate(inputs)
        labels = np.concatenate(labels)
        # Denormalize
        inputs = self.denormalize(inputs)

        return inputs, labels

    def patching_train(self, clean_sample, x_train):
        '''
        Return a randomly perturbed copy of one image.

        One of five perturbations is drawn uniformly: a white block, a random-noise
        block, Gaussian noise, a random shadow, or blending with a random image
        from `x_train` (weight 0.3). Blocks are placed near one of the four corners.

        @Args:
            clean_sample (np.ndarray): one image, indexed as (row, column, channel)
            x_train (np.ndarray): images to blend with, indexed along the first axis
        @Return:
            the perturbed image; `clean_sample` itself is not modified
        '''
        import cv2
        import albumentations

        # rows are bounded by the height and columns by the width
        height, width = clean_sample.shape[0], clean_sample.shape[1]
        channels = clean_sample.shape[-1]

        def addnoise(img):
            aug = albumentations.GaussNoise(p=1, mean=25, var_limit=(10,70))
            augmented = aug(image=(img*255).astype(np.uint8))
            return augmented['image']/255

        def randshadow(img):
            aug = albumentations.RandomShadow(p=1)
            test = (img*255).astype(np.uint8)
            # single-channel images are expanded to RGB before the shadow is applied
            if test.shape[-1] == 1:
                test = cv2.cvtColor(test, cv2.COLOR_GRAY2RGB)
            augmented = aug(image=test)
            auged = augmented['image']/255
            # and reduced back to one channel afterwards
            if channels == 1:
                auged = auged[:, :, :1]
            return auged

        attack = np.random.randint(0,5)
        pat_size_x = np.random.randint(2,8)
        pat_size_y = np.random.randint(2,8)
        output = np.copy(clean_sample)

        # perturbations that act on the whole image return immediately
        if attack == 2:
            return addnoise(output)
        if attack == 3:
            return randshadow(output)
        if attack == 4:
            randind = np.random.randint(x_train.shape[0])
            tri = x_train[randind]
            mid = output+0.3*tri
            mid[mid>1]=1
            return mid

        # remaining cases paste a block: all ones (0) or uniform noise (1)
        if attack == 0:
            block = np.ones((pat_size_x, pat_size_y, channels))
        else:
            block = np.random.rand(pat_size_x, pat_size_y, channels)

        margin = np.random.randint(0,6)
        rand_loc = np.random.randint(0,4)
        top = slice(margin, margin+pat_size_x)
        bottom = slice(height-margin-pat_size_x, height-margin)
        left = slice(margin, margin+pat_size_y)
        right = slice(width-margin-pat_size_y, width-margin)
        if rand_loc == 0:
            output[top, left, :] = block  # upper left
        elif rand_loc == 1:
            output[top, right, :] = block  # upper right
        elif rand_loc == 2:
            output[bottom, left, :] = block  # lower left
        elif rand_loc == 3:
            output[bottom, right, :] = block  # lower right

        output[output > 1] = 1
        return output

    def train_a_model_bak(self, netC, optimizerC, schedulerC, train_dl, valid_dl,  ckpt_path, opt, *args, current_epoch=0):
        """
        Train for a fixed 10 epochs, evaluate once, and checkpoint the result.

        @Args:
            netC, optimizerC, schedulerC: model, optimizer and scheduler to train and save
            train_dl, valid_dl: training and validation dataloaders
            ckpt_path: file the checkpoint is written to
            opt: configuration, stored in the checkpoint under 'config'
            current_epoch: offset added to the printed and saved epoch numbers
        @Return:
            0
        """
        netC.train()
        for epoch in range(10):
            print("Epoch {}:".format(current_epoch + epoch + 1))
            self.train_one_epoch(netC, optimizerC, schedulerC, train_dl, epoch, opt)
        perf_dict = self.eval(netC=netC, valid_dl=valid_dl, epoch=epoch, opt=opt)
        print(perf_dict)
        state_dict = {
                    "netC": netC.state_dict(),
                    "schedulerC": schedulerC.state_dict(),
                    "optimizerC": optimizerC.state_dict(),
                    "epoch_current": epoch + current_epoch + 1,
                    'config': opt
                }
        paddle.save(state_dict, ckpt_path)
        # reload the checkpoint that was just written
        state_dict = paddle.load(ckpt_path)
        netC.set_state_dict(state_dict["netC"])
        optimizerC.set_state_dict(state_dict["optimizerC"])
        schedulerC.set_state_dict(state_dict["schedulerC"])
        return 0

    def _dct_into(self, src, dst):
        """Write the 2-D DCT of every image channel of `src` (scaled to uint8) into `dst`; `dst` may be `src`."""
        n_channels = 1 if self.opt.dataset == 'mnist' else 3
        for i in range(src.shape[0]):
            for channel in range(n_channels):
                dst[i][channel, :, :] = self.dct2((src[i][channel, :, :]*255).astype(np.uint8))
        return dst

    def train(self):
        """
        Build the clean/perturbed training set, train the detector and report its accuracy.

        Prints 'Acc' (accuracy on clean test images, label 0) and 'BDR'
        (accuracy on poisoned test images, label 1).

        @Return:
            0
        """
        train_dl, test_dl = self.get_clean_data(self.opt, attack=False)

        inputs, labels = self.get_raw_data(train_dl)
        # get poisoned training data; patching works on (N, H, W, C)
        inputs = inputs.transpose((0, 2, 3, 1))
        poi_inputs = np.zeros_like(inputs)
        for i in range(inputs.shape[0]):
            poi_inputs[i] = self.patching_train(inputs[i], x_train=inputs)
        poi_inputs = poi_inputs.transpose((0, 3, 1, 2))
        # clean half first (label 0), perturbed half second (label 1)
        inputs = np.concatenate((inputs.transpose((0, 3, 1, 2)), poi_inputs))
        labels = np.concatenate((np.zeros_like(labels), np.ones_like(labels)))
        self._dct_into(inputs, inputs)
        dataset = DictDataset({'input':inputs, 'target':labels})
        attacked_train_dl = paddle.io.DataLoader(dataset, batch_size=self.opt.bs, num_workers=self.opt.num_workers, shuffle=True)

        # get testing data
        inputs, test_labels = self.get_raw_data(test_dl)
        test_inputs = self._dct_into(inputs, np.zeros_like(inputs))
        test_labels[:] = 0
        dataset = DictDataset({'input':test_inputs, 'target':test_labels})
        benign_test_dl = paddle.io.DataLoader(dataset, batch_size=self.opt.bs, num_workers=self.opt.num_workers, shuffle=False)

        # get poisoned testing data
        # NOTE(review): uses self.attackers here while eval_attack uses
        # self.train_attackers -- confirm the difference is intended.
        attacked_test_dl = self.get_poisoned_eval_data(test_dl, attackers=self.attackers, opt=self.opt, mode='all')
        inputs, labels = self.get_raw_data(attacked_test_dl)
        self._dct_into(inputs, inputs)
        labels[:] = 1
        dataset = DictDataset({'input':inputs, 'target':labels})
        attacked_test_dl = paddle.io.DataLoader(dataset, batch_size=self.opt.bs, num_workers=self.opt.num_workers, shuffle=False)

        # train the defense model
        netT, optimizerT, schedulerT = self.get_model(self.opt, mode='target')
        self.train_a_model(netT, optimizerT, schedulerT, attacked_train_dl,
                                                    benign_test_dl, self.opt.target_ckpt_path, self.opt)
        # evaluate the model
        acc = self.eval(netT, valid_dl=benign_test_dl, epoch=self.opt.n_iters, opt=self.opt)['Acc']
        print('Acc:{}'.format(acc))
        bdr = self.eval(netT, valid_dl=attacked_test_dl, epoch=self.opt.n_iters, opt=self.opt)['Acc']
        print('BDR:{}'.format(bdr))

        return 0

    def dct2 (self, block):
        """Return the orthonormal 2-D DCT of a 2-D array (DCT along each axis in turn)."""
        from scipy.fftpack import dct
        return dct(dct(block.T, norm = 'ortho').T, norm = 'ortho')

    def idct2(self, block):
        """Return the orthonormal 2-D inverse DCT of a 2-D array."""
        from scipy.fftpack import idct
        return idct(idct(block.T, norm = 'ortho').T, norm = 'ortho')

    def eval_attack(self):
        """
        Get the results of Table 2 of the manuscript.

        Loads the detector from `opt.target_ckpt_path` and prints 'Acc' (accuracy on
        clean test images, label 0) and 'BDR' (accuracy on test images poisoned by
        `self.train_attackers`, label 1).

        @Return:
            0
        """
        # load the target model
        _, test_dl = self.get_clean_data(self.opt, attack=True)
        netT, _, _ = self.get_model(self.opt, mode='target')
        state_dict = paddle.load(self.opt.target_ckpt_path)
        netT.set_state_dict(state_dict["netC"])

        inputs, test_labels = self.get_raw_data(test_dl)
        test_inputs = self._dct_into(inputs, np.zeros_like(inputs))
        test_labels[:] = 0
        dataset = DictDataset({'input':test_inputs, 'target':test_labels})
        benign_test_dl = paddle.io.DataLoader(dataset, batch_size=self.opt.bs, num_workers=self.opt.num_workers, shuffle=False)
        acc = self.eval(netT, valid_dl=benign_test_dl, epoch=self.opt.n_iters, opt=self.opt)['Acc']
        print('Acc:{}'.format(acc))
        # get poisoned testing data
        attacked_test_dl = self.get_poisoned_eval_data(test_dl, attackers=self.train_attackers, opt=self.opt, mode='all')
        inputs, labels = self.get_raw_data(attacked_test_dl)
        self._dct_into(inputs, inputs)
        labels[:] = 1
        dataset = DictDataset({'input':inputs, 'target':labels})
        attacked_test_dl = paddle.io.DataLoader(dataset, batch_size=self.opt.bs, num_workers=self.opt.num_workers, shuffle=False)

        # evaluate the model
        bdr = self.eval(netT, valid_dl=attacked_test_dl, epoch=self.opt.n_iters, opt=self.opt)['Acc']
        print('BDR:{}'.format(bdr))
        return 0


def run():
    """
    Command-line entry point: parse the configuration, build triggers, attackers
    and the matching trainer, then run the action selected by `opt.train_mode`.

    @Raises:
        Exception: for an invalid dataset, an unsupported defence method or an
            unsupported train mode
        ValueError: for an unsupported attack type
    """
    opt = config.get_arguments().parse_args()
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    # creates missing parent folders too and tolerates an existing folder
    os.makedirs(opt.ckpt_folder, exist_ok=True)
    paddle.set_device(opt.device)
    opt.source_ckpt_path = os.path.join(opt.ckpt_folder, "source_morph_paddle.pth.pdmodel")
    opt.target_ckpt_path = os.path.join(opt.ckpt_folder, "target_morph_{}_{}_{}_paddle.pth.pdmodel".format(opt.attack_method, opt.attack_ratio, opt.attack_type))
    locs = opt.attack_locs.split(',')

    # (height, width, channels) of the inputs of each supported dataset
    input_shapes = {
        "cifar10": (32, 32, 3),
        "gtsrb": (32, 32, 3),
        "mnist": (28, 28, 1),
        "celeba": (64, 64, 3),
    }
    if opt.dataset not in input_shapes:
        raise Exception("Invalid Dataset")
    opt.input_height, opt.input_width, opt.input_channel = input_shapes[opt.dataset]

    # only consumed by NaiveTcbAttackTrainer; None makes it extract a pattern itself
    target_image = None
    if opt.attack_type == 'BadNet':
        triggers = [BadNetTrigger(opt, loc=locs[0])]
    elif opt.attack_type == 'WaNet':
        s = list(map(float, opt.s.split(',')))
        k = list(map(int, opt.k.split(',')))
        # one trigger per entry of opt.s, paired with the entry of opt.k at the same position
        triggers = [WaNetTrigger(opt, s=s[idx], k=k[idx], num=idx) for idx in range(len(s))]
    elif opt.attack_type == 'RandTCB':
        pattern_path = os.path.join(opt.data_root, '{}_target_image_{}.npy'.format(opt.dataset, opt.target_label))
        try:
            # reuse a previously saved pattern when one is available
            target_image = np.load(pattern_path).squeeze()
        except (OSError, ValueError, EOFError):
            target_image = None
            print('Cannot load predefined pattern, randomly select from training data.')
        triggers = [NaiveTcbTrigger(opt)]
    elif opt.attack_type == 'NaiveTCB':
        triggers = [NaiveTcbTrigger(opt)]
    elif opt.attack_type == 'AETCB': # the proposed strategy
        from PIL import Image
        target_input = (np.asarray(Image.open("{}/{}/easy_pattern_cls_{}.png".format(opt.trigger_dir, opt.dataset, opt.target_label))) / 255.).astype('float32')

        auto_encoder_path = '{}/{}/model_best.pdparams'.format(opt.trigger_dir, opt.dataset)
        if opt.dataset == 'mnist':
            auto_encoder = AutoencoderMnist()
        elif opt.dataset == 'cifar10':
            auto_encoder = AutoencoderCifar()
        elif opt.dataset == 'gtsrb':
            auto_encoder = AutoencoderCifar(cls_num=43)
        elif opt.dataset == 'celeba':
            auto_encoder = AutoencoderCeleba(cls_num=8)
        state_dict = paddle.load(auto_encoder_path)
        auto_encoder.set_dict(state_dict['state_dict'])
        triggers = [AETcbTrigger(opt=opt, auto_encoder=auto_encoder, target_input=target_input)]
    else:
        # fail with a clear message rather than a NameError on `triggers` below
        raise ValueError("Unsupported attack type: {}".format(opt.attack_type))

    attackers = [HardAttacker(trigger, opt.attack_ratio, opt.target_label) for trigger in triggers]

    if opt.train_mode == 'eval_defence':
        if opt.defencer == 'ShrinkPad':
            defencer = ShrinkPad(opt)
        elif opt.defencer == "STRIP":
            defencer = StripDefencer(5, opt)
        else:
            raise Exception("Unsupported defence method")
        trainer = DefenceAttackTrainer(defencer, opt, attackers)
    elif opt.attack_type == 'NaiveTCB' or opt.attack_type == 'RandTCB':
        trainer = NaiveTcbAttackTrainer(target_image, opt, attackers)
    elif opt.attack_type == 'AETCB':
        trainer = AETcbAttackTrainer(opt, attackers)
    else:
        trainer = BaseAttackTrainer(opt, attackers)

    if opt.train_mode == "train":
        trainer.train()
    elif opt.train_mode == "train_attack":
        trainer.train_attack()
    elif opt.train_mode == "eval":
        trainer.eval_attack()
    elif opt.train_mode == "eval_defence":
        trainer.eval_defence()
    else:
        raise Exception('train mode is not supported!')


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run()
