"""
This code is modified from
https://github.com/Kim-Minseon/RoCL
"""

import torch
import torch.nn.functional as F

def project(x, original_x, epsilon, _type='linf'):
    """Project ``x`` back into the ``epsilon``-ball around ``original_x``.

    Only the L-inf ball (``_type='linf'``) is supported: every element of
    ``x`` is clipped to ``[original_x - epsilon, original_x + epsilon]``.

    Raises:
        NotImplementedError: for any other ``_type``.
    """
    if _type != 'linf':
        raise NotImplementedError

    upper = original_x + epsilon
    lower = original_x - epsilon
    # Clip from above first, then from below (same order as before).
    return torch.max(torch.min(x, upper), lower)


class RepresentationAdv():
    """L-inf signed-gradient (PGD-style) attacker on ``model``'s embeddings.

    The batch is viewed as ``batch_size // num_sample`` groups of
    ``num_sample`` consecutive samples. ``alignment_attacker`` ascends the
    negative within-group similarity (pushing same-group embeddings apart);
    ``uniformity_attacker`` ascends the log-mean-exp of cross-group
    similarity (pulling different groups together). Each of the
    ``max_iters`` steps moves by ``alpha * sign(grad)``, clamps to
    ``[min_val, max_val]`` and projects into the ``epsilon`` L-inf ball
    around the clean images.
    """

    def __init__(self, model, epsilon, alpha, min_val, max_val, max_iters,
                 num_sample, _type='linf', loss_type='sim',
                 regularize='original', random_start=True):
        # Model whose outputs (embeddings) are attacked.
        self.model = model
        # Stored for compatibility; not read within this class.
        self.regularize = regularize
        # Maximum perturbation (radius of the L-inf ball).
        self.epsilon = epsilon
        # Step size of each signed-gradient update.
        self.alpha = alpha
        # Minimum valid pixel value.
        self.min_val = min_val
        # Maximum valid pixel value.
        self.max_val = max_val
        # Number of attack iterations.
        self.max_iters = max_iters
        # Threat model; only 'linf' is implemented.
        self._type = _type
        # Number of consecutive samples forming one group (the original
        # comment describes them as samples of the same class in a batch).
        self.num_sample = num_sample
        # Loss used by alignment_attacker; only 'sim' is implemented.
        self.loss_type = loss_type
        # Whether attacks start from a uniform random point in the ball.
        self.random_start = random_start

    def alignment_attacker(self, original_images, target, optimizer, weight, random_start=True):
        """Return adversarial images that lower within-group similarity.

        The attacked loss is the negative mean dot product between embeddings
        of the same group. Per the original comment, this equals attacking
        the alignment loss only if ``model`` outputs unit-norm embeddings —
        that is assumed, not checked.

        Args:
            original_images: clean batch; its length must be a positive
                multiple of ``num_sample``.
            target, optimizer, weight: unused; kept for interface
                compatibility.
            random_start: unused; ``self.random_start`` controls the start.

        Returns:
            Detached adversarial batch with the shape of ``original_images``.

        Raises:
            ValueError: ``loss_type`` is not ``'sim'`` or the batch size is
                not a positive multiple of ``num_sample``.
            NotImplementedError: ``_type`` is not ``'linf'``.
        """
        if self.loss_type != 'sim':
            raise ValueError(
                "unsupported loss_type %r; only 'sim' is implemented"
                % (self.loss_type,))
        return self._attack(original_images, self._alignment_loss)

    def uniformity_attacker(self, original_images, target, optimizer, weight, random_start=True):
        """Return adversarial images that raise cross-group similarity.

        The attacked loss is ``log(mean(exp(s_ij)))`` over all pairs ``(i, j)``
        belonging to different groups, where ``s_ij`` is the dot product of
        the flattened embeddings.

        Args:
            original_images: clean batch; its length must be a multiple of
                ``num_sample`` spanning at least two groups.
            target, optimizer, weight: unused; kept for interface
                compatibility.
            random_start: unused; ``self.random_start`` controls the start.

        Returns:
            Detached adversarial batch with the shape of ``original_images``.

        Raises:
            ValueError: the batch does not split into at least two groups of
                ``num_sample`` (with one group there are no cross-group pairs
                and the loss would be NaN).
            NotImplementedError: ``_type`` is not ``'linf'``.
        """
        if self._num_groups(len(original_images)) < 2:
            raise ValueError(
                "uniformity_attacker needs at least two groups of "
                "num_sample=%d samples" % (self.num_sample,))
        return self._attack(original_images, self._uniformity_loss)

    def _num_groups(self, batch_size):
        """Return how many ``num_sample``-sized groups ``batch_size`` holds."""
        if (batch_size <= 0 or self.num_sample <= 0
                or batch_size % self.num_sample != 0):
            raise ValueError(
                "batch size %d must be a positive multiple of num_sample=%d"
                % (batch_size, self.num_sample))
        return batch_size // self.num_sample

    def _initial_point(self, clean):
        """Return the attack's starting point as a tensor without grad history."""
        if not self.random_start:
            return clean.clone()
        # Draw on the CPU generator (as the original FloatTensor did) and move
        # the noise to the input's device instead of hard-coding .cuda().
        noise = torch.empty(clean.shape, dtype=torch.float32).uniform_(
            -self.epsilon, self.epsilon)
        start = clean.float() + noise.to(clean.device)
        return torch.clamp(start, self.min_val, self.max_val)

    def _alignment_loss(self, output, num_groups):
        """Negative mean pairwise dot product within each group."""
        grouped = output.reshape(num_groups, self.num_sample, -1)
        pairwise_similarity = grouped.matmul(grouped.transpose(1, 2))
        return -pairwise_similarity.mean()

    def _uniformity_loss(self, output, num_groups):
        """Log of the mean exp dot product over all cross-group pairs."""
        flat = output.reshape(output.shape[0], -1)
        similarity = flat.matmul(flat.t())
        group_id = torch.arange(num_groups, device=flat.device).repeat_interleave(
            self.num_sample)
        # Select cross-group entries directly rather than subtracting the
        # within-group exp-sum from the total (avoids cancellation).
        cross_group = group_id.unsqueeze(0) != group_id.unsqueeze(1)
        return similarity[cross_group].exp().mean().log()

    def _attack(self, original_images, loss_fn):
        """Run ``max_iters`` signed-gradient ascent steps on ``loss_fn``."""
        if self._type != 'linf':
            raise NotImplementedError(
                "only the 'linf' threat model is supported, got %r"
                % (self._type,))
        num_groups = self._num_groups(len(original_images))
        clean = original_images.detach()
        x = self._initial_point(clean)

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.enable_grad():
                for _ in range(self.max_iters):
                    x.requires_grad_(True)
                    # Kept from the original implementation; autograd.grad
                    # below does not accumulate into parameter .grad.
                    self.model.zero_grad()
                    loss = loss_fn(self.model(x), num_groups)
                    grads = torch.autograd.grad(loss, x)[0]
                    # Update outside autograd so x stays a fresh leaf and the
                    # graph does not grow across iterations.
                    with torch.no_grad():
                        x = x + self.alpha * torch.sign(grads)
                        x = torch.clamp(x, self.min_val, self.max_val)
                        x = project(x, clean, self.epsilon, self._type)
        finally:
            # Restore the caller's train/eval mode.
            self.model.train(was_training)
        return x.detach()

    
