import torch

from ..utils import *
from ..iterative.mifgsm import MIFGSM
import torch_dct as dct
from torch.autograd import Variable as V

class SSA(MIFGSM):
    """
    SSA: MI-FGSM whose inputs are augmented in the DCT (frequency) domain.

    NOTE(review): the previous docstring was copied from the SIM attack. The code
    matches the Spectrum Simulation Attack, presumably 'Frequency Domain Model
    Augmentation for Adversarial Attack (ECCV 2022)' (https://arxiv.org/abs/2207.05382)
    -- confirm the citation.

    In every iteration `num_scale` randomly transformed copies of the input are
    built (see `dct_perturbation`) and the loss is computed over all of them.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack.
        epsilon (float): the perturbation budget.
        alpha (float): the step size.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation.
        num_scale (int): the number of spectrum-transformed copies in each iteration.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        attack (str): attack name forwarded to MIFGSM (the default is still 'SIM').
        adapted (bool): forwarded unchanged to MIFGSM (meaning defined there -- TODO confirm).
        rho (float): half-width of the random spectrum mask, sampled from [1 - rho, 1 + rho).
        ratio (float): fraction of the height/width (counted from the low-frequency
            corner of the DCT spectrum) that is rescaled by the random mask.
        **kwargs: accepted for call compatibility and ignored.

    Defaults in this implementation:
        epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=7, rho=0.5, ratio=0.2
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=7, targeted=False, random_start=False,
                norm='linfty', loss='crossentropy', device=None, attack='SIM', adapted=False,rho=0.5, ratio=0.2,**kwargs):
        super().__init__(model, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack, adapted,rho)
        self.num_scale = num_scale
        self.rho = rho
        # Standard deviation of the additive Gaussian noise, on the 0-255 pixel scale.
        self.sigma = 16.0
        self.ratio = ratio

    # Original author's note: 'low 77.28 77.4' (presumably recorded experiment results -- unverified).
    def dct_perturbation(self, x):
        """
        Return one randomly spectrum-transformed copy of x.

        Gaussian noise (std = sigma / 255) is added to x, the result is moved to the
        DCT domain, the top-left (low-frequency) block covering `ratio` of each
        spatial axis is multiplied element-wise by a mask drawn uniformly from
        [1 - rho, 1 + rho), and the spectrum is transformed back.

        Arguments:
            x: tensor indexed as (N, C, H, W); the output has the same shape and
               stays on the same device as x.
        """
        # randn_like / rand_like / zeros_like inherit device and dtype from x, so no
        # explicit .cuda() is needed; this also keeps CPU and non-default GPUs working.
        gauss = torch.randn_like(x) * (self.sigma / 255)
        x_dct = dct.dct_2d(x + gauss)
        mask = torch.rand_like(x) * 2 * self.rho + 1 - self.rho

        size_h = int(x.shape[2] * self.ratio)
        size_w = int(x.shape[3] * self.ratio)
        low_frequency_mask = torch.zeros_like(x_dct)
        low_frequency_mask[:, :, :size_h, :size_w] = 1
        high_frequency_mask = 1 - low_frequency_mask

        # Only the selected low-frequency block is rescaled; the rest passes through.
        return dct.idct_2d(x_dct * mask * low_frequency_mask + x_dct * high_frequency_mask)

    def transform(self, x, **kwargs):
        """
        Stack num_scale independently transformed copies of x along the batch dimension.
        """
        return torch.cat([self.dct_perturbation(x) for _ in range(self.num_scale)])

    def get_loss(self, logits, label):
        """
        Calculate the loss over all transformed copies.

        The labels are tiled num_scale times so they line up with the batch
        produced by `transform`.
        NOTE(review): no sign change is applied here for targeted attacks --
        confirm whether that is handled elsewhere.
        """
        return self.loss(logits, label.repeat(self.num_scale))
    
    
class l2hSSA(MIFGSM):
    """
    l2hSSA: MI-FGSM with DCT-domain input augmentation restricted to the
    low-frequency corner of the spectrum.

    NOTE(review): the previous docstring was copied from the SIM attack. The code
    follows the Spectrum Simulation Attack recipe, presumably 'Frequency Domain
    Model Augmentation for Adversarial Attack (ECCV 2022)'
    (https://arxiv.org/abs/2207.05382) -- confirm the citation.

    The implementation is the same as `SSA` apart from the defaults
    (num_scale=20, ratio=0.1).

    Arguments:
        model (torch.nn.Module): the surrogate model for attack.
        epsilon (float): the perturbation budget.
        alpha (float): the step size.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation.
        num_scale (int): the number of spectrum-transformed copies in each iteration.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        attack (str): attack name forwarded to MIFGSM (the default is still 'SIM').
        adapted (bool): forwarded unchanged to MIFGSM (meaning defined there -- TODO confirm).
        rho (float): half-width of the random spectrum mask, sampled from [1 - rho, 1 + rho).
        ratio (float): fraction of the height/width (counted from the low-frequency
            corner of the DCT spectrum) that is rescaled by the random mask.
        **kwargs: accepted for call compatibility and ignored.

    Defaults in this implementation:
        epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, rho=0.5, ratio=0.1
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, targeted=False, random_start=False,
                norm='linfty', loss='crossentropy', device=None, attack='SIM', adapted=False,rho=0.5, ratio=0.1,**kwargs):
        super().__init__(model, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack, adapted,rho)
        self.num_scale = num_scale
        self.rho = rho
        # Standard deviation of the additive Gaussian noise, on the 0-255 pixel scale.
        self.sigma = 16.0
        self.ratio = ratio

    # Original author's note: 'low 77.28 77.4' (presumably recorded experiment results -- unverified).
    def dct_perturbation(self, x):
        """
        Return one randomly spectrum-transformed copy of x.

        Steps: add Gaussian noise (std = sigma / 255), apply a 2-D DCT, multiply the
        top-left block of size (int(H * ratio), int(W * ratio)) element-wise by a
        mask drawn uniformly from [1 - rho, 1 + rho), then apply the inverse DCT.

        Arguments:
            x: tensor indexed as (N, C, H, W); the output has the same shape and
               stays on the same device as x.
        """
        # The *_like factories already follow x's device and dtype; the former
        # unconditional .cuda() calls broke CPU runs and non-default GPUs.
        noise = torch.randn_like(x) * (self.sigma / 255)
        spectrum = dct.dct_2d(x + noise)
        mask = torch.rand_like(x) * 2 * self.rho + 1 - self.rho

        size_h = int(x.shape[2] * self.ratio)
        size_w = int(x.shape[3] * self.ratio)
        low_frequency_mask = torch.zeros_like(spectrum)
        low_frequency_mask[:, :, :size_h, :size_w] = 1
        high_frequency_mask = 1 - low_frequency_mask

        # Rescale only the low-frequency block; leave every other coefficient as is.
        return dct.idct_2d(spectrum * mask * low_frequency_mask + spectrum * high_frequency_mask)

    def transform(self, x, **kwargs):
        """
        Stack num_scale independently transformed copies of x along the batch dimension.
        """
        return torch.cat([self.dct_perturbation(x) for _ in range(self.num_scale)])

    def get_loss(self, logits, label):
        """
        Calculate the loss over all transformed copies.

        The labels are tiled num_scale times so they line up with the batch
        produced by `transform`.
        NOTE(review): no sign change is applied here for targeted attacks --
        confirm whether that is handled elsewhere.
        """
        return self.loss(logits, label.repeat(self.num_scale))
    
class h2lSSA(MIFGSM):
    """
    h2lSSA: MI-FGSM with DCT-domain input augmentation restricted to the
    highest-index (bottom-right) corner of the spectrum.

    NOTE(review): the previous docstring was copied from the SIM attack. The code
    follows the Spectrum Simulation Attack recipe, presumably 'Frequency Domain
    Model Augmentation for Adversarial Attack (ECCV 2022)'
    (https://arxiv.org/abs/2207.05382) -- confirm the citation.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack.
        epsilon (float): the perturbation budget.
        alpha (float): the step size.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation.
        num_scale (int): the number of spectrum-transformed copies in each iteration.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        attack (str): attack name forwarded to MIFGSM (the default is still 'SIM').
        adapted (bool): forwarded unchanged to MIFGSM (meaning defined there -- TODO confirm).
        rho (float): half-width of the random spectrum mask, sampled from [1 - rho, 1 + rho).
        ratio (float): fraction of the height/width, counted from the bottom-right
            corner of the DCT spectrum, that is rescaled by the random mask.
        **kwargs: accepted for call compatibility and ignored.

    Defaults in this implementation:
        epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, rho=0.5, ratio=0.1
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, targeted=False, random_start=False,
                norm='linfty', loss='crossentropy', device=None, attack='SIM', adapted=False,rho=0.5, ratio=0.1,**kwargs):
        super().__init__(model, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack, adapted,rho)
        self.num_scale = num_scale
        self.rho = rho
        # Standard deviation of the additive Gaussian noise, on the 0-255 pixel scale.
        self.sigma = 16.0
        self.ratio = ratio

    # Original author's note: 'low 77.28 77.4' (presumably recorded experiment results -- unverified).
    def dct_perturbation(self, x):
        """
        Return one randomly spectrum-transformed copy of x.

        Steps: add Gaussian noise (std = sigma / 255), apply a 2-D DCT, multiply the
        bottom-right block of size (int(H * ratio), int(W * ratio)) element-wise by
        a mask drawn uniformly from [1 - rho, 1 + rho), then apply the inverse DCT.

        Arguments:
            x: tensor indexed as (N, C, H, W); the output has the same shape and
               stays on the same device as x.
        """
        # The *_like factories already follow x's device and dtype; the former
        # unconditional .cuda() calls broke CPU runs and non-default GPUs.
        noise = torch.randn_like(x) * (self.sigma / 255)
        spectrum = dct.dct_2d(x + noise)
        mask = torch.rand_like(x) * 2 * self.rho + 1 - self.rho

        height, width = x.shape[2], x.shape[3]
        size_h = int(height * self.ratio)
        size_w = int(width * self.ratio)

        # Use explicit start indices instead of negative slices: `-0:` is the same
        # as `0:` and would select the WHOLE axis when a size truncates to zero.
        # Clamping at 0 keeps the old "select everything" result for ratio >= 1.
        row_start = max(height - size_h, 0)
        col_start = max(width - size_w, 0)
        selected_band = torch.zeros_like(spectrum)
        selected_band[:, :, row_start:, col_start:] = 1
        untouched_band = 1 - selected_band

        # Rescale only the selected block; leave every other coefficient as is.
        return dct.idct_2d(spectrum * mask * selected_band + spectrum * untouched_band)

    def transform(self, x, **kwargs):
        """
        Stack num_scale independently transformed copies of x along the batch dimension.
        """
        return torch.cat([self.dct_perturbation(x) for _ in range(self.num_scale)])

    def get_loss(self, logits, label):
        """
        Calculate the loss over all transformed copies.

        The labels are tiled num_scale times so they line up with the batch
        produced by `transform`.
        NOTE(review): no sign change is applied here for targeted attacks --
        confirm whether that is handled elsewhere.
        """
        return self.loss(logits, label.repeat(self.num_scale))
    
    
class FIA_SSA(MIFGSM):
    """
    FIA_SSA: feature-level attack with DCT-domain input augmentation.

    NOTE(review): the previous docstring was copied from the SIM attack. The code
    combines an aggregate-gradient feature loss (in the style of the Feature
    Importance-aware Attack) with spectrum-transformed inputs (in the style of the
    Spectrum Simulation Attack) -- confirm the intended references.

    Procedure implemented by `forward`:
        1. Run `num_ens` forward/backward passes on copies of the clean input with
           randomly dropped DCT coefficients and sum the gradients that reach the
           output of `feature_layer`; normalise the sum per sample (L2).
        2. For `epoch` iterations, minimise
           mean(features * aggregate_gradient) - 0.1 * classification_loss
           over `num_scale` spectrum-transformed copies of the adversarial input.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack.
        epsilon (float): the perturbation budget.
        alpha (float): the step size.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation.
        num_scale (int): the number of spectrum-transformed copies in each iteration.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        attack (str): attack name forwarded to MIFGSM (the default is still 'SIM').
        adapted (bool): forwarded unchanged to MIFGSM (meaning defined there -- TODO confirm).
        rho (float): half-width of the random spectrum mask, sampled from [1 - rho, 1 + rho).
        ratio (float): stored on the instance but not read by this class.
        **kwargs: accepted for call compatibility and ignored.

    Defaults in this implementation:
        epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=10, rho=0.5,
        num_ens=10 and drop_rate=0.5 (both fixed in __init__)
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=10, targeted=False, random_start=False,
                norm='linfty', loss='crossentropy', device=None, attack='SIM', adapted=False,rho=0.5, ratio=0.1,**kwargs):
        super().__init__(model, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack, adapted,rho)
        self.num_scale = num_scale
        self.rho = rho
        # Standard deviation of the additive Gaussian noise, on the 0-255 pixel scale.
        self.sigma = 16.0
        self.ratio = ratio
        # Number of randomly dropped inputs used to build the aggregate gradient.
        self.num_ens = 10
        self.feature_layer = self.find_layer()
        # Probability of zeroing each DCT coefficient in `drop`.
        self.drop_rate = 0.5
        # Filled by the hooks while `forward` runs; kept per instance rather than in
        # module-level globals so that several attack objects cannot clash.
        self._mid_output = None
        self._mid_grad = None

    # Original author's note: 'low 77.28 77.4' (presumably recorded experiment results -- unverified).
    def dct_perturbation(self, x):
        """
        Return one randomly spectrum-transformed copy of x.

        Gaussian noise (std = sigma / 255) is added, the result is moved to the DCT
        domain, every coefficient is multiplied by a value drawn uniformly from
        [1 - rho, 1 + rho), and the spectrum is transformed back.
        """
        # randn_like / rand_like follow x's device and dtype, so the former
        # unconditional .cuda() calls (which broke CPU runs) are not needed.
        gauss = torch.randn_like(x) * (self.sigma / 255)
        x_dct = dct.dct_2d(x + gauss)
        mask = torch.rand_like(x) * 2 * self.rho + 1 - self.rho
        return dct.idct_2d(x_dct * mask)

    def transform(self, x, **kwargs):
        """
        Stack num_scale independently transformed copies of x along the batch dimension.
        """
        return torch.cat([self.dct_perturbation(x) for _ in range(self.num_scale)])

    def get_loss(self, logits, label):
        """
        Calculate the classification loss over all transformed copies.

        The labels are tiled num_scale times so they line up with the batch
        produced by `transform`.
        """
        return self.loss(logits, label.repeat(self.num_scale))

    def find_layer(self):
        """
        Return the module whose output is used as the attacked feature map.

        The lookup is hard-coded: it expects `self.model[1]` to expose a `features`
        container and takes its element 9.
        NOTE(review): this only works for surrogate models with that layout.
        """
        return self.model[1]._modules['features'][9]

    def __forward_hook(self, m, i, o):
        # Forward hook: remember the output of the feature layer.
        self._mid_output = o

    def __backward_hook(self, m, i, o):
        # Full backward hook: `o` is the tuple of gradients w.r.t. the layer outputs.
        self._mid_grad = o

    def drop(self, data):
        """
        Return a copy of data with DCT coefficients randomly zeroed.

        Each coefficient is kept with probability 1 - drop_rate. The returned tensor
        is connected to a fresh leaf that requires grad, so a backward pass through
        the model is possible.
        """
        # clone() keeps the copy on data's own device instead of forcing CUDA.
        x_drop = data.clone().detach()
        x_drop.requires_grad = True
        x_dct = dct.dct_2d(x_drop)
        keep_mask = torch.bernoulli(torch.full_like(x_dct, 1 - self.drop_rate))
        return dct.idct_2d(x_dct * keep_mask)

    def _aggregate_gradient(self, data, label):
        # Sum the feature-layer gradients over num_ens randomly dropped inputs and
        # L2-normalise the result per sample. Requires the backward hook to be active.
        class_index = label.long().view(-1, 1)
        agg_grad = 0
        for _ in range(self.num_ens):
            probs = torch.softmax(self.model(self.drop(data)), 1)
            # Sum over the batch of the probability assigned to each sample's label.
            loss = probs.gather(1, class_index).sum()
            self.model.zero_grad()
            loss.backward()
            agg_grad = agg_grad + self._mid_grad[0].detach()
        per_sample_norm = agg_grad.flatten(1).norm(p=2, dim=1)
        return agg_grad / per_sample_norm.view(-1, *([1] * (agg_grad.dim() - 1)))

    def forward(self, data, label, **kwargs):
        """
        The general attack procedure

        Arguments:
            data: (N, C, H, W) tensor for input images
            labels: (N,) tensor for ground-truth labels if untargetd, otherwise targeted labels

        Returns:
            The detached adversarial perturbation delta.
        """
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)

        # Initialize adversarial perturbation
        delta = self.init_delta(data)

        forward_handle = self.feature_layer.register_forward_hook(self.__forward_hook)
        backward_handle = self.feature_layer.register_full_backward_hook(self.__backward_hook)
        try:
            agg_grad = self._aggregate_gradient(data, label)
            # The backward hook is only needed for the aggregation phase.
            backward_handle.remove()

            # Loop-invariant: tile the aggregate gradient to match the num_scale copies.
            # The feature map is taken to be 4-D, as in the original code.
            agg_grad_tiled = agg_grad.repeat(self.num_scale, 1, 1, 1)

            momentum = 0.
            for _ in range(self.epoch):
                # Obtain the output; the forward hook is expected to capture the
                # feature map of this pass (assumes get_logits runs self.model).
                logits = self.get_logits(self.transform(data + delta))
                # Feature loss minus a down-weighted classification loss
                feature_loss = (self._mid_output * agg_grad_tiled).mean()
                loss = feature_loss - 1e-1 * self.get_loss(logits, label)
                self.model.zero_grad()
                # Calculate the gradients
                grad = torch.autograd.grad(loss, delta, retain_graph=False, create_graph=False)[0]
                # Update adversarial perturbation; the momentum is negated because
                # the combined loss is minimised.
                momentum = self.get_momentum(grad, momentum, self.decay)
                delta = self.update_delta(delta, data, -momentum, self.alpha)
        finally:
            # Always unregister the hooks (remove() is safe to call twice) and drop
            # the captured tensors so the autograd graph can be freed.
            backward_handle.remove()
            forward_handle.remove()
            self._mid_output = None
            self._mid_grad = None
        return delta.detach()