import torch
import torch.nn as nn
import torch.nn.functional as F

class EFCEnTLoss(nn.Module):
    """EFCEnTLoss: eliminating-fuzzy-domain cross-entropy loss for non-targeted attacks.

    All logits are first divided by ``temperature_scale``. The logit of each
    sample's true class (given by ``ys``) is then multiplied by
    ``fuzzy_scale``. Finally the standard cross-entropy criterion (default
    ``nn.CrossEntropyLoss``, mean reduction) is applied.

    Args:
        temperature_scale: Divisor applied to every logit.
        fuzzy_scale: Multiplier applied only to the true-class logit, after
            temperature scaling.
    """

    def __init__(self, temperature_scale=1.0, fuzzy_scale=1.0):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.temperature_scale = temperature_scale
        self.fuzzy_scale = fuzzy_scale

    def forward(self, outputs, ys):
        """Return the fuzzy-scaled cross-entropy loss.

        Args:
            outputs: Logits with the class dimension at dim 1, i.e. ``(N, C)``
                or ``(N, C, d1, ...)`` as accepted by ``nn.CrossEntropyLoss``.
            ys: Integer class indices shaped like ``outputs`` without dim 1.
                Targets equal to the criterion's ``ignore_index`` get no
                scaling and are skipped by the criterion.

        Returns:
            The loss tensor produced by ``self.criterion``.
        """
        # Division allocates a new tensor, so the caller's logits are never modified.
        outputs = outputs / self.temperature_scale
        num_classes = outputs.shape[1]
        # Shape (1, C, 1, ..., 1) so it broadcasts against ys.unsqueeze(1),
        # which keeps the class axis at dim 1 for both 2-D and K-dim inputs.
        class_ids = torch.arange(num_classes, device=outputs.device).view(
            1, num_classes, *([1] * (outputs.dim() - 2))
        )
        # True exactly at each target logit. Unlike F.one_hot, a negative
        # ignore_index simply matches nothing instead of raising.
        target_mask = class_ids == ys.unsqueeze(1)
        # Out-of-place select: avoids in-place masked writes inside the autograd graph.
        outputs = torch.where(target_mask, outputs * self.fuzzy_scale, outputs)
        return self.criterion(outputs, ys)
