import os
import cv2
import scipy
import scipy.stats
import numpy as np
from tqdm import tqdm

import paddle
import paddle.vision.transforms as transforms
import paddle.fluid as fluid

import matplotlib.pyplot as plt

from utils.dataloader import DictDataset
from utils.utils import progress_bar


class Defencer(object):
    """Common base for the defence classes in this module.

    Stores the option object and offers helpers to convert images to tensors
    and to undo the dataset-specific normalisation. Subclasses are expected
    to override ``detect_poisoned_data`` and/or ``defend``.
    """

    def __init__(self, opt, *args):
        # ``opt`` is kept as-is; this class only reads ``opt.dataset``.
        self.opt = opt

    def detect_poisoned_data(self, netC, benign_dl, attacked_dl, *args):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError()

    def defend(self, netC, benign_dl, attacked_dl, *args):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError()

    def normalize(self, input):
        """Run ``input`` through a ``Compose`` containing only ``ToTensor``.

        No mean/std normalisation is applied: the per-dataset ``Normalize``
        steps (cifar10 / mnist) are deliberately left out of the pipeline.
        """
        pipeline = transforms.Compose([transforms.ToTensor()])
        return pipeline(input)

    def denormalize(self, inputs, opt=None):
        """Return ``inputs * std + mean`` using per-dataset statistics.

        Args:
            inputs: A numpy array or another tensor-like object; the
                statistics are broadcast as ``(1, C, 1, 1)``.
            opt: Optional options object; falls back to ``self.opt`` when
                falsy. Only ``opt.dataset`` is read.

        Returns:
            The rescaled data, a numpy array for numpy input and otherwise
            the result of arithmetic with paddle tensors.
        """
        options = opt if opt else self.opt

        # Pick the (mean, std) pair for the dataset; unknown datasets get
        # the identity statistics.
        if options.dataset == 'cifar10':
            mean_values, std_values = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
        elif options.dataset == 'mnist':
            mean_values, std_values = [0.5], [0.5]
        else:
            mean_values, std_values = [0.], [1.]

        mean = np.array(mean_values, dtype='float32')[:, None, None]
        std = np.array(std_values, dtype='float32')[:, None, None]

        if isinstance(inputs, np.ndarray):
            return inputs * std[None, ...] + mean[None, ...]
        std_tensor = paddle.to_tensor(std).unsqueeze(0)
        mean_tensor = paddle.to_tensor(mean).unsqueeze(0)
        return inputs * std_tensor + mean_tensor


class StripDefencer(Defencer):
    """Detect poisoned inputs from the prediction entropy of blended images.

    Every inspected input is blended with ``n_sample`` randomly chosen
    benign images; the entropy of the classifier's softmax output over those
    blends is the detection score (inputs scoring below a threshold are
    flagged).
    """

    def __init__(self, n_sample=5, *args, **kwargs):
        """Store the number of overlays and forward the rest to ``Defencer``.

        Note that ``n_sample`` is the first positional parameter, so ``opt``
        has to be passed after it (or as the keyword ``opt=``).
        """
        super(StripDefencer, self).__init__(*args, **kwargs)
        self.n_sample = n_sample

    def perturbate(self, input, overlay):
        """Blend ``overlay`` into ``input`` with weights 0.1 / 0.9.

        Non-numpy arguments are converted with ``.numpy()``. The blend is
        clipped to the value range of ``input`` and returned with its axes
        permuted by ``(1, 2, 0)``.
        # assumes both images are 3-D channel-first arrays — TODO confirm
        """
        if not isinstance(input, np.ndarray):
            input = input.numpy()
        if not isinstance(overlay, np.ndarray):
            overlay = overlay.numpy()
        blended = cv2.addWeighted(input, 0.9, overlay, 0.1, 0)
        x1_add = np.clip(blended, input.min(), input.max())
        return np.transpose(x1_add, (1, 2, 0))

    def get_entropy(self, input, dataset, classifier):
        """Return the summed softmax entropy of the blends divided by ``n_sample``.

        Args:
            input: Image to inspect.
            dataset: Indexable dataset whose items expose an ``'input'``
                entry; overlays are drawn from it uniformly at random.
            classifier: Callable mapping a stacked batch to logits.
        """
        index_overlay = np.random.randint(0, len(dataset), size=self.n_sample)
        x1_add = [
            self.normalize(self.perturbate(input, dataset[overlay_idx]['input']))
            for overlay_idx in index_overlay
        ]
        # Pure inference: no autograd bookkeeping is needed.
        with paddle.no_grad():
            py1_add = classifier(paddle.stack(x1_add))
            py1_add = paddle.nn.functional.softmax(py1_add, axis=-1).cpu().numpy()
        # The epsilon keeps log2 finite for probabilities that are exactly 0.
        entropy_sum = -np.nansum(py1_add * np.log2(py1_add + 1e-10))
        return entropy_sum / self.n_sample

    def defend(self, netC, benign_dl, detect_dl):
        """Alias for :meth:`detect_poisoned_data`."""
        return self.detect_poisoned_data(netC, benign_dl, detect_dl)

    def _collect_entropy(self, query_ds, label_key, idxes, benign_ds, netC):
        """Score ``query_ds[i]`` for every ``i`` in ``idxes``.

        Returns a tuple ``(entropies, inputs, labels)`` of numpy arrays where
        the labels are read from ``sample[label_key]``.
        """
        entropies, inputs, labels = [], [], []
        for i in idxes:
            # Index the dataset once so that the scored image and the stored
            # image/label always come from the same fetch.
            sample = query_ds[i]
            entropies.append(self.get_entropy(sample['input'], benign_ds, netC))
            inputs.append(sample['input'])
            labels.append(sample[label_key])
        return np.array(entropies), np.array(inputs), np.array(labels)

    def detect_poisoned_data(self, netC, benign_dl, detect_dl):
        """Run the entropy test and report FAR / FRR at the 1% operating points.

        Up to 4000 randomly chosen indices of the benign dataset are scored
        in both ``detect_dl.dataset`` and ``benign_dl.dataset``. Scores,
        inputs and labels are pickled to ``../output/`` and a histogram is
        saved by :meth:`plot_distribution`.

        Args:
            netC: Classifier; switched to eval mode (and left there).
            benign_dl: Loader whose dataset items have ``'input'`` and
                ``'target'`` entries.
            detect_dl: Loader whose dataset items have ``'input'`` and
                ``'origin_target'`` entries; it is indexed with indices
                drawn from the benign dataset.

        Returns:
            Boolean numpy mask ``benign_entropy < threshold`` from the final
            (FAR = 0.01) step.

        Raises:
            ValueError: If the benign dataset is empty.
        """
        opt = self.opt
        netC.eval()
        benign_ds = benign_dl.dataset
        detect_ds = detect_dl.dataset

        idxes = np.arange(len(benign_ds))
        np.random.shuffle(idxes)
        idxes = idxes[:4000]
        # Use the number of samples actually scored; the dataset may hold
        # fewer than 4000 items.
        n_sample = len(idxes)
        if n_sample == 0:
            raise ValueError("benign dataset is empty; nothing to evaluate")

        # introduce noise to the inputs (detect set first, then benign set)
        detect_entropy, detect_inputs, detect_labels = self._collect_entropy(
            detect_ds, 'origin_target', idxes, benign_ds, netC)
        benign_entropy, benign_inputs, benign_labels = self._collect_entropy(
            benign_ds, 'target', idxes, benign_ds, netC)

        import pickle
        dump_path = '../output/{}_{}_{}_strip.pkl'.format(opt.dataset, opt.attack_method, opt.attack_type)
        with open(dump_path, 'wb+') as dump_file:
            pickle.dump([benign_entropy, detect_entropy, benign_inputs, detect_inputs, benign_labels, detect_labels],
                        dump_file)

        # Rank of the 1% order statistic; int() truncates towards zero, so
        # this is never negative.
        k = int(n_sample * 0.01 - 1)

        # report FAR when FRR = 0.01 (FRR benign samples being recognized as poisoned ones)
        (mu, sigma) = scipy.stats.norm.fit(benign_entropy)
        threshold = np.sort(benign_entropy)[k]  # FRR = 0.01
        print(threshold)
        print(mu, sigma)
        # detect the poisoned samples by the threshold
        (mu, sigma) = scipy.stats.norm.fit(detect_entropy)
        print(mu, sigma)
        poisoned_indices = detect_entropy < threshold
        labels, counts = np.unique(detect_labels[poisoned_indices], return_counts=True)

        FAR = 1 - poisoned_indices.mean()  # false accept rate (the higher the worse) with false reject rate being 0.01
        print("False Accept Rate (FAR):{} (the higher the better stealthiness.)".format(FAR))

        for label, count in zip(labels, counts):
            print('{} poisoned samples of class {} are correctly identified.'.format(1. * count / n_sample, label))

        # report FRR when FAR = 0.01 (FAR poisoned samples being recognized as benign ones)
        # Count from the top of the sorted scores. ``max(k, 1)`` avoids the
        # ``[-0]`` pitfall, which would select the smallest score when k == 0.
        threshold = np.sort(detect_entropy)[n_sample - max(k, 1)]  # FAR = 0.01
        print(threshold)
        poisoned_indices = benign_entropy < threshold
        # The mask is over benign samples, so it must index the benign labels.
        labels, counts = np.unique(benign_labels[poisoned_indices], return_counts=True)

        FRR = poisoned_indices.mean()
        print("False Reject Rate (FRR):{} (the higher the better stealthiness.)".format(FRR))

        for label, count in zip(labels, counts):
            print('{} benign samples of class {} are incorrectly recognized as poisoned ones.'.format(1. * count / n_sample, label))

        # plot the histogram
        self.plot_distribution(benign_entropy, detect_entropy, opt)
        # NOTE(review): this is the mask over *benign* samples from the FRR
        # step, not the detections on the inspected set — confirm that this
        # is what callers expect.
        return poisoned_indices

    def plot_distribution(self, benign_entropy, detect_entropy, opt):
        """Save overlaid, frequency-normalised histograms of both score sets.

        The figure is written as ``.jpg`` and ``.pdf`` under ``../output/``.
        Always returns 0.
        """
        bins = 30
        plt.hist(benign_entropy, bins, weights=np.ones(len(benign_entropy)) / len(benign_entropy), alpha=0.5, label='Clean')
        plt.hist(detect_entropy, bins, weights=np.ones(len(detect_entropy)) / len(detect_entropy), alpha=0.5, label='Attack')
        plt.legend(loc='upper right', fontsize=12)
        plt.ylabel('Entropy (%)', fontsize=12)
        plt.title('{}'.format(opt.dataset.upper()), fontsize=12)
        plt.tick_params(labelsize=20)
        fig1 = plt.gcf()
        # save the figure both as a jpg and as a pdf file
        fig1.savefig('../output/{}_{}_{}_strip.jpg'.format(opt.dataset, opt.attack_method, opt.attack_type), bbox_inches="tight")
        fig1.savefig('../output/{}_{}_{}_strip.pdf'.format(opt.dataset, opt.attack_method, opt.attack_type), bbox_inches="tight")
        return 0


class EntropyDefencer(StripDefencer):
    """STRIP variant that scores inputs by the variance of the blend entropies."""

    def get_entropy_variance(self, input, dataset, classifier):
        """Return the variance of the per-blend softmax entropies.

        ``input`` is blended with ``n_sample`` random items of ``dataset``;
        the entropy of each blend's prediction is computed separately and
        the variance over the blends is returned.
        """
        index_overlay = np.random.randint(0, len(dataset), size=self.n_sample)
        x1_add = [
            self.normalize(self.perturbate(input, dataset[overlay_idx]['input']))
            for overlay_idx in index_overlay
        ]
        # Pure inference: no autograd bookkeeping is needed.
        with paddle.no_grad():
            py1_add = classifier(paddle.stack(x1_add))  # batch
            py1_add = paddle.nn.functional.softmax(py1_add, axis=-1).cpu().numpy()
        # Same guard as StripDefencer.get_entropy: without the epsilon a
        # probability of exactly 0 yields 0 * -inf = NaN and a NaN variance.
        entropys = -np.nansum(py1_add * np.log2(py1_add + 1e-10), axis=-1)
        return entropys.var()

    def detect_poisoned_data(self, netC, benign_dl, detect_dl):
        """
        Identify poisoned samples by the variance of the entropy of the
        predicted probabilities over k random blends.

        The plain STRIP evaluation of the parent class is run first (its
        return value is discarded), then up to 2000 leading samples of each
        dataset are scored with :meth:`get_entropy_variance`. The model is
        put back into train mode before plotting.

        Returns:
            Boolean numpy mask ``detect_entropy < threshold`` over the scored
            samples of ``detect_dl.dataset``.
        """
        print('------------------- Defence by STRIP: based on the average entropy value ----------------------')
        super().detect_poisoned_data(netC, benign_dl, detect_dl)
        plt.close()
        print('------------------- Defence by Variance: based on the variance of the entropy values ------------')
        opt = self.opt
        netC.eval()
        benign_ds = benign_dl.dataset
        detect_ds = detect_dl.dataset
        # introduce noise to the inputs
        benign_entropy = np.array([
            self.get_entropy_variance(benign_ds[i]['input'], benign_ds, netC)
            for i in range(min(2000, len(benign_ds)))
        ])
        detect_entropy = np.array([
            self.get_entropy_variance(detect_ds[i]['input'], benign_ds, netC)
            for i in range(min(2000, len(detect_ds)))
        ])
        # get the entropy threshold
        (mu, sigma) = scipy.stats.norm.fit(benign_entropy)
        # 1% order statistic of the scores actually collected (index 19 for
        # the full 2000 samples); int() truncation keeps it non-negative.
        threshold = np.sort(benign_entropy)[int(len(benign_entropy) * 0.01 - 1)]  # FRR = 0.01
        print(threshold)
        print(mu, sigma)
        # detect the poisoned samples by the threshold
        (mu, sigma) = scipy.stats.norm.fit(detect_entropy)
        print(mu, sigma)
        poisoned_indices = detect_entropy < threshold
        FAR = 1 - poisoned_indices.mean()  # false accept rate (the higher the worse) with false reject rate being 0.01
        print("False Accept Rate (FAR):{} (the higher the better stealthiness.)".format(FAR))
        netC.train()
        # plot the histogram
        bins = 30
        plt.hist(benign_entropy, bins, weights=np.ones(len(benign_entropy)) / len(benign_entropy), alpha=0.5, label='Clean')
        plt.hist(detect_entropy, bins, weights=np.ones(len(detect_entropy)) / len(detect_entropy), alpha=0.5, label='Attack')
        plt.legend(loc='upper right', fontsize=12)
        plt.ylabel('Variance (%)', fontsize=12)
        plt.title('{}'.format(opt.dataset.upper()), fontsize=12)
        plt.tick_params(labelsize=20)
        fig1 = plt.gcf()
        # NOTE(review): these artefacts go to 'output/' while the parent
        # class writes to '../output/' — confirm which directory is intended.
        # save the figure both as a jpg and as a pdf file
        fig1.savefig('output/{}_{}_{}_strip_var.jpg'.format(opt.dataset, opt.attack_method, opt.attack_type), bbox_inches="tight")
        fig1.savefig('output/{}_{}_{}_strip_var.pdf'.format(opt.dataset, opt.attack_method, opt.attack_type), bbox_inches="tight")
        import pickle
        dump_path = 'output/{}_{}_{}_strip_var.pkl'.format(opt.dataset, opt.attack_method, opt.attack_type)
        with open(dump_path, 'wb+') as dump_file:
            pickle.dump([benign_entropy, detect_entropy], dump_file)
        return poisoned_indices


'''
The following is the implementation of the pre-processing-based backdoor defense ShrinkPad proposed in [1].
Reference:
[1] Backdoor Attack in the Physical World. ICLR Workshop, 2021.
'''


from copy import deepcopy
import random

class RandomChoice(object):
    """Apply a single transformation randomly picked from a list.

    Args:
        transforms: Sequence of callables to choose from.
        p: Optional sequence of selection weights, one per transform, passed
            to ``random.choices``; ``None`` selects uniformly.

    Raises:
        TypeError: If ``p`` is given but is not a sequence.
    """

    def __init__(self, transforms, p=None):
        # Imported locally so the check does not depend on a module-level
        # import of ``Sequence``.
        from collections.abc import Sequence

        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        self.p = p
        self.transforms = transforms

    def __call__(self, *args):
        """Pick one transform (weighted by ``p``) and apply it to ``args``."""
        t = random.choices(self.transforms, weights=self.p)[0]
        return t(*args)

    def __repr__(self) -> str:
        # Use the class name instead of the default ``<... object at 0x..>``
        # so the representation is stable and readable.
        return f"{self.__class__.__name__}(p={self.p})"


def RandomPad(sum_w, sum_h, fill=0):
    """Build every ``Pad`` transform that adds ``sum_w`` x ``sum_h`` pixels in total.

    One transform is created for each way of splitting the horizontal
    padding between the first and third padding entries and the vertical
    padding between the second and fourth, i.e.
    ``(sum_w + 1) * (sum_h + 1)`` transforms.

    Args:
        sum_w: Total horizontal padding.
        sum_h: Total vertical padding.
        fill: Fill value forwarded to every ``Pad`` (previously accepted but
            ignored; the default 0 matches the earlier behaviour).

    Returns:
        List of ``transforms.Pad`` instances.
    """
    return [
        transforms.Pad(padding=(i, j, sum_w - i, sum_h - j), fill=fill)
        for i in range(sum_w + 1)
        for j in range(sum_h + 1)
    ]


def build_ShrinkPad(size_map, pad):
    """Compose the ShrinkPad transform.

    The image is first resized to ``(size_map - pad, size_map - pad)`` and
    then padded by one randomly chosen placement from ``RandomPad`` with a
    total of ``pad`` pixels per axis.
    """
    shrunk_side = size_map - pad
    shrink = transforms.Resize((shrunk_side, shrunk_side))
    random_pad = RandomChoice(RandomPad(sum_w=pad, sum_h=pad))
    return transforms.Compose([shrink, random_pad])


class ShrinkPad(Defencer):
    """Evaluate a model under the ShrinkPad pre-processing defence.

    Each image is shrunk by ``pad`` pixels and randomly padded back before
    it is classified. ``denormalize`` is inherited from ``Defencer``.

    Args:
        opt: Options object; ``input_width``, ``input_height``,
            ``random_crop`` and ``dataset`` are read by this class.
        size_map (int): Side length of the image. Default: ``None``, which
            uses ``int(opt.input_width)``.
        pad (int): Total padding per axis. Default: ``None``, which uses 6.
    """

    def __init__(self,
                 opt,
                 size_map=None,
                 pad=None,
                 ):

        super(ShrinkPad, self).__init__(opt)
        # Honour explicit arguments; fall back to the previous fixed values
        # when they are omitted.
        self.global_size_map = int(self.opt.input_width) if size_map is None else int(size_map)
        self.current_size_map = self.global_size_map

        self.global_pad = 6 if pad is None else int(pad)
        self.current_pad = self.global_pad

    def preprocess(self, data, size_map=None, pad=None):
        """Perform the ShrinkPad defense on ``data`` and return the result.

        Args:
            data: A single image accepted by the paddle ``Resize`` and
                ``Pad`` transforms.
            size_map (int): Size of image. Default: None (use the value
                configured on the instance).
            pad (int): Size of pad. Default: None (use the value configured
                on the instance).

        Returns:
            The preprocessed image.
        """
        # The values used for the latest call are recorded on the instance.
        self.current_size_map = self.global_size_map if size_map is None else size_map
        self.current_pad = self.global_pad if pad is None else pad

        shrinkpad = build_ShrinkPad(self.current_size_map, self.current_pad)
        return shrinkpad(data)

    def get_transformer(self, opt=None, train=False):
        """Build the dataset-specific transform pipeline.

        Args:
            opt: Options object; defaults to ``self.opt`` when falsy.
            train: When true, add random crop (and horizontal flip for
                cifar10) before normalisation.

        Returns:
            A ``transforms.Compose`` or ``None`` when no step applies.
        """
        if not opt:
            opt = self.opt
        transforms_list = []
        if train:
            transforms_list.append(transforms.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop))
        if opt.dataset == "cifar10":
            if train:
                transforms_list.append(transforms.RandomHorizontalFlip(prob=0.5))
            transforms_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]))
        elif opt.dataset == 'mnist':
            transforms_list.append(transforms.Normalize([0.5], [0.5]))
        if not transforms_list:
            return None
        return transforms.Compose(transforms_list)

    def _processed_accuracy(self, netC, dl):
        """Return the accuracy of ``netC`` on ``dl`` after ShrinkPad pre-processing."""
        transformer = self.get_transformer()
        inputs = []
        labels = []
        for batch in dl:
            inputs.append(batch['input'])
            labels.append(batch['target'])
        inputs = np.concatenate(inputs)

        labels = np.concatenate(labels)
        if len(labels.shape) > 1:
            # Keep only the first column of multi-column targets.
            labels = labels[:, 0]

        # Denormalize, then move the channel axis last for the image transforms.
        inputs = self.denormalize(inputs).transpose(0, 2, 3, 1)
        processed = []
        for image in inputs:
            # Back to channel-first before handing the image to the model.
            tensor = paddle.to_tensor(self.preprocess(image).transpose(2, 0, 1))
            processed.append(transformer(tensor) if transformer else tensor)

        correct = 0.
        batch_size = 32
        # Pure inference: disable autograd, as ``eval`` does.
        with paddle.no_grad():
            for start in tqdm(range(0, len(processed), batch_size)):
                end = min(start + batch_size, len(processed))
                y_pred = netC(paddle.stack(processed[start:end])).argmax(-1).cpu().numpy()
                y_true = labels[start:end]
                correct += (y_pred == y_true).sum()
        return correct / len(processed)

    def defend(self, netC, benign_dl, attacked_dl):
        """Measure benign accuracy and attack success rate under ShrinkPad.

        Args:
            netC: Classifier; switched to eval mode (and left there).
            benign_dl: Loader yielding dict batches with ``'input'`` and
                ``'target'``.
            attacked_dl: Loader with the same batch layout holding the
                attacked samples.

        Returns:
            dict with keys ``'BA'`` and ``'ASR'``.
        """
        netC.eval()
        # apply transform to benign testing data and get the benign accuracy
        BA = self._processed_accuracy(netC, benign_dl)
        print(BA)
        # apply transform to attacked testing data and get the attack success rate
        ASR = self._processed_accuracy(netC, attacked_dl)
        print(ASR)
        return {'BA': BA, 'ASR': ASR}

    def eval(self, netC, valid_dl, epoch, opt, *args):
        """Return the accuracy (in percent) of ``netC`` on ``valid_dl``.

        The model is evaluated in eval mode and switched back to train mode
        afterwards. ``epoch``, ``opt`` and ``args`` are not used. An empty
        loader yields 0.0.
        """
        netC.eval()  # dropout/normalization
        total_correct = 0
        total_sample = 0
        # Defined up-front so an empty loader does not hit an unbound name.
        acc_value = 0.0
        with paddle.no_grad():
            for batch_idx, batch in enumerate(valid_dl):
                inputs, targets = batch['input'], batch['target']
                targets = targets.squeeze()
                total_sample += inputs.shape[0]
                preds = netC(inputs)
                total_correct += paddle.sum(paddle.to_tensor(paddle.argmax(preds, 1) == targets, dtype=paddle.int64))
                acc = total_correct * 100.0 / total_sample
                acc_value = acc.item()
                info_string = "Acc: {:.4f}".format(acc_value)
                progress_bar(batch_idx, len(valid_dl), info_string)
        netC.train()
        return acc_value
