from torch.nn import functional as F
from Losses.focal_loss import FocalLoss
from Losses.focal_loss_sd import FocalLossSD
from Losses.adafocal import AdaFocal
from Losses.mmce import MMCE, MMCE_weighted
from Losses.brier_score import BrierScore


def cross_entropy(logits, targets, **kwargs):
    """Return the summed (not averaged) cross-entropy of ``logits`` w.r.t. ``targets``.

    Extra keyword arguments are accepted only so every loss in this module
    shares one call signature; they are ignored here.
    """
    # NOTE: reduction is 'sum', whereas the CE term inside mmce/mmce_weighted
    # uses PyTorch's default ('mean') — keep in mind when comparing magnitudes.
    summed_nll = F.cross_entropy(input=logits, target=targets, reduction='sum')
    return summed_nll

def focal_loss(logits, targets, **kwargs):
    """Evaluate ``FocalLoss`` on ``logits``/``targets``.

    Reads ``kwargs['gamma']`` (raises KeyError if absent); all other keyword
    arguments are ignored.
    """
    criterion = FocalLoss(gamma=kwargs['gamma'])
    return criterion(logits, targets)

def focal_loss_sd(logits, targets, **kwargs):
    """Evaluate ``FocalLossSD`` on ``logits``/``targets``.

    Reads ``kwargs['gamma']`` and then ``kwargs['device']`` (KeyError if
    either is missing); other keyword arguments are ignored.
    """
    gamma = kwargs['gamma']
    device = kwargs['device']
    criterion = FocalLossSD(gamma=gamma, device=device)
    return criterion(logits, targets)

def adafocal(logits, targets, **kwargs):
    """Evaluate ``AdaFocal`` on ``logits``/``targets``.

    Required keyword arguments, forwarded unchanged to ``AdaFocal``:
    ``gamma``, ``device``, ``prev_epoch_adabin_dict``, ``gamma_lambda``,
    ``adafocal_start_epoch`` and ``epoch``. A missing key raises KeyError.
    Any other keyword arguments are ignored.
    """
    criterion = AdaFocal(
        gamma=kwargs['gamma'],
        device=kwargs['device'],
        prev_epoch_adabin_dict=kwargs['prev_epoch_adabin_dict'],
        gamma_lambda=kwargs['gamma_lambda'],
        adafocal_start_epoch=kwargs['adafocal_start_epoch'],
        epoch=kwargs['epoch'],
    )
    return criterion(logits, targets)

def mmce(logits, targets, **kwargs):
    """Return cross-entropy plus a weighted ``MMCE`` term.

    The cross-entropy uses PyTorch's default reduction. ``kwargs['device']``
    is passed to ``MMCE`` and ``kwargs['lamda']`` (spelled this way by
    callers) scales the MMCE term.
    """
    nll_term = F.cross_entropy(logits, targets)
    calibration_term = MMCE(kwargs['device'])(logits, targets)
    weight = kwargs['lamda']
    return nll_term + weight * calibration_term

def mmce_weighted(logits, targets, **kwargs):
    """Return cross-entropy plus a weighted ``MMCE_weighted`` term.

    The cross-entropy uses PyTorch's default reduction. ``kwargs['device']``
    is passed to ``MMCE_weighted`` and ``kwargs['lamda']`` (spelled this way
    by callers) scales that term.
    """
    nll_term = F.cross_entropy(logits, targets)
    calibration_term = MMCE_weighted(kwargs['device'])(logits, targets)
    weight = kwargs['lamda']
    return nll_term + weight * calibration_term

def brier_score(logits, targets, **kwargs):
    """Evaluate a freshly constructed ``BrierScore`` on ``logits``/``targets``.

    Keyword arguments are accepted for signature uniformity and ignored.
    """
    scorer = BrierScore()
    return scorer(logits, targets)