import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from utils import *

def me_trades_loss(model,
                   x_natural,
                   y,
                   optimizer,
                   step_size=0.008,
                   epsilon=0.03125,
                   perturb_steps=10,
                   Lambda=6.0,
                   beta=1.0,
                   entropy_weight=0,
                   adv_loss='kl'):
    """TRADES loss with an additional prediction-entropy bonus.

    When ``adv_loss == 'kl'`` an L-inf adversarial example is crafted by
    sign-gradient ascent on KL(model(x_adv) || model(x_natural)) with the
    model in eval mode. For any other ``adv_loss`` value no attack is run and
    ``x_adv`` equals the clean input, so the robust term is ~0.

    The returned loss is::

        CE(logits, y) - beta * H(softmax(logits)) + Lambda * sum-KL / batch_size

    Args:
        model: classifier; softmax is taken over dim 1 of its output, so it is
            assumed to return (batch, num_classes) logits -- confirm at caller.
        x_natural: clean inputs; adversarial inputs are clamped to [0, 1].
        y: class targets passed to ``F.cross_entropy``.
        optimizer: its gradients are zeroed before the training forward pass.
        step_size: sign-gradient step size per attack iteration.
        epsilon: L-inf radius of the perturbation around ``x_natural``.
        perturb_steps: number of attack iterations.
        Lambda: weight of the robust (KL) term.
        beta: weight of the entropy term; it is subtracted, so higher
            prediction entropy lowers the loss.
        entropy_weight: unused; kept only for call-site compatibility.
        adv_loss: ``'kl'`` runs the attack; anything else uses clean inputs.

    Returns:
        Scalar loss tensor. The model is left in training mode.
    """
    criterion_kl = nn.KLDivLoss(reduction='sum')
    batch_size = len(x_natural)

    # --- craft the adversarial example with the model in eval mode ---
    model.eval()
    x_adv = x_natural.detach() + 0.
    if adv_loss == 'kl':
        # The clean prediction is the attack target and does not depend on
        # x_adv: compute it once, without building an autograd graph.
        with torch.no_grad():
            p_natural = F.softmax(model(x_natural), dim=1)
        # Small random start so the first KL gradient is not exactly zero.
        x_adv = x_adv + 0.001 * torch.randn_like(x_adv)
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            # Project back into the epsilon-ball, then into the [0, 1] range.
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon),
                              x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()

    # Equivalent of the deprecated Variable(..., requires_grad=False).
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()

    # --- training loss ---
    optimizer.zero_grad()
    logits = model(x_natural)
    adv_logits = model(x_adv)
    loss_natural = F.cross_entropy(logits, y)
    # Reuse `logits` as the KL target rather than a third forward pass: saves
    # compute and avoids updating BatchNorm running stats twice per step.
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(adv_logits, dim=1),
                                                    F.softmax(logits, dim=1))
    # Mean Shannon entropy of the clean predictive distribution.
    loss_entropy = -(F.softmax(logits, dim=1) * F.log_softmax(logits, dim=1))
    loss_entropy = loss_entropy.sum(dim=1).mean()

    return loss_natural - beta * loss_entropy + Lambda * loss_robust



def trades_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=0.008,
                epsilon=0.03125,
                perturb_steps=10,
                Lambda=6.0,
                distance='l_inf',
                adv_loss='kl',
                ):
    """Standard TRADES loss: CE on clean inputs plus a weighted KL robust term.

    An adversarial example is generated with the model in eval mode:

    * ``'kl'``: sign-gradient ascent on KL(model(x_adv) || model(x_natural))
      inside an L-inf ball of radius ``epsilon``.
    * ``'cw'`` / ``'ce'``: delegated to ``inf_pgd`` (defined elsewhere), with
      ``cw_loss`` for ``'cw'``.

    Args:
        model: classifier; softmax is taken over dim 1 of its output, so it is
            assumed to return (batch, num_classes) logits -- confirm at caller.
        x_natural: clean inputs; adversarial inputs are clamped to [0, 1].
        y: class targets passed to ``F.cross_entropy``.
        optimizer: its gradients are zeroed before the training forward pass.
        step_size: sign-gradient step size (``'kl'`` attack only).
        epsilon: L-inf radius (``'kl'`` attack only, see note below).
        perturb_steps: attack iterations (``'kl'`` attack only).
        Lambda: weight of the robust (KL) term.
        distance: threat model; only ``'l_inf'`` is implemented.
        adv_loss: one of ``'kl'``, ``'cw'``, ``'ce'``.

    Returns:
        Scalar loss tensor. The model is left in training mode.

    Raises:
        ValueError: for an unsupported ``distance`` or ``adv_loss``. This is
            checked before the model mode is changed.
    """
    # Validate first. Previously an unsupported value fell through every branch
    # and crashed later with UnboundLocalError on `x_adv`, with the model
    # already switched to eval mode.
    if distance != 'l_inf':
        raise ValueError(
            "unsupported distance {!r}; only 'l_inf' is implemented".format(distance))
    if adv_loss not in ('kl', 'cw', 'ce'):
        raise ValueError(
            "unsupported adv_loss {!r}; expected 'kl', 'cw' or 'ce'".format(adv_loss))

    # size_average=False is deprecated; reduction='sum' is the same reduction.
    criterion_kl = nn.KLDivLoss(reduction='sum')
    batch_size = len(x_natural)

    # --- craft the adversarial example with the model in eval mode ---
    model.eval()
    if adv_loss == 'kl':
        # The attack target does not depend on x_adv: compute it once, graph-free.
        with torch.no_grad():
            p_natural = F.softmax(model(x_natural), dim=1)
        # Small random start so the first KL gradient is not exactly zero.
        x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural)
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            # Project back into the epsilon-ball, then into the [0, 1] range.
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif adv_loss == 'cw':
        # NOTE(review): eps and the step count (10) are hard-coded, so
        # `epsilon`, `perturb_steps` and `step_size` are ignored for 'cw'/'ce'.
        # Confirm inf_pgd's signature before wiring them through.
        x_adv = inf_pgd(model, x_natural, y, 10, eps=0.03125, loss_fn=cw_loss)
    else:  # adv_loss == 'ce'
        x_adv = inf_pgd(model, x_natural, y, 10, eps=0.03125)
    model.train()

    # Equivalent of the deprecated Variable(..., requires_grad=False).
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()

    # --- training loss ---
    optimizer.zero_grad()
    logits = model(x_natural)
    loss_natural = F.cross_entropy(logits, y)
    # Reuse `logits` as the KL target rather than another forward pass: saves
    # compute and avoids updating BatchNorm running stats twice per step.
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                    F.softmax(logits, dim=1))
    return loss_natural + Lambda * loss_robust

def ls_trades_loss(model,
                   x_natural,
                   labels,
                   optimizer,
                   step_size=0.008,
                   epsilon=0.03125,
                   perturb_steps=10,
                   alpha=0.1,
                   Lambda=6.0,
                   distance='l_inf',
                   adv_loss='kl',
                   ):
    """TRADES loss whose natural term uses ``LabelSmoothingLoss``.

    The adversarial example is always crafted with the KL attack: sign-gradient
    ascent on KL(model(x_adv) || model(x_natural)) inside an L-inf ball of
    radius ``epsilon``, with the model in eval mode.

    Args:
        model: classifier; softmax is taken over dim 1 of its output, so it is
            assumed to return (batch, num_classes) logits -- confirm at caller.
        x_natural: clean inputs; adversarial inputs are clamped to [0, 1].
        labels: class targets passed to ``LabelSmoothingLoss``.
        optimizer: its gradients are zeroed before the training forward pass.
        step_size: sign-gradient step size per attack iteration.
        epsilon: L-inf radius of the perturbation around ``x_natural``.
        perturb_steps: number of attack iterations.
        alpha: passed as ``smoothing=`` to ``LabelSmoothingLoss`` (defined
            elsewhere; presumably the label-smoothing factor -- confirm).
        Lambda: weight of the robust (KL) term.
        distance: threat model; only ``'l_inf'`` is implemented.
        adv_loss: ignored -- the KL attack is always used. It is kept for
            signature parity with ``trades_loss``.

    Returns:
        Scalar loss tensor. The model is left in training mode.

    Raises:
        ValueError: if ``distance`` is not ``'l_inf'``. This is checked before
            the model mode is changed.
    """
    # Validate first. Previously a non-'l_inf' distance crashed later with
    # UnboundLocalError on `x_adv`, with the model left in eval mode.
    if distance != 'l_inf':
        raise ValueError(
            "unsupported distance {!r}; only 'l_inf' is implemented".format(distance))

    # size_average=False is deprecated; reduction='sum' is the same reduction.
    criterion_kl = nn.KLDivLoss(reduction='sum')
    batch_size = len(x_natural)

    # --- craft the adversarial example with the model in eval mode ---
    model.eval()
    # The attack target does not depend on x_adv: compute it once, graph-free.
    with torch.no_grad():
        p_natural = F.softmax(model(x_natural), dim=1)
    # Small random start so the first KL gradient is not exactly zero.
    x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural)
    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural)
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        # Project back into the epsilon-ball, then into the [0, 1] range.
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()

    # Equivalent of the deprecated Variable(..., requires_grad=False).
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()

    # --- training loss ---
    optimizer.zero_grad()
    criterion = LabelSmoothingLoss(smoothing=alpha)
    logits = model(x_natural)
    loss_natural = criterion(logits, labels)
    # Reuse `logits` as the KL target rather than another forward pass: saves
    # compute and avoids updating BatchNorm running stats twice per step.
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                    F.softmax(logits, dim=1))
    return loss_natural + Lambda * loss_robust




