import torch
import torch.nn.functional as F
import torch.nn as nn


# Bounds on the input value range. pgd_whitebox2 uses them to keep
# X + delta inside [lower_limit, upper_limit].
upper_limit = 1
lower_limit = 0

def clamp(X, lower_limit, upper_limit):
    """Element-wise clamp of ``X`` into ``[lower_limit, upper_limit]``.

    Unlike ``torch.clamp``, the bounds may be tensors that broadcast against
    ``X``; they are applied via ``torch.min`` / ``torch.max``.
    """
    capped_above = torch.min(X, upper_limit)
    return torch.max(capped_above, lower_limit)


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mix two losses for mixup/cutmix training.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def pgd(model, X, y, normalize, targeted=False, rs=True,
        epsilon=8, attack_iters=10, restarts=1):
    """Run an L-infinity PGD attack and return the adversarial inputs.

    Args:
        model: classifier called as ``model(normalize(x))``. It is switched to
            eval mode and left in eval mode.
        X: clean inputs; presumably in [0, 1], since outputs are clamped there.
        y: labels for ``F.cross_entropy``. With ``targeted=True`` these act as
            target labels and the cross-entropy is minimised instead.
        normalize: callable applied to inputs before ``model``.
        targeted: if True, step to decrease the loss rather than increase it.
        rs: if True, each restart begins at a uniform random point in the
            epsilon box around ``X``; otherwise it begins at ``X`` itself.
        epsilon: L-inf radius on the 0-255 scale (divided by 255 here).
        attack_iters: number of signed-gradient steps per restart.
        restarts: number of restarts (must be > 0). Each restart starts
            afresh; only the result of the last restart is returned.

    Returns:
        Tensor shaped like ``X`` with ``|x_adv - X| <= epsilon / 255``
        element-wise and values in [0, 1].
    """
    epsilon = epsilon / 255.0
    alpha = epsilon / 4  # fixed step size: a quarter of the radius
    model.eval()
    assert restarts > 0
    for _ in range(restarts):
        # Always (re)start from the clean input. Previously x_adv was only
        # assigned when rs=True, so rs=False raised UnboundLocalError.
        x_adv = X.detach()
        if rs:
            x_adv = x_adv + torch.empty_like(X).uniform_(-epsilon, epsilon)
        x_adv = torch.clamp(x_adv, min=0, max=1).detach()

        for _ in range(attack_iters):
            x_adv.requires_grad_()
            # enable_grad lets the attack run even if the caller is in no_grad.
            with torch.enable_grad():
                output = model(normalize(x_adv))
                loss = F.cross_entropy(output, y)
                if targeted:
                    loss = -loss
            grad = torch.autograd.grad(loss, [x_adv])[0]
            x_adv = x_adv.detach() + alpha * torch.sign(grad.detach())
            # Project onto the epsilon box around X, then onto [0, 1].
            x_adv = torch.min(torch.max(x_adv, X - epsilon), X + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)

    return x_adv


def pgd_l2(model, X, y, normalize, targeted=False, rs=True, epsilon=255,
           attack_iters=10, restarts=1):
    """Run an L2 PGD attack and return the adversarial inputs.

    Args:
        model: classifier called as ``model(normalize(x))``. It is switched to
            eval mode and left in eval mode.
        X: clean inputs of any shape ``(batch, ...)``; presumably in [0, 1],
            since outputs are clamped there.
        y: labels for ``F.cross_entropy`` (moved to ``X``'s device). With
            ``targeted=True`` the cross-entropy is minimised instead.
        normalize: callable applied to inputs before ``model``.
        targeted: if True, step to decrease the loss rather than increase it.
        rs: if True, each restart begins at a random point within the L2 ball.
        epsilon: L2 radius on the 0-255 scale (divided by 255 here).
        attack_iters: number of normalised-gradient steps per restart.
        restarts: number of restarts (must be > 0). Each restart begins from
            ``X``; only the last restart's result is returned.

    Returns:
        Tensor shaped like ``X`` with per-sample ``||x_adv - X||_2 <= epsilon / 255``
        and values in [0, 1].
    """
    epsilon = epsilon / 255.0
    alpha = epsilon / 5  # fixed step length: a fifth of the radius
    model.eval()

    X = X.clone().detach()
    y = y.clone().detach().to(X.device)

    batch_size = len(X)
    # Shape that broadcasts one scalar per sample over all non-batch dims,
    # e.g. (B, 1, 1, 1) for 4-D image batches or (B, 1) for 2-D inputs.
    per_sample = (batch_size,) + (1,) * (X.dim() - 1)

    assert restarts > 0
    for _ in range(restarts):
        # Every restart begins from the clean input, not the previous result.
        x_adv = X.clone()
        if rs:
            # Random direction (normalised Gaussian) scaled by a radius drawn
            # uniformly from [0, epsilon]. Note: not uniform over the ball volume.
            delta = torch.empty_like(x_adv).normal_()
            n = delta.reshape(batch_size, -1).norm(p=2, dim=1).view(per_sample)
            r = torch.zeros_like(n).uniform_(0, 1)
            delta *= r / n * epsilon
            x_adv = x_adv + delta
        x_adv = torch.clamp(x_adv, min=0, max=1).detach()

        for _ in range(attack_iters):
            x_adv.requires_grad_()
            with torch.enable_grad():
                output = model(normalize(x_adv))
                loss = F.cross_entropy(output, y)
                if targeted:
                    loss = -loss

            grad = torch.autograd.grad(loss, x_adv,
                                       retain_graph=False, create_graph=False)[0]
            # Step of L2 length alpha along the normalised gradient; the 1e-10
            # avoids division by zero for a vanishing gradient.
            grad_norms = grad.reshape(batch_size, -1).norm(p=2, dim=1) + 1e-10
            x_adv = x_adv.detach() + alpha * (grad / grad_norms.view(per_sample))

            # Project the perturbation back onto the L2 ball of radius epsilon.
            # torch.where avoids the 0/0 = NaN that min(eps / norm, 1) gave
            # when epsilon == 0 and the perturbation is zero.
            delta = x_adv - X
            delta_norms = delta.reshape(batch_size, -1).norm(p=2, dim=1)
            factor = torch.where(delta_norms > epsilon,
                                 epsilon / delta_norms,
                                 torch.ones_like(delta_norms))
            delta = delta * factor.view(per_sample)

            x_adv = torch.clamp(X + delta, min=0, max=1).detach()

    return x_adv



def pgd_whitebox2(model, X, y, normalize, epsilon=8, alpha=2, attack_iters=20, restarts=1):
    """Run a restarted L-infinity PGD attack and return the worst perturbation.

    Unlike ``pgd``, this returns the perturbation ``delta``, not ``X + delta``.
    Across restarts, each sample keeps the delta that gave it the highest
    per-sample cross-entropy.

    Args:
        model: classifier called as ``model(normalize(x))``. It runs in eval
            mode during the attack; its previous train/eval mode is restored
            afterwards.
        X: clean inputs; ``X + delta`` is kept within
            ``[lower_limit, upper_limit]`` (module-level bounds).
        y: labels for ``F.cross_entropy``.
        normalize: callable applied to inputs before ``model``.
        epsilon: L-inf radius on the 0-255 scale (divided by 255 here).
        alpha: step size on the 0-255 scale (divided by 255 here).
        attack_iters: signed-gradient steps per restart.
        restarts: number of random restarts (must be > 0).

    Returns:
        Tensor shaped like ``X`` holding the selected perturbation per sample.
    """
    epsilon = epsilon / 255.0
    alpha = alpha / 255.0
    was_training = model.training
    model.eval()
    # Allocate on X's device instead of a hard-coded .cuda().
    max_loss = torch.zeros(y.shape[0], device=X.device)
    max_delta = torch.zeros_like(X)
    assert restarts > 0
    for _ in range(restarts):
        delta = torch.zeros_like(X).uniform_(-epsilon, epsilon)
        delta = clamp(delta, lower_limit - X, upper_limit - X).detach()
        for _ in range(attack_iters):
            delta.requires_grad_()
            with torch.enable_grad():
                loss = F.cross_entropy(model(normalize(X + delta)), y)
            # autograd.grad (rather than loss.backward()) computes only the
            # input gradient and leaves the model's parameter .grad untouched.
            grad = torch.autograd.grad(loss, [delta])[0]
            delta = torch.clamp(delta.detach() + alpha * torch.sign(grad),
                                min=-epsilon, max=epsilon)
            delta = clamp(delta, lower_limit - X, upper_limit - X).detach()
        with torch.no_grad():
            all_loss = F.cross_entropy(model(normalize(X + delta)), y,
                                       reduction='none')
        # Keep, per sample, the delta from the restart with the largest loss.
        improved = all_loss >= max_loss
        max_delta[improved] = delta[improved]
        max_loss = torch.max(max_loss, all_loss)

    model.train(was_training)
    return max_delta


# suitable for mixup and cutmix
def pgd_whitebox_mix(model, X, y, y2, lam, normalize, epsilon=8, alpha=2,
                     attack_iters=20, restarts=1):
    """L-infinity PGD attack against a mixup/cutmix loss; returns adversarial inputs.

    The maximised loss is
    ``lam * CE(model(normalize(x)), y) + (1 - lam) * CE(model(normalize(x)), y2)``.

    Args:
        model: classifier. It is switched to eval mode and left in eval mode.
        X: clean inputs; presumably in [0, 1], since outputs are clamped there.
        y, y2: the two label sets being mixed.
        lam: mixing weight for ``y`` (``1 - lam`` goes to ``y2``).
        normalize: callable applied to inputs before ``model``.
        epsilon: L-inf radius on the 0-255 scale (divided by 255 here).
        alpha: step size on the 0-255 scale (divided by 255 here).
        attack_iters: signed-gradient steps per restart.
        restarts: number of restarts (must be > 0). Each starts afresh from a
            random point; only the last restart's result is returned.

    Returns:
        Tensor shaped like ``X`` with ``|x_adv - X| <= epsilon / 255`` and
        values in [0, 1].
    """
    epsilon = epsilon / 255.0
    alpha = alpha / 255.0
    model.eval()
    criterion = nn.CrossEntropyLoss()  # built once, not every iteration
    assert restarts > 0
    for _ in range(restarts):
        # Random start generated on X's own device (was a hard-coded .cuda()).
        x_adv = X.detach() + torch.empty_like(X).uniform_(-epsilon, epsilon)
        x_adv = torch.clamp(x_adv, min=0, max=1).detach()

        for _ in range(attack_iters):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss = mixup_criterion(criterion, model(normalize(x_adv)), y, y2, lam)
            grad = torch.autograd.grad(loss, [x_adv])[0]
            x_adv = x_adv.detach() + alpha * torch.sign(grad.detach())
            # Project onto the epsilon box around X, then onto [0, 1].
            x_adv = torch.min(torch.max(x_adv, X - epsilon), X + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)

    return x_adv
