import torch
import torch.nn as nn

from src.utils import powmax


class KLDivLoss(nn.Module):
    """
    KL-divergence-style loss between a target distribution and softmax(logits / T).

    Both distributions are offset by `eps` and passed through `powmax`
    (from `src.utils`) before the per-row divergence is computed. The result
    is averaged over the (optionally masked) rows.
    """

    def __init__(self, eps=1e-12):
        """
        :param eps: small constant added to both distributions before `powmax`,
            so that the `log` and the division in `forward` do not see exact zeros.
        """
        super().__init__()
        self.eps = eps

    def forward(
            self,
            logits: torch.Tensor,
            targets: torch.Tensor,
            masks: torch.Tensor = None,
            T: float = 1.0
    ):
        """
        Compute KL-Divergence loss.
        :param logits: the logits of the estimated distribution, before `softmax`.
            All leading dims are flattened; the last dim is the class dim.
        :param targets: the target distribution, which should sum to 1 along the last dim.
            It is flattened the same way as `logits` and cast to float.
        :param masks: Optional. Selects which rows contribute to the loss.
            Its shape is identical to the shape of `logits` up to the last dim.
            Bool masks are expected; integer/float masks are converted with `.bool()`
            (non-zero means "keep").
        :param T: Temperature applied to `logits` before `softmax`, default 1.
        :return: scalar loss, the mean of the per-row terms over the selected rows.
            If no row is selected (or there are no rows), returns a zero that stays
            attached to the autograd graph instead of NaN.
        """
        # `reshape` rather than `view`, so non-contiguous inputs (e.g. transposed or
        # sliced tensors) are accepted; it is a no-copy view for contiguous ones.
        logits = logits.reshape(-1, logits.size(-1))
        estimated = torch.softmax(logits / T, dim=-1)
        # NOTE(review): `powmax` is an opaque helper from src.utils; presumably it
        # re-normalizes the eps-shifted rows into distributions. TODO confirm.
        estimated = powmax(estimated + self.eps)
        targets = targets.reshape(-1, targets.size(-1)).float()
        targets = powmax(targets + self.eps)

        # Per-row sum_k t_k * log(t_k / e_k), i.e. KL(targets || estimated)
        # when both rows are probability distributions.
        loss = torch.sum(targets * torch.log(targets / estimated), dim=-1)
        if masks is not None:
            # masked_select needs a bool mask; also accept 0/1 integer or float masks.
            loss = torch.masked_select(loss, masks.reshape(-1).bool())
        if loss.numel() == 0:
            # torch.mean of an empty tensor is NaN, which would poison training.
            # sum() of an empty tensor is 0 and keeps the graph connected.
            return loss.sum()
        return torch.mean(loss)


if __name__ == '__main__':
    # Smoke test: random logits against a random target distribution.
    criterion = KLDivLoss()
    _logits = torch.rand(size=(5, 12))
    _targets = torch.softmax(torch.rand(size=(5, 12)), dim=-1)
    # Keep only rows 0 and 4 for the masked variant.
    _masks = torch.tensor((True, False, False, False, True))
    # Call the module itself (not `.forward`) so nn.Module hooks are run.
    print(criterion(_logits, _targets))
    print(criterion(_logits, _targets, masks=_masks))
