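'''
Thin wrappers exposing the following loss functions through a common
``(logits, targets, **kwargs)`` call signature:
1. Cross Entropy
2. Focal Loss
3. Cross Entropy + MMCE_weighted
4. Cross Entropy + MMCE
5. Brier Score
6. Focal Loss with adaptive gamma
7. Dual Focal Loss
8. BSCE-GRA
'''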

from torch.nn import functional as F
from Losses.focal_loss import FocalLoss
from Losses.focal_loss_adaptive_gamma import FocalLossAdaptive
from Losses.mmce import MMCE, MMCE_weighted
from Losses.brier_score import BrierScore
from Losses.dual_focal_loss import DualFocalLoss
from Losses.soft_ece import SoftBinnedECE, SoftAvUCLoss
# from Losses.maxEnt_loss import MaxEntLoss
from Losses.BSCE_GRA import BSCE_GRA_Loss

def cross_entropy(logits, targets, **kwargs):
    '''
    Return the cross-entropy loss summed (not averaged) over the inputs.

    Extra keyword arguments are accepted so that this function shares a
    call signature with the other loss wrappers in this module; they are
    ignored here.
    '''
    summed_loss = F.cross_entropy(
        input=logits,
        target=targets,
        reduction='sum',
    )
    return summed_loss


def focal_loss(logits, targets, **kwargs):
    '''
    Build a ``FocalLoss`` with ``gamma=kwargs['gamma']`` and apply it to
    ``(logits, targets)``.

    Raises ``KeyError`` if ``gamma`` is not supplied as a keyword argument.
    '''
    criterion = FocalLoss(gamma=kwargs['gamma'])
    return criterion(logits, targets)


def focal_loss_adaptive(logits, targets, **kwargs):
    '''
    Build a ``FocalLossAdaptive`` from ``kwargs['gamma']`` and
    ``kwargs['device']`` and apply it to ``(logits, targets)``.

    Raises ``KeyError`` if either ``gamma`` or ``device`` is not supplied.
    '''
    # Look the keys up in the same order as the constructor arguments.
    gamma = kwargs['gamma']
    device = kwargs['device']
    criterion = FocalLossAdaptive(gamma=gamma, device=device)
    return criterion(logits, targets)

def dual_focal_loss(logits, targets, **kwargs):
    '''
    Build a ``DualFocalLoss`` with ``gamma=kwargs['gamma']`` and apply it to
    ``(logits, targets)``.

    Raises ``KeyError`` if ``gamma`` is not supplied as a keyword argument.
    '''
    criterion = DualFocalLoss(gamma=kwargs['gamma'])
    return criterion(logits, targets)

def mmce(logits, targets, **kwargs):
    '''
    Return cross entropy plus a weighted ``MMCE`` term.

    The cross-entropy part uses ``F.cross_entropy`` with its default
    reduction. The ``MMCE`` module is constructed with ``kwargs['device']``
    and its output is scaled by ``kwargs['lamda']`` (note the key spelling)
    before being added.

    Raises ``KeyError`` if ``device`` or ``lamda`` is not supplied.
    '''
    ce_term = F.cross_entropy(logits, targets)
    calibration_term = MMCE(kwargs['device'])(logits, targets)
    weighted_calibration = kwargs['lamda'] * calibration_term
    return ce_term + weighted_calibration


def mmce_weighted(logits, targets, **kwargs):
    '''
    Return cross entropy plus a weighted ``MMCE_weighted`` term.

    The cross-entropy part uses ``F.cross_entropy`` with its default
    reduction. The ``MMCE_weighted`` module is constructed with
    ``kwargs['device']`` and its output is scaled by ``kwargs['lamda']``
    (note the key spelling) before being added.

    Raises ``KeyError`` if ``device`` or ``lamda`` is not supplied.
    '''
    ce_term = F.cross_entropy(logits, targets)
    calibration_term = MMCE_weighted(kwargs['device'])(logits, targets)
    weighted_calibration = kwargs['lamda'] * calibration_term
    return ce_term + weighted_calibration


def brier_score(logits, targets, **kwargs):
    '''
    Apply a default-constructed ``BrierScore`` to ``(logits, targets)``.

    Extra keyword arguments are accepted so that this function shares a
    call signature with the other loss wrappers in this module; they are
    ignored here.
    '''
    criterion = BrierScore()
    return criterion(logits, targets)

def bsce_gra(logits, targets, **kwargs):
    '''
    Build a ``BSCE_GRA_Loss`` with ``gamma=kwargs['gamma']`` and apply it to
    ``(logits, targets)``.

    Raises ``KeyError`` if ``gamma`` is not supplied as a keyword argument.
    '''
    criterion = BSCE_GRA_Loss(gamma=kwargs['gamma'])
    return criterion(logits, targets)

# def maxEnt_loss(logits, targets, **kwargs):
#     return MaxEntLoss(ratio=kwargs['ratio'],
#                       constraints=6, 
#                       device=kwargs['device'], 
#                       gamma=1,
#                       num_classes=kwargs['num_classes'] 
#                     )(logits, targets)

