import torch

from ..utils import *
from ..iterative.mifgsm import MIFGSM
import torch_dct as dct
from torch.autograd import Variable as V
import torch.nn.functional as F
import scipy.stats as st

class SSA_alpha(MIFGSM):
    """
    Frequency-domain input-transformation attack built on top of MI-FGSM.

    In every iteration the input is copied ``num_scale`` times. Each copy gets
    Gaussian noise added, is mapped to the DCT domain, has a randomly chosen
    operation (scaling, multiplicative noise or channel dropout) applied to
    every DCT coefficient outside the top-left (low-frequency) corner, and is
    mapped back with the inverse DCT. The copies are concatenated along the
    batch dimension and fed to the surrogate model.

    NOTE(review): the class name suggests a variant of the Spectrum Simulation
    Attack (SSA), but the ``attack`` default is still the string 'SIM' --
    confirm which label is intended.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack.
        epsilon (float): the perturbation budget.
        alpha (float): the step size.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation.
        num_scale (int): the number of transformed copies in each iteration.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        attack (str): the attack name forwarded to the parent class.
        adapted (bool): forwarded unchanged to the parent class.
        rho (float): forwarded to the parent class; also the half-width of the
            multiplicative noise mask in ``add_noise``, which is drawn from
            [1 - rho, 1 + rho).
        **kwargs: accepted for call compatibility and ignored.

    Defaults:
        epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, rho=0.5
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, targeted=False, random_start=False,
                 norm='linfty', loss='crossentropy', device=None, attack='SIM', adapted=False, rho=0.5, **kwargs):
        super().__init__(model, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack, adapted, rho)
        self.num_scale = num_scale
        self.rho = rho
        # Standard deviation of the additive Gaussian noise, on a 0-255 scale
        # (it is divided by 255 before use).
        self.sigma = 16.0
        # Not read by any method of this class; kept for compatibility.
        self.num_block = 3
        # Not read by any method of this class; kept for compatibility.
        self.kernel = self.gkern()
        # Candidate operations applied to the DCT coefficients.
        self.ops = [self.scale, self.add_noise, self.drop_out]

    def dct_perturbation(self, x):
        """
        Return a randomly perturbed copy of x built in the DCT domain.

        Gaussian noise with standard deviation ``sigma / 255`` is added to x,
        the result is transformed with a 2D DCT, modified by
        ``blocktransform`` and transformed back with the inverse 2D DCT.

        All intermediate tensors stay on the device of x, so the method works
        on CPU as well as on any GPU.
        """
        # randn_like already allocates on x's device with x's dtype.
        noise = torch.randn_like(x) * (self.sigma / 255)
        spectrum = dct.dct_2d(x + noise)
        spectrum = self.blocktransform(spectrum)
        return dct.idct_2d(spectrum)

    def blocktransform(self, x, choice=-1):
        """
        Apply one operation from ``self.ops`` to all of x except its top-left corner.

        Arguments:
            x (torch.Tensor): a 4D tensor; here the DCT coefficients of an image batch.
            choice (int): index into ``self.ops``. A negative value (the
                default) selects the operation uniformly at random.

        Returns:
            A modified clone of x; x itself is not changed.
        """
        x_copy = x.clone()
        _, _, w, h = x.shape
        ratio = 0.2
        pw = int(w * ratio)
        ph = int(h * ratio)

        # One operation is drawn per call and shared by all three regions.
        chosen = choice if choice >= 0 else np.random.randint(0, high=len(self.ops), dtype=np.int32)
        op = self.ops[chosen]

        # The three regions tile everything except the top-left pw x ph block,
        # which holds the lowest-frequency coefficients and is left untouched.
        # Each region gets its own call to op, so random ops draw fresh
        # randomness per region.
        regions = (
            (slice(pw, None), slice(ph, None)),
            (slice(None, pw), slice(ph, None)),
            (slice(pw, None), slice(None, ph)),
        )
        for rows, cols in regions:
            index = (slice(None), slice(None), rows, cols)
            x_copy[index] = op(x_copy[index])

        return x_copy

    def transform(self, x, **kwargs):
        """
        Build ``num_scale`` independently perturbed copies of x, concatenated along dim 0.
        """
        return torch.cat([self.dct_perturbation(x) for _ in range(self.num_scale)])

    def get_loss(self, logits, label):
        """
        Calculate the loss, repeating the labels once per transformed copy.
        """
        return self.loss(logits, label.repeat(self.num_scale))

    def vertical_shift(self, x):
        """Roll x along dim 2 by a random offset in [0, size of dim 2)."""
        _, _, w, _ = x.shape
        step = np.random.randint(low=0, high=w, dtype=np.int32)
        return x.roll(step, dims=2)

    def horizontal_shift(self, x):
        """Roll x along dim 3 by a random offset in [0, size of dim 3)."""
        _, _, _, h = x.shape
        step = np.random.randint(low=0, high=h, dtype=np.int32)
        return x.roll(step, dims=3)

    def vertical_flip(self, x):
        """Flip x along dim 2."""
        return x.flip(dims=(2,))

    def horizontal_flip(self, x):
        """Flip x along dim 3."""
        return x.flip(dims=(3,))

    def rotate180(self, x):
        """Rotate x by 180 degrees in the plane spanned by dims 2 and 3."""
        return x.rot90(k=2, dims=(2, 3))

    def scale(self, x):
        """Multiply x by a single random factor drawn uniformly from [0, 1)."""
        # Two values are drawn and only the first is used; this is kept so the
        # random stream consumed per call does not change.
        return torch.rand(2)[0] * x

    def resize(self, x):
        """
        Resize the input

        x is bilinearly downsampled to int(0.8 * size) + 1 along dims 2 and 3
        and then bilinearly resampled back to its original size.
        """
        _, _, w, h = x.shape
        scale_factor = 0.8
        new_w = int(w * scale_factor) + 1
        new_h = int(h * scale_factor) + 1
        # The size tuple follows the tensor's own (dim 2, dim 3) order so that
        # non-square inputs keep their aspect ratio in the intermediate step.
        x = F.interpolate(x, size=(new_w, new_h), mode='bilinear', align_corners=False)
        x = F.interpolate(x, size=(w, h), mode='bilinear', align_corners=False)
        return x

    def add_noise(self, x):
        """Multiply x element-wise by a random mask drawn uniformly from [1 - rho, 1 + rho)."""
        # rand_like allocates on x's device, so no explicit device move is needed.
        mask = torch.rand_like(x) * 2 * self.rho + 1 - self.rho
        return x * mask

    def gkern(self, kernel_size=3, nsig=3):
        """
        Build a normalised 2D Gaussian kernel replicated for three channels.

        Arguments:
            kernel_size (int): the side length of the kernel.
            nsig (float): the kernel samples the standard normal pdf on [-nsig, nsig].

        Returns:
            A float32 tensor of shape (3, 1, kernel_size, kernel_size) on ``self.device``.
        """
        # self.device is presumably set by the parent constructor -- verify.
        x = np.linspace(-nsig, nsig, kernel_size)
        kern1d = st.norm.pdf(x)
        kernel_raw = np.outer(kern1d, kern1d)
        kernel = kernel_raw / kernel_raw.sum()
        stack_kernel = np.stack([kernel, kernel, kernel])
        stack_kernel = np.expand_dims(stack_kernel, 1)
        return torch.from_numpy(stack_kernel.astype(np.float32)).to(self.device)

    def drop_out(self, x):
        """Apply 2D (channel-wise) dropout with p=0.2 to x; always active."""
        return F.dropout2d(x, p=0.2, training=True)