import torch
import torch.nn.functional as F
from attack.rocl import pairwise_similarity, NT_xent

def project(x, original_x, epsilon, _type="linf"):
    """Clip ``x`` into the ``epsilon`` ball centred on ``original_x``.

    Args:
        x: tensor to project.
        original_x: centre of the ball (broadcast against ``x``).
        epsilon: ball radius.
        _type: norm of the ball; only ``"linf"`` (element-wise box) is supported.

    Returns:
        A new tensor with every element of ``x`` limited to
        ``[original_x - epsilon, original_x + epsilon]``.

    Raises:
        NotImplementedError: for any ``_type`` other than ``"linf"``.
    """
    if _type != "linf":
        raise NotImplementedError

    upper = original_x + epsilon
    lower = original_x - epsilon
    # Cap from above first, then from below (same order as torch.clamp-free form).
    return torch.max(torch.min(x, upper), lower)

class RepresentationAdv:
    """Input-space attack that maximises a similarity-based contrastive loss.

    ``perturb`` runs signed-gradient ascent on the input images so that their
    representation becomes dissimilar (under the loss from ``attack.rocl``) to
    that of a paired ``target`` batch, while staying within an ``epsilon`` ball
    of the clean images and within ``[min_val, max_val]``.
    """

    def __init__(
        self,
        epsilon,
        alpha,
        min_val,
        max_val,
        max_iters,
        attack_type="linf",
        attack_loss_type="sim",
        random_start=True,
        regularize="original",
    ):
        # Stored for callers; not read by any method of this class.
        self.regularize = regularize
        # Radius of the perturbation ball around the clean images.
        self.epsilon = epsilon
        # Step size applied to the gradient sign at each iteration.
        self.alpha = alpha
        # Minimum value of the pixels.
        self.min_val = min_val
        # Maximum value of the pixels.
        self.max_val = max_val
        # Number of attack iterations used to generate the adversaries.
        self.max_iters = max_iters
        # Norm of the perturbation ball; perturb() only supports "linf".
        self.attack_type = attack_type
        # Start from a uniformly random point in the ball instead of the clean images.
        self.random_start = random_start
        # Loss to maximise; perturb() only supports "sim".
        self.attack_loss_type = attack_loss_type

    def print_attack_info(self):
        """Print the attack hyper-parameters to stdout."""
        print(
            "[Representation attack info]\nepsilon: {}\nalpha: {}\nmin_val: {}\nmax_val: {}\nmax_iters: {}\nattack_type: {}\nattack_loss_type: {}\nrandom_start: {}\n".format(
                self.epsilon,
                self.alpha,
                self.min_val,
                self.max_val,
                self.max_iters,
                self.attack_type,
                self.attack_loss_type,
                self.random_start,
            )
        )

    def _initial_point(self, original_images):
        """Return the leaf tensor (requires_grad=True) the attack starts from."""
        base = original_images.detach()
        if self.random_start:
            # Sample on the CPU generator (so seeded runs draw the same noise as
            # before) and move it to the images' device rather than assuming CUDA.
            noise = torch.empty(base.shape, dtype=torch.float32).uniform_(
                -self.epsilon, self.epsilon
            )
            x = base.float() + noise.to(base.device)
            x = torch.clamp(x, self.min_val, self.max_val)
        else:
            x = base.clone()
        x.requires_grad_(True)
        return x

    def _sim_loss(self, x, target, inner_update_type, params, params2):
        """Compute the similarity-based contrastive loss of ``x`` against ``target``."""
        if params2 is None:
            # Both batches go through the same weights in a single forward pass.
            inputs = torch.cat((x, target))
            # model(..., feat=True) is unpacked as a 2-tuple and its second element
            # is used as the representation (TODO confirm against the model).
            if params is None or inner_update_type == 'linear_only':
                _, output = self.model(inputs, feat=True)
            else:
                _, output = self.model(inputs, params=params, feat=True)
        else:
            # Separate weights per batch: x uses params, target uses params2.
            output1 = self.model(x, params=params, feat=False)
            output2 = self.model(target, params=params2, feat=False)
            output = torch.cat((output1, output2), dim=0)

        similarity, _ = pairwise_similarity(
            output, temperature=0.5, multi_gpu=False, adv_type="None"
        )
        return NT_xent(similarity, "None")

    def perturb(self, model, original_images, target, inner_update_type='both', params=None, params2=None):  # weight
        """Craft adversarial images that maximise the contrastive loss against ``target``.

        Runs ``max_iters`` steps of signed-gradient ascent on the input. After
        each step the result is clamped to ``[min_val, max_val]`` and projected
        into the ``epsilon`` ball around ``original_images``.

        Args:
            model: network under attack. It is put in eval mode while the attack
                runs and its previous train/eval mode is restored afterwards
                (also on error). ``model.zero_grad()`` is called every iteration.
            original_images: clean batch the perturbation is centred on.
            target: batch paired with ``original_images``; it is concatenated
                after ``x`` along dim 0 for the similarity computation
                (presumably the other augmented view - TODO confirm).
            inner_update_type: when ``'linear_only'`` (and ``params2`` is None),
                ``params`` is ignored and the model's own weights are used.
            params: optional weights passed as ``model(..., params=params)``.
            params2: optional second set of weights; when given, ``x`` is encoded
                with ``params`` and ``target`` with ``params2``.

        Returns:
            The detached adversarial batch, same shape as ``original_images``.

        Raises:
            NotImplementedError: if ``attack_type`` is not ``"linf"`` or
                ``attack_loss_type`` is not ``"sim"``.
        """
        # Validate up front instead of failing later with an unbound local.
        if self.attack_type != "linf":
            raise NotImplementedError(
                "Unsupported attack_type: {}".format(self.attack_type)
            )
        if self.attack_loss_type != "sim":
            raise NotImplementedError(
                "Unsupported attack_loss_type: {}".format(self.attack_loss_type)
            )

        x = self._initial_point(original_images)

        self.model = model
        was_training = model.training
        model.eval()
        try:
            with torch.enable_grad():
                for _ in range(self.max_iters):
                    model.zero_grad()
                    loss = self._sim_loss(x, target, inner_update_type, params, params2)
                    (grads,) = torch.autograd.grad(loss, x)

                    with torch.no_grad():
                        # L-inf ascent step, then back into the pixel range and the ball.
                        x = x + self.alpha * torch.sign(grads)
                        x = torch.clamp(x, self.min_val, self.max_val)
                        x = project(x, original_images, self.epsilon, self.attack_type)
                    # Fresh leaf each iteration so the autograd graph does not
                    # accumulate across iterations.
                    x.requires_grad_(True)
        finally:
            model.train(was_training)

        return x.detach()
