import torch
import torch.nn as nn
import torch.nn.functional as F

'''
EFCETLoss: eliminating-fuzzy-domain-based cross-entropy loss for targeted attacks.
'''
class EFCETLoss(nn.Module):
    """Cross-entropy loss that rescales the non-target logits before scoring.

    All logits are first divided by ``temperature_scale``. Every logit that is
    *not* the target class is then multiplied by ``fuzzy_scale``. The result is
    scored with a default (mean-reduced) ``torch.nn.CrossEntropyLoss`` against
    ``target_ys``.

    Args:
        temperature_scale: Divisor applied to every logit.
        fuzzy_scale: Multiplier applied only to the non-target logits.
    """

    def __init__(self, temperature_scale=1.0, fuzzy_scale=1.0):
        super().__init__()
        self.criterion = torch.nn.CrossEntropyLoss()
        self.temperature_scale = temperature_scale
        self.fuzzy_scale = fuzzy_scale

    def forward(self, outputs, target_ys):
        """Return the loss for a batch of logits.

        Args:
            outputs: Logits laid out as ``(batch, num_classes)``. The one-hot
                mask is built with ``outputs.shape[1]`` classes, so only 2-D
                inputs are supported.
            target_ys: int64 class indices, one per row of ``outputs``
                (``F.one_hot`` requires an integer index tensor).

        Returns:
            Scalar tensor: mean cross-entropy of the rescaled logits.
        """
        scaled = outputs / self.temperature_scale
        is_target = F.one_hot(target_ys, scaled.shape[1]).bool()
        # Out-of-place select instead of a masked in-place assignment. The
        # autograd graph is never mutated, and the mask is built only once.
        rescaled = torch.where(is_target, scaled, scaled * self.fuzzy_scale)
        return self.criterion(rescaled, target_ys)