import torch
import torch.nn.functional as F
from torch import autograd, nn
from torch.nn import KLDivLoss

class KLLoss(nn.Module):
    """KL divergence ``KL(feature || target)`` built on ``torch.nn.KLDivLoss``.

    ``KLDivLoss`` takes log-probabilities as its first argument and
    probabilities as its second, so ``target`` is passed in log space and
    ``feature`` as plain probabilities.

    Args:
        softmax: If true, both inputs are treated as unnormalised scores and
            normalised along ``dim`` before the divergence is computed. If
            false, both inputs are used as probabilities as given.
        reduction: Forwarded unchanged to ``KLDivLoss``.
        dim: Axis along which scores are normalised when ``softmax`` is true.
            Defaults to the last axis, matching the other losses in this file.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, softmax=True, reduction='batchmean', dim=-1, **kwargs):
        super(KLLoss, self).__init__()
        self.softmax_mark = softmax
        self.dim = dim
        self.loss = KLDivLoss(reduction=reduction)

    def forward(self, feature, target):
        """Return the reduced KL divergence of ``feature`` from ``target``.

        Args:
            feature: Scores (``softmax=True``) or probabilities.
            target: Scores (``softmax=True``) or probabilities, same shape
                as ``feature``.

        Returns:
            The value produced by ``KLDivLoss`` under the configured reduction.
        """
        if self.softmax_mark:
            # An explicit axis avoids PyTorch's deprecated implicit-dim rule,
            # which would normalise 3-D inputs over axis 0.
            feature = F.softmax(feature, dim=self.dim)
            # log_softmax stays finite where log(softmax(x)) underflows to -inf.
            log_target = F.log_softmax(target, dim=self.dim)
        else:
            log_target = torch.log(target)
        return self.loss(log_target, feature)



class JSDivLoss(nn.Module):
    """Symmetric divergence: the average of two ``KLLoss`` terms, each taken
    against the midpoint of the two inputs.

    Args:
        softmax: If true, both inputs are softmax-normalised over the last
            axis before the midpoint is formed.
        reduction: Forwarded to the inner ``KLLoss``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, softmax=True, reduction='batchmean', **kwargs):
        super().__init__()
        self.softmax_mark = softmax
        # Normalisation is handled in forward(), so the inner loss must not
        # apply softmax a second time.
        self.kl = KLLoss(softmax=False, reduction=reduction)

    def forward(self, feature, target):
        """Return 0.5 * kl(feature, m) + 0.5 * kl(target, m), m = midpoint."""
        if self.softmax_mark:
            feature, target = (
                F.softmax(scores, dim=-1) for scores in (feature, target)
            )
        midpoint = (feature + target) / 2
        feature_term = 0.5 * self.kl(feature, midpoint)
        target_term = 0.5 * self.kl(target, midpoint)
        return feature_term + target_term

class CELoss(nn.Module):
    """Cross-entropy between predicted scores and a soft target distribution.

    Constructor keyword arguments are accepted and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, input_data, target, softmax=True):
        """Return the mean over samples of ``-sum(target * log_softmax(input))``.

        Args:
            input_data: Unnormalised scores; log-softmax is taken over the
                last axis.
            target: Target scores (``softmax=True``) or target weights used
                as given (``softmax=False``).
            softmax: Whether to softmax-normalise ``target`` over the last
                axis first.

        Returns:
            Scalar tensor holding the averaged loss.
        """
        weights = F.softmax(target, dim=-1) if softmax else target
        log_probs = F.log_softmax(input_data, dim=-1)
        per_sample = torch.sum(-weights * log_probs, dim=-1)
        return per_sample.mean()
        

