import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .cross_entropy_loss import cross_entropy


@LOSSES.register_module()
class SoftCrossEntropyLoss(nn.Module):
    """Cross-entropy loss against hard pseudo labels derived from logits.

    Despite the name, the target is not a soft distribution: ``forward``
    converts ``gt_logit`` into hard class indices (the arg-max of its
    softmax along dim 1) and passes those indices as the label to the
    ``cross_entropy`` helper imported from ``.cross_entropy_loss``.

    Args:
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(SoftCrossEntropyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.cls_criterion = cross_entropy

    def forward(self,
                cls_score,
                gt_label,
                gt_logit,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                *args,
                **kwargs):
        """Compute the weighted cross-entropy loss against pseudo labels.

        Args:
            cls_score (torch.Tensor): Predicted class scores, passed to
                ``cross_entropy`` as the prediction. Presumably shaped
                (N, C) like ``gt_logit`` -- TODO confirm against callers.
            gt_label: Accepted for interface compatibility but NOT used;
                the target is derived from ``gt_logit`` instead.
            gt_logit (torch.Tensor): Logits whose class axis is dim 1; the
                arg-max of their softmax gives the pseudo labels.
            weight (torch.Tensor, optional): Forwarded unchanged to
                ``cross_entropy``. Defaults to None.
            avg_factor (int | float, optional): Forwarded unchanged to
                ``cross_entropy``. Defaults to None.
            reduction_override (str, optional): One of None, 'none', 'mean',
                'sum'. When given (not None) it replaces ``self.reduction``.
            *args, **kwargs: Accepted and ignored.

        Returns:
            torch.Tensor: ``self.loss_weight`` times the value returned by
            ``cross_entropy``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # The pseudo labels are integer indices and cannot carry gradient, so
        # detach first to avoid recording softmax/max in the autograd graph
        # when gt_logit requires grad. Softmax is kept (instead of arg-max on
        # the raw logits) so tie-breaking, including ties created by float
        # rounding inside softmax, matches the original behaviour exactly.
        _, label = torch.max(F.softmax(gt_logit.detach(), dim=1), dim=1)
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor)
        return loss_cls