import torch
from torch import nn

import pdb


def nce(predict, labels):
    """Return the negated cross-entropy, summed over the batch.

    Args:
        predict: Class scores (logits) as accepted by cross-entropy.
        labels: Target class indices matching ``predict``.

    Returns:
        A scalar tensor equal to minus the summed cross-entropy, so that
        minimizing it maximizes the cross-entropy.
    """
    summed_ce = nn.functional.cross_entropy(predict, labels, reduction="sum")
    return -summed_ce

def compute_consts(images, delta, eps, clip_min=0., clip_max=1.):
    """Evaluate the box constraints of the perturbation, one row per sample.

    Each entry is the value of an inequality constraint ``c <= 0``:

    * upper half: ``delta - min(eps, clip_max - images)``
    * lower half: ``-delta - min(eps, images - clip_min)``

    The computation runs under ``torch.no_grad()``, so the results carry no
    autograd history.

    Args:
        images: Batch of inputs; the first dimension is the batch dimension.
        delta: Perturbation with the same shape as ``images``.
        eps: Maximum allowed magnitude of each perturbation entry.
        clip_min: Smallest valid value of a perturbed input (default 0).
        clip_max: Largest valid value of a perturbed input (default 1).

    Returns:
        Tuple ``(cu, cl)`` of tensors of shape ``(batch, -1)`` holding the
        upper and lower constraint values.
    """
    n = images.shape[0]
    with torch.no_grad():
        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. permuted or channels_last batches.
        flat_images = images.reshape(n, -1)
        flat_delta = delta.reshape(n, -1)
        eps_cap = eps * torch.ones_like(flat_images)

        # Room left before hitting either the eps-ball or the valid range.
        upper_room = torch.min(eps_cap, clip_max - flat_images)
        lower_room = torch.min(eps_cap, flat_images - clip_min)

        cu = flat_delta - upper_room  # upper half inequality constraints
        cl = -flat_delta - lower_room  # lower half inequality constraints
    return cu, cl

class Attack:
    """Base class for attacks; concrete attacks override ``perturb``.

    Instances are callable: calling one forwards all arguments to
    ``perturb``.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max):
        """Store the attack configuration on the instance.

        Args:
            predict: Stored as ``self.predict``.
            loss_fn: Stored as ``self.loss_fn``.
            clip_min: Stored as ``self.clip_min``.
            clip_max: Stored as ``self.clip_max``.
        """
        self.predict, self.loss_fn = predict, loss_fn
        self.clip_min, self.clip_max = clip_min, clip_max

    def perturb(self, x, **kwargs):
        """Placeholder to be overridden by sub-classes.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Sub-classes must implement perturb.")

    def __call__(self, *args, **kwargs):
        """Forward the call to ``perturb`` and return its result."""
        return self.perturb(*args, **kwargs)


class LogBarrierAttacks:
    """Adversarial attacks that handle the perturbation constraints with log barriers.

    Two variants are provided:

    * ``attack``: Adam steps on the attack loss plus a log-barrier term; the
      iterate is kept strictly inside the constraint set.
    * ``attack_be``: plain gradient steps using the log-barrier *extension*,
      whose gradient is also defined on and outside the constraint boundary.

    Both minimize ``attack_loss(model(images + delta), labels)`` over the
    perturbation ``delta`` and return ``images + delta``.

    After each step ``self.delta_gross`` holds the perturbation as it was
    right after the update and before the projection.
    """

    def __init__(self, model, eps=8/255, alpha=2/255, steps=7,
                 random_start=True,
                 attack_loss=nce,
                 mu=0.1,
                 clip_min=0.,
                 clip_max=1.,
                 epsilon=1e-6,
                 lbs=1e-6):
        """Configure the attack.

        Args:
            model: Callable mapping a batch of inputs to outputs accepted by
                ``attack_loss``.
            eps: Maximum magnitude of each perturbation entry.
            alpha: Step size (Adam learning rate in ``attack``, gradient step
                size in ``attack_be``).
            steps: Number of optimization steps.
            random_start: If true, start from a uniformly random perturbation.
            attack_loss: Callable ``(outputs, labels) -> scalar`` to minimize.
            mu: Initial barrier weight; it is multiplied by 0.1 after every
                step within a single attack call.
            clip_min: Lower bound used when projecting ``images + delta``.
            clip_max: Upper bound used when projecting ``images + delta``.
            epsilon: Margin that keeps the iterate of ``attack`` strictly
                inside the constraint set.
            lbs: Scale applied to the barrier gradient in ``attack``.
        """
        self.model = model
        self.eps = eps
        self.steps = steps
        self.lr = alpha
        self.attack_loss = attack_loss
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.rand_init = random_start
        self.mu = mu
        self.e = epsilon
        self.lbs = lbs
        self.delta_gross = None

    @staticmethod
    def _barrier_extension_grad(c, mu):
        """Return the derivative of the log-barrier extension at constraint values ``c``.

        The derivative is ``-mu / c`` where ``c <= -mu**2`` and the constant
        ``1 / mu`` elsewhere.
        """
        in_log_region = c <= -mu ** 2
        linear_slope = torch.full_like(c, 1. / mu)
        # torch.where evaluates both branches; values of -mu / c outside the
        # log region (possibly inf) are discarded.
        return torch.where(in_log_region, -mu / c, linear_slope)

    def attack(self, images, labels):
        """Run the log-barrier attack.

        Args:
            images: Batch of clean inputs.
            labels: Targets passed to ``attack_loss``.

        Returns:
            Detached tensor ``images + delta`` with the final perturbation.
        """
        delta = torch.zeros_like(images)  # same device as images
        if self.rand_init:
            # initialize with random point interior to constraint set
            delta.uniform_(-self.eps + self.e, self.eps - self.e)
        # Project even without a random start: with a zero perturbation, a
        # pixel lying exactly on the clip boundary makes a constraint equal to
        # 0, and the barrier gradient mu / c would be infinite. Pixels already
        # inside [clip_min + e, clip_max - e] are left unchanged.
        delta = torch.clamp(delta + images, min=self.clip_min + self.e, max=self.clip_max - self.e) - images

        mu = 1. * self.mu  # local copy, decayed below without touching self.mu
        delta_var = delta.detach().requires_grad_(True)
        opt = torch.optim.Adam([delta_var], lr=self.lr, betas=(0.1, 0.1))

        for _ in range(self.steps):
            delta = delta_var.detach().requires_grad_(True)
            opt.zero_grad()

            outputs = self.model(images + delta)
            cost = self.attack_loss(outputs, labels)
            grad_attack_loss = torch.autograd.grad(cost, delta)[0]

            # Gradient w.r.t. delta of -mu*log(-cup) - mu*log(-clow).
            cup, clow = compute_consts(images, delta, self.eps)
            grad_barrier = -mu / cup + mu / clow

            grad = grad_attack_loss + self.lbs * grad_barrier.view(delta.shape)

            # Hand the combined gradient to Adam manually.
            delta_var.grad = grad.detach()
            opt.step()

            self.delta_gross = delta_var.clone().detach()

            # the minimization of the barrier function tends to make the
            # projection step unimportant, but it guarantees an interior iterate
            with torch.no_grad():
                projected = torch.clamp(delta_var, min=-self.eps + self.e, max=self.eps - self.e)
                projected = torch.clamp(projected + images, min=self.clip_min + self.e, max=self.clip_max - self.e) - images
                delta_var.copy_(projected)

            mu = mu * 0.1

        return (images + delta_var).detach()

    def attack_be(self, images, labels):
        """Run the log-barrier-extension attack.

        The barrier weight is decayed on a local copy, so ``self.mu`` is left
        unchanged and repeated calls start from the same configured value.

        Args:
            images: Batch of clean inputs.
            labels: Targets passed to ``attack_loss``.

        Returns:
            Detached tensor ``images + delta`` with the final perturbation.
        """
        delta = torch.zeros_like(images)  # same device as images
        if self.rand_init:
            # initialize with random point in constraint set
            delta.uniform_(-self.eps, self.eps)
            delta = torch.clamp(delta + images, min=self.clip_min, max=self.clip_max) - images

        mu = 1. * self.mu  # local copy, decayed below without touching self.mu

        for _ in range(self.steps):
            delta = delta.detach().requires_grad_(True)

            outputs = self.model(images + delta)
            cost = self.attack_loss(outputs, labels)
            grad_attack_loss = torch.autograd.grad(cost, delta)[0]

            cu, cl = compute_consts(images, delta, self.eps)
            # cu grows with delta and cl shrinks with it, hence the minus sign.
            grad_barrier = self._barrier_extension_grad(cu, mu) - self._barrier_extension_grad(cl, mu)

            grad = grad_attack_loss + grad_barrier.view(delta.shape)
            delta = delta.detach() - self.lr * grad.detach()

            self.delta_gross = delta.clone().detach()

            # the minimization of the barrier function will tend to make the projection step useless
            delta = torch.clamp(delta, min=-self.eps, max=self.eps)
            delta = torch.clamp(delta + images, min=self.clip_min, max=self.clip_max) - images

            mu = mu * 0.1

        return (images + delta).detach()





