import torch

from attacks import Attack
import torch.nn.functional as F

from constants import DEVICE
from utils import cross_entropy_loss, de_normalization, normalization


class FIA(Attack):
    """ FIA (feature importance-aware attack)

    The attack first estimates an "aggregate gradient": the gradient of the classification loss with
    respect to the feature maps of one intermediate layer, summed over several randomly masked copies
    of the input. It then iteratively perturbs the input so that the sum of the feature maps weighted
    by that aggregate gradient increases, using a momentum-accumulated sign-gradient step.

    the drop probability p = 0.3 when attacking normally trained models and p = 0.1 when attacking defense models.
    """

    def __init__(self, model, eps=16 / 255, steps=10, decay=1.0, drop_probability=0.3, ensemble_number=30,
                 intermediate_layer_name='Mixed_5b', num_classes=1000):
        """
        inception_v3: Mixed_5b / Conv2d_4a_3x3
        inception_resnet_v2: conv2d_4a
        resnet152: layer4.1 / layer4.1.relu
        vgg16: features.15

        :param model: DNN model
        :param eps: the maximum perturbation
        :param steps: the number of iterations
        :param decay: the decay factor
        :param drop_probability: the drop probability
        :param ensemble_number: the number of random masks
        :param intermediate_layer_name: the name of k-th layer   feature maps
        :param num_classes: the number of classes used to one-hot encode the labels
        :raises ValueError: if no sub-module of ``model`` is named ``intermediate_layer_name``
        """
        super().__init__("FIA", model)
        self.eps = eps
        self.steps = steps
        # Per-iteration step size: the total budget eps is split evenly over the iterations.
        self.alpha = self.eps / self.steps
        self.decay = decay
        self.drop_probability = drop_probability
        self.ensemble_number = ensemble_number
        self.intermediate_layer_name = intermediate_layer_name
        self.num_classes = num_classes
        # Filled in by `hook` on every forward pass of the model.
        self.intermediate_layer_feature_maps = None
        # Handle of the registered forward hook, kept so that the hook can be removed again.
        self.hook_handle = None
        self.register_hook()

    def hook(self, module, input, output):
        """Forward hook: remember the output of the chosen intermediate layer."""
        self.intermediate_layer_feature_maps = output
        return None

    def register_hook(self):
        """
        Attach `hook` to the sub-module whose name equals ``intermediate_layer_name``.

        :raises ValueError: if the model has no sub-module with that name
        """
        # Avoid stacking several hooks on the model when this method is called more than once.
        self.remove_hook()
        for name, module in self.model.named_modules():
            if name == self.intermediate_layer_name:
                self.hook_handle = module.register_forward_hook(hook=self.hook)
                return
        raise ValueError(
            "intermediate layer '{}' was not found in the model".format(self.intermediate_layer_name))

    def remove_hook(self):
        """Detach the forward hook from the model (no-op when no hook is registered)."""
        handle = getattr(self, 'hook_handle', None)
        if handle is not None:
            handle.remove()
            self.hook_handle = None

    @staticmethod
    def _safe_divide(tensor, scale):
        """Divide ``tensor`` by ``scale``, flooring ``scale`` so an all-zero gradient gives 0, not NaN."""
        return tensor / torch.clamp(scale, min=torch.finfo(scale.dtype).tiny)

    def get_aggregate_gradient(self, images, targets):
        """
        Sum the loss gradients w.r.t. the intermediate feature maps over randomly masked inputs.

        :param images: the input batch fed to the model
        :param targets: the targets passed to ``cross_entropy_loss`` together with the logits
        :return: the summed gradient, divided by its L2 norm taken over dims (1, 2, 3)
        :raises ValueError: if ``ensemble_number`` is smaller than 1
        """
        keep_probability = 1 - self.drop_probability
        aggregate_grad = None
        for _ in range(self.ensemble_number):
            # Each element of the input is kept with probability 1 - drop_probability.
            mask = torch.bernoulli(torch.full_like(images, keep_probability))
            # Enabling grad on the masked input guarantees the hooked feature maps are part of the
            # autograd graph even when the model parameters are frozen.
            images_masked = (images * mask).detach().requires_grad_(True)
            logits = self.model(images_masked)
            loss = cross_entropy_loss(logits, targets)
            grad = torch.autograd.grad(loss, self.intermediate_layer_feature_maps)[0]
            aggregate_grad = grad if aggregate_grad is None else aggregate_grad + grad
        if aggregate_grad is None:
            raise ValueError("ensemble_number must be at least 1")
        norm = torch.sqrt(torch.sum(torch.square(aggregate_grad), dim=(1, 2, 3), keepdim=True))
        return self._safe_divide(aggregate_grad, norm)

    def FIA_loss_function(self, aggregate_grad, x):
        """
        Run the model on ``x`` and weight the captured feature maps by ``aggregate_grad``.

        :param aggregate_grad: the weights returned by `get_aggregate_gradient`
        :param x: the current adversarial batch
        :return: the scalar sum of ``aggregate_grad * feature_maps``
        """
        # The forward pass is run only for its side effect: the hook refreshes the feature maps.
        _ = self.model(x)
        FIA_loss = torch.sum(aggregate_grad * self.intermediate_layer_feature_maps)
        return FIA_loss

    def forward(self, images, labels):
        """
        Craft adversarial examples for ``images``.

        :param images: the input batch in the value range the model expects (it is mapped through the
            project helper ``de_normalization`` and clipped to [0, 1] there)
        :param labels: the class indices of the batch
        :return: the adversarial batch, mapped back through ``normalization`` and detached from autograd
        """
        # Work on a detached view so no graph from the caller leaks into the attack.
        images = images.detach()
        targets = F.one_hot(labels.type(torch.int64), self.num_classes).float().to(images.device)
        images_de_normalized = de_normalization(images)
        # Element-wise bounds of the eps-ball, intersected with the valid range [0, 1].
        images_min = torch.clamp(images_de_normalized - self.eps, min=0.0, max=1.0)
        images_max = torch.clamp(images_de_normalized + self.eps, min=0.0, max=1.0)

        g = torch.zeros_like(images)
        adv = images.clone()
        aggregate_grad = self.get_aggregate_gradient(images, targets)
        for _ in range(self.steps):
            # A fresh leaf per iteration: autograd.grad needs requires_grad, and detaching drops
            # the graph of the previous iteration instead of keeping all of them alive.
            adv = adv.detach().requires_grad_(True)
            FIA_loss = self.FIA_loss_function(aggregate_grad, adv)
            grad = torch.autograd.grad(FIA_loss, adv)[0]
            # Momentum accumulation of the gradient scaled by its mean absolute value.
            mean_abs = torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True)
            g = self.decay * g + self._safe_divide(grad, mean_abs)

            # The update itself does not need to be tracked by autograd.
            with torch.no_grad():
                adv_de_normalized = de_normalization(adv)
                adv_de_normalized = torch.clamp(adv_de_normalized + self.alpha * torch.sign(g),
                                                min=images_min, max=images_max)
                adv = normalization(adv_de_normalized)

        return adv.detach()


# VGG16 ablation over the intermediate layer
# | features.3     | 0.8421    | 0.6972 | 0.7932 | 0.6286    | 0.8802 | 0.8670 | 0.8319 | 0.7608  | 0.6952  | 0.9781 | 0.9820 | 1.0000 | 0.9865 | 0.8346    | 0.7524    | 0.7353    | 0.7227    | 0.5165     | 0.3482        |
# | features.8     | 0.9331    | 0.8423 | 0.8977 | 0.7541    | 0.9300 | 0.9233 | 0.9030 | 0.8528  | 0.8159  | 0.9891 | 0.9964 | 1.0000 | 0.9966 | 0.9232    | 0.8794    | 0.8538    | 0.8414    | 0.6217     | 0.4414        |
# | features.15    | 0.9627    | 0.8822 | 0.9283 | 0.7984    | 0.9585 | 0.9493 | 0.9429 | 0.9015  | 0.8720  | 0.9976 | 0.9976 | 1.0000 | 0.9989 | 0.9427    | 0.9175    | 0.9072    | 0.8971    | 0.6277     | 0.4512        |
# | features.22    | 0.9364    | 0.8423 | 0.8882 | 0.6944    | 0.9371 | 0.9222 | 0.9159 | 0.8517  | 0.8222  | 0.9915 | 0.9976 | 1.0000 | 0.9989 | 0.9200    | 0.8868    | 0.8719    | 0.8624    | 0.5130     | 0.3167        |
# | features.29    | 0.8575    | 0.6951 | 0.7669 | 0.5278    | 0.8612 | 0.8320 | 0.8384 | 0.7273  | 0.6804  | 0.9757 | 0.9928 | 1.0000 | 0.9989 | 0.8378    | 0.7577    | 0.7321    | 0.7248    | 0.3688     | 0.2256        |
