from filecmp import DEFAULT_IGNORES
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


def project(x, original_x, epsilon, _type="linf"):
    """Project ``x`` back into the ``epsilon`` ball around ``original_x``.

    Only the L-infinity ball is supported: every element of ``x`` is clipped
    into ``[original_x - epsilon, original_x + epsilon]``.

    Args:
        x: tensor to project.
        original_x: centre of the ball (same shape as ``x`` or broadcastable).
        epsilon: radius of the ball.
        _type: norm of the ball; must be ``"linf"``.

    Returns:
        The projected tensor.

    Raises:
        NotImplementedError: for any ``_type`` other than ``"linf"``.
    """
    if _type != "linf":
        raise NotImplementedError

    upper = original_x + epsilon
    lower = original_x - epsilon
    return torch.max(torch.min(x, upper), lower)

class AttackPGD():
    """L-infinity projected gradient descent (PGD) attack.

    Each iteration takes a signed-gradient ascent step of size ``alpha`` on the
    cross-entropy loss, clamps the result to ``[min_val, max_val]`` and then
    projects it back into the ``epsilon`` ball around the clean images.
    """

    def __init__(
        self,
        epsilon=8.0 / 255.0,
        alpha=2.0 / 255.0,
        max_iters=20,
        min_val=0.0,
        max_val=1.0,
        random_start=True,
        attack_type='linf',
        reduction4loss="mean",
    ):
        self.eps = epsilon  # radius of the L-inf perturbation ball
        self.alpha = alpha  # size of each signed-gradient step
        self.max_iters = max_iters  # number of PGD iterations
        self.min_val = min_val  # lower bound of the valid input range
        self.max_val = max_val  # upper bound of the valid input range
        self.random_start = random_start  # start from a uniform random point in the ball
        self.attack_type = attack_type  # only 'linf' is implemented
        self.reduction4loss = reduction4loss  # `reduction` passed to F.cross_entropy

    def perturb(self, model, original_images, labels, inner_update_type='both', params=None):
        """Return adversarial versions of ``original_images``.

        Args:
            model: classifier. Called as ``model(x)``, or as
                ``model(x, params=params)`` when ``params`` is given and
                ``inner_update_type != 'linear_only'``.
            original_images: clean inputs the perturbation is bounded around.
            labels: targets passed to ``F.cross_entropy``.
            inner_update_type: ``'linear_only'`` suppresses forwarding ``params``.
            params: optional parameters forwarded to ``model``.

        Returns:
            A detached tensor with the same shape as ``original_images``.

        Raises:
            NotImplementedError: if ``attack_type`` is not ``'linf'``.

        Side effects: stores ``model`` on ``self.model``, switches it to eval
        mode and calls ``model.zero_grad()`` on every iteration.
        """
        if self.attack_type != 'linf':
            # Only the sign-gradient (L-inf) step exists; fail fast and clearly.
            raise NotImplementedError(
                f"AttackPGD: unsupported attack_type {self.attack_type!r}"
            )

        # Detached anchor: the attack never back-propagates into the clean images.
        clean = original_images.detach()
        x = clean.clone()
        if self.random_start:
            # Noise is drawn on the CPU generator (as before, so seeded runs
            # reproduce) and moved to the images' device instead of assuming CUDA.
            noise = torch.empty(clean.shape, dtype=torch.float32).uniform_(-self.eps, self.eps)
            x = torch.clamp(x + noise.to(clean.device), self.min_val, self.max_val)

        self.model = model
        self.model.eval()

        with torch.enable_grad():
            for _ in range(self.max_iters):
                # A fresh leaf each step so earlier iterations' graphs are not kept alive.
                x = x.detach().requires_grad_(True)
                self.model.zero_grad()

                if inner_update_type == 'linear_only' or params is None:
                    outputs = self.model(x)
                else:
                    outputs = self.model(x, params=params)

                loss = F.cross_entropy(outputs, labels, reduction=self.reduction4loss)
                # ones_like(loss) equals the implicit default for a scalar loss and
                # additionally supports reduction4loss='none' (per-sample losses).
                grads = torch.autograd.grad(loss, x, grad_outputs=torch.ones_like(loss))[0]

                x = x.detach() + self.alpha * torch.sign(grads)
                x = torch.clamp(x, self.min_val, self.max_val)
                x = project(x, clean, self.eps, self.attack_type)
        return x.detach()



class AttackBPDA(object):
    """Backward Pass Differentiable Approximation (BPDA) attack.

    The forward pass applies ``defense`` before the classifier; the backward
    pass treats ``defense`` as the identity, i.e. the gradient of the loss with
    respect to the *purified* input is used as the ascent direction for the
    un-purified adversarial input.
    """

    def __init__(
        self,
        model=None,
        linear=None,
        defense=None,
        eps=None,
        learning_rate=0.5,
        max_iters=100,
        min_val=0.0,
        max_val=1.0,
    ):
        self.model = model  # feature extractor, or full classifier when `linear` is None
        self.linear = linear  # optional classification head applied to model outputs
        self.epsilon = eps  # L-inf radius around the clean input; must be set before attack()
        self.loss_fn = nn.CrossEntropyLoss(reduction='sum')
        self.defense = defense  # callable applied to the input before the model
        self.clip_min = min_val  # lower bound of the valid input range
        self.clip_max = max_val  # upper bound of the valid input range

        self.LEARNING_RATE = learning_rate  # size of each signed-gradient step
        self.MAX_ITERATIONS = max_iters

    def attack(self, x, y):
        """
        Given examples (X_nat, y), returns their adversarial
        counterparts with an attack length of epsilon.

        Each iteration adds ``LEARNING_RATE * sign(grad)`` and clips the result
        into ``[x - epsilon, x + epsilon]`` intersected with
        ``[clip_min, clip_max]``.

        Raises:
            ValueError: if ``eps`` was not provided.
        """
        if self.epsilon is None:
            raise ValueError("AttackBPDA.attack requires eps to be set")

        self.model.eval()
        if self.linear is not None:
            self.linear.eval()

        x_nat = x.detach()
        # Per-element box: epsilon ball around x, intersected with the valid range.
        lower = torch.clamp(x_nat - self.epsilon, self.clip_min, self.clip_max)
        upper = torch.clamp(x_nat + self.epsilon, self.clip_min, self.clip_max)
        adv = x_nat.clone()

        with torch.enable_grad():
            for _ in range(self.MAX_ITERATIONS):
                self.model.zero_grad()
                if self.linear is not None:
                    self.linear.zero_grad()

                # BPDA: cut the graph at the defense output and differentiate w.r.t.
                # it. Detaching also keeps `adv` a plain tensor even when `defense`
                # returns its input unchanged, and avoids accumulating .grad into
                # model/defense parameters.
                purified = self.defense(adv).detach().requires_grad_(True)

                features = self.model(purified)
                scores = features if self.linear is None else self.linear(features)
                loss = F.cross_entropy(scores, y, reduction='sum')
                grad = torch.autograd.grad(loss, purified)[0]

                adv = adv + self.LEARNING_RATE * grad.sign()
                # Same element-wise clip as np.clip(adv, lower, upper), kept on-device.
                adv = torch.min(torch.max(adv, lower), upper)
        return adv.detach()


class AttackFGSM():
    """Fast Gradient Sign Method: one signed-gradient step of size ``epsilon``.

    The perturbed input is clamped to ``[min_val, max_val]``.
    """

    def __init__(
        self,
        epsilon=0.2,
        min_val=0.0,
        max_val=1.0,
        reduction4loss="mean",
    ):
        self.eps = epsilon  # size of the single signed-gradient step
        self.min_val = min_val  # lower bound of the valid input range
        self.max_val = max_val  # upper bound of the valid input range
        self.reduction4loss = reduction4loss  # `reduction` passed to F.cross_entropy

    def perturb(self, model, original_images, labels, inner_update_type='both', params=None):
        """Return ``clamp(x + eps * sign(grad_x loss), min_val, max_val)``.

        Args:
            model: classifier. Called as ``model(x)``, or as
                ``model(x, params=params)`` when ``params`` is given and
                ``inner_update_type != 'linear_only'``.
            original_images: clean inputs.
            labels: targets passed to ``F.cross_entropy``.
            inner_update_type: ``'linear_only'`` suppresses forwarding ``params``.
            params: optional parameters forwarded to ``model``.

        Returns:
            A detached tensor with the same shape as ``original_images``.

        Side effects: stores ``model`` on ``self.model``, switches it to eval
        mode and calls ``model.zero_grad()``.
        """
        # Detach first so this works even when the caller's images require grad
        # (a clone of such a tensor is a non-leaf whose requires_grad can't be set).
        x = original_images.detach().clone().requires_grad_(True)

        self.model = model
        self.model.eval()

        with torch.enable_grad():
            self.model.zero_grad()

            if inner_update_type == 'linear_only' or params is None:
                outputs = self.model(x)
            else:
                outputs = self.model(x, params=params)

            loss = F.cross_entropy(outputs, labels, reduction=self.reduction4loss)
            # ones_like(loss) equals the implicit default for a scalar loss and
            # additionally supports reduction4loss='none' (per-sample losses).
            grads = torch.autograd.grad(loss, x, grad_outputs=torch.ones_like(loss))[0]

        x_adv = x.detach() + self.eps * torch.sign(grads)
        return torch.clamp(x_adv, self.min_val, self.max_val)


def pgd_attack(
    model,
    original_images,
    labels,
    linear=None,
    params=None,
    update_type='both',
    eps=8.0 / 255.0,
    alpha=2.0 / 255.0,
    max_iters=20,
    min_val=0.0,
    max_val=1.0,
    random_start=True,
    _type='linf',
    reduction4loss="mean",
):
    """L-infinity PGD attack against ``model`` (optionally followed by ``linear``).

    Each iteration takes a signed-gradient ascent step of size ``alpha`` on the
    cross-entropy loss, clamps to ``[min_val, max_val]`` and projects back into
    the ``eps`` ball around ``original_images``.

    When both ``linear`` and ``params`` are given, ``update_type`` selects where
    ``params`` is forwarded: ``'encoder_only'`` -> ``model``, ``'linear_only'``
    -> ``linear``, ``'both'`` -> both.

    Returns:
        A detached tensor with the same shape as ``original_images``.

    Raises:
        NotImplementedError: if ``_type`` is not ``'linf'``.
        ValueError: if ``linear`` and ``params`` are given and ``update_type``
            is not one of the three values above.

    Side effects: switches ``model`` (and ``linear``) to eval mode and calls
    their ``zero_grad()`` on every iteration.
    """
    if _type != 'linf':
        raise NotImplementedError(f"pgd_attack: unsupported attack type {_type!r}")
    if linear is not None and params is not None and update_type not in (
        'encoder_only', 'linear_only', 'both'
    ):
        raise ValueError(f"pgd_attack: unknown update_type {update_type!r}")

    # Detached anchor: the attack never back-propagates into the clean images.
    clean = original_images.detach()
    x = clean.clone()
    if random_start:
        # Noise is drawn on the CPU generator (as before, so seeded runs reproduce)
        # and moved to the images' device instead of assuming a CUDA device.
        noise = torch.empty(clean.shape, dtype=torch.float32).uniform_(-eps, eps)
        x = torch.clamp(x + noise.to(clean.device), min_val, max_val)

    model.eval()
    if linear is not None:
        linear.eval()

    with torch.enable_grad():
        for _ in range(max_iters):
            # A fresh leaf each step so earlier iterations' graphs are not kept alive.
            x = x.detach().requires_grad_(True)

            model.zero_grad()
            if linear is not None:
                linear.zero_grad()

            if linear is None:
                outputs = model(x)
            elif params is None:
                outputs = linear(model(x))
            elif update_type == 'encoder_only':
                outputs = linear(model(x, params=params))
            elif update_type == 'linear_only':
                outputs = linear(model(x), params=params)
            else:  # 'both' (validated above)
                outputs = linear(model(x, params=params), params=params)

            loss = F.cross_entropy(outputs, labels, reduction=reduction4loss)
            # ones_like(loss) equals the implicit default for a scalar loss and
            # additionally supports reduction4loss='none' (per-sample losses).
            grads = torch.autograd.grad(loss, x, grad_outputs=torch.ones_like(loss))[0]

            x = x.detach() + alpha * torch.sign(grads)
            x = torch.clamp(x, min_val, max_val)
            x = project(x, clean, eps, _type)
    return x.detach()


def aq_r2d2_attack(
    x_qry,
    feats_sprt,
    target_qry,
    target_sprt,
    way,
    shot,
    batch_size,
    train_n_qry,
    model,
    linear,
    random_init=True,
    eps=8.0/255.0,
    alpha=8.0/2550.0,
    max_iters=20,
    min_val=0.0,
    max_val=1.0,
):
    """L-infinity PGD attack on query images against a support-set-fitted head.

    Each iteration embeds the (perturbed) queries with ``model``, builds the
    head from the support features via ``linear(feats_sprt, target_sprt, way,
    shot)``, scores queries with ``linear.scale * bmm(query_feats, head)``,
    takes a signed-gradient ascent step of size ``alpha`` on the summed
    cross-entropy, projects into the ``eps`` ball around ``x_qry`` and clamps
    to ``[min_val, max_val]``.

    Args:
        x_qry: query images; the last three dims are treated as one image and
            the features are reshaped to ``(batch_size, train_n_qry, -1)``.
        feats_sprt, target_sprt, way, shot: forwarded to ``linear``
            (presumably an R2D2-style ridge-regression head — verify).
        target_qry: query labels; flattened to one label per query.
        model, linear: switched to eval mode; ``linear`` must expose ``scale``.
        min_val, max_val: valid input range (defaults keep the former [0, 1]).

    Returns:
        A detached tensor with the same shape as ``x_qry``.
    """
    labels_qry = target_qry.reshape(-1)
    # Detached anchor: the attack never back-propagates into the clean queries.
    x_nat = x_qry.detach()
    x = x_nat

    model.eval()
    linear.eval()

    if random_init:
        x = x + torch.zeros_like(x).uniform_(-eps, eps)

    for _ in range(max_iters):
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            flat_imgs = x.reshape([-1] + list(x.shape[-3:]))
            feats_qry_adv = model(flat_imgs).reshape(batch_size, train_n_qry, -1)
            ridge_sol = linear(feats_sprt, target_sprt, way, shot)
            logits = linear.scale * torch.bmm(feats_qry_adv, ridge_sol)
            logits = logits.reshape(logits.size(0) * logits.size(1), logits.size(2))
            # reduction='sum' is the non-deprecated spelling of size_average=False.
            loss = F.cross_entropy(logits, labels_qry, reduction='sum')
            grad = torch.autograd.grad(loss, [x])[0]

        x = x.detach() + alpha * torch.sign(grad)
        x = torch.min(torch.max(x, x_nat - eps), x_nat + eps)
        x = torch.clamp(x, min_val, max_val)
    return x.detach()

