import math
import torch
import torch.nn.functional as F
from .linf import Linf


'''PGD attack and EOT-PGD attack'''

# PGD attack [normalized-data version]: inputs are expected as (pixel - 0.5) / 0.25
# Used for CIFAR100, Tiny-ImageNet
def Linf_PGD(x_in, y_true, net, steps, eps):
    """Craft L-inf bounded PGD adversarial examples for normalized inputs.

    ``x_in`` is mapped to pixel space with ``x * 0.25 + 0.5``, perturbed there
    inside an ``eps``-ball intersected with ``[0, 1]``, and mapped back with
    ``(x - 0.5) / 0.25`` before every forward pass and on return.

    Args:
        x_in: Batch of normalized inputs.
        y_true: Target labels passed to ``F.cross_entropy``.
        net: Model whose forward returns a 2-tuple; the first element is used
            as the logits. It is switched to eval mode for the attack and
            restored afterwards; its parameter gradients are cleared.
        steps: Number of attack iterations. Must be non-zero because the step
            size is ``eps / steps * 2.3``.
        eps: L-inf radius in pixel space. ``0`` disables the attack.

    Returns:
        ``x_in`` itself when ``eps == 0``; otherwise the adversarial batch in
        normalized space (still attached to the internal ``x_adv`` leaf).
    """
    if eps == 0:
        return x_in
    training = net.training
    if training:
        net.eval()

    try:
        # Pixel-space copy of the clean input; constant for the whole attack.
        clean = (x_in * 0.25 + 0.5).detach()
        # Uniform start in [-eps, eps), created on the input's own device
        # instead of unconditionally on the default CUDA device.
        noise = 2 * eps * torch.rand(x_in.shape, device=x_in.device) - eps
        x_adv = (clean + noise).detach().clamp_(0, 1).requires_grad_()

        # NOTE(review): `Linf` is the project optimizer from `.linf`; its
        # update rule is not visible here (presumably a signed-gradient
        # step) -- confirm against that module.
        optimizer = Linf([x_adv], lr=eps / steps * 2.3)
        for _ in range(steps):
            optimizer.zero_grad()
            net.zero_grad()

            out, _aux = net((x_adv - 0.5) / 0.25)
            # Negated so that the optimizer's descent step increases the
            # cross-entropy.
            loss = -F.cross_entropy(out, y_true)
            loss.backward()
            optimizer.step()

            # Project back onto the eps-ball around the clean image and the
            # valid pixel range, without recording the projection in autograd.
            with torch.no_grad():
                diff = (x_adv - clean).clamp_(-eps, eps)
                x_adv.copy_((clean + diff).clamp_(0, 1))
    finally:
        # Always leave the model as we found it, even if the attack fails.
        net.zero_grad()
        if training:
            net.train()

    return (x_adv - 0.5) / 0.25


# EOT-PGD attack [normalized-data version]: inputs are expected as (pixel - 0.5) / 0.25
# Used for CIFAR100, Tiny-ImageNet
def EOT_Linf_PGD(x_in, y_true, net, steps, eps, eot_iter=10):
    """Craft L-inf bounded EOT-PGD adversarial examples for normalized inputs.

    At every attack step the input gradient of the cross-entropy loss is
    summed over ``eot_iter`` forward passes, and the sign of that sum is used
    for an ascent step of size ``eps / steps * 2.3``. The perturbation lives in
    pixel space (``x * 0.25 + 0.5``) and is projected onto the ``eps``-ball
    intersected with ``[0, 1]`` after every step.

    Args:
        x_in: Batch of normalized inputs.
        y_true: Target labels for the cross-entropy loss.
        net: Model whose forward returns a 2-tuple; the first element is used
            as the logits. It is switched to eval mode for the attack and
            restored afterwards; its parameter gradients are cleared.
        steps: Number of attack iterations.
        eps: L-inf radius in pixel space. ``0`` disables the attack.
        eot_iter: Number of forward/backward passes summed per step.

    Returns:
        ``x_in`` itself when ``eps == 0``; otherwise the adversarial batch in
        normalized space (still attached to the internal ``x_adv`` leaf).
    """
    if eps == 0:
        return x_in
    training = net.training
    if training:
        net.eval()

    try:
        # Pixel-space copy of the clean input; constant for the whole attack.
        clean = (x_in * 0.25 + 0.5).detach()
        # Uniform start in [-eps, eps), created on the input's own device
        # instead of unconditionally on the default CUDA device.
        noise = 2 * eps * torch.rand(x_in.shape, device=x_in.device) - eps
        x_adv = (clean + noise).detach().clamp_(0, 1).requires_grad_()

        # With steps <= 0 the loop below never runs, so avoid dividing by 0.
        step_size = eps / steps * 2.3 if steps else 0.0

        for _step in range(steps):
            grad_sum = torch.zeros_like(x_adv)

            for _sample in range(eot_iter):
                outputs, _aux = net((x_adv - 0.5) / 0.25)
                cost = F.cross_entropy(outputs, y_true)
                # autograd.grad returns a fresh gradient every time.
                # `cost.backward()` would accumulate into `x_adv.grad`, which
                # `net.zero_grad()` never clears, so stale gradients from
                # earlier samples and steps would leak into the sum.
                grad_sum += torch.autograd.grad(cost, x_adv)[0]

            # Ascent step plus projection, kept out of the autograd graph.
            with torch.no_grad():
                stepped = x_adv + step_size * grad_sum.sign()
                diff = (stepped - clean).clamp_(-eps, eps)
                x_adv.copy_((clean + diff).clamp_(0, 1))
    finally:
        # Always leave the model as we found it, even if the attack fails.
        net.zero_grad()
        if training:
            net.train()

    return (x_adv - 0.5) / 0.25


# PGD attack [unnormalized-data version] (disabled; kept for reference)
# For CIFAR10, STL10: inputs are fed to the network directly in the [0, 1] range
# def Linf_PGD(x_in, y_true, net, steps, eps):
#     if eps == 0:
#         return x_in
#     training = net.training
#     if training:
#         net.eval()

#     x_adv = x_in.clone().clamp_(0, 1).requires_grad_()
#     optimizer = Linf([x_adv], lr=0.007)
#     for _ in range(steps):
#         optimizer.zero_grad()
#         net.zero_grad()
#         out, _ = net(x_adv)
#         loss = -F.cross_entropy(out, y_true)
#         loss.backward()
#         optimizer.step()

#         diff = x_adv - x_in
#         diff.clamp_(-eps, eps)
#         x_adv.detach().copy_((diff + x_in).clamp_(0, 1))

#     net.zero_grad()
#     # reset to the original state
#     if training:
#         net.train()
    
#     return x_adv


# EOT-PGD attack [unnormalized-data version] (disabled; kept for reference)
# For CIFAR10, STL10: inputs are fed to the network directly in the [0, 1] range
# def EOT_Linf_PGD(x_in, y_true, net, steps, eps):
#     if eps == 0:
#         return x_in
#     training = net.training
#     if training:
#         net.eval()

#     loss = torch.nn.CrossEntropyLoss().cuda()
#     x_adv = x_in.clone().clamp_(0, 1).requires_grad_()

#     for _ in range(steps):

#         grad_ = torch.zeros_like(x_adv)

#         for _ in range(10):
#             net.zero_grad()
#             outputs, _ = net(x_adv)
#             cost = loss(outputs, y_true)
#             cost.backward()
#             grad_ += x_adv.grad

#         x_adv_ = x_adv + 0.007 * grad_.sign()
#         diff = x_adv_ - x_in
#         diff.clamp_(-eps, eps)
#         x_adv.detach().copy_((diff + x_in).clamp_(0, 1))

#     net.zero_grad()
#     if training:
#         net.train()

#     return x_adv
