import torch
import torch.nn as nn
from torch.autograd import Variable


class CWLoss(nn.Module):
    """Carlini & Wagner (C&W) margin loss computed on classifier logits.

    For every sample, the margin between the logit of the labelled class
    ("real") and the largest logit among all other classes ("other") is
    floored at ``kappa`` and averaged over the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, target, kappa=0, tar=False):
        """Carlini & Wagner attack loss.

        Args:
            logits (torch.Tensor): predicted logits, shape [B, num_classes].
                Any device / floating dtype is accepted; all work stays on
                ``logits.device`` in ``logits.dtype``.
            target (torch.Tensor or int): class label(s). A scalar or a
                single-element tensor is broadcast to all B rows; otherwise
                one label per row, shape [B].
            kappa (float): lower bound applied to each per-sample margin
                before averaging. NOTE: the floor is ``kappa`` itself, so the
                C&W paper's confidence ``κ`` (floor ``-κ``) corresponds to
                passing ``kappa=-κ``.
            tar (bool): if True, the targeted loss ``max(other - real, kappa)``
                (drives ``target`` above every other class); otherwise the
                untargeted loss ``max(real - other, kappa)`` (drives
                ``target`` below the best competing class).

        Returns:
            torch.Tensor: scalar, the batch mean of the floored margins.
        """
        batch_size, num_classes = logits.size(0), logits.size(1)

        # One label per row, as a [B, 1] index tensor on the logits' device.
        labels = torch.as_tensor(target, device=logits.device).long().reshape(-1)
        labels = labels.expand(batch_size).unsqueeze(1)

        # Logit of the labelled class. gather() avoids the 0 * inf = nan a
        # one-hot product would give when another logit is infinite.
        real = logits.gather(1, labels).squeeze(1)

        # Best logit among the *other* classes. Masking with -inf (instead of
        # subtracting a large constant) guarantees the labelled class can
        # never be chosen, however negative the logits are.
        is_target = torch.arange(num_classes, device=logits.device) == labels
        other = logits.masked_fill(is_target, float("-inf")).max(1)[0]

        margin = other - real if tar else real - other
        floor = torch.full_like(margin, float(kappa))
        return torch.mean(torch.max(margin, floor))
