from spaghettini import quick_register

from torch import nn

from src.dl.losses.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
from src.dl.losses.noisy_labels import flip_binary_labels


@quick_register
class DoubleCrossEntropyLoss(nn.Module):
    """Binary classification loss: cross-entropy on the labels minus cross-entropy on the flipped labels.

    The returned value is ``ce1(input, target) - ce2(input, 1 - target)``, where
    ``ce1`` is a ``LabelSmoothedCrossEntropy`` with smoothing ``smoothing_coeff``
    and ``ce2`` is a ``LabelSmoothedCrossEntropy`` with ``alpha=0`` (presumably
    plain cross-entropy — confirm against its implementation). Optionally, a
    random fraction of the binary labels is flipped before the loss is computed.

    Args:
        smoothing_coeff: Label-smoothing coefficient, passed as ``alpha`` to the
            first cross-entropy term only.
        label_flip_probability: Probability passed to ``flip_binary_labels``;
            flipping is skipped entirely when this is ``<= 0``.
    """

    def __init__(self, smoothing_coeff=0., label_flip_probability=0.):
        super().__init__()
        # By adding smoothing only in one of the losses, we make sure the smoothing terms don't cancel each other.
        self.ce1 = LabelSmoothedCrossEntropy(alpha=smoothing_coeff, downweight_ce=False)
        self.ce2 = LabelSmoothedCrossEntropy(alpha=0)
        self.label_flip_probability = label_flip_probability

    def forward(self, input, target):
        """Compute the double cross-entropy loss.

        Args:
            input: Model outputs forwarded unchanged to both cross-entropy terms
                (presumably two-class logits — TODO confirm against
                ``LabelSmoothedCrossEntropy``).
            target: Tensor of binary labels; every element must equal 0 or 1.

        Returns:
            The smoothed cross-entropy on ``target`` plus the negated
            cross-entropy on the opposite labels ``1 - target``.

        Raises:
            ValueError: If ``target`` contains any value other than 0 or 1.
        """
        # Validate with a tensor-level check rather than `assert` (stripped under
        # `python -O`) or a host-side Python set (copies the whole target off the
        # device every step and fails on targets with more than one dimension).
        is_binary = ((target == 0) | (target == 1)).all()
        if not bool(is_binary):
            raise ValueError("DoubleCrossEntropyLoss expects binary targets containing only 0 and 1.")

        # Flip a (usually small) ratio of the labels, if asked.
        if self.label_flip_probability > 0:
            target = flip_binary_labels(target, p=self.label_flip_probability)

        # Opposite binary labels (0 <-> 1), cast to integer class indices.
        opposite_target = (1 - target).long()

        # Cross-entropy towards the true labels, minus cross-entropy towards the
        # opposite labels.
        ce1 = self.ce1(input=input, target=target)
        ce2 = -self.ce2(input=input, target=opposite_target)

        return ce1 + ce2
