"""
Implementation of different backdoor injection functions
"""
import os
import random
import scipy.stats as st

import cv2 
import numpy as np
import paddle
import paddle.nn.functional as F
import matplotlib.pyplot as plt

from utils.dataloader import PostTensorTransform, get_dataloader, DictDataset, get_dataset, get_transform


class BaseTrigger:
    """
    Common interface shared by every backdoor trigger.

    Subclasses override ``apply`` to poison a single input; ``apply_all`` maps it over
    a batch.
    """

    def __init__(self, opt, *args, **kwargs):
        """Store the experiment options; extra arguments are accepted and ignored."""
        self.opt = opt

    def apply(self, input, *args):
        """Return a poisoned version of one input. Subclasses must override this."""
        raise NotImplementedError("A backdoor trigger should implement the apply function!")

    def apply_all(self, inputs, *args):
        """
        Poison every element of ``inputs``.

        The result is an array copy of ``inputs``, so ``inputs`` itself is left untouched
        and the input dtype is kept: each ``apply`` result is cast into that dtype on
        assignment.
        """
        poisoned = np.array(inputs, copy=True)
        for slot, sample in enumerate(inputs):
            poisoned[slot] = self.apply(sample, *args)
        return poisoned

    def set_trigger_pattern(self, pattern):
        """Hook for triggers with a configurable pattern; does nothing by default."""
        return None


class BadNetTrigger(BaseTrigger):
    """
    Backdoor injection function based on BadNet
    """
    # 3 x 3 checkerboard; +1 cells end up white and -1 cells black once apply()
    # clips the image to [0, 1].
    _PATCH = ((1, -1, -1),
              (-1, 1, -1),
              (1, -1, 1))

    # (location name, map from patch cell (r, c) to the image index (x, y)).
    _CORNERS = (
        ('top-left', lambda r, c: (r, c)),
        ('top-right', lambda r, c: (r, -(c + 1))),
        ('bottom-left', lambda r, c: (-(c + 1), r)),
        ('bottom-right', lambda r, c: (-(c + 1), -(r + 1))),
    )

    def __init__(self, opt, loc, *args, random_seed=1234, reduced_amplitude=None):
        """
        This attack poisons the data, applying a mask to some of the inputs and
        changing the labels of those inputs to that of the target_class.

        @Args:
            opt: experiment options (stored by BaseTrigger).
            loc (str): corner receiving the patch: 'top-left', 'top-right',
                'bottom-left' or 'bottom-right'.
            random_seed, reduced_amplitude: accepted for API compatibility; unused.
        @Raises:
            Exception: for any other ``loc``.
        """
        super(BadNetTrigger, self).__init__(opt, *args)
        self.trigger_add_mask = []  # For adding or subtracting to pixel values (unused)
        to_index = next((fn for name, fn in self._CORNERS if name == loc), None)
        if to_index is None:
            raise Exception('Unsupported location to insert the backdoor:{}'.format(loc))
        # ((x, y), value) pairs, row-major over the patch, used to override pixel values.
        self.trigger_mask = [(to_index(r, c), cell)
                             for r, patch_row in enumerate(self._PATCH)
                             for c, cell in enumerate(patch_row)]

    def apply(self, input):
        """
        Stamp the checkerboard patch onto a copy of ``input``.

        @Args
            input (np.ndarray):
                image of the n_channel * height * width shape. The value range of the input should be 0 to 1.
        @Returns:
            np.ndarray: float32 image clipped to [0, 1].
        """
        stamped = input.astype(np.float32)
        for (x, y), value in self.trigger_mask:
            stamped[:, x, y] = value
        return np.clip(stamped, 0, 1.)

    def set_trigger_pattern(self, pattern):
        """The BadNet patch is fixed, so setting a pattern has no effect."""
        return None


class SIGTrigger(BaseTrigger):
    """
    Backdoor injection function based on SIG
    """
    def __init__(self, opt, mode, alpha, *args, freq=1, random_seed=1234):
        """
        Pre-compute the additive signal ``self.v`` of shape
        (input_channel, input_height, input_width).

        @Args:
            opt: options providing input_channel, input_height and input_width.
            mode (str): option of [all_col, all_row, half_left_col, half_right_col,
                half_top_row, half_bottom_row, sin]
            alpha (float): amplitude of the signal.
            freq (int): number of sine periods across the width; only used by ``sin``.
            random_seed (int): accepted for API compatibility; unused.
        @Raises:
            Exception: if ``mode`` is not one of the options above.
        """
        super(SIGTrigger, self).__init__(opt, *args)
        self.mode = mode
        self.alpha = alpha
        channel, row, col = opt.input_channel, opt.input_height, opt.input_width
        base_mask = np.zeros((channel, row, col))
        if self.mode == 'all_col':
            v = self.alpha * np.arange(col) / col
            v = v[None, None, :]
        elif self.mode == 'half_left_col':
            v = self.alpha * np.arange(col) / col
            v[int(col / 2):] = 0
            v = v[None, None, :]
        elif self.mode == 'half_right_col':
            # Ramp that starts from 0 at the middle column and is 0 to its left.
            v = self.alpha * np.arange(col) / col
            v -= v[int(col / 2)]
            v[:int(col / 2)] = 0
            v = v[None, None, :]
        elif self.mode == 'all_row':
            v = self.alpha * np.arange(row) / row
            v = v[None, :, None]
        elif self.mode == 'half_top_row':
            v = self.alpha * np.arange(row) / row
            v[int(row / 2):] = 0
            v = v[None, :, None]
        elif self.mode == 'half_bottom_row':
            v = self.alpha * np.arange(row) / row
            v -= v[int(row / 2)]
            v[:int(row / 2)] = 0
            v = v[None, :, None]
        elif self.mode == 'sin':
            v = self.alpha * np.sin(2 * np.pi * np.arange(col) * freq / col)
            v = v[None, None, :]
        else:
            raise Exception('{} is not supported mode!'.format(self.mode))
        # Broadcast the 1-D ramp / sinusoid to the full (channel, row, col) shape.
        self.v = base_mask + v

    def apply(self, input):
        """
        Implement the SIG attacks proposed by:
        "A NEW BACKDOOR ATTACK IN CNNS BY TRAINING SET CORRUPTION WITHOUT LABEL POISONING"

        Formulations:
            1. v(i, j) = alpha * sin(2*pi*j*f/m), where i, j denotes the row and column index, and m
                denotes the column number
            2. v(i, j) = j*alpha/m, 1 ≤ j ≤ m/2,
            3. v(i, j) = (m − j)*alpha/m, m/2 < j ≤ m
            4. v(i, j) = i*alpha/m, 1 ≤ i ≤ n/2, where n is the row number

        @Args
            input (np.ndarray):
                image of the n_channel * height (row) * width (column) shape
        @Returns:
            np.ndarray: ``input + v`` clipped to [0, 1].
        """
        image = input.astype(np.float32)
        # Add the signal exactly once; it used to be added twice, doubling alpha.
        return np.clip(image + self.v, 0., 1.)


class WaNetTrigger(BaseTrigger):
    """
    Warping-based trigger from "WANET – IMPERCEPTIBLE WARPING-BASED BACKDOOR ATTACK".

    Inputs are resampled through a fixed, slightly perturbed sampling grid, which
    produces imperceptible poisoned images.
    """
    def __init__(self, opt, *args, mode=False, s=0.5, k=4, random_seed=1111, num=0, generate_grid=False):
        """
        @Args:
            opt: options; reads grid_rescale, input_height, input_width and ckpt_folder
                (and train_mode in ``apply_all``).
            mode (bool): stored and forwarded to ``generate_posion``, which currently
                ignores it.
            s (float): strength of the warping field.
            k (int): side of the random control grid that is upsampled to the image size.
            random_seed, num: accepted for API compatibility; unused.
            generate_grid (bool): if True, sample a fresh random warping field. If False
                (default, the previous hard-coded behaviour), load
                ``warp_{s}_{k}.pth.tar`` from ``opt.ckpt_folder``.
        """
        super(WaNetTrigger, self).__init__(opt, *args)
        self.opt = opt
        self.grid_rescale = opt.grid_rescale
        self.s = s
        self.k = k
        if generate_grid:
            self.noise_grid, self.identity_grid = self._generate_grids()
        else:
            state_dict = paddle.load(self._grid_path())
            self.noise_grid = state_dict["noise_grid"]
            self.identity_grid = state_dict["identity_grid"]
        grid_temps = (self.identity_grid + self.s * self.noise_grid / self.opt.input_height) * self.grid_rescale
        # grid_sample expects normalised coordinates in [-1, 1].
        self.grid_temps = paddle.clip(grid_temps, -1, 1)
        self.mode = mode

    def _grid_path(self):
        """Checkpoint file holding the noise/identity grids for this (s, k)."""
        return os.path.join(self.opt.ckpt_folder, "warp_{}_{}.pth.tar".format(self.s, self.k))

    def _generate_grids(self):
        """Sample a random (1, H, W, 2) noise grid and build the matching identity grid."""
        height, width = self.opt.input_height, self.opt.input_width
        ins = paddle.rand([1, 2, self.k, self.k]) * 2 - 1
        ins = ins / paddle.mean(paddle.abs(ins))
        upsampled = F.upsample(ins, size=[height, width], mode="bicubic", align_corners=True)
        noise_grid = paddle.transpose(upsampled, [0, 2, 3, 1])
        rows, cols = paddle.meshgrid(paddle.linspace(-1, 1, num=height),
                                     paddle.linspace(-1, 1, num=width))
        # grid_sample wants (x, y) = (column, row) in the last axis.
        identity_grid = paddle.stack((cols, rows), 2)[None, ...]
        return noise_grid, identity_grid

    def apply_all(self, inputs, *args):
        """
        Warp every input of the batch and return the result as a numpy array.

        Outside ``eval`` mode, the grids are first saved to the checkpoint folder so
        that later runs can reload the same warp.
        """
        poisoned_inputs = paddle.to_tensor(inputs)
        if self.opt.train_mode != "eval":
            grid_dict = {
                "noise_grid": self.noise_grid,
                "identity_grid": self.identity_grid
            }
            paddle.save(grid_dict, self._grid_path())
        for idx in range(poisoned_inputs.shape[0]):
            poisoned_inputs[idx] = self.apply(poisoned_inputs[idx], *args)
        return poisoned_inputs.cpu().numpy()

    def apply(self, input):
        """Warp a single (channel, height, width) input; see ``generate_posion``."""
        return self.generate_posion(input, mode=self.mode)

    def change_mode(self, mode=False):
        """Replace the stored ``mode`` flag."""
        self.mode = mode

    def generate_posion(self, input, mode=False):
        """
        Resample ``input`` through ``self.grid_temps`` and return a numpy array.

        @Args:
            input: one image (array or tensor) of the channel * height * width shape.
            mode (bool): currently ignored; every call uses the fixed warping grid.
        """
        batch = paddle.to_tensor(input).unsqueeze(0)
        output = F.grid_sample(batch, self.grid_temps, align_corners=True)
        return output[0].cpu().numpy()


class ReflectionTrigger(BaseTrigger):
    """
    Reflection backdoor trigger.

    This attack poisons the data by blending a reflection image into some of the inputs
    from the target classes, keeping the labels of those inputs as the target_class.
    The algorithm is described in "Reflection Backdoor: A Natural Backdoor Attack on
    Deep Neural Networks".
    """
    def __init__(self, adv_images, opt,  max_allowed_pixel_value=1.0, random_seed=1234, reflection_mode=True):
        """
        @Args:
            adv_images (sequence of np.ndarray): candidate reflection images of the
                n_channel * height * width shape; one is drawn at random per input.
            opt: experiment options (stored by BaseTrigger).
            max_allowed_pixel_value (float): upper bound of the pixel range; its
                negation is the lower clipping bound of the simple blend.
            random_seed (int): stored on the instance; no generator is seeded here.
            reflection_mode (bool): if True, always use the simple intensity-weighted
                blend. If False, use the simple blend about a third of the time and the
                ghost / focal-blur reflection model otherwise.
        """
        super(ReflectionTrigger, self).__init__(opt)
        self.random_seed = random_seed
        self.max_allowed_pixel_value = max_allowed_pixel_value
        self.min_allowed_pixel_value = -max_allowed_pixel_value
        # Honour the argument (it was previously overwritten with True).
        self.reflection_mode = reflection_mode
        self.adv_images = adv_images

    @staticmethod
    def _resize(img, size, interpolation=cv2.INTER_LINEAR):
        """
        Resize an H x W (x C) image to ``size`` = (width, height).

        cv2 returns a 2-D array for single-channel input, so the trailing channel axis
        is restored when the input had one.
        """
        resized = cv2.resize(img, size, interpolation=interpolation)
        if img.ndim == 3 and resized.ndim == 2:
            resized = resized[:, :, None]
        return resized

    @staticmethod
    def _gaussian_kernel(kern_len=100, nsig=1):
        """Return a kern_len x kern_len Gaussian kernel scaled so that its peak is 1."""
        interval = (2 * nsig + 1.) / kern_len
        x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kern_len + 1)
        kern1d = np.diff(st.norm.cdf(x))
        kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
        kernel = kernel_raw / kernel_raw.sum()
        return kernel / kernel.max()

    def apply(self, image_t):
        """
        Blend one randomly chosen reflection image into ``image_t``.

        @Args:
            image_t (np.ndarray): image of the n_channel * height * width shape.
        @Returns:
            np.ndarray: blended image in the n_channel * height * width layout.
        """
        image_r = np.copy(random.choices(self.adv_images)[0])
        blended = self.blend_images_basic(image_t.transpose(1, 2, 0), image_r.transpose(1, 2, 0))
        return blended.transpose(2, 0, 1)

    def blend_images_basic(self, img_t, img_r, max_image_size=560, ghost_rate=0.49, alpha_t=-1., offset=(0, 0), sigma=-1,
                           ghost_alpha=-1.):
        """
        Blend the transmission layer ``img_t`` with the reflection layer ``img_r``.

        @Args:
            img_t, img_r (np.ndarray): height * width * channel images.
            max_image_size (int): longer side used while synthesising the ghost /
                focal-blur reflection.
            ghost_rate (float): probability of the ghost effect (otherwise focal blur).
            alpha_t (float): transmission weight; drawn from [0.55, 0.95] if negative.
            offset (tuple): ghost shift in pixels; drawn per axis from [3, 8] if (0, 0).
            sigma (float): focal-blur sigma; drawn from [1, 5] if negative.
            ghost_alpha (float): ghost mixing weight; drawn at random if negative.
        @Returns:
            np.ndarray: float32 blended image with the height and width of ``img_t``.
        """
        # Draw unconditionally so the global NumPy random stream advances as before,
        # even when reflection_mode short-circuits the test.
        mode = np.random.rand()
        if self.reflection_mode or mode < 0.33:
            return self._blend_simple(img_t, img_r)

        orig_h, orig_w = img_t.shape[:2]
        t = np.float32(img_t) / self.max_allowed_pixel_value
        r = np.float32(img_r) / self.max_allowed_pixel_value
        # Rescale so that the longer side equals max_image_size.
        scale_ratio = float(max(orig_h, orig_w)) / float(max_image_size)
        w, h = (max_image_size, int(round(orig_h / scale_ratio))) if orig_w > orig_h \
            else (int(round(orig_w / scale_ratio)), max_image_size)
        # The interpolation flag must be a keyword (cv2.resize's 3rd positional
        # argument is ``dst``). Bicubic can overshoot, so clip back to [0, 1] to keep
        # the gamma (np.power) steps well defined.
        t = np.clip(self._resize(t, (w, h), cv2.INTER_CUBIC), 0., 1.)
        r = np.clip(self._resize(r, (w, h), cv2.INTER_CUBIC), 0., 1.)

        if alpha_t < 0:
            alpha_t = 1. - random.uniform(0.05, 0.45)

        if random.randint(0, 100) < ghost_rate * 100:
            blended = self._ghost_blend(t, r, alpha_t, offset, ghost_alpha)
        else:
            blended = self._focal_blur_blend(t, r, alpha_t, sigma, max_image_size)
        blended = np.float32(blended * self.max_allowed_pixel_value)
        # Back to the input resolution (previously hard-coded to 32 x 32).
        return self._resize(blended, (orig_w, orig_h))

    def _blend_simple(self, img_t, img_r):
        """Mix the two layers with weights proportional to their mean intensities."""
        h, w = img_t.shape[:2]
        img_r = self._resize(img_r, (w, h))
        weight_t = np.mean(img_t)
        weight_r = np.mean(img_r)
        total = weight_t + weight_r
        if total == 0:
            # Avoid 0 / 0 when both layers average to zero: mix evenly.
            param_t = param_r = 0.5
        else:
            param_t = weight_t / total
            param_r = weight_r / total
        blended = param_t * img_t + param_r * img_r
        return np.float32(np.clip(blended, self.min_allowed_pixel_value, self.max_allowed_pixel_value))

    def _ghost_blend(self, t, r, alpha_t, offset, ghost_alpha):
        """Blend a double-imaged ("ghosted") reflection into ``t``; inputs lie in [0, 1]."""
        h, w = t.shape[:2]
        t = np.power(t, 2.2)
        r = np.power(r, 2.2)
        if offset[0] == 0 and offset[1] == 0:
            offset = (random.randint(3, 8), random.randint(3, 8))
        r_1 = np.pad(r, ((0, offset[0]), (0, offset[1]), (0, 0)), 'constant', constant_values=0)
        r_2 = np.pad(r, ((offset[0], 0), (offset[1], 0), (0, 0)), 'constant', constant_values=(0, 0))
        if ghost_alpha < 0:
            ghost_alpha_switch = 1 if random.random() > 0.5 else 0
            ghost_alpha = abs(ghost_alpha_switch - random.uniform(0.15, 0.5))
        ghost_r = r_1 * ghost_alpha + r_2 * (1 - ghost_alpha)
        # Explicit stop indices: a ``-offset`` stop is empty when a component is 0.
        ghost_r = ghost_r[offset[0]: ghost_r.shape[0] - offset[0], offset[1]: ghost_r.shape[1] - offset[1], :]
        ghost_r = self._resize(ghost_r, (w, h))
        reflection_mask = ghost_r * (1 - alpha_t)
        blended = np.power(reflection_mask + t * alpha_t, 1 / 2.2)
        return np.clip(blended, 0., 1.)

    def _focal_blur_blend(self, t, r, alpha_t, sigma, max_image_size):
        """Blend an out-of-focus (Gaussian-blurred) reflection into ``t``; inputs lie in [0, 1]."""
        if sigma < 0:
            sigma = random.uniform(1, 5)
        t = np.power(t, 2.2)
        r = np.power(r, 2.2)

        sz = int(2 * np.ceil(2 * sigma) + 1)
        # Keywords: GaussianBlur's 4th positional parameter is ``dst``.
        r_blur = cv2.GaussianBlur(r, (sz, sz), sigmaX=sigma, sigmaY=sigma)
        if r_blur.ndim == 2:
            r_blur = r_blur[:, :, None]
        blend = r_blur + t

        # Pull the reflection down, per channel, where the sum saturates above 1.
        att = 1.08 + np.random.random() / 10.0
        for i in range(blend.shape[2]):
            maski = blend[:, :, i] > 1
            mean_i = max(1., np.sum(blend[:, :, i] * maski) / (maski.sum() + 1e-6))
            r_blur[:, :, i] = r_blur[:, :, i] - (mean_i - 1) * att
        r_blur = np.clip(r_blur, 0, 1)

        # Spatially varying reflection strength: random crop of a large Gaussian.
        h, w = r_blur.shape[:2]
        new_w = np.random.randint(0, max_image_size - w - 10) if w < max_image_size - 10 else 0
        new_h = np.random.randint(0, max_image_size - h - 10) if h < max_image_size - 10 else 0
        g_mask = self._gaussian_kernel(max_image_size, 3)
        alpha_r = g_mask[new_h: new_h + h, new_w: new_w + w, None] * (1. - alpha_t / 2.)

        r_blur_mask = np.multiply(r_blur, alpha_r)
        blend = np.power(r_blur_mask + t * alpha_t, 1 / 2.2)
        return np.clip(blend, 0, 1)


class TcbTrigger(BaseTrigger):
    """
    Target class feature bounded trigger.
    """
    def __init__(self, opt, loc, target_image, *args, random_seed=1234, reduced_amplitude=None, shift=7):
        """
        This attack poisons the data, adding a resized image of the target class to some of the inputs and
        changing the labels of those inputs to that of the target_class.

        @Args:
            opt: options providing input_height and input_width.
            loc: stored on the instance; not used by this class.
            target_image (np.ndarray): height * width (* channel) image accepted by
                ``PIL.Image.fromarray``; non-float32 results are divided by 255.
            random_seed, reduced_amplitude: accepted for API compatibility; unused.
            shift (int): number of columns the resized target image is shifted to the
                left; vacated columns are zero. Defaults to 7.
        @Raises:
            ValueError: if ``shift`` is negative.
        """
        from PIL import Image

        super(TcbTrigger, self).__init__(opt, *args)
        if shift < 0:
            raise ValueError('shift must be non-negative, got {}'.format(shift))
        self.trigger_size = opt.input_height
        height, width = opt.input_height, opt.input_width
        # PIL sizes are (width, height); using both dimensions supports non-square inputs.
        resized = np.asarray(Image.fromarray(target_image).resize((width, height)))
        if resized.ndim == 2:
            resized = resized[:, :, None]
        pattern = np.transpose(resized, (2, 0, 1))
        if pattern.dtype != np.float32:
            pattern = pattern / 255.
        shifted = np.zeros_like(pattern)
        if shift < width:
            shifted[:, :, :width - shift] = pattern[:, :, shift:]
        self.target_image = shifted
        self.loc = loc

    def apply(self, input):
        """
        Add the shifted target image to ``input`` and clip to [0, 1].

        @Args
            input (np.ndarray):
                image of the n_channel * height * width shape. The value range of the input should be 0 to 1.
        """
        image = input.astype(np.float32)
        return np.clip(image + self.target_image, 0, 1.)


class NaiveTcbTrigger(BaseTrigger):
    """
    Add a 10 x 10 down-sampled copy of a target image to the top-left corner of inputs.

    ``set_trigger_pattern`` must be called before ``apply``.
    """
    def set_trigger_pattern(self, pattern):
        """
        @Args:
            pattern (np.ndarray): target image as height * width * channel,
                channel * height * width, or height * width. Non-uint8 values with a
                maximum <= 1 are treated as [0, 1] and rescaled to [0, 255].
        """
        from PIL import Image

        # Drop a singleton channel axis (leading or trailing) so gray images are 2-D.
        if pattern.shape[0] == 1:
            pattern = pattern[0, :, :]
        elif pattern.shape[-1] == 1:
            pattern = pattern[:, :, 0]

        if pattern.ndim == 2:
            # Use the squeezed shape; the pre-squeeze shape broke (1, H, W) inputs.
            pattern_size = (pattern.shape[0], pattern.shape[1], 1)
        else:
            if pattern.shape[0] < pattern.shape[1]:
                # Heuristic: a first axis shorter than the second is the channel axis.
                pattern = np.transpose(pattern, (1, 2, 0))
            pattern_size = pattern.shape

        if pattern.dtype != 'uint8':
            if pattern.max() <= 1.:
                pattern = (pattern * 255).astype('uint8')
            else:
                pattern = pattern.astype('uint8')

        newsize = (10, 10)  # PIL order: (width, height)
        small = np.asarray(Image.fromarray(pattern).resize(newsize, reducing_gap=1.0))
        if small.ndim == 2:
            small = small[:, :, None]
        small = np.transpose(small, (2, 0, 1))
        if small.dtype != np.float32:
            small = small / 255.

        canvas = np.zeros((pattern_size[-1], pattern_size[0], pattern_size[1]), dtype='float32')
        # Rows come from the PIL height (newsize[1]), columns from the width (newsize[0]).
        canvas[:, :newsize[1], :newsize[0]] = small
        self.pattern = canvas
        self.mask = (self.pattern > 0).astype('float32')

    def apply(self, input):
        """
        Add the trigger pattern to ``input`` and clip to [0, 1].

        @Args
            input (np.ndarray):
                image of the n_channel * height * width shape. The value range of the input should be 0 to 1.
        """
        image = input.astype(np.float32)
        # Additive blend; a replacing variant would be
        # image * (1 - self.mask) + self.pattern * self.mask.
        return np.clip(image + self.pattern, 0, 1.)


class GeneTcbTrigger(NaiveTcbTrigger):
    """
    Add a caller-supplied (e.g. generated) pattern to inputs, without clipping.
    """
    def set_trigger_pattern(self, pattern):
        """
        @Args:
            pattern (np.ndarray): additive pattern, stored as-is. It is added to
                n_channel * height * width inputs in ``apply``, so it presumably has that
                (or a broadcastable) shape.
        """
        self.pattern = pattern

    def apply(self, input):
        """
        @Args
            input (np.ndarray):
                image of the n_channel * height * width shape. The value range of the input should be 0 to 1.
        @Returns:
            np.ndarray: ``input + pattern`` as float32; unlike NaiveTcbTrigger, the result
                is not clipped.
        """
        image = input.astype(np.float32)
        return image + self.pattern


class AETcbTrigger(TcbTrigger):
    """
    Auto-encoder based target-class trigger.

    The latent code of each input is mixed with the latent code of a target-class image
    at the positions selected by a mask learned in ``find_optimal_mask``. The mixed
    code is then decoded back into an image.
    """
    def __init__(self, opt, auto_encoder, target_input, *args, criteria=None, transform=None, random_seed=1234):
        """
        @Args:
            opt: options; reads dataset and target_label.
            auto_encoder: model exposing ``encoder``, ``decoder``, ``class_linear`` and
                ``eval()``.
            target_input: target-class image in the format ``transform`` accepts.
            criteria: classification loss; defaults to a new
                ``paddle.nn.CrossEntropyLoss()``.
            transform: input transform; defaults to a new
                ``paddle.vision.transforms.ToTensor()``.
            random_seed: accepted for API compatibility; unused.
        """
        import paddle.vision.transforms as transforms

        self.opt = opt
        self.auto_encoder = auto_encoder
        # Build defaults per instance rather than once at class-definition time.
        self.criteria = paddle.nn.CrossEntropyLoss() if criteria is None else criteria
        self.transform = transforms.ToTensor() if transform is None else transform
        # Switch to inference mode before encoding the target image, so that h_t is
        # computed in the same mode as the encodings in apply().
        self.auto_encoder.eval()
        self.set_trigger_pattern(target_input)
        self.count = 0

    def find_optimal_mask(self, h, l):
        """
        Learn a soft mask over the flattened latent code ``h``.

        The mask minimises ``beta * loss(class_linear(h * mask), l)`` plus an entropy
        term and ``alpha * mean(mask)``. It uses ``n`` plain gradient steps with a
        step size of 1.

        @Args:
            h (paddle.Tensor): flattened latent code of the target image.
            l (paddle.Tensor): label tensor passed to ``self.criteria``.
        @Returns:
            paddle.Tensor: detached ``sigmoid(m)`` with the shape of ``h``.
        """
        m = paddle.rand(h.shape, h.dtype)
        m = paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(m))
        m = paddle.create_parameter(h.shape, paddle.float32, attr=(m))
        self.auto_encoder.class_linear.eval()
        if self.opt.dataset == 'mnist':
            n = 200
            alpha = 5
            beta = 1
        else:
            n = 1000
            alpha = 5
            beta = 2

        for _ in range(n):
            mask = paddle.nn.functional.sigmoid(m)
            sup_loss = self.criteria(self.auto_encoder.class_linear(h*mask), l)
            reg = -(mask*paddle.log(mask)).mean()  # entropy pushes mask values towards 0/1
            loss = beta * sup_loss + reg + alpha * mask.mean()
            m.clear_grad()
            loss.backward(retain_graph=True)
            m -= m.grad.detach()
        return paddle.nn.functional.sigmoid(m).detach()

    def set_trigger_pattern(self, target_input):
        """
        Encode ``target_input`` and learn which latent positions carry the target class.

        Sets ``self.h_t`` (target latent code) and ``self.mask`` (``1 - learned mask``).
        Both are reshaped to the encoder's output shape. Diagnostic values are printed.
        """
        h_t = self.auto_encoder.encoder(self.transform(target_input).unsqueeze(0))
        shape = h_t.shape
        h_t = h_t.detach().reshape((-1, ))
        mask = self.find_optimal_mask(h_t, paddle.to_tensor([self.opt.target_label]))

        # Diagnostics: outputs for the full and masked target code, the target label,
        # and the mean mask value.
        print(self.auto_encoder.class_linear(h_t))
        print(self.auto_encoder.class_linear(h_t*mask))
        print(self.opt.target_label)
        print(mask.mean())

        self.mask = (1 - mask).reshape(shape)
        self.h_t = h_t.reshape(shape)

    def apply(self, input):
        """
        @Args:
            input (np.ndarray): image of the n_channel * height * width shape.
        @Returns:
            np.ndarray: first element of the decoder output for the mixed latent code.
        """
        # The transform is applied to a height * width * channel array.
        input = input.transpose((1, 2, 0))
        h = self.auto_encoder.encoder(self.transform(input).unsqueeze(0))
        # Keep the input's features where self.mask is 1 and the target's where it is 0.
        h = self.mask * h + (1-self.mask) * self.h_t
        x = self.auto_encoder.decoder(h).cpu().numpy()
        return x[0]


if __name__ == '__main__':
    # Save the first test-set image of the target class as .npy, plus a preview JPEG.
    from PIL import Image, ImageOps
    from utils.dataloader import get_dataset
    import config
    opt = config.get_arguments().parse_args()
    opt.dataset = 'cifar10'
    from utils.dataloader import ToNumpy
    dataset = get_dataset(opt, train='test')
    output_dir = '/root/projects/AttackDefence/data/AttackDefence'
    for i in range(len(dataset)):
        # Fetch once: every dataset[i] call re-runs loading (and any transforms).
        sample = dataset[i]
        if sample[1] == opt.target_label:
            np.save(os.path.join(output_dir, 'cifar10_target_image_{}.npy'.format(opt.target_label)), sample[0])
            target_image = sample[0].squeeze()
            plt.imsave('origin.jpg', target_image)
            break