import torch

from ..utils import *
from ..attack import Attack
from random import shuffle
import torch.nn.functional as F
import scipy.stats as st


def pcgrad(domain_grads):
    """Merge per-domain gradients with a PCGrad-style conflict projection.

    Domains are visited in a random order (``random.shuffle``). For each
    domain, every *other* domain's original gradient is tested for conflict;
    when it conflicts, its direction is projected out of the current domain's
    (possibly already projected) gradient.

    Unlike the original PCGrad, which tests the sign of the inner product,
    "conflict" here means the element-wise sign agreement
    ``sum(sign(g_i) * sign(g_j))`` is negative, i.e. more coordinates
    disagree in sign than agree (zero entries count as neither).

    Args:
        domain_grads: sequence of same-shaped tensors, one per domain. The
            input tensors are not modified (they are cloned first).

    Returns:
        The element-wise mean (not the sum) of the projected gradients, with
        the same shape as each input.
    """
    order = list(range(len(domain_grads)))
    shuffle(order)

    projected = [grad.clone() for grad in domain_grads]

    for target in order:
        for other in order:
            if other == target:
                continue
            # Projections are always taken against the *original* gradient
            # of the other domain, not its projected version.
            reference = domain_grads[other]
            sign_agreement = torch.sum(projected[target].sign() * reference.sign())
            if sign_agreement < 0:
                # Subtract the component of the current gradient along `reference`.
                dot = torch.sum(projected[target] * reference)
                projected[target] -= dot / (reference ** 2).sum() * reference

    return torch.stack(projected).mean(0)


class PredictEns(Attack):
    """
    Iterative sign-gradient attack whose logits come from ``model.forward_predict``.

    Built on the I-FGSM template
    'Adversarial Examples in the Physical World (ICLR 2017)'(https://arxiv.org/abs/1607.02533);
    this class only overrides ``get_logits`` (to call ``forward_predict``) and
    fixes ``decay``/``alpha`` in the constructor.

    Arguments:
        model (torch.nn.Module): the surrogate model for attack; must provide ``forward_predict``.
        epsilon (float): the perturbation budget.
        alpha (float): accepted for API compatibility but ignored; the step size is always ``epsilon / epoch``.
        epoch (int): the number of iterations.
        targeted (bool): targeted/untargeted attack
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model

    Official arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, targeted=False, random_start=False,
                norm='linfty', loss='crossentropy', device=None, **kwargs):
        super().__init__('I-FGSM', model, epsilon, targeted, random_start, norm, loss, device)
        self.epoch = epoch
        # NOTE(review): decay=0 presumably disables momentum accumulation in the
        # base class's momentum update -- confirm against Attack.get_momentum.
        self.decay = 0
        # The ``alpha`` argument is deliberately overridden: step = epsilon / epoch.
        self.alpha = epsilon / self.epoch

    def get_logits(self, x, **kwargs):
        """
        Return the surrogate's output for ``x`` via ``model.forward_predict``.

        Overrides the base-class inference hook, which is the place to change
        how the model (or its input) is queried during the attack.
        """
        return self.model.forward_predict(x)
        
        
class AIT_predict(Attack):
    """
    Prediction-ensemble attack with random input transformations.

    Each iteration draws one transformation from ``self.ops``, applies it to
    the current adversarial input once per sub-model in ``model.models``,
    averages the sub-models' softmax probabilities, and takes a momentum
    step on the loss of that averaged prediction.

    Built on the I-FGSM template
    'Adversarial Examples in the Physical World (ICLR 2017)'(https://arxiv.org/abs/1607.02533).

    Arguments:
        model (torch.nn.Module): the surrogate ensemble; its members must be exposed as ``model.models``.
        epsilon (float): the perturbation budget.
        alpha (float): accepted for API compatibility but ignored; the step size is always ``epsilon / epoch``.
        epoch (int): the number of iterations.
        targeted (bool): targeted/untargeted attack
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model

    Official arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, targeted=False, random_start=False,
                norm='linfty', loss='crossentropy', device=None, **kwargs):
        super().__init__('I-FGSM', model, epsilon, targeted, random_start, norm, loss, device)
        self.epoch = epoch
        # Momentum decay forwarded to get_momentum in forward().
        self.decay = 1
        # The ``alpha`` argument is deliberately overridden: step = epsilon / epoch.
        self.alpha = epsilon / self.epoch
        # Candidate input transformations; forward() draws one per iteration.
        self.ops = [self.vertical_shift, self.horizontal_shift, self.vertical_flip, self.horizontal_flip,
                    self.rotate180, self.resize, self.add_noise, self.drop_out]

    def transform(self, data, **kwargs):
        """Apply one transformation drawn uniformly at random from ``self.ops``."""
        selected_ops = np.random.choice(self.ops, 1)
        return selected_ops[0](data)

    def forward(self, data, label, **kwargs):
        """
        The attack procedure.

        Arguments:
            data: (N, C, H, W) tensor for input images
            label: (N,) tensor for ground-truth labels if untargeted, otherwise targeted labels

        Returns:
            The detached adversarial perturbation ``delta``.
        """
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)
        # Initialize the perturbation; this honours ``self.random_start`` and,
        # unlike before, no longer overwrites that setting as a side effect.
        delta = self.init_delta(data, label=label)

        momentum = 0.
        for _ in range(self.epoch):
            # One transformation type per iteration. It is re-invoked for each
            # sub-model, so stochastic ops (shifts, noise, dropout) draw fresh
            # randomness per member.
            op = np.random.choice(self.ops, 1)[0]
            probs = [F.softmax(sub_model(op(data + delta)), dim=1) for sub_model in self.model.models]
            # Ensemble at the prediction level: average member probabilities.
            probs_mean = torch.stack(probs).mean(0)

            # NOTE(review): get_loss receives probabilities rather than logits.
            # If the configured 'crossentropy' loss expects logits (e.g.
            # nn.CrossEntropyLoss), this is a softmax-of-softmax; kept as-is
            # to preserve the attack's behaviour -- confirm intent.
            loss = self.get_loss(probs_mean, label)
            grad = self.get_grad(loss, delta)
            momentum = self.get_momentum(grad, momentum, decay=self.decay)
            delta = self.update_delta(delta, data, momentum, self.alpha)

        return delta.detach()

    def vertical_shift(self, x):
        """Cyclically shift ``x`` along dim 2 (rows) by a random offset."""
        _, _, height, _ = x.shape
        step = np.random.randint(low=0, high=height, dtype=np.int32)
        return x.roll(step, dims=2)

    def horizontal_shift(self, x):
        """Cyclically shift ``x`` along dim 3 (columns) by a random offset."""
        _, _, _, width = x.shape
        step = np.random.randint(low=0, high=width, dtype=np.int32)
        return x.roll(step, dims=3)

    def vertical_flip(self, x):
        """Flip ``x`` along dim 2."""
        return x.flip(dims=(2,))

    def horizontal_flip(self, x):
        """Flip ``x`` along dim 3."""
        return x.flip(dims=(3,))

    def rotate180(self, x):
        """Rotate ``x`` by 180 degrees in the (dim 2, dim 3) plane."""
        return x.rot90(k=2, dims=(2, 3))

    def scale(self, x):
        """Multiply ``x`` by a random factor in [0, 1). Not included in ``self.ops``."""
        return torch.rand(2)[0] * x

    def resize(self, x):
        """
        Downscale ``x`` to ~80% and bilinearly resize it back to its original size.

        Height and width are scaled independently, so non-square inputs keep
        their aspect ratio in the intermediate image (the previous version
        swapped the two dimensions there).
        """
        _, _, height, width = x.shape
        scale_factor = 0.8
        new_height = int(height * scale_factor) + 1
        new_width = int(width * scale_factor) + 1
        x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=False)
        x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
        return x

    def add_noise(self, x):
        """Add zero-mean Gaussian noise with standard deviation 0.1."""
        return x + torch.randn_like(x) * 0.1

    def gkern(self, kernel_size=3, nsig=3):
        """
        Build a normalized Gaussian kernel stacked for 3 channels, shape (3, 1, k, k).

        Not used by forward() in this class.
        """
        x = np.linspace(-nsig, nsig, kernel_size)
        kern1d = st.norm.pdf(x)
        kernel_raw = np.outer(kern1d, kern1d)
        kernel = kernel_raw / kernel_raw.sum()
        stack_kernel = np.stack([kernel, kernel, kernel])
        stack_kernel = np.expand_dims(stack_kernel, 1)
        return torch.from_numpy(stack_kernel.astype(np.float32)).to(self.device)

    def drop_out(self, x):
        """Channel-wise dropout (``F.dropout2d``, p=0.2), active regardless of train/eval mode."""
        return F.dropout2d(x, p=0.2, training=True)
    
    

class GA_predict(Attack):
    """
    Prediction-ensemble attack with random input transformations and
    PCGrad-style combination of per-member gradients.

    Each iteration draws one transformation from ``self.ops`` and gives every
    sub-model in ``model.models`` its own differentiable copy of ``delta``.
    The members' softmax probabilities are averaged and a single loss is
    computed. The gradient w.r.t. each copy (i.e. each member's contribution)
    is taken separately, and the gradients are merged with ``pcgrad`` before
    the momentum step.

    Built on the I-FGSM template
    'Adversarial Examples in the Physical World (ICLR 2017)'(https://arxiv.org/abs/1607.02533).

    Arguments:
        model (torch.nn.Module): the surrogate ensemble; its members must be exposed as ``model.models``.
        epsilon (float): the perturbation budget.
        alpha (float): accepted for API compatibility but ignored; the step size is always ``epsilon / epoch``.
        epoch (int): the number of iterations.
        targeted (bool): targeted/untargeted attack
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model

    Official arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10
    """

    def __init__(self, model, epsilon=16/255, alpha=1.6/255, epoch=10, targeted=False, random_start=False,
                norm='linfty', loss='crossentropy', device=None, **kwargs):
        super().__init__('I-FGSM', model, epsilon, targeted, random_start, norm, loss, device)
        self.epoch = epoch
        # Momentum decay forwarded to get_momentum in forward().
        self.decay = 1
        # The ``alpha`` argument is deliberately overridden: step = epsilon / epoch.
        self.alpha = epsilon / self.epoch
        # Candidate input transformations; forward() draws one per iteration.
        self.ops = [self.vertical_shift, self.horizontal_shift, self.vertical_flip, self.horizontal_flip,
                    self.rotate180, self.resize, self.add_noise, self.drop_out]

    def transform(self, data, **kwargs):
        """Apply one transformation drawn uniformly at random from ``self.ops``."""
        selected_ops = np.random.choice(self.ops, 1)
        return selected_ops[0](data)

    def forward(self, data, label, **kwargs):
        """
        The attack procedure.

        Arguments:
            data: (N, C, H, W) tensor for input images
            label: (N,) tensor for ground-truth labels if untargeted, otherwise targeted labels

        Returns:
            The detached adversarial perturbation ``delta``.
        """
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)
        # Initialize the perturbation; this honours ``self.random_start`` and,
        # unlike before, no longer overwrites that setting as a side effect.
        delta = self.init_delta(data, label=label)

        momentum = 0.
        num_models = len(self.model.models)
        for _ in range(self.epoch):
            # One transformation type per iteration. It is re-invoked for each
            # sub-model, so stochastic ops draw fresh randomness per member.
            op = np.random.choice(self.ops, 1)[0]
            # A separate differentiable copy of delta per member lets the
            # gradient of the shared loss be split by member.
            delta_copies = [delta.clone() for _ in range(num_models)]
            probs = [F.softmax(sub_model(op(data + delta_copy)), dim=1)
                     for sub_model, delta_copy in zip(self.model.models, delta_copies)]
            # Ensemble at the prediction level: average member probabilities.
            probs_mean = torch.stack(probs).mean(0)

            # NOTE(review): get_loss receives probabilities rather than logits.
            # If the configured 'crossentropy' loss expects logits (e.g.
            # nn.CrossEntropyLoss), this is a softmax-of-softmax; kept as-is
            # to preserve the attack's behaviour -- confirm intent.
            loss = self.get_loss(probs_mean, label)
            grad_list = [self.get_grad(loss, delta_copy) for delta_copy in delta_copies]
            grad = pcgrad(grad_list)
            momentum = self.get_momentum(grad, momentum, decay=self.decay)
            delta = self.update_delta(delta, data, momentum, self.alpha)

        return delta.detach()

    def get_grad(self, loss, delta, **kwargs):
        """
        Gradient of ``loss`` w.r.t. ``delta``, keeping the autograd graph alive.

        ``retain_graph=True`` is required because forward() differentiates the
        same loss once per member's copy of delta.
        """
        return torch.autograd.grad(loss, delta, retain_graph=True, create_graph=False)[0]

    def vertical_shift(self, x):
        """Cyclically shift ``x`` along dim 2 (rows) by a random offset."""
        _, _, height, _ = x.shape
        step = np.random.randint(low=0, high=height, dtype=np.int32)
        return x.roll(step, dims=2)

    def horizontal_shift(self, x):
        """Cyclically shift ``x`` along dim 3 (columns) by a random offset."""
        _, _, _, width = x.shape
        step = np.random.randint(low=0, high=width, dtype=np.int32)
        return x.roll(step, dims=3)

    def vertical_flip(self, x):
        """Flip ``x`` along dim 2."""
        return x.flip(dims=(2,))

    def horizontal_flip(self, x):
        """Flip ``x`` along dim 3."""
        return x.flip(dims=(3,))

    def rotate180(self, x):
        """Rotate ``x`` by 180 degrees in the (dim 2, dim 3) plane."""
        return x.rot90(k=2, dims=(2, 3))

    def scale(self, x):
        """Multiply ``x`` by a random factor in [0, 1). Not included in ``self.ops``."""
        return torch.rand(2)[0] * x

    def resize(self, x):
        """
        Downscale ``x`` to ~80% and bilinearly resize it back to its original size.

        Height and width are scaled independently, so non-square inputs keep
        their aspect ratio in the intermediate image (the previous version
        swapped the two dimensions there).
        """
        _, _, height, width = x.shape
        scale_factor = 0.8
        new_height = int(height * scale_factor) + 1
        new_width = int(width * scale_factor) + 1
        x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=False)
        x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
        return x

    def add_noise(self, x):
        """Add zero-mean Gaussian noise with standard deviation 0.1."""
        return x + torch.randn_like(x) * 0.1

    def gkern(self, kernel_size=3, nsig=3):
        """
        Build a normalized Gaussian kernel stacked for 3 channels, shape (3, 1, k, k).

        Not used by forward() in this class.
        """
        x = np.linspace(-nsig, nsig, kernel_size)
        kern1d = st.norm.pdf(x)
        kernel_raw = np.outer(kern1d, kern1d)
        kernel = kernel_raw / kernel_raw.sum()
        stack_kernel = np.stack([kernel, kernel, kernel])
        stack_kernel = np.expand_dims(stack_kernel, 1)
        return torch.from_numpy(stack_kernel.astype(np.float32)).to(self.device)

    def drop_out(self, x):
        """Channel-wise dropout (``F.dropout2d``, p=0.2), active regardless of train/eval mode."""
        return F.dropout2d(x, p=0.2, training=True)