import torch
import torch.nn.functional as F
import torch.nn as nn

# Bounds used by fgsm's clamp calls for X + delta (inputs presumably scaled
# to [0, 1], matching the explicit [0, 1] clamps in the L1/L2 attacks).
lower_limit = 0
upper_limit = 1

def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise into ``[lower_limit, upper_limit]``.

    The bounds are combined with the binary forms of ``torch.min`` /
    ``torch.max``, so they are expected to be tensors broadcastable with
    ``X`` (e.g. per-element ranges such as ``lower_limit - X``).
    The upper bound is applied first, so if a lower bound exceeds its upper
    bound the lower bound wins.
    """
    capped_above = torch.min(X, upper_limit)
    return torch.max(capped_above, lower_limit)


# FGSM-RS (random start)
# \alpha = 1.25 * \epsilon
def fgsm(model, X, y, normalize, rs=True, epsilon=8, targeted=False):
    """Craft L-infinity FGSM adversarial examples, optionally with random start.

    Takes one signed-gradient step of size ``alpha = 1.25 * epsilon`` from an
    (optionally random) starting perturbation, then clips the perturbation to
    ``[-epsilon, epsilon]`` and so that ``X + delta`` stays within
    ``[lower_limit, upper_limit]``.

    Args:
        model: classifier called as ``model(normalize(inputs))``. It is
            switched to eval mode and NOT restored afterwards.
        X: clean inputs.
        y: labels for ``F.cross_entropy`` (target labels when ``targeted``).
        normalize: callable applied to the inputs before ``model``.
        rs: if True, start from a perturbation drawn uniformly from
            ``[-epsilon, epsilon]``.
        epsilon: L-infinity budget in 0-255 units (divided by 255 here).
        targeted: if True, step to decrease the cross-entropy w.r.t. ``y``
            instead of increasing it.

    Returns:
        The adversarial inputs ``X + delta`` with ``delta`` detached.
    """
    epsilon = epsilon / 255.0
    alpha = epsilon * 1.25
    model.eval()

    # zeros_like already allocates on X's device; forcing `.cuda()` here made
    # `X + delta` mix devices (and crash) whenever X lived on the CPU.
    delta = torch.zeros_like(X)
    if rs:
        delta.uniform_(-epsilon, epsilon)
    delta = clamp(delta, lower_limit - X, upper_limit - X)
    delta.requires_grad = True

    output = model(normalize(X + delta))
    loss = F.cross_entropy(output, y)
    if targeted:
        loss = -loss

    # autograd.grad returns only d(loss)/d(delta); unlike loss.backward() it
    # does not accumulate into the model parameters' .grad buffers, which
    # would otherwise leak into a surrounding training step.
    grad, = torch.autograd.grad(loss, delta)

    delta = delta.detach()
    delta = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
    delta = clamp(delta, lower_limit - X, upper_limit - X)

    return delta + X



# FGSM under an L2-norm budget (random start optional, off by default)
# \alpha = 1.25 * \epsilon
def fgsm_l2(model, X, y, normalize, rs=False, epsilon=255, targeted=False):
    """Craft single-step FGSM adversarial examples under an L2-norm budget.

    Takes one step of L2 length ``alpha = 1.25 * epsilon`` along the
    per-sample L2-normalized loss gradient, clips to ``[0, 1]``, then rescales
    each perturbation into the L2 ball of radius ``epsilon`` around ``X``.

    Args:
        model: classifier called as ``model(normalize(inputs))``. It is
            switched to eval mode and NOT restored afterwards.
        X: clean inputs with the batch on dim 0 (any rank >= 2). Copied to
            the CUDA device when available, otherwise kept on X's device.
        y: labels for ``F.cross_entropy`` (target labels when ``targeted``).
        normalize: callable applied to the inputs before ``model``.
        rs: if True, start from a random point inside the epsilon ball.
        epsilon: L2 budget in 0-255 units (divided by 255 here).
        targeted: if True, step to decrease the cross-entropy w.r.t. ``y``.

    Returns:
        Detached adversarial inputs clipped to ``[0, 1]``.
    """
    epsilon = epsilon / 255.0
    alpha = epsilon * 1.25
    model.eval()

    device = torch.device("cuda") if torch.cuda.is_available() else X.device
    X = X.clone().detach().to(device)
    y = y.clone().detach().to(device)

    batch_size = len(X)
    # Broadcasts a per-sample scalar over every non-batch dim
    # (equals (batch_size, 1, 1, 1) for 4-D image batches).
    per_sample = (batch_size,) + (1,) * (X.dim() - 1)

    delta = torch.zeros_like(X)
    if rs:
        # Gaussian direction rescaled to L2 radius r * epsilon, r ~ U(0, 1).
        # The noise must not require grad: in-place normal_() / *= on a
        # grad-requiring leaf raises a RuntimeError.
        delta = torch.empty_like(X).normal_()
        n = delta.flatten(1).norm(p=2, dim=1).view(per_sample)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon

    adv_images = torch.clamp(X + delta, min=0, max=1).detach()
    adv_images.requires_grad = True
    output = model(normalize(adv_images))

    loss = F.cross_entropy(output, y)
    if targeted:
        loss = -loss

    # Step along the L2-normalized gradient (1e-10 avoids division by zero).
    grad, = torch.autograd.grad(loss, adv_images)
    grad_norms = grad.flatten(1).norm(p=2, dim=1) + 1e-10
    grad = grad / grad_norms.view(per_sample)
    adv_images = torch.clamp(adv_images.detach() + alpha * grad, min=0, max=1)

    # Rescale perturbations longer than epsilon back onto the L2 ball.
    # torch.where (instead of min(eps / norm, 1)) avoids a 0/0 NaN when
    # epsilon == 0 and the perturbation is all zeros.
    delta = adv_images - X
    delta_norms = delta.flatten(1).norm(p=2, dim=1)
    factor = torch.where(delta_norms > epsilon,
                         epsilon / delta_norms,
                         torch.ones_like(delta_norms))
    delta = delta * factor.view(per_sample)

    return torch.clamp(X + delta, min=0, max=1).detach()





def fgsm_l1(model, X, y, normalize, rs=False, epsilon=8, targeted=False):
    """Craft single-step FGSM adversarial examples under an L1-norm budget.

    The total L1 budget is ``epsilon / 255`` times the number of features per
    sample (C*H*W for image batches). The function takes one step of L1
    length ``alpha = 1.25 * budget`` along the per-sample L1-normalized loss
    gradient, clips to ``[0, 1]``, then rescales each perturbation so its L1
    norm does not exceed the budget.

    Args:
        model: classifier called as ``model(normalize(inputs))``. It is
            switched to eval mode and NOT restored afterwards.
        X: clean inputs with the batch on dim 0 (any rank >= 2); device is
            left unchanged.
        y: labels for ``F.cross_entropy`` (target labels when ``targeted``).
        normalize: callable applied to the inputs before ``model``.
        rs: if True, start from a random point inside the L1 budget.
        epsilon: per-feature budget in 0-255 units.
        targeted: if True, step to decrease the cross-entropy w.r.t. ``y``
            (new keyword, default keeps the previous untargeted behaviour).

    Returns:
        Detached adversarial inputs clipped to ``[0, 1]``.
    """
    epsilon = epsilon / 255.0
    model.eval()

    X = X.clone().detach()
    y = y.clone().detach()
    # Scale the per-feature budget to a total L1 budget (C*H*W for 4-D input).
    epsilon *= X.shape[1:].numel()
    alpha = epsilon * 1.25

    batch_size = len(X)
    # Broadcasts a per-sample scalar over every non-batch dim.
    per_sample = (batch_size,) + (1,) * (X.dim() - 1)

    adv_images = X.clone().detach()
    if rs:
        # Gaussian direction rescaled to L1 radius r * epsilon, r ~ U(0, 1).
        delta = torch.empty_like(adv_images).normal_()
        n = delta.flatten(1).norm(p=1, dim=1).view(per_sample)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon
        adv_images = torch.clamp(adv_images + delta, min=0, max=1).detach()

    adv_images.requires_grad = True
    output = model(normalize(adv_images))

    loss = F.cross_entropy(output, y)
    if targeted:
        loss = -loss

    # Step along the L1-normalized gradient (1e-10 avoids division by zero).
    grad, = torch.autograd.grad(loss, adv_images)
    grad_norms = grad.flatten(1).norm(p=1, dim=1) + 1e-10
    grad = grad / grad_norms.view(per_sample)
    adv_images = torch.clamp(adv_images.detach() + alpha * grad, min=0, max=1)

    # Rescale into the L1 ball. The size must be measured with the L1 norm:
    # the L2 norm under-counts it and let perturbations exceed the budget.
    # torch.where avoids a 0/0 NaN when epsilon == 0 and delta is all zeros.
    delta = adv_images - X
    delta_norms = delta.flatten(1).norm(p=1, dim=1)
    factor = torch.where(delta_norms > epsilon,
                         epsilon / delta_norms,
                         torch.ones_like(delta_norms))
    delta = delta * factor.view(per_sample)

    return torch.clamp(X + delta, min=0, max=1).detach()