"""
Implementations of different attack strategies, i.e., how to select the samples to poison.
"""
import os
import random

import numpy as np
import paddle
import cv2
import matplotlib.pyplot as plt

from utils.utils import progress_bar
from utils.dataloader import PostTensorTransform, get_dataloader, DictDataset, get_dataset, get_transform


class BaseAttacker(object):
    """Common machinery for the poisoning strategies below.

    Subclasses decide *which* samples to poison by overriding
    ``get_indices_to_poison``. This base class then stamps the trigger onto
    the selected inputs and overwrites the selected labels with
    ``target_label``.
    """

    def __init__(self, trigger, attack_ratio, target_label, *args, random_seed=1234):
        """
        @Args
            trigger: object exposing ``apply(x)``; its return value replaces the
                corresponding input when a sample is triggered.
            attack_ratio (float): poisoning ratio, must lie in [0, 1].
            target_label: label written into samples whose label is changed.
            random_seed (int or None): when not None, seeds both NumPy's global
                RNG and Paddle's global RNG.
        """
        assert 0 <= attack_ratio <= 1
        seed = random_seed
        if seed is not None:
            np.random.seed(seed)
            paddle.seed(seed)
        self.trigger, self.attack_ratio, self.target_label = trigger, attack_ratio, target_label

    def get_indices_to_poison(self, source_probs, target_probs, *args):
        """Choose which samples to poison; must be overridden by subclasses.

        @Args
            source_probs (np.ndarray): i.e., alpha
                the probabilities of the corresponding samples being predicted as the target class by the source model.
            target_probs (np.ndarray): i.e., beta
                the objective probabilities of the corresponding samples being predicted as the target class.
        @Return
            add_trigger_indices (np.ndarray):
                the indices of samples to insert the trigger to.
            change_label_indices (np.ndarray):
                the indices of samples to change the label.
        """
        raise NotImplementedError

    def attack_from_indices(self, inputs, labels, add_trigger_indices, change_label_indices, *args):
        """Poison ``inputs`` / ``labels`` IN PLACE at the given indices.

        The trigger is applied once per occurrence in ``add_trigger_indices``
        (a repeated index is triggered repeatedly). Every index in
        ``change_label_indices`` gets ``self.target_label``.

        @Return
            The same ``inputs`` and ``labels`` objects that were passed in.
        """
        assert len(inputs) == len(labels)

        for position in add_trigger_indices:
            inputs[position] = self.trigger.apply(inputs[position])
        for position in change_label_indices:
            labels[position] = self.target_label

        return inputs, labels

    def attack(self, inputs, labels, source_probs, target_probs, *args):
        """Select samples via ``get_indices_to_poison`` and poison them in place."""
        trigger_idx, relabel_idx = self.get_indices_to_poison(source_probs, target_probs)
        poisoned_inputs, poisoned_labels = self.attack_from_indices(inputs, labels, trigger_idx, relabel_idx)
        return poisoned_inputs, poisoned_labels


class SmoothAttacker(BaseAttacker):
    """
    Label-smoothed attacker: the proposed attacker.

    Each sample independently receives the trigger with probability
    ``attack_ratio``. A triggered sample additionally has its label changed to
    ``target_label`` with probability

        p_n(x) = clip((beta - alpha) / (1 - alpha), 0, 1),

    where alpha = ``source_probs`` and beta = ``target_probs``. Algebraically,
    alpha + (1 - alpha) * p_n == beta whenever alpha <= beta <= 1.
    """
    def get_indices_to_poison(self, source_probs, target_probs, *args):
        """

        @Args
            source_probs (np.ndarray): i.e., alpha
                the probabilities of the corresponding samples being predicted as the target class by the source model.
            target_probs (np.ndarray): i.e., beta
                the objective probabilities of the corresponding samples being predicted as the target class.
        @Return
            add_trigger_indices (np.ndarray):
                the indices of samples to insert the trigger to.
            change_label_indices (np.ndarray):
                the indices of samples to change the label (always a subset of add_trigger_indices).
        """
        assert len(target_probs) == len(source_probs)

        # Clip alpha into [1e-5, 1 - 1e-5] so that (1 - alpha) below is never zero.
        source_probs = np.clip(source_probs, 1e-5, 1 - 1e-5)
        # p_n(x): probability of relabelling a triggered sample.
        change_label_probs = np.clip((target_probs - source_probs) / (1 - source_probs), 0, 1.)

        num_examples = len(target_probs)
        indices = np.arange(num_examples)

        # np.random.rand samples from the half-open interval [0, 1), so `r < p`
        # holds with probability exactly p. A strict comparison guarantees that
        # p == 0 (attack_ratio == 0, or p_n clipped to 0) never selects a sample;
        # with `<=` a draw of exactly 0.0 would still poison / relabel it.
        # The two draws are kept in this order to preserve the RNG stream.
        r = np.random.rand(num_examples)
        add_trigger_mask = r < self.attack_ratio
        r = np.random.rand(num_examples)
        change_label_mask = np.logical_and(r < change_label_probs, add_trigger_mask)

        add_trigger_indices = indices[add_trigger_mask]
        change_label_indices = indices[change_label_mask]

        return add_trigger_indices, change_label_indices


class HardAttacker(BaseAttacker):
    """
    Conventional poison-label attacker: randomly select samples to poison.

    Each sample is selected independently with probability ``attack_ratio``.
    Every selected sample both receives the trigger and has its label changed
    to ``target_label``.
    """
    def get_indices_to_poison(self, source_probs, target_probs, *args):
        """

        @Args
            source_probs (np.ndarray): i.e., alpha
                only its length is used (it must match target_probs).
            target_probs (np.ndarray): i.e., beta
                only its length is used (the number of candidate samples).
        @Return
            add_trigger_indices (np.ndarray):
                the indices of samples to insert the trigger to.
            change_label_indices (np.ndarray):
                the same array object as add_trigger_indices.
        """
        assert len(target_probs) == len(source_probs)

        num_examples = len(target_probs)
        indices = np.arange(num_examples)

        # np.random.rand samples from [0, 1), so `r < attack_ratio` holds with
        # probability exactly attack_ratio. The strict comparison also guarantees
        # that attack_ratio == 0 poisons nothing, even on a draw of exactly 0.0.
        r = np.random.rand(num_examples)
        add_trigger_mask = r < self.attack_ratio

        add_trigger_indices = indices[add_trigger_mask]

        # Poison-label attack: every triggered sample is also relabelled.
        return add_trigger_indices, add_trigger_indices


class CleanLabelAttacker(SmoothAttacker):
    """
    Clean-label attack: poison samples of the target class only.

    Among the samples whose label already equals ``target_label``, the
    floor(n_target * attack_ratio) samples with the LOWEST ``source_probs``
    receive the trigger. No label is ever changed.
    """
    def get_indices_to_poison(self, source_probs, labels, *args):
        """

        @Args
            source_probs (array-like): i.e., alpha
                the probabilities of the corresponding samples being predicted as the target class by the source model.
            labels (array-like):
                the current labels of the corresponding samples; only samples whose
                label equals ``self.target_label`` are candidates.
        @Return
            add_trigger_indices (np.ndarray):
                the indices of samples to insert the trigger to.
            change_label_indices (np.ndarray):
                always an empty integer array (clean-label attack never relabels).
        """
        assert len(source_probs) == len(labels)

        # Normalise to ndarrays: with a plain list, `labels == target_label` is a
        # single bool rather than an element-wise mask, which silently breaks the
        # indexing below.
        source_probs = np.asarray(source_probs)
        labels = np.asarray(labels)

        indices = np.arange(len(labels))[labels == self.target_label]
        # Ascending order: target-class samples with the lowest alpha come first.
        sorted_indices = np.argsort(source_probs[indices])

        # The tiny epsilon absorbs floating-point error in the product
        # (e.g. 100 * 0.29 == 28.999999999999996), which would otherwise be
        # truncated to one sample too few. The count is otherwise still floored.
        num_poison = int(len(indices) * self.attack_ratio + 1e-9)
        add_trigger_indices = indices[sorted_indices[:num_poison]]

        # Return an ndarray (not a list) for consistency with the other strategies.
        return add_trigger_indices, np.empty(0, dtype=indices.dtype)

    def attack(self, inputs, labels, source_probs, target_probs, *args):
        """Trigger the selected target-class samples in place.

        ``target_probs`` is accepted for interface compatibility but ignored;
        selection is driven by ``source_probs`` and ``labels``.
        """
        add_trigger_indices, change_label_indices = self.get_indices_to_poison(source_probs, labels)
        inputs, labels = self.attack_from_indices(inputs, labels, add_trigger_indices, change_label_indices)
        return inputs, labels


class InputDependAttacker(BaseAttacker):
    """
    Input-dependent attacker.

    NOTE(review): this class currently adds nothing over BaseAttacker, so
    ``get_indices_to_poison`` still raises NotImplementedError.
    """


if __name__ == '__main__':
    # Manual smoke test: apply a sinusoidal SIG trigger to a single GTSRB image.
    # Requires the project's `config` module and a local copy of the dataset
    # at the hard-coded path below.
    import config
    import os
    from utils.dataloader import ToNumpy, get_dataloader
    opt = config.get_arguments().parse_args()
    # NOTE(review): ckpt_folder is derived from the CLI-provided dataset *before*
    # opt.dataset is overridden with "gtsrb" on the next line — confirm intended.
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.dataset = "gtsrb"
    # opt.input_channel = 3
    # opt.input_height = 225
    # opt.input_width = 225
    # dataset = get_dataloader(opt, train=True).dataset
    # image = dataset[0]['input'].numpy() # np.transpose(dataset[0]['input'], (2, 0, 1))
    from matplotlib import pyplot as plt
    import cv2
    image = plt.imread("/root/projects/AttackDefence/data/AttackDefence/GTSRB/Train/00001/00006_00026.ppm").astype('float32')
    # Resize to 117x117 and rescale so the brightest value becomes 1.0.
    image = cv2.resize(image, (117, 117))
    image /= image.max()
    # plt.imsave('origin.pdf', image)
    # plt.imsave('origin.jpg', image)
    # (H, W, C) -> (C, H, W); the shape lookups below rely on this channel-first layout.
    image = np.transpose(image, (2, 0, 1))
    opt.input_channel = 3
    opt.input_height = image.shape[1]
    opt.input_width = image.shape[2]
    # mean = np.array([0.4914, 0.4822, 0.4465], dtype='float32')
    # std = np.array([0.247, 0.243, 0.261], dtype='float32')
    # image = image * std[:, None, None] + mean[:, None, None]
    # badnet = ImagePatchTrigger(opt, loc='top-left')
    # badnet.apply(image)
    opt.freq = 1
    # NOTE(review): SIGTrigger is neither defined nor imported anywhere visible in
    # this file, so this line presumably raises NameError. Import it from the
    # module that defines the triggers (TODO confirm its location).
    sig = SIGTrigger(opt, mode='sin', alpha=0.05)
    sig.apply(image)