import torch

from attacks import Attack
import torch.nn.functional as F

from constants import DEVICE
from utils import cross_entropy_loss, de_normalization, normalization


class EPNDIM(Attack):
    """EPNDIM: DIM + EPNI-FGSM.

    The attack runs ``epochs`` warm-up passes of ``steps`` iterations each.
    Every warm-up pass restarts from the clean images, feeds the model a
    channel-exchanged and input-diversified copy of the current iterate, and
    accumulates the shared momentum ``g``. A final pass of ``steps``
    iterations (input diversity only, no channel exchange) then continues
    from that accumulated momentum and produces the returned adversarial
    images.
    """

    def __init__(self, model, eps=16 / 255, steps=10, decay=1.0, epochs=5, resize_rate=0.9, diversity_prob=0.5,
                 num_classes=1000):
        """
        :param model: DNN model
        :param eps: the maximum perturbation, applied to the de-normalized images
        :param steps: the number of iterations per pass
        :param decay: the decay factor of the accumulated gradient
        :param epochs: the number of warm-up passes run before the final pass
        :param resize_rate: ratio between the random resize bound and the input size
        :param diversity_prob: probability of applying the random resize-and-pad transform
        :param num_classes: number of classes used to one-hot encode the labels
        """
        super().__init__("EPNDIM", model)
        self.eps = eps
        self.steps = steps
        # Per-iteration step size: the eps budget is split evenly over the steps.
        self.alpha = self.eps / self.steps
        self.decay = decay
        self.epochs = epochs
        self.resize_rate = resize_rate
        self.diversity_prob = diversity_prob
        self.num_classes = num_classes

    def input_diversity(self, x):
        """Randomly resize ``x`` and zero-pad it, with probability ``diversity_prob``.

        :param x: batch of images; the last dimension is used as the image size
        :return: the resized and padded batch, or ``x`` itself when the
            transform is not applied
        """
        img_size = x.shape[-1]
        img_resize = int(img_size * self.resize_rate)

        if self.resize_rate < 1:
            # Shrinking: sample the new size in [resized, original) and pad back to the original size.
            img_size, img_resize = img_resize, x.shape[-1]

        if img_size >= img_resize:
            # Empty sampling range (e.g. resize_rate == 1): torch.randint would raise, so skip the transform.
            return x

        # The random draws are made unconditionally and in a fixed order so the
        # RNG stream does not depend on whether the transform ends up applied.
        rnd = torch.randint(low=img_size, high=img_resize, size=(1,), dtype=torch.int32).item()
        remainder = img_resize - rnd
        pad_top = torch.randint(low=0, high=remainder, size=(1,), dtype=torch.int32).item()
        pad_left = torch.randint(low=0, high=remainder, size=(1,), dtype=torch.int32).item()

        if torch.rand(1) < self.diversity_prob:
            rescaled = F.interpolate(x, size=[rnd, rnd], mode='bilinear', align_corners=False)
            return F.pad(rescaled, [pad_left, remainder - pad_left, pad_top, remainder - pad_top], value=0)
        return x

    @staticmethod
    def _grad_scale(grad):
        """Return the per-sample mean absolute gradient, with exact zeros replaced by one."""
        scale = torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True)
        # An all-zero gradient would otherwise produce 0 / 0 = NaN and poison the momentum.
        return torch.where(scale > 0, scale, torch.ones_like(scale))

    def _momentum_step(self, adv, g, targets, images_min, images_max, exchange_channels):
        """Run one iteration and return the updated ``(adv, g)`` pair."""
        # Start from a fresh leaf so gradients can be taken regardless of how the
        # caller created the tensor, and so no graph is kept from earlier steps.
        adv = adv.detach().requires_grad_(True)

        model_input = adv
        if exchange_channels:
            # Independently shuffle the channel order of every image in the batch.
            model_input = torch.stack([x[torch.randperm(x.shape[0]), :, :] for x in adv])

        loss = cross_entropy_loss(self.model(self.input_diversity(model_input)), targets)
        grad = torch.autograd.grad(loss, adv)[0]
        grad_scale = self._grad_scale(grad)

        adv_de_normalized = de_normalization(adv.detach())

        # Look-ahead point: current iterate moved along the normalized gradient and the momentum.
        look_ahead = normalization(
            adv_de_normalized + self.alpha * grad / grad_scale + self.decay * self.alpha * g
        ).detach().requires_grad_(True)
        look_ahead_loss = cross_entropy_loss(self.model(look_ahead), targets)
        look_ahead_grad = torch.autograd.grad(look_ahead_loss, look_ahead)[0]

        g = self.decay * g + grad / grad_scale + look_ahead_grad / self._grad_scale(look_ahead_grad)

        # Signed step on the de-normalized image, kept inside the eps-ball around the clean image.
        adv_de_normalized = torch.clamp(adv_de_normalized + self.alpha * torch.sign(g), min=images_min,
                                        max=images_max)
        return normalization(adv_de_normalized), g

    def forward(self, images, labels):
        """Craft adversarial examples for ``images``.

        :param images: batch of normalized input images
        :param labels: class indices of the images
        :return: the adversarial images in normalized space, detached from the autograd graph
        """
        targets = F.one_hot(labels.type(torch.int64), self.num_classes).float().to(DEVICE)
        images = images.detach()
        images_de_normalized = de_normalization(images)
        # Bounds of the eps-ball around the clean images, restricted to [0, 1].
        images_min = torch.clamp(images_de_normalized - self.eps, min=0.0, max=1.0)
        images_max = torch.clamp(images_de_normalized + self.eps, min=0.0, max=1.0)

        g = torch.zeros_like(images)
        for _ in range(self.epochs):
            # Every warm-up pass restarts from the clean images; only g carries over.
            adv_hat = images.clone()
            for _ in range(self.steps):
                adv_hat, g = self._momentum_step(adv_hat, g, targets, images_min, images_max,
                                                 exchange_channels=True)

        # Final pass: channel exchange is only used during the warm-up passes above.
        adv = images.clone()
        for _ in range(self.steps):
            adv, g = self._momentum_step(adv, g, targets, images_min, images_max, exchange_channels=False)

        return adv.detach()
