import torch

from attacks import Attack
import torch.nn.functional as F

from constants import DEVICE
from utils import cross_entropy_loss, de_normalization, normalization


class VTEPNIFGSM(Attack):
    """ VT-EPNI-FGSM: VT + EPNI-FGSM

    ``forward`` first runs ``epochs`` warm-up passes of ``steps`` iterations each. Every warm-up pass restarts from
    the clean images with zeroed variance terms; only the momentum ``g`` is carried from one pass to the next. A
    final pass of ``steps`` iterations then starts from the clean images with the pre-accumulated momentum, and its
    result is returned.
    """

    def __init__(self, model, eps=16 / 255, steps=10, decay=1.0, N=20, beta=1.5, epochs=5, num_classes=1000):
        """

        :param model: DNN model
        :param eps: the maximum perturbation (L-inf bound, applied in de-normalized pixel space clipped to [0, 1])
        :param steps: the number of iterations per pass; the step size is ``eps / steps``
        :param decay: the decay factor of the momentum ``g``
        :param N: the number of sampled examples used to estimate the neighbourhood gradient
        :param beta: radius of the uniform sampling neighbourhood, as a multiple of ``eps``
        :param epochs: the number of warm-up passes used to pre-accumulate the momentum before the final pass
        :param num_classes: number of output classes of ``model``, used to one-hot encode the labels
        """
        super().__init__("VTEPNIFGSM", model)
        self.eps = eps
        self.steps = steps
        self.alpha = self.eps / self.steps
        self.decay = decay
        self.N = N
        self.beta = beta
        self.epochs = epochs
        self.num_classes = num_classes

    def uniform_distribution(self, size):
        """Sample noise of shape ``size`` uniformly from [-beta * eps, beta * eps) on ``DEVICE``."""
        return torch.rand(size, device=DEVICE) * 2 * self.beta * self.eps - self.beta * self.eps

    @staticmethod
    def _l1_normalize(grad):
        """Divide ``grad`` by its per-sample mean absolute value (over dims 1, 2, 3).

        A sample whose gradient is identically zero would otherwise give 0 / 0 = NaN. That NaN would propagate into
        the momentum and the adversarial images for every remaining step. Such samples are divided by 1 instead,
        so they stay zero. Non-zero gradients are scaled exactly as before.
        """
        scale = torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True)
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        return grad / scale

    def _neighbourhood_gradient(self, adv, targets):
        """Average the loss gradient over ``N`` points sampled uniformly around ``adv`` in de-normalized space."""
        V = torch.zeros_like(adv)
        adv_de_normalized = de_normalization(adv)
        for _ in range(self.N):
            x = normalization(adv_de_normalized + self.uniform_distribution(adv.shape))
            loss = cross_entropy_loss(self.model(x), targets)
            V += torch.autograd.grad(loss, x)[0]
        return V / self.N

    def one_step(self, images, images_min, images_max, targets, adv, g, v, v_nes):
        """Perform one attack iteration.

        :param images: the clean images (kept for interface compatibility; not used directly)
        :param images_min: per-pixel lower bound in de-normalized space
        :param images_max: per-pixel upper bound in de-normalized space
        :param targets: one-hot labels passed to ``cross_entropy_loss``
        :param adv: current adversarial images (normalized space)
        :param g: accumulated momentum
        :param v: variance term for the gradient at ``adv``, from the previous step
        :param v_nes: variance term for the look-ahead gradient, from the previous step
        :return: updated ``(adv, g, v, v_nes)``
        """
        # Take gradients w.r.t. a fresh leaf. This works whether or not the caller's images require grad, and it
        # keeps the autograd graph from chaining across iterations.
        adv = adv.detach().requires_grad_(True)

        # Gradient at the current point, corrected by the previous variance term.
        loss = cross_entropy_loss(self.model(adv), targets)
        grad = torch.autograd.grad(loss, adv)[0]
        new_grad = grad + v

        # Gradient at a look-ahead point, moved along the normalized gradient and the current momentum.
        adv_de_normalized = de_normalization(adv)
        adv_nes = normalization(adv_de_normalized + self.alpha * self._l1_normalize(grad)
                                + self.decay * self.alpha * g)
        loss_nes = cross_entropy_loss(self.model(adv_nes), targets)
        grad_nes = torch.autograd.grad(loss_nes, adv_nes)[0]
        new_grad_nes = grad_nes + v_nes

        # Variance terms for the next step: neighbourhood-average gradient minus each point gradient.
        V = self._neighbourhood_gradient(adv, targets)
        v = V - grad
        v_nes = V - grad_nes

        g = self.decay * g + self._l1_normalize(new_grad) + self._l1_normalize(new_grad_nes)

        # Signed step, projected back into the eps-ball intersected with the valid pixel range.
        adv_de_normalized = torch.clamp(adv_de_normalized + self.alpha * torch.sign(g), min=images_min,
                                        max=images_max)
        adv = normalization(adv_de_normalized)

        return adv, g, v, v_nes

    def _run_pass(self, images, images_min, images_max, targets, g):
        """Run ``steps`` iterations from the clean images with zeroed variance terms; return ``(adv, g)``."""
        adv = images.clone()
        v = torch.zeros_like(images)
        v_nes = torch.zeros_like(images)
        for _ in range(self.steps):
            adv, g, v, v_nes = self.one_step(images, images_min, images_max, targets, adv, g, v, v_nes)
        return adv, g

    def forward(self, images, labels):
        """Craft adversarial examples for ``images`` (normalized space) with integer class ``labels``."""
        targets = F.one_hot(labels.type(torch.int64), self.num_classes).float().to(DEVICE)
        images_de_normalized = de_normalization(images)
        images_min = torch.clamp(images_de_normalized - self.eps, min=0.0, max=1.0)
        images_max = torch.clamp(images_de_normalized + self.eps, min=0.0, max=1.0)

        # Warm-up passes: only the momentum g carries over; each pass's adversarial images are discarded.
        g = torch.zeros_like(images)
        for _ in range(self.epochs):
            _, g = self._run_pass(images, images_min, images_max, targets, g)

        # Final pass starting from the pre-accumulated momentum.
        adv, _ = self._run_pass(images, images_min, images_max, targets, g)
        return adv
