import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def TRADES_loss(inputs, adv_inputs, targets, model):
    """Return the clean cross-entropy and a clean-vs-adversarial KL term.

    ``model`` is evaluated on ``inputs`` first and on ``adv_inputs`` second.

    Args:
        inputs: Batch of unperturbed examples passed to ``model``.
        adv_inputs: Perturbed counterparts of ``inputs``, passed to ``model``.
        targets: Class indices used as cross-entropy targets for the
            clean logits.
        model: Callable returning logits with the class scores along dim 1.

    Returns:
        Tuple ``(sup_loss, rob_loss)``. ``sup_loss`` is the cross-entropy of
        the clean logits against ``targets``. ``rob_loss`` is the KL
        divergence from the clean softmax to the adversarial softmax,
        summed over dim 1 and averaged over the remaining elements.
    """
    clean_logits = model(inputs)
    perturbed_logits = model(adv_inputs)

    clean_dist = F.softmax(clean_logits, dim=1)
    # The 1e-12 offset keeps log() finite if a softmax entry underflows to 0.
    perturbed_log_dist = (F.softmax(perturbed_logits, dim=1) + 1e-12).log()

    # kl_div expects log-probabilities first and target probabilities second;
    # summing over dim 1 leaves one divergence value per example.
    per_example_kl = F.kl_div(
        perturbed_log_dist, clean_dist, reduction='none'
    ).sum(dim=1)

    return F.cross_entropy(clean_logits, targets), per_example_kl.mean()

# Proposed loss
def ARoW_CE_loss(inputs, adv_inputs, targets, model, tau):
    """Return the clean cross-entropy and a confidence-weighted KL term.

    ``model`` is evaluated on ``inputs`` first and on ``adv_inputs`` second.

    Args:
        inputs: Batch of unperturbed examples passed to ``model``.
        adv_inputs: Perturbed counterparts of ``inputs``, passed to ``model``.
        targets: Class indices; used as cross-entropy targets and (cast to
            int64) to pick the labelled-class probability of each example.
        model: Callable returning logits with the class scores along dim 1.
        tau: Offset subtracted from the per-example weight before it is
            clamped at zero.

    Returns:
        Tuple ``(sup_loss, rob_loss)``. ``sup_loss`` is the cross-entropy of
        the clean logits against ``targets``. ``rob_loss`` is the mean of the
        per-example KL divergence (clean softmax vs. adversarial softmax)
        multiplied by ``max(1 - p_adv(target) - tau, 0)``.
    """
    clean_logits = model(inputs)
    perturbed_logits = model(adv_inputs)

    perturbed_dist = F.softmax(perturbed_logits, dim=1)
    clean_dist = F.softmax(clean_logits, dim=1)

    # Probability the model assigns to the labelled class on the perturbed input.
    label_index = targets.unsqueeze(1).long()
    adv_label_prob = torch.gather(perturbed_dist, 1, label_index).squeeze()

    # The 1e-12 offset keeps log() finite if a softmax entry underflows to 0.
    per_example_kl = F.kl_div(
        (perturbed_dist + 1e-12).log(), clean_dist, reduction='none'
    ).sum(dim=1)

    # Weight reaches zero once the adversarial label probability is >= 1 - tau.
    # It is not detached, so gradients also flow through the weight.
    example_weight = torch.clamp(1. - adv_label_prob - tau, min=0)

    sup_loss = F.cross_entropy(clean_logits, targets)
    rob_loss = (per_example_kl * example_weight).mean()
    return sup_loss, rob_loss

def ARoW_loss(inputs, adv_inputs, targets, model, tau, smoothing):
    """Return a label-smoothed clean loss and a confidence-weighted KL term.

    The criterion is built before ``model`` is called; ``model`` is then
    evaluated on ``inputs`` first and on ``adv_inputs`` second.

    Args:
        inputs: Batch of unperturbed examples passed to ``model``.
        adv_inputs: Perturbed counterparts of ``inputs``, passed to ``model``.
        targets: Class indices; handed to the label-smoothing criterion and
            (cast to int64) used to pick the labelled-class probability.
        model: Callable returning logits with the class scores along dim 1.
        tau: Offset subtracted from the per-example weight before it is
            clamped at zero.
        smoothing: Smoothing factor forwarded to
            ``LabelSmoothingCrossEntropy``.

    Returns:
        Tuple ``(sup_loss, rob_loss)``. ``sup_loss`` is the label-smoothed
        loss of the clean logits. ``rob_loss`` is the mean of the per-example
        KL divergence (clean softmax vs. adversarial softmax) multiplied by
        ``max(1 - p_adv(target) - tau, 0)``.
    """
    criterion = LabelSmoothingCrossEntropy(smoothing)

    clean_logits = model(inputs)
    perturbed_logits = model(adv_inputs)

    perturbed_dist = F.softmax(perturbed_logits, dim=1)
    clean_dist = F.softmax(clean_logits, dim=1)

    # Probability the model assigns to the labelled class on the perturbed input.
    label_index = targets.unsqueeze(1).long()
    adv_label_prob = torch.gather(perturbed_dist, 1, label_index).squeeze()

    # The 1e-12 offset keeps log() finite if a softmax entry underflows to 0.
    per_example_kl = F.kl_div(
        (perturbed_dist + 1e-12).log(), clean_dist, reduction='none'
    ).sum(dim=1)

    # Weight reaches zero once the adversarial label probability is >= 1 - tau.
    example_weight = torch.clamp(1. - adv_label_prob - tau, min=0)

    sup_loss = criterion(clean_logits, targets)
    rob_loss = (per_example_kl * example_weight).mean()
    return sup_loss, rob_loss


def HAT_loss(inputs, adv_inputs, targets, model, std_model):
    """Return the clean cross-entropy, a KL robustness term and a helper loss.

    Forward passes happen in this order: ``model(inputs)``,
    ``model(adv_inputs)``, ``std_model(adv_inputs)``, ``model(helper_inputs)``.

    Args:
        inputs: Batch of unperturbed examples passed to ``model``.
        adv_inputs: Perturbed counterparts of ``inputs``.
        targets: Class indices used as cross-entropy targets for the
            clean logits.
        model: Callable returning logits with the class scores along dim 1;
            it receives gradients from all three losses.
        std_model: Callable whose arg-max prediction on ``adv_inputs``
            labels the helper examples; evaluated without gradient tracking.

    Returns:
        Tuple ``(sup_loss, rob_loss, help_loss)``. ``sup_loss`` is the
        cross-entropy of the clean logits. ``rob_loss`` is the KL divergence
        from the clean softmax to the adversarial softmax, summed over dim 1
        and averaged. ``help_loss`` is the cross-entropy of ``model`` on the
        helper examples against the labels predicted by ``std_model``.
    """
    clean_logits = model(inputs)
    perturbed_logits = model(adv_inputs)

    clean_dist = F.softmax(clean_logits, dim=1)
    # The 1e-12 offset keeps log() finite if a softmax entry underflows to 0.
    perturbed_log_dist = (F.softmax(perturbed_logits, dim=1) + 1e-12).log()
    per_example_kl = F.kl_div(
        perturbed_log_dist, clean_dist, reduction='none'
    ).sum(dim=1)
    rob_loss = per_example_kl.mean()

    sup_loss = F.cross_entropy(clean_logits, targets)

    # Helper examples lie twice as far from the clean input as the
    # perturbed input does, along the same direction.
    perturbation = adv_inputs - inputs
    helper_inputs = inputs + 2 * perturbation

    # Helper labels are std_model's hard predictions on the perturbed batch;
    # no graph is recorded for std_model.
    with torch.no_grad():
        helper_targets = std_model(adv_inputs).argmax(dim=1)

    helper_logits = model(helper_inputs)
    help_loss = F.cross_entropy(helper_logits, helper_targets)

    return sup_loss, rob_loss, help_loss

def CoW_loss(inputs, adv_inputs, targets, model, tau, smoothing):
    """Return a label-smoothed clean loss and a clean-confidence-weighted KL term.

    The criterion is built before ``model`` is called; ``model`` is then
    evaluated on ``inputs`` first and on ``adv_inputs`` second.

    Args:
        inputs: Batch of unperturbed examples passed to ``model``.
        adv_inputs: Perturbed counterparts of ``inputs``, passed to ``model``.
        targets: Class indices; handed to the label-smoothing criterion and
            (cast to int64) used to pick the labelled-class probability.
        model: Callable returning logits with the class scores along dim 1.
        tau: Offset added to the per-example weight before it is capped at 1.
        smoothing: Smoothing factor forwarded to
            ``LabelSmoothingCrossEntropy``.

    Returns:
        Tuple ``(sup_loss, rob_loss)``. ``sup_loss`` is the label-smoothed
        loss of the clean logits. ``rob_loss`` is the mean of the per-example
        KL divergence (clean softmax vs. adversarial softmax) multiplied by
        ``min(p_nat(target) + tau, 1)``.
    """
    criterion = LabelSmoothingCrossEntropy(smoothing)

    clean_logits = model(inputs)
    perturbed_logits = model(adv_inputs)

    perturbed_dist = F.softmax(perturbed_logits, dim=1)
    clean_dist = F.softmax(clean_logits, dim=1)

    # Probability the model assigns to the labelled class on the clean input.
    label_index = targets.unsqueeze(1).long()
    clean_label_prob = torch.gather(clean_dist, 1, label_index).squeeze()

    # The 1e-12 offset keeps log() finite if a softmax entry underflows to 0.
    per_example_kl = F.kl_div(
        (perturbed_dist + 1e-12).log(), clean_dist, reduction='none'
    ).sum(dim=1)

    # Weight grows with the clean label probability and saturates at 1.
    example_weight = torch.clamp(clean_label_prob + tau, max=1)

    sup_loss = criterion(clean_logits, targets)
    rob_loss = (per_example_kl * example_weight).mean()
    return sup_loss, rob_loss

class LabelSmoothingCrossEntropy(nn.Module):
    """Negative log-likelihood loss with uniform label smoothing.

    The loss for one example is
    ``(1 - smoothing) * nll(target) + smoothing * mean_over_classes(nll)``,
    and ``forward`` returns the mean of that quantity over all examples.
    """

    def __init__(self, smoothing=0.1):
        """Store the smoothing factor and its complement.

        Args:
            smoothing: Weight given to the uniform component; must be
                below 1.0.

        Raises:
            AssertionError: If ``smoothing`` is not below 1.0.
        """
        super().__init__()
        assert smoothing < 1.0, "smoothing must be < 1.0"
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the label-smoothed loss.

        Args:
            x: Logits with the class scores along the last dimension.
            target: Integer class indices, one per row of ``x``. Any integer
                dtype is accepted; the indices are cast to int64 for
                ``gather``.

        Returns:
            Scalar tensor holding the mean smoothed loss.
        """
        logprobs = F.log_softmax(x, dim=-1)
        # gather only accepts int64 indices; cast so int32/uint8 label
        # tensors do not fail here.
        index = target.long().unsqueeze(1)
        nll_loss = -logprobs.gather(dim=-1, index=index).squeeze(1)
        # Mean negative log-probability over all classes (uniform target).
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()