import torch

from ..utils import *
from ..attack import Attack
from random import shuffle
import torch.nn.functional as F
import scipy.stats as st

class SeqEns(Attack):
    """
    Sequential-ensemble iterative attack with random input transformations.

    The class is registered with the base class under the attack name 'I-FGSM'
    ('Adversarial Examples in the Physical World (ICLR 2017)',
    https://arxiv.org/abs/1607.02533), but the loop differs from a plain
    iterative attack: in every epoch the surrogate models stored in
    ``self.model.models`` are visited one after another. For each model a
    transformation is drawn at random from ``self.ops``, applied to the current
    adversarial example, the gradient w.r.t. the perturbation is computed on
    that single model, and momentum and perturbation are updated right away.
    One epoch therefore performs ``len(self.model.models)`` updates.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack. It must expose
            the individual surrogate models as an indexable ``models`` attribute.
        epsilon (float): the perturbation budget.
        alpha (float): accepted for interface compatibility only; the step size
            actually used is always ``epsilon / epoch``.
        epoch (int): the number of iterations.
        targeted (bool): targeted/untargeted attack
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model

    Default arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, targeted=False, random_start=False,
                 norm='linfty', loss='crossentropy', device=None, **kwargs):
        super().__init__('I-FGSM', model, epsilon, targeted, random_start, norm, loss, device)
        self.epoch = epoch
        self.decay = 1
        # The step size is derived from the budget; the `alpha` argument is
        # intentionally not used (this matches the historical behaviour).
        self.alpha = epsilon / self.epoch
        # Pool of input transformations sampled from during the attack.
        self.ops = [
            self.vertical_shift,
            self.horizontal_shift,
            self.vertical_flip,
            self.horizontal_flip,
            self.rotate180,
            self.resize,
            self.add_noise,
            self.drop_out,
        ]

    def forward(self, data, label, **kwargs):
        """
        Run the attack and return the adversarial perturbation.

        Arguments:
            data: (N, C, H, W) tensor for input images
            label: (N,) tensor for ground-truth labels if untargeted, otherwise targeted labels

        Returns:
            The detached perturbation ``delta`` produced by ``update_delta``.
        """
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)

        # Initialize adversarial perturbation
        delta = self.init_delta(data)

        momentum = 0.
        num_models = len(self.model.models)
        for _ in range(self.epoch):
            for model_idx in range(num_models):
                # A fresh transformation is drawn for every single model.
                op = np.random.choice(self.ops, 1)[0]

                # Obtain the output of the current surrogate model only
                logits = self.get_logits(self.transform(op(data + delta), momentum=momentum), model_idx)

                # Calculate the loss and its gradient w.r.t. the perturbation
                loss = self.get_loss(logits, label)
                grad = self.get_grad(loss, delta)

                # Update momentum and perturbation before moving on to the next model
                momentum = self.get_momentum(grad, momentum, decay=self.decay)
                delta = self.update_delta(delta, data, momentum, self.alpha)
        return delta.detach()

    def get_logits(self, x, i, **kwargs):
        """
        The inference stage: evaluate only the i-th surrogate model of the ensemble on x.
        """
        return self.model.models[i](x)

    def get_loss(self, logits, label):
        """
        The loss calculation; the sign is flipped for targeted attacks.
        """
        value = self.loss(logits, label)
        return -value if self.targeted else value

    def vertical_shift(self, x):
        """Cyclically roll x along dim 2 (image height for NCHW input) by a random offset."""
        height = x.shape[2]
        step = int(np.random.randint(low=0, high=height, dtype=np.int32))
        return x.roll(step, dims=2)

    def horizontal_shift(self, x):
        """Cyclically roll x along dim 3 (image width for NCHW input) by a random offset."""
        width = x.shape[3]
        step = int(np.random.randint(low=0, high=width, dtype=np.int32))
        return x.roll(step, dims=3)

    def vertical_flip(self, x):
        """Flip x along dim 2."""
        return x.flip(dims=(2,))

    def horizontal_flip(self, x):
        """Flip x along dim 3."""
        return x.flip(dims=(3,))

    def rotate180(self, x):
        """Rotate x by 180 degrees in the plane spanned by dims 2 and 3."""
        return x.rot90(k=2, dims=(2, 3))

    def scale(self, x):
        """Multiply x by a single random factor drawn uniformly from [0, 1). Not part of ``self.ops``."""
        return torch.rand(2)[0] * x

    def resize(self, x, scale_factor=0.8):
        """
        Shrink x bilinearly by ``scale_factor`` and resize it back to its original size.

        ``F.interpolate`` takes ``size`` in the same order as the spatial dims
        of the input, i.e. (dim 2, dim 3), so the intermediate size keeps that
        order as well (it used to be transposed for non-square inputs).
        """
        _, _, height, width = x.shape
        shrunk_h = int(height * scale_factor) + 1
        shrunk_w = int(width * scale_factor) + 1
        x = F.interpolate(x, size=(shrunk_h, shrunk_w), mode='bilinear', align_corners=False)
        return F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)

    def add_noise(self, x, std=0.1):
        """Add zero-mean Gaussian noise with standard deviation ``std`` to x."""
        return x + torch.randn_like(x) * std

    def gkern(self, kernel_size=3, nsig=3):
        """
        Build a normalized 2-D Gaussian kernel replicated for 3 channels.

        Returns a float32 tensor of shape (3, 1, kernel_size, kernel_size) on
        ``self.device``. Not used by ``forward`` in this class.
        """
        x = np.linspace(-nsig, nsig, kernel_size)
        kern1d = st.norm.pdf(x)
        kernel_raw = np.outer(kern1d, kern1d)
        kernel = kernel_raw / kernel_raw.sum()
        stack_kernel = np.stack([kernel, kernel, kernel])
        stack_kernel = np.expand_dims(stack_kernel, 1)
        return torch.from_numpy(stack_kernel.astype(np.float32)).to(self.device)

    def drop_out(self, x, p=0.2):
        """Apply channel-wise dropout with probability ``p``; always active (training=True)."""
        return F.dropout2d(x, p=p, training=True)
    
class RSeqEns(Attack):
    """
    Randomly-ordered sequential-ensemble iterative attack with input transformations.

    The class is registered with the base class under the attack name 'I-FGSM'
    ('Adversarial Examples in the Physical World (ICLR 2017)',
    https://arxiv.org/abs/1607.02533), but the loop differs from a plain
    iterative attack: in every epoch the surrogate models stored in
    ``self.model.models`` are visited one after another in a freshly shuffled
    order, and one transformation drawn from ``self.ops`` is shared by all
    models of that epoch. After each model the momentum and the perturbation
    are updated right away, so one epoch performs ``len(self.model.models)``
    updates.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack. It must expose
            the individual surrogate models as an indexable ``models`` attribute.
        epsilon (float): the perturbation budget.
        alpha (float): accepted for interface compatibility only; the step size
            actually used is always ``epsilon / epoch``.
        epoch (int): the number of iterations.
        targeted (bool): targeted/untargeted attack
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model

    Default arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, targeted=False, random_start=False,
                 norm='linfty', loss='crossentropy', device=None, **kwargs):
        super().__init__('I-FGSM', model, epsilon, targeted, random_start, norm, loss, device)
        self.epoch = epoch
        self.decay = 1
        # The step size is derived from the budget; the `alpha` argument is
        # intentionally not used (this matches the historical behaviour).
        self.alpha = epsilon / self.epoch
        # Pool of input transformations sampled from during the attack.
        self.ops = [
            self.vertical_shift,
            self.horizontal_shift,
            self.vertical_flip,
            self.horizontal_flip,
            self.rotate180,
            self.resize,
            self.add_noise,
            self.drop_out,
        ]

    def forward(self, data, label, **kwargs):
        """
        Run the attack and return the adversarial perturbation.

        Arguments:
            data: (N, C, H, W) tensor for input images
            label: (N,) tensor for ground-truth labels if untargeted, otherwise targeted labels

        Returns:
            The detached perturbation ``delta`` produced by ``update_delta``.
        """
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)

        # Initialize adversarial perturbation
        delta = self.init_delta(data)

        momentum = 0.
        # Shuffle a local list of indices rather than the model container
        # itself, so the caller's ensemble is not reordered as a side effect.
        visit_order = list(range(len(self.model.models)))
        for _ in range(self.epoch):
            shuffle(visit_order)
            # One transformation is drawn per epoch and reused for every model
            # (randomized transformations still re-sample on each call).
            op = np.random.choice(self.ops, 1)[0]
            for model_idx in visit_order:
                # Obtain the output of the current surrogate model only
                logits = self.get_logits(self.transform(op(data + delta), momentum=momentum), model_idx)

                # Calculate the loss and its gradient w.r.t. the perturbation
                loss = self.get_loss(logits, label)
                grad = self.get_grad(loss, delta)

                # Update momentum and perturbation before moving on to the next model
                momentum = self.get_momentum(grad, momentum, decay=self.decay)
                delta = self.update_delta(delta, data, momentum, self.alpha)
        return delta.detach()

    def get_logits(self, x, i, **kwargs):
        """
        The inference stage: evaluate only the i-th surrogate model of the ensemble on x.
        """
        return self.model.models[i](x)

    def get_loss(self, logits, label):
        """
        The loss calculation; the sign is flipped for targeted attacks.
        """
        value = self.loss(logits, label)
        return -value if self.targeted else value

    def vertical_shift(self, x):
        """Cyclically roll x along dim 2 (image height for NCHW input) by a random offset."""
        height = x.shape[2]
        step = int(np.random.randint(low=0, high=height, dtype=np.int32))
        return x.roll(step, dims=2)

    def horizontal_shift(self, x):
        """Cyclically roll x along dim 3 (image width for NCHW input) by a random offset."""
        width = x.shape[3]
        step = int(np.random.randint(low=0, high=width, dtype=np.int32))
        return x.roll(step, dims=3)

    def vertical_flip(self, x):
        """Flip x along dim 2."""
        return x.flip(dims=(2,))

    def horizontal_flip(self, x):
        """Flip x along dim 3."""
        return x.flip(dims=(3,))

    def rotate180(self, x):
        """Rotate x by 180 degrees in the plane spanned by dims 2 and 3."""
        return x.rot90(k=2, dims=(2, 3))

    def scale(self, x):
        """Multiply x by a single random factor drawn uniformly from [0, 1). Not part of ``self.ops``."""
        return torch.rand(2)[0] * x

    def resize(self, x, scale_factor=0.8):
        """
        Shrink x bilinearly by ``scale_factor`` and resize it back to its original size.

        ``F.interpolate`` takes ``size`` in the same order as the spatial dims
        of the input, i.e. (dim 2, dim 3), so the intermediate size keeps that
        order as well (it used to be transposed for non-square inputs).
        """
        _, _, height, width = x.shape
        shrunk_h = int(height * scale_factor) + 1
        shrunk_w = int(width * scale_factor) + 1
        x = F.interpolate(x, size=(shrunk_h, shrunk_w), mode='bilinear', align_corners=False)
        return F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)

    def add_noise(self, x, std=0.1):
        """Add zero-mean Gaussian noise with standard deviation ``std`` to x."""
        return x + torch.randn_like(x) * std

    def gkern(self, kernel_size=3, nsig=3):
        """
        Build a normalized 2-D Gaussian kernel replicated for 3 channels.

        Returns a float32 tensor of shape (3, 1, kernel_size, kernel_size) on
        ``self.device``. Not used by ``forward`` in this class.
        """
        x = np.linspace(-nsig, nsig, kernel_size)
        kern1d = st.norm.pdf(x)
        kernel_raw = np.outer(kern1d, kern1d)
        kernel = kernel_raw / kernel_raw.sum()
        stack_kernel = np.stack([kernel, kernel, kernel])
        stack_kernel = np.expand_dims(stack_kernel, 1)
        return torch.from_numpy(stack_kernel.astype(np.float32)).to(self.device)

    def drop_out(self, x, p=0.2):
        """Apply channel-wise dropout with probability ``p``; always active (training=True)."""
        return F.dropout2d(x, p=p, training=True)
