import torch

from ..utils import *
from ..iterative.mifgsm import MIFGSM
import torch_dct as dct

class Admix(MIFGSM):
    """
    Admix Attack
    'Admix: Enhancing the Transferability of Adversarial Attacks (ICCV 2021)'(https://arxiv.org/abs/2102.00436)

    Each iteration feeds the surrogate model ``num_scale * num_admix`` transformed
    copies of the batch: ``num_admix`` admixed copies (each image plus
    ``admix_strength`` times another image drawn from the same batch via a random
    permutation), each of which is further divided by ``2**i`` for
    ``i in range(num_scale)``.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack.
        epsilon (float): the perturbation budget.
        alpha (float): the step size.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation.
        num_scale (int): the number of scaled copies in each iteration.
        num_admix (int): the number of admixed images in each iteration.
        admix_strength (float): the strength of admixed images.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model

    Official arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10, decay=1., num_scale=5, num_admix=3, admix_strength=0.2

    Note: the default ``num_scale`` in this implementation is 20, not the official 5;
    pass ``num_scale=5`` to reproduce the paper's setting.
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, num_admix=3, admix_strength=0.2, targeted=False,
                random_start=False, norm='linfty', loss='crossentropy', device=None, attack='Admix', **kwargs):
        super().__init__(model, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack)
        self.num_scale = num_scale
        self.num_admix = num_admix
        self.admix_strength = admix_strength

    def transform(self, x, **kwargs):
        """
        Admix the input for Admix Attack.

        Arguments:
            x (torch.Tensor): the batch to transform; dim 0 is treated as the batch dimension.

        Returns:
            torch.Tensor: ``num_scale * num_admix`` copies of the batch concatenated along
            dim 0. The layout is scale-major: all ``num_admix`` admixed copies divided by 1,
            then all of them divided by 2, and so on. Row ``j`` of every batch-sized chunk
            is derived from ``x[j]``, which is what ``get_loss`` relies on.
        """
        # Each admixed copy pairs every image with another image of the same batch,
        # using a fresh random permutation per copy.
        admix_images = torch.concat(
            [x + self.admix_strength * x[torch.randperm(x.size(0))] for _ in range(self.num_admix)],
            dim=0,
        )
        # Scaled copies: divide the admixed batch by 2**i for i in [0, num_scale).
        return torch.concat([admix_images / (2 ** i) for i in range(self.num_scale)])

    def get_loss(self, logits, label):
        """
        Calculate the loss.

        The labels are tiled ``num_scale * num_admix`` times so that they line up with
        the rows produced by ``transform``.
        """
        return self.loss(logits, label.repeat(self.num_scale*self.num_admix))
    

class spec_Admix(MIFGSM):
    """
    Spectral variant of the Admix Attack
    Based on 'Admix: Enhancing the Transferability of Adversarial Attacks (ICCV 2021)'(https://arxiv.org/abs/2102.00436)

    Each iteration, the input is perturbed with Gaussian noise (standard deviation
    ``epsilon``) and transformed with a 2-D DCT. It is then admixed ``num_admix``
    times, and each admixed batch is divided by ``2**i`` for ``i in range(num_scale)``.
    The result is mapped back with the inverse 2-D DCT before being fed to the
    surrogate model.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack.
        epsilon (float): the perturbation budget (also used as the std of the Gaussian noise).
        alpha (float): the step size.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation.
        num_scale (int): the number of scaled copies in each iteration.
        num_admix (int): the number of admixed images in each iteration.
        admix_strength (float): the strength of admixed images.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model

    Official (Admix) arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10, decay=1., num_scale=5, num_admix=3, admix_strength=0.2

    Note: the default ``num_scale`` here is 20, not the official 5, and the default
    ``attack`` name is 'Admix'.
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, num_admix=3, admix_strength=0.2, targeted=False,
                random_start=False, norm='linfty', loss='crossentropy', device=None, attack='Admix', **kwargs):
        super().__init__(model, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack)
        self.num_scale = num_scale
        self.num_admix = num_admix
        self.admix_strength = admix_strength

    def transform(self, x, **kwargs):
        """
        Spectral Admix transform.

        Arguments:
            x (torch.Tensor): the batch to transform; dim 0 is treated as the batch dimension.

        Returns:
            torch.Tensor: ``num_scale * num_admix`` transformed copies of the batch,
            concatenated along dim 0 in scale-major order (same layout as ``Admix``),
            mapped back from the DCT domain with ``idct_2d``.
        """
        # randn_like already allocates on x's device and dtype. The former hard-coded
        # .cuda() made x + gauss fail for CPU inputs and for tensors on a non-default GPU.
        gauss = torch.randn_like(x) * self.epsilon
        dct_imgs = dct.dct_2d(x + gauss)
        # NOTE(review): the admixed term indexes the spatial-domain ``x`` while ``dct_imgs``
        # holds DCT coefficients, so two different domains are summed here. If mixing in
        # the frequency domain was intended, this is likely meant to be ``dct_imgs[perm]``
        # (or ``dct.dct_2d(x[perm])``). Confirm the intended method before changing.
        dct_admix_images = torch.concat(
            [dct_imgs + self.admix_strength * x[torch.randperm(dct_imgs.size(0))] for _ in range(self.num_admix)],
            dim=0,
        )
        # Scaled copies (divide by 2**i), then back to the spatial domain.
        return dct.idct_2d(torch.concat([dct_admix_images / (2 ** i) for i in range(self.num_scale)]))

    def get_loss(self, logits, label):
        """
        Calculate the loss.

        The labels are tiled ``num_scale * num_admix`` times so that they line up with
        the rows produced by ``transform``.
        """
        return self.loss(logits, label.repeat(self.num_scale*self.num_admix))




    