from spaghettini import quick_register

import torch
from torch.nn.functional import cross_entropy


@quick_register
def class_balanced_cross_entropy(inputs, targets, num_classes=2):
    """Cross-entropy weighted equally per class instead of per sample.

    The mean cross-entropy is computed separately over the samples of each
    class ``0 .. num_classes - 1``. The per-class losses are then combined
    with weight ``1 / num_classes`` each. A class with no samples in the
    batch contributes zero; its weight is *not* redistributed to the other
    classes.

    Args:
        inputs: Logits accepted by ``torch.nn.functional.cross_entropy``.
            The first dimension indexes samples (presumably ``(N, C)`` --
            confirm against callers).
        targets: Class labels of shape ``(N,)`` or ``(N, 1)``. They are cast
            to ``long`` before use.
        num_classes: Number of classes to balance over. Defaults to 2 (the
            binary case). Samples whose label lies outside
            ``[0, num_classes)`` do not contribute to the loss.

    Returns:
        A scalar tensor. If no sample belongs to any of the considered
        classes (e.g. an empty batch), this is a zero tensor on the device
        and dtype of ``inputs``.

    Raises:
        AssertionError: If ``targets`` has more than two dimensions, or is
            2-D with a second dimension other than 1.
    """
    assert len(targets.shape) <= 2
    if len(targets.shape) == 2:
        assert targets.shape[1] == 1
        targets = targets[:, 0]
    # Both accepted shapes are now a 1-D label vector.
    target = targets.long()

    weight = 1 / num_classes
    loss = inputs.new_zeros(())
    for cls in range(num_classes):
        mask = target == cls
        # A mean over an empty selection would be NaN, so absent classes are
        # skipped and effectively contribute 0.
        if mask.any():
            loss = loss + weight * cross_entropy(inputs[mask], target[mask])
    return loss
