import numpy as np
from PIL import Image
import paddle
import paddle.vision.transforms as transforms
import matplotlib.pyplot as plt
import paddle.nn.functional as F
from model.network import AutoencoderMnist, AutoencoderCifar


# class BadNets(object):
#     """The BadNets [paper]_ backdoor transformation. Inject a trigger into an image (ndarray with
#     shape H*W*C) to get a poisoned image (ndarray with shape H*W*C).

#     Args:
#         trigger_path (str): The path of trigger image whose background is in black.

#     .. rubric:: Reference

#     .. [paper] "Badnets: Evaluating backdooring attacks on deep neural networks."
    #  Tianyu Gu, et al. IEEE Access 2019.
    # """

    # def __init__(self, trigger_path):
    #     with open(trigger_path, "rb") as f:
    #         trigger_ptn = Image.open(f).convert("RGB")
    #     self.trigger_ptn = np.array(trigger_ptn)
    #     self.trigger_loc = np.nonzero(self.trigger_ptn)

    # def __call__(self, img):
    #     return self.add_trigger(img)

    # def add_trigger(self, img):
    #     if not isinstance(img, np.ndarray):
    #         raise TypeError("Img should be np.ndarray. Got {}".format(type(img)))
    #     if len(img.shape) != 3:
    #         raise ValueError("The shape of img should be HWC. Got {}".format(img.shape))
    #     img[self.trigger_loc] = 0
    #     poison_img = img + self.trigger_ptn

    #     return poison_img


class BadNets(object):
    """BadNets backdoor trigger: stamp a fixed 3x3 pattern into one image corner.

    The same pattern cell value is written into every channel. Cells marked ``1``
    end up at full intensity and cells marked ``-1`` are clipped to 0.

    .. rubric:: Reference

    "Badnets: Evaluating backdooring attacks on deep neural networks."
    Tianyu Gu, et al. IEEE Access 2019.
    """

    # ((row, col), value) cells of the 3x3 trigger for each supported corner.
    # Negative indices count from the bottom / right edge of the image.
    _TRIGGER_MASKS = {
        'top-left': [
            ((0, 0), 1), ((0, 1), -1), ((0, 2), -1),
            ((1, 0), -1), ((1, 1), 1), ((1, 2), -1),
            ((2, 0), 1), ((2, 1), -1), ((2, 2), 1)],
        'top-right': [
            ((0, -1), 1), ((0, -2), -1), ((0, -3), -1),
            ((1, -1), -1), ((1, -2), 1), ((1, -3), -1),
            ((2, -1), 1), ((2, -2), -1), ((2, -3), 1)],
        'bottom-left': [
            ((-1, 0), 1), ((-2, 0), -1), ((-3, 0), -1),
            ((-1, 1), -1), ((-2, 1), 1), ((-3, 1), -1),
            ((-1, 2), 1), ((-2, 2), -1), ((-3, 2), 1)],
        'bottom-right': [
            ((-1, -1), 1), ((-2, -1), -1), ((-3, -1), -1),
            ((-1, -2), -1), ((-2, -2), 1), ((-3, -2), -1),
            ((-1, -3), 1), ((-2, -3), -1), ((-3, -3), 1)],
    }

    def __init__(self, loc, *args, random_seed=1234, reduced_amplitude=None, transform=None):
        """
        Args:
            loc (str): corner to stamp: 'top-left', 'top-right', 'bottom-left'
                or 'bottom-right'.
            random_seed, reduced_amplitude: accepted for interface
                compatibility; not used by this class.
            transform: stored as ``self.transform``; not used by this class.

        Raises:
            ValueError: if ``loc`` is not one of the supported corners.
        """
        if loc not in self._TRIGGER_MASKS:
            raise ValueError('Unsupported location to insert the backdoor:{}'.format(loc))
        # Per-instance copy so mutating it cannot affect other instances.
        self.trigger_mask = list(self._TRIGGER_MASKS[loc])  # For overriding pixel values
        self.trigger_add_mask = []  # For adding or subtracting to pixel values (unused here)
        self.transform = transform

    def __call__(self, img):
        return self.apply(img)

    def apply(self, image):
        """Return a poisoned copy of ``image``; the input array is never modified.

        @Args
            image (np.ndarray): ``H x W`` or ``H x W x C`` image. Multi-channel
                input is treated as 0-255 and scaled to [0, 1] before stamping.
                Single-channel input is used at its own scale and is expected
                to already be in [0, 1].

        Returns:
            np.ndarray: the poisoned image with singleton axes squeezed.
            Multi-channel results are ``uint8`` in 0-255. Single-channel
            results keep the input dtype and are clipped to [0, 1].
        """
        if image.ndim == 2:
            image = np.expand_dims(image, axis=2)
        # CHW layout so one assignment writes a trigger cell into every channel.
        image = image.transpose(2, 0, 1)
        if image.shape[0] != 1:
            image = (image / 255.).astype(np.float32)
        else:
            # expand_dims/transpose return views of the caller's array; copy so
            # stamping the trigger does not poison the original data in place.
            image = image.copy()
        for (row, col), value in self.trigger_mask:
            image[:, row, col] = value
        image = np.clip(image, 0, 1.).transpose(1, 2, 0)
        if image.shape[2] != 1:
            image = (image * 255).astype(np.uint8)
        return image.squeeze()

class Blend(object):
    """Blended backdoor trigger: alpha-blend a trigger image over the input.

    An ``H x W x C`` ndarray goes in and a poisoned ndarray comes out, produced by
    ``PIL.Image.blend`` between the input and the trigger resized to the input size.

    Args:
        trigger_path (str): path of the trigger image (converted to RGB on load).
        alpha (float): interpolation factor handed to ``Image.blend``.

    .. rubric:: Reference

    .. [paper] "Targeted backdoor attacks on deep learning systems using data poisoning."
     Xinyun Chen, et al. arXiv:1712.05526.
    """

    def __init__(self, trigger_path, alpha=0.1):
        with open(trigger_path, "rb") as handle:
            # convert() loads the pixels while the file is still open.
            loaded = Image.open(handle)
            self.trigger_ptn = loaded.convert("RGB")
        self.alpha = alpha

    def __call__(self, img):
        return self.blend_trigger(img)

    def blend_trigger(self, img):
        """Blend the stored trigger into ``img`` and return the result as an ndarray.

        Raises:
            TypeError: if ``img`` is not an ``np.ndarray``.
            ValueError: if ``img`` is not 3-dimensional (HWC).
        """
        if not isinstance(img, np.ndarray):
            raise TypeError("Img should be np.ndarray. Got {}".format(type(img)))
        if img.ndim != 3:
            raise ValueError("The shape of img should be HWC. Got {}".format(img.shape))
        clean = Image.fromarray(img)
        # Match the trigger to the input's (width, height) before blending.
        return np.array(Image.blend(clean, self.trigger_ptn.resize(clean.size), self.alpha))


class SIGTrigger(object):
    """SIG backdoor trigger: add a fixed ramp or sinusoidal signal to an image.

    Implements the attack proposed in "A NEW BACKDOOR ATTACK IN CNNS BY TRAINING
    SET CORRUPTION WITHOUT LABEL POISONING".
    """
    def __init__(self, mode, alpha, input_channel, input_height, input_width, *args, freq=1, random_seed=1234):
        """Precompute the additive signal ``self.v`` with shape (C, H, W).

        With ``m`` the number of columns (rows for the ``*_row`` modes) and ``j``
        the 0-based column (row) index, the 1-D profile is:

        - ``all_col`` / ``all_row``: ``alpha * j / m``.
        - ``half_left_col`` / ``half_top_row``: ``alpha * j / m`` for
          ``j < m // 2``, else 0.
        - ``half_right_col`` / ``half_bottom_row``: ``alpha * (j - m // 2) / m``
          for ``j >= m // 2``, else 0.
        - ``sin``: ``alpha * sin(2 * pi * j * freq / m)`` along the columns.

        The profile is broadcast over the other spatial axis and all channels.

        @Args:
            mode (str): one of the modes listed above.
            alpha (float): signal amplitude.
            input_channel, input_height, input_width (int): image dimensions.
            freq (int): number of sine periods across the width (``sin`` only).
            random_seed: accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        self.mode = mode
        self.alpha = alpha
        channel, row, col = input_channel, input_height, input_width
        base_mask = np.zeros((channel, row, col))
        if self.mode == 'all_col':
            v = self.alpha * np.arange(col) / col
            v = v[None, None, :]
        elif self.mode == 'half_left_col':
            v = self.alpha * np.arange(col) / col
            v[int(col / 2):] = 0
            v = v[None, None, :]
        elif self.mode == 'half_right_col':
            v = self.alpha * np.arange(col) / col
            v -= v[int(col / 2)]  # restart the ramp at 0 from the middle column
            v[:int(col / 2)] = 0
            v = v[None, None, :]
        elif self.mode == 'all_row':
            v = self.alpha * np.arange(row) / row
            v = v[None, :, None]
        elif self.mode == 'half_top_row':
            v = self.alpha * np.arange(row) / row
            v[int(row / 2):] = 0
            v = v[None, :, None]
        elif self.mode == 'half_bottom_row':
            v = self.alpha * np.arange(row) / row
            v -= v[int(row / 2)]  # restart the ramp at 0 from the middle row
            v[:int(row / 2)] = 0
            v = v[None, :, None]
        elif self.mode == 'sin':
            v = self.alpha * np.sin(2 * np.pi * np.arange(col) * freq / col)
            v = v[None, None, :]
        else:
            raise ValueError('{} is not supported mode!'.format(self.mode))
        # Broadcast the 1-D profile to the full (C, H, W) shape.
        self.v = base_mask + v

    def __call__(self, img):
        return self.apply(img)

    def apply(self, image):
        """Return ``image`` with the SIG signal added once and clipped to the valid range.

        @Args
            image (np.ndarray): ``H x W`` or ``H x W x C`` image. Multi-channel
                input is treated as 0-255 and scaled to [0, 1]. Single-channel
                input is used at its own scale and is expected to be in [0, 1].

        Returns:
            np.ndarray: the poisoned image with singleton axes squeezed.
            Multi-channel results are ``uint8`` in 0-255. Single-channel
            results are floats clipped to [0, 1].
        """
        if image.ndim == 2:
            image = np.expand_dims(image, axis=2)
        # Work in CHW so the image lines up with self.v.
        image = image.transpose(2, 0, 1)
        if image.shape[0] != 1:
            image = (image / 255.).astype(np.float32)
        # Add the signal exactly once (x + v); previously it was applied twice.
        image = np.clip(image + self.v, 0., 1.).transpose(1, 2, 0)
        if image.shape[2] != 1:
            image = (image * 255).astype(np.uint8)
        return image.squeeze()


class WaNetTrigger(object):
    """WaNet backdoor trigger: warp an image with a fixed smooth displacement field.

    Implements the warp attack proposed in
    "WANET – IMPERCEPTIBLE WARPING-BASED BACKDOOR ATTACK".
    """
    def __init__(self, s, k, input_height, input_width, trigger_path, *args, resume=True, grid_rescale=1.0, random_seed=1111, transform=transforms.ToTensor()):
        """Build (or load) the sampling grid later passed to ``grid_sample``.

        Args:
            s (float): scale of the random displacement added to the identity grid.
            k (int): side length of the random ``[1, 2, k, k]`` control tensor,
                bicubically upsampled to the image size.
            input_height, input_width (int): image size the grid is built for.
            trigger_path (str): file the grids are saved to (``resume=False``)
                or loaded from (``resume=True``).
            resume (bool): load existing grids instead of sampling new ones.
            grid_rescale (float): multiplier applied before clipping to [-1, 1].
            random_seed: accepted for interface compatibility; unused.
            transform: callable turning an input image into a tensor
                (presumably CHW, as the default ``ToTensor`` produces).
        """
        self.grid_rescale = grid_rescale
        self.s = s
        self.k = k
        self.transform = transform
        if not resume:
            # Random control offsets in [-1, 1), normalised to unit mean magnitude.
            ins = paddle.rand([1, 2, self.k, self.k]) * 2 - 1
            ins = ins / paddle.mean(paddle.abs(ins))
            # Upsample to (H, W) and move the 2 offset components to the last axis.
            self.noise_grid = paddle.transpose(
                F.upsample(ins, size=[input_height, input_width], mode="bicubic", align_corners=True),
                [0, 2, 3, 1],
            )
            # Identity grid in normalised coordinates. Separate row/column ranges
            # give an (H, W, 2) grid that matches noise_grid for non-square images
            # too (previously both axes used input_height). Last axis is
            # (column coordinate, row coordinate).
            rows = paddle.linspace(-1, 1, num=input_height)
            cols = paddle.linspace(-1, 1, num=input_width)
            x, y = paddle.meshgrid(rows, cols)
            self.identity_grid = paddle.stack((y, x), 2)[None, ...]
            state_dict = {"noise_grid": self.noise_grid, "identity_grid": self.identity_grid}
            paddle.save(state_dict, trigger_path)
        else:
            state_dict = paddle.load(trigger_path)
            self.noise_grid = state_dict["noise_grid"]
            self.identity_grid = state_dict["identity_grid"]
        grid_temps = (self.identity_grid + self.s * self.noise_grid / input_height) * self.grid_rescale
        # Keep sampling locations inside the image.
        self.grid_temps = paddle.clip(grid_temps, -1, 1)

    def __call__(self, img):
        return self.generate_posion(img)

    def generate_posion(self, input):
        """Warp ``input`` with the precomputed grid.

        Args:
            input: image accepted by ``self.transform``.

        Returns:
            np.ndarray: ``uint8`` H x W x C image with singleton axes squeezed.
            The warped tensor is multiplied by 255, which assumes ``transform``
            yields values in [0, 1] (TODO confirm for custom transforms).
        """
        tensor = self.transform(input).unsqueeze(0)
        warped = F.grid_sample(tensor, self.grid_temps, align_corners=True)[0].cpu().numpy()
        image = (warped.transpose(1, 2, 0) * 255).astype(np.uint8)
        return image.squeeze()

    # Correctly spelled alias; ``generate_posion`` is kept for existing callers.
    generate_poison = generate_posion


class AETcbTrigger(object):
    """Autoencoder-based trigger: pull an image's latent code toward a target image's code.

    The target image's encoder output ``h_t`` is blended into each input's latent
    code, and the result is decoded back to pixel space.
    """

    def __init__(self, saved_path, target_input, target_label, trigger_dim, dataset, *args, criteria=paddle.nn.CrossEntropyLoss(), transform=transforms.ToTensor(), random_seed=1234):
        """
        Args:
            saved_path (str): checkpoint read with ``paddle.load``; its
                ``'state_dict'`` entry holds the autoencoder weights.
            target_input: image whose latent code becomes the trigger; passed
                through ``transform``.
            target_label (int): label used for the loss in ``set_trigger_pattern``.
            trigger_dim (int): number of latent units selected for the mask.
            dataset (str): 'mnist', 'cifar10', 'gtsrb' or 'celeba'.
            criteria: loss used to rank latent units by gradient magnitude.
            transform: callable turning an image into a tensor for the encoder.
            random_seed: accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``dataset`` is not supported.
        """
        if dataset == "mnist":
            auto_encoder = AutoencoderMnist()
        elif dataset in ("cifar10", "gtsrb"):
            auto_encoder = AutoencoderCifar()
        elif dataset == "celeba":
            # NOTE(review): AutoencoderCeleba is not among this module's visible
            # imports (only AutoencoderMnist/AutoencoderCifar are); confirm it is
            # importable, otherwise this branch raises NameError.
            auto_encoder = AutoencoderCeleba()
        else:
            # Previously fell through and failed with UnboundLocalError below.
            raise ValueError('Unsupported dataset: {}'.format(dataset))
        state_dict = paddle.load(saved_path)
        auto_encoder.set_dict(state_dict['state_dict'])
        self.auto_encoder = auto_encoder
        self.trigger_dim = trigger_dim
        self.target_label = target_label
        self.criteria = criteria
        self.transform = transform
        self.set_trigger_pattern(target_input)
        self.auto_encoder.eval()

    def set_trigger_pattern(self, target_input):
        """Encode ``target_input`` and derive a latent mask from loss gradients.

        Sets ``self.h_t`` to the detached target latent code. Sets ``self.mask``
        to 0 at the ``trigger_dim`` latent positions with the largest absolute
        gradient of the classification loss, and to 1 elsewhere. ``self.mask``
        is not used by ``apply`` in this class.
        """
        image = self.transform(target_input).unsqueeze(0)
        h_t = self.auto_encoder.encoder(image)
        logits = self.auto_encoder.class_linear(paddle.flatten(h_t, 1))
        loss = self.criteria(logits, paddle.to_tensor([self.target_label]))
        h_t.clear_grad()
        loss.backward()
        gradients = paddle.abs(h_t.grad).reshape((-1,)).detach()
        shape = h_t.shape
        # Indices sorted by descending gradient magnitude.
        idxes = paddle.argsort(gradients)[::-1]
        mask = paddle.zeros_like(gradients)
        mask[idxes[:self.trigger_dim]] = 1.
        mask = mask.reshape(shape)
        self.h_t = h_t.detach()
        self.mask = 1 - mask

    def __call__(self, img):
        return self.apply(img)

    def apply(self, input):
        """Poison ``input`` by mixing its latent code with the target code.

        Returns:
            np.ndarray: ``uint8`` H x W x C image with singleton axes squeezed.
            The decoder output is multiplied by 255, presumably because it is in
            [0, 1] (TODO confirm against the decoder).
        """
        image = self.transform(input).unsqueeze(0)
        with paddle.no_grad():
            h = self.auto_encoder.encoder(image)
            # Fixed 30/70 interpolation toward the target latent code.
            h = self.h_t * 0.3 + h * 0.7
            x = self.auto_encoder.decoder(h).cpu().numpy()  # batch of one; x[0] is C x H x W
        image = (x[0].transpose(1, 2, 0) * 255).astype(np.uint8)
        return image.squeeze()
