from torch import nn

from src.losses.focal_loss import FocalLoss
from src.losses.kde_ece import get_ece_reg, get_sharpness_reg, get_bandwidth
from src.losses.mmce import MMCE_weighted


def cross_entropy(logits, targets, **kwargs):
    """Return the cross-entropy loss between ``logits`` and ``targets``.

    When ``kwargs['num_classes'] == 1`` the task is treated as binary and
    ``nn.BCEWithLogitsLoss`` is used. Otherwise ``nn.CrossEntropyLoss`` is
    used. Both use their default ``'mean'`` reduction.

    Args:
        logits: Raw (un-normalised) model outputs.
        targets: Ground-truth labels. In the binary branch they are cast to
            ``logits.dtype`` because ``BCEWithLogitsLoss`` requires
            floating-point targets.
        **kwargs: May contain ``num_classes``. If it is absent, the
            multi-class branch is used. Other keys are ignored.

    Returns:
        The loss tensor produced by the selected criterion.
    """
    if kwargs.get('num_classes') == 1:
        # BCEWithLogitsLoss rejects integer targets ("result type Float can't
        # be cast to ... Long"). The cast is harmless for float targets.
        return nn.BCEWithLogitsLoss()(logits, targets.to(logits.dtype))
    return nn.CrossEntropyLoss()(logits, targets)


def mean_squared_error(scores, targets, **kwargs):
    """Return the mean squared error between ``scores`` and ``targets``.

    Extra keyword arguments are accepted and ignored, so this function can be
    called with the same keyword arguments as the other losses in this module.
    """
    criterion = nn.MSELoss(reduction='mean')
    return criterion(scores, targets)


def focal_loss(logits, targets, **kwargs):
    """Compute the focal loss via the project's ``FocalLoss`` module.

    Args:
        logits: Model outputs, forwarded unchanged to ``FocalLoss``.
        targets: Ground-truth labels, forwarded unchanged to ``FocalLoss``.
        **kwargs: Must provide ``adaptive`` (forwarded as ``adaptive``),
            ``loss_param`` (forwarded as ``gamma``) and ``device`` (forwarded
            as ``device``).

    Returns:
        Whatever calling the ``FocalLoss`` instance returns (presumably a
        scalar loss tensor; confirm in ``src.losses.focal_loss``).
    """
    criterion = FocalLoss(
        adaptive=kwargs['adaptive'],
        gamma=kwargs['loss_param'],
        device=kwargs['device'],
    )
    return criterion(logits, targets)


def mmce(logits, targets, **kwargs):
    """Return cross-entropy plus a weighted ``MMCE_weighted`` penalty.

    The result is ``CE(logits, targets) + loss_param * MMCE(logits, targets)``,
    where CE is ``nn.CrossEntropyLoss`` used directly. This means
    ``num_classes`` is not consulted here, unlike in ``cross_entropy``.

    Args:
        logits: Model outputs passed to both terms.
        targets: Ground-truth labels passed to both terms.
        **kwargs: Must provide ``device`` (used to build ``MMCE_weighted``)
            and ``loss_param`` (weight of the MMCE term).
    """
    ce_criterion = nn.CrossEntropyLoss()
    nll_term = ce_criterion(logits, targets)

    mmce_criterion = MMCE_weighted(kwargs['device'])
    calibration_term = mmce_criterion(logits, targets)

    weight = kwargs['loss_param']
    return nll_term + weight * calibration_term


def kde_ce(logits, targets, **kwargs):
    """Return cross-entropy plus a weighted ``get_ece_reg`` regulariser.

    The cross-entropy term is ``cross_entropy(logits, targets, **kwargs)``,
    which additionally reads ``num_classes``. The regulariser is evaluated
    with a bandwidth obtained from ``get_bandwidth``.

    Keyword arguments used:
        b: Forwarded to ``get_bandwidth`` (presumably a bandwidth setting;
            confirm in ``src.losses.kde_ece``).
        f: Forwarded to both ``get_bandwidth`` and ``get_ece_reg``
            (presumably the model's predicted scores; confirm with the caller).
        target_orig: Forwarded to ``get_ece_reg`` (NOTE(review): presumably
            the original, un-transformed labels; confirm).
        p, mc_type: Forwarded to ``get_ece_reg`` (meaning defined there).
        device: Forwarded to both helpers.
        loss_param: Weight applied to the regulariser.
    """
    ce_term = cross_entropy(logits, targets, **kwargs)

    # Same lookup order as before: b, f, device.
    b, f, device = kwargs['b'], kwargs['f'], kwargs['device']
    h = get_bandwidth(b, f, device)
    ece_term = get_ece_reg(f, kwargs['target_orig'], h, kwargs['p'], kwargs['mc_type'], device)

    weight = kwargs['loss_param']
    return ce_term + weight * ece_term


def kde_mse(scores, targets, **kwargs):
    """Return mean squared error plus a weighted ``get_sharpness_reg`` term.

    The MSE term is ``mean_squared_error(scores, targets)``. Keyword arguments
    are not forwarded to it. The regulariser is evaluated with a bandwidth
    obtained from ``get_bandwidth``.

    Keyword arguments used:
        b: Forwarded to ``get_bandwidth`` (presumably a bandwidth setting;
            confirm in ``src.losses.kde_ece``).
        f: Forwarded to both ``get_bandwidth`` and ``get_sharpness_reg``
            (presumably the model's predicted scores; confirm with the caller).
        target_orig: Forwarded to ``get_sharpness_reg`` (NOTE(review):
            presumably the original, un-transformed targets; confirm).
        device: Forwarded to both helpers.
        loss_param: Weight applied to the regulariser.
    """
    mse_term = mean_squared_error(scores, targets)

    # Same lookup order as before: b, f, device.
    b, f, device = kwargs['b'], kwargs['f'], kwargs['device']
    h = get_bandwidth(b, f, device)
    sharpness_term = get_sharpness_reg(f, kwargs['target_orig'], h, device)

    return mse_term + kwargs['loss_param'] * sharpness_term
