'''From https://github.com/alinlab/LfF/blob/master/module/loss.py'''

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

class GeneralizedCELoss(nn.Module):
    """Generalized cross-entropy (GCE) loss, unreduced.

    The per-element loss returned is ``CE(logits, targets) * |q| * p_y**q``,
    where ``p_y`` is the softmax probability of the target class. The weight
    ``|q| * p_y**q`` is detached, so the gradient w.r.t. ``logits`` is
    ``|q| * p_y**q * grad(CE)``. Up to the constant ``|q|`` factor, this
    equals the gradient of the GCE objective ``(1 - p_y**q) / q``.

    Args:
        q: Exponent applied to the target-class probability (default 0.7).
    """

    def __init__(self, q=0.7):
        super().__init__()
        self.q = q

    def forward(self, logits, targets):
        """Compute the per-element GCE loss (no reduction).

        Args:
            logits: Raw class scores with classes along dim 1, e.g. ``(N, C)``
                or ``(N, C, d1, ...)``.
            targets: Integer (int64) class indices with the shape of
                ``logits`` minus dim 1, e.g. ``(N,)`` or ``(N, d1, ...)``.

        Returns:
            Loss tensor with the same shape as ``targets``.

        Raises:
            NameError: ``'GCE_p'`` if the softmax probabilities contain NaN.
        """
        p = F.softmax(logits, dim=1)
        # Use any-NaN rather than isnan(mean): the mean of an empty batch is
        # NaN and used to raise spuriously. Yg is gathered from p, so this
        # single check also covers it.
        if torch.isnan(p).any().item():
            raise NameError('GCE_p')
        # Probability assigned to each element's target class; the class axis
        # is kept with size 1 by gather.
        Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1))
        # squeeze(1), not squeeze(): a bare squeeze also removes other size-1
        # dims (e.g. singleton spatial dims), which can make the product below
        # broadcast to the wrong shape. Detached so the weight only rescales
        # the cross-entropy gradient.
        loss_weight = Yg.squeeze(1).detach() ** self.q * abs(self.q)

        loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight
        return loss
    
