import torch

from ..utils import *
from ..attack import Attack
import torch.nn.functional as F
import scipy.stats as st
from random import shuffle

def pcgrad(domain_grads):
    """
    Merge per-task gradients with projected conflicting gradients (PCGrad).

    Every task gradient is visited in random order and compared against the
    original (unprojected) gradient of each other task. When the two are
    judged to conflict, the component of the current gradient along the other
    task's gradient is subtracted. The adjusted gradients are finally averaged.

    Note: the conflict test uses the sum of element-wise sign products, not
    the raw inner product used in the PCGrad paper.

    Arguments:
        domain_grads (list of torch.Tensor): one gradient per task, all of the
            same shape. The input tensors are not modified.

    Returns:
        torch.Tensor: mean of the projected gradients, same shape as one input.
    """
    task_order = list(range(len(domain_grads)))

    # Visit tasks in random order
    shuffle(task_order)

    # Work on copies so the caller's gradients stay untouched
    grad_pc = [g.clone() for g in domain_grads]

    for i in task_order:
        for j in task_order:
            if j == i:
                continue
            grad_j = domain_grads[j]

            # Sign-based agreement score; negative means the gradients conflict
            sign_agreement = torch.sum(grad_pc[i].sign() * grad_j.sign())
            if sign_agreement < 0:
                # A negative score implies grad_j has non-zero entries, so the
                # squared norm below cannot be zero.
                inner_prod = torch.sum(grad_pc[i] * grad_j)
                # Subtract the component of grad_pc[i] that lies along grad_j
                grad_pc[i] = grad_pc[i] - inner_prod / (grad_j ** 2).sum() * grad_j

    # Average the (possibly projected) task gradients
    return torch.stack(grad_pc).mean(0)


class LossEns(Attack):
    """
    Iterative ensemble attack that fuses the surrogate models at the loss level.

    The logits of every sub-model in ``model.models`` are concatenated along
    the batch dimension and a single loss is computed against the labels
    repeated once per sub-model. The update loop follows the I-FGSM layout
    ('Adversarial Examples in the Physical World (ICLR 2017)',
    https://arxiv.org/abs/1607.02533) and the attack is registered under the
    name 'I-FGSM'.

    Arguments:
        model (torch.nn.Module): the surrogate ensemble; must expose a ``models`` sequence.
        epsilon (float): the perturbation budget.
        alpha (float): accepted for interface compatibility; the step size actually
            used is always ``epsilon / epoch``.
        epoch (int): the number of iterations.
        targeted (bool): targeted/untargeted attack
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        adapted (bool): currently unused.

    Official arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, targeted=False, random_start=False,
                 norm='linfty', loss='crossentropy', device=None, adapted=False, **kwargs):
        super().__init__('I-FGSM', model, epsilon, targeted, random_start, norm, loss, device)
        self.epoch = epoch
        # Momentum decay factor handed to get_momentum
        self.decay = 0
        # The step size is derived from the budget; the `alpha` argument is not used
        self.alpha = epsilon / self.epoch

    def forward(self, data, label, **kwargs):
        """
        Run the iterative attack and return the final perturbation.

        Arguments:
            data: (N, C, H, W) tensor for input images
            label: (N,) tensor for ground-truth labels if untargeted, otherwise targeted labels

        Returns:
            The detached perturbation tensor produced after ``epoch`` updates.
        """
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)

        # Starting perturbation
        delta = self.init_delta(data)

        momentum = 0.
        for _ in range(self.epoch):
            # Forward the (transformed) adversarial input through the ensemble
            adv_input = self.transform(data + delta, momentum=momentum)
            ens_logits = self.get_logits(adv_input)

            # Ensemble loss and its gradient w.r.t. the perturbation
            ens_loss = self.get_loss(ens_logits, label)
            grad = self.get_grad(ens_loss, delta)

            # Momentum accumulation followed by the perturbation update
            momentum = self.get_momentum(grad, momentum, decay=self.decay)
            delta = self.update_delta(delta, data, momentum, self.alpha)

        return delta.detach()

    def get_logits(self, x, **kwargs):
        """
        Evaluate every sub-model on ``x`` and concatenate the outputs along the first dimension.
        """
        per_model_logits = [sub_model(x) for sub_model in self.model.models]
        return torch.concatenate(per_model_logits)

    def get_loss(self, logits, label):
        """
        Compute the loss over the concatenated logits; the sign is flipped for targeted attacks.
        """
        # Repeat the labels so they line up with the concatenated logits
        tiled_label = label.repeat(len(self.model.models))
        loss_value = self.loss(logits, tiled_label)
        if self.targeted:
            return -loss_value
        return loss_value
    
    
    
class LossEns_AIT(Attack):
    """
    Iterative ensemble attack with random input transformations and PCGrad fusion.

    At every iteration one operation is drawn from ``self.ops`` and applied to
    the adversarial input before it is fed to each sub-model in
    ``model.models``. A separate loss and gradient are computed per sub-model,
    the gradients are merged with ``pcgrad`` and the result drives a
    momentum update of the perturbation. The attack is registered under the
    name 'I-FGSM'.

    Arguments:
        model (torch.nn.Module): the surrogate ensemble; must expose a ``models`` sequence.
        epsilon (float): the perturbation budget.
        alpha (float): accepted for interface compatibility; the step size actually
            used is always ``epsilon / epoch``.
        epoch (int): the number of iterations.
        targeted (bool): targeted/untargeted attack
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        adapted (bool): currently unused.

    Official arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, targeted=False, random_start=False,
                 norm='linfty', loss='crossentropy', device=None, adapted=False, **kwargs):
        super().__init__('I-FGSM', model, epsilon, targeted, random_start, norm, loss, device)
        self.epoch = epoch
        # Momentum decay factor handed to get_momentum
        self.decay = 1
        # The step size is derived from the budget; the `alpha` argument is not used
        self.alpha = epsilon / self.epoch
        # Pool of input transformations sampled during the attack
        self.ops = [self.vertical_shift, self.horizontal_shift, self.vertical_flip, self.horizontal_flip,
                    self.rotate180, self.resize, self.add_noise, self.drop_out]

    def forward(self, data, label, **kwargs):
        """
        Run the iterative attack and return the final perturbation.

        Arguments:
            data: (N, C, H, W) tensor for input images
            label: (N,) tensor for ground-truth labels if untargeted, otherwise targeted labels

        Returns:
            The detached perturbation tensor produced after ``epoch`` updates.
        """
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)
        # Initialize adversarial perturbation
        delta = self.init_delta(data)

        momentum = 0.
        # NOTE(review): per-model momentum is accumulated below but is not
        # consumed by the perturbation update; only the fused momentum is.
        momentum_list = [0. for _ in self.model.models]
        for _ in range(self.epoch):
            # One transformation type is drawn per iteration and shared by all sub-models
            op = np.random.choice(self.ops, 1)[0]

            gs = []
            for idx, sub_model in enumerate(self.model.models):
                # The transformation is re-evaluated for each sub-model, so
                # stochastic operations draw fresh parameters every time.
                o = sub_model(op(data + delta))
                l = self.get_loss(o, label)
                g = self.get_grad(l, delta)
                momentum_list[idx] = self.get_momentum(g, momentum_list[idx], decay=self.decay)
                gs.append(g)

            # Fuse the per-model gradients, projecting out conflicting components
            grad = pcgrad(gs)

            # Calculate the momentum and update the perturbation
            momentum = self.get_momentum(grad, momentum, decay=self.decay)
            delta = self.update_delta(delta, data, momentum, self.alpha)
        return delta.detach()

    def get_logits(self, x, **kwargs):
        """
        Evaluate every sub-model on an independently transformed copy of ``x``
        and concatenate the outputs along the first dimension.
        """
        return torch.concatenate(
            [sub_model(np.random.choice(self.ops, 1)[0](x)) for sub_model in self.model.models]
        )

    def get_loss(self, logits, label):
        """
        Compute the loss for a single sub-model; the sign is flipped for targeted attacks.
        """
        loss_value = self.loss(logits, label)
        return -loss_value if self.targeted else loss_value

    def vertical_shift(self, x):
        """Circularly shift ``x`` along dim 2 by a random offset in [0, size)."""
        _, _, h, _ = x.shape
        step = int(np.random.randint(low=0, high=h, dtype=np.int32))
        return x.roll(step, dims=2)

    def horizontal_shift(self, x):
        """Circularly shift ``x`` along dim 3 by a random offset in [0, size)."""
        _, _, _, w = x.shape
        step = int(np.random.randint(low=0, high=w, dtype=np.int32))
        return x.roll(step, dims=3)

    def vertical_flip(self, x):
        """Reverse ``x`` along dim 2."""
        return x.flip(dims=(2,))

    def horizontal_flip(self, x):
        """Reverse ``x`` along dim 3."""
        return x.flip(dims=(3,))

    def rotate180(self, x):
        """Rotate ``x`` by 180 degrees in the plane spanned by dims 2 and 3."""
        return x.rot90(k=2, dims=(2, 3))

    def scale(self, x):
        """Multiply ``x`` by a random factor drawn uniformly from [0, 1)."""
        # Two samples are drawn and only the first is used
        return torch.rand(2)[0] * x

    def resize(self, x):
        """
        Downscale ``x`` to roughly 80% of its spatial size and scale it back,
        returning a tensor with the original shape.
        """
        _, _, h, w = x.shape
        scale_factor = 0.8
        new_h = int(h * scale_factor) + 1
        new_w = int(w * scale_factor) + 1
        x = F.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)
        return x

    def add_noise(self, x):
        """Add zero-mean Gaussian noise with standard deviation 0.1 to ``x``."""
        return x + torch.randn_like(x) * 0.1

    def gkern(self, kernel_size=3, nsig=3):
        """
        Build a normalized 2-D Gaussian kernel replicated for three channels.

        Arguments:
            kernel_size (int): side length of the square kernel.
            nsig (float): the kernel samples the normal pdf on [-nsig, nsig].

        Returns:
            A float32 tensor of shape (3, 1, kernel_size, kernel_size) on ``self.device``.
        """
        x = np.linspace(-nsig, nsig, kernel_size)
        kern1d = st.norm.pdf(x)
        kernel_raw = np.outer(kern1d, kern1d)
        kernel = kernel_raw / kernel_raw.sum()
        stack_kernel = np.stack([kernel, kernel, kernel])
        stack_kernel = np.expand_dims(stack_kernel, 1)
        return torch.from_numpy(stack_kernel.astype(np.float32)).to(self.device)

    def drop_out(self, x):
        """Apply channel-wise dropout (p=0.2) to ``x``, always in training mode."""
        return F.dropout2d(x, p=0.2, training=True)
    
    
