import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES

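# Reference implementation this loss is adapted from:
# https://github.com/hessamb/label-refinery/blob/master/extensions/refinery_loss.py
# It computes the soft-target cross-entropy
#     -sum_c softmax(target)_c * log_softmax(output)_c
# which differs from KL(softmax(target) || softmax(output)) only by the entropy
# of the target distribution (a term that does not depend on `output`).
# def kl_div(output, target):
#     # output: NxC pre-softmax output
#     # target: NxC pre-softmax logits
#     return -torch.bmm(
#             F.softmax(target, dim=1).unsqueeze(1),
#             F.log_softmax(output, dim=1).unsqueeze(2)
#         ).squeeze() # Nx1x1 --> N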

@LOSSES.register_module()
class SoftKLDivLoss(nn.Module):
    """KL-divergence loss between predicted logits and soft target logits.

    Both ``cls_score`` and ``gt_logit`` are treated as pre-softmax logits:
    the predictions are converted to log-probabilities and the targets to
    probabilities along ``dim=1``, then passed to
    :func:`torch.nn.functional.kl_div`.

    Args:
        reduction (str): Default reduction forwarded to ``F.kl_div``. One of
            ``'none'``, ``'mean'``, ``'sum'`` or ``'batchmean'``.
            Defaults to ``'batchmean'``.
        loss_weight (float): Scalar multiplier applied to the loss.
            Defaults to 1.0.
        *args, **kwargs: Accepted and ignored (presumably so the loss can be
            built from configs carrying extra keys -- confirm with callers).
    """

    # Reduction modes supported by ``F.kl_div``.
    _REDUCTIONS = ('none', 'mean', 'sum', 'batchmean')

    def __init__(self, reduction='batchmean', loss_weight=1.0, *args, **kwargs):
        super(SoftKLDivLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

        self.cls_criterion = F.kl_div

    def forward(self,
                cls_score,
                gt_label,
                gt_logit,
                avg_factor=None,
                reduction_override=None,
                *args,
                **kwargs):
        """Compute ``loss_weight * KL(softmax(gt_logit) || softmax(cls_score))``.

        Args:
            cls_score (Tensor): Predicted logits; softmax is taken over dim 1.
            gt_label: Unused. Presumably kept so the signature matches the
                other classification losses -- confirm with callers.
            gt_logit (Tensor): Target logits, same shape as ``cls_score``;
                softmax is taken over dim 1.
            avg_factor: Unused; accepted for interface compatibility only.
            reduction_override (str, optional): If given, overrides
                ``self.reduction`` for this call.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Tensor: The weighted loss. With ``'batchmean'`` (the default) this
            is the summed divergence divided by the batch size, i.e. the mean
            per-sample KL. With ``'none'`` it is element-wise, with the same
            shape as ``cls_score``. ``'mean'`` averages over every element;
            PyTorch warns that this does not match the mathematical KL
            definition.

        Raises:
            ValueError: If the effective reduction is not supported by
                ``F.kl_div``.
        """
        reduction = (
            reduction_override
            if reduction_override is not None else self.reduction)
        # Previously only 'batchmean' was accepted even though the override
        # check listed other modes; honour every mode F.kl_div supports.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                'reduction must be one of {}, got {!r}'.format(
                    self._REDUCTIONS, reduction))

        # F.kl_div expects log-probabilities as input and (by default)
        # probabilities as target.
        # NOTE(review): F.kl_div(..., log_target=True) with
        # F.log_softmax(gt_logit, dim=1) may be numerically more stable when
        # target probabilities underflow; it requires PyTorch >= 1.6 --
        # confirm the supported torch version before switching.
        loss_cls = self.loss_weight * self.cls_criterion(
            F.log_softmax(cls_score, dim=1),
            F.softmax(gt_logit, dim=1),
            reduction=reduction)
        return loss_cls