import torch
import torch.nn.functional as F
from torch.nn import Dropout
from ..utils import *
from ..iterative.mifgsm import MIFGSM

class IDE(MIFGSM):
    """
    IDE Attack: MI-FGSM whose gradients are computed on an ensemble of
    dropout-transformed copies of the adversarial input.

    At every step the current adversarial example is copied once per entry
    of ``dropout_prob``. In each copy, elements are randomly zeroed with that
    probability. The copies are stacked along the batch dimension, so a
    single forward/backward pass covers the whole dropout ensemble.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack.
        epsilon (float): the perturbation budget.
        alpha (float): the step size.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        attack (str): the attack name, forwarded to the base class.
        dropout_prob (float or sequence of float): the dropout probabilities of the ensemble.
            Each entry is the probability of an element being zeroed in one copy of the input.
            A single float is treated as a one-element ensemble.

    Official arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10, decay=1., dropout_prob=[0, 0.1, 0.2, 0.3, 0.4]
    """

    def __init__(self, model, epsilon=16 / 255, alpha=1.6 / 255, epoch=10, decay=1., targeted=False,
                 random_start=False, norm='linfty', loss='crossentropy', device=None, attack='ID',
                 dropout_prob=(0, 0.1, 0.2, 0.3, 0.4), **kwargs):
        super().__init__(model, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack)
        # Accept a bare probability as a one-element ensemble.
        if isinstance(dropout_prob, (int, float)):
            dropout_prob = [dropout_prob]
        # Copy into a fresh list so instances never share (or alias) a mutable default.
        self.dropout_prob = list(dropout_prob)
        if not self.dropout_prob:
            # An empty ensemble would otherwise fail later inside torch.cat.
            raise ValueError("dropout_prob must contain at least one probability")

    def transform(self, x, **kwargs):
        """
        Ensemble Dropout.

        Return one dropout-masked copy of ``x`` per probability in
        ``self.dropout_prob``, concatenated along dim 0.
        """
        # F.dropout rescales surviving elements by 1/(1-p). Multiplying by (1-p)
        # undoes that rescaling, so kept elements retain their original values
        # and only the random zeroing remains. training=True forces the mask
        # to be applied regardless of any module train/eval mode.
        return torch.cat(
            [F.dropout(x, p=prob, training=True) * (1 - prob) for prob in self.dropout_prob],
            dim=0,
        )

    def get_loss(self, logits, label):
        """
        Compute the attack loss over the dropout ensemble.

        The labels are tiled once per dropout probability so they line up with
        the batch built by ``transform``, which concatenates the per-probability
        copies along dim 0. Tiling via ``repeat`` assumes ``label`` is 1-D.
        For targeted attacks the loss is negated.
        """
        loss = self.loss(logits, label.repeat(len(self.dropout_prob)))
        # NOTE(review): assumes the base attack ascends this value and that the
        # base constructor stores `targeted` as `self.targeted`; confirm against MIFGSM.
        return -loss if self.targeted else loss

