import torch
import torch.nn.functional as F


class L2PGDAttack(object):
    def __init__(self, model, num_steps=10, epsilon=1.0):
        self.model = model
        self.num_steps = num_steps
        self.epsilon = epsilon
        self.alpha = epsilon / num_steps * 2

    def perturb(self, x, y):
        B = x.shape[0]
        adv_x = x.detach()
        adv_x = adv_x + torch.zeros_like(adv_x).uniform_(-self.epsilon, self.epsilon)
        delta = adv_x - x
        delta_norm = delta.view(B, -1).norm(dim=1).view(-1,1,1)
        clamp_delta = delta / delta_norm * torch.clamp(delta_norm,0,self.epsilon)
        adv_x = x + clamp_delta
        #adv_x = torch.clamp(adv_x,0,1)

        for i in range(self.num_steps):
            adv_x.requires_grad_()
            with torch.enable_grad():
                logits = self.model(adv_x)
                loss = F.cross_entropy(logits, y)
            grad = torch.autograd.grad(loss, [adv_x])[0]
            grad_norm = grad.view(B, -1).norm(dim=1).view(-1,1,1)
            grad = grad / grad_norm
            adv_x = adv_x.detach() + self.alpha * grad.detach()

            delta = adv_x - x
            delta_norm = delta.view(B, -1).norm(dim=1).view(-1,1,1)
            clamp_delta = delta / delta_norm * torch.clamp(delta_norm,0,self.epsilon)
            adv_x = x + clamp_delta
            #adv_x = torch.clamp(adv_x,0,1)
        return adv_x.detach()
