from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf

from pado.core.base.loss import PadoLoss

# Names exported by a wildcard import of this module.
__all__ = ["CrossEntropyLossWithSmoothing"]


class CrossEntropyLossWithSmoothing(PadoLoss):
    """
    Cross entropy loss with optional label smoothing.

    With smoothing ``s`` and ``C`` classes, the smoothed target distribution puts
    ``1 - s`` on the true class and ``s / (C - 1)`` on every other class.
    Targets equal to ``ignore_index`` contribute zero loss and are excluded from the
    ``"mean"`` normalization, matching ``torch.nn.functional.cross_entropy``.
    """

    def __init__(self,
                 smoothing: float = 0.0,
                 ignore_index: int = -1,
                 *, reduction: str = "mean"):
        """
        :param smoothing:       float in [0, 1), default label smoothing strength
        :param ignore_index:    target value that is excluded from the loss
        :param reduction:       forwarded to the base class; ``forward`` reads it back as
                                ``self.reduction`` ("sum", "mean", anything else means no reduction)
        :raises ValueError:     if smoothing is outside [0, 1)
        """
        super().__init__(reduction=reduction)

        if not (0 <= smoothing < 1):
            raise ValueError(f"Invalid label smoothing: {smoothing}, should be in range [0, 1)")
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self,
                output: torch.Tensor,
                target: torch.Tensor,
                *, smoothing: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Cross entropy loss with smoothing
        :param output:              (batch_size, num_classes)       float, before softmax
        :param target:              (batch_size,)                   long
        :param smoothing:           float [0, 1), overrides the instance default when given;
                                    values <= 0 disable smoothing
        :return:
                smoothed_loss:              scalar for "sum"/"mean" reduction, else (batch_size,)
                nll_loss (non-smoothed):    scalar for "sum"/"mean" reduction, else (batch_size,)
        """
        if smoothing is None:  # no per-call override, use the instance default
            smoothing = self.smoothing

        output = output.float()

        if smoothing <= 0:
            # Without smoothing both returned losses are the plain cross entropy.
            loss = F.cross_entropy(output, target, ignore_index=self.ignore_index, reduction=self.reduction)
            return loss, loss

        log_prob = F.log_softmax(output, dim=-1)  # (n, c)
        num_classes = log_prob.shape[-1]

        # Always build the mask: ignore_index may be negative (the default is -1) or
        # >= num_classes, and such values must never reach gather() as an index.
        ignore_mask = target.eq(self.ignore_index)  # (n,)
        safe_target = target.masked_fill(ignore_mask, 0)  # any valid class; masked out below

        nll_loss = -log_prob.gather(dim=-1, index=safe_target.unsqueeze(-1)).squeeze(-1)  # (n,)
        smooth_loss = -log_prob.sum(dim=-1)  # (n,)
        nll_loss = nll_loss.masked_fill(ignore_mask, 0.0)
        smooth_loss = smooth_loss.masked_fill(ignore_mask, 0.0)

        # Probability mass given to each non-target class. With a single class there is
        # nothing to spread the mass over (and log_prob is identically zero anyway).
        smoothing_val = smoothing / (num_classes - 1) if num_classes > 1 else 0.0
        # smooth_loss also contains the target class, hence the extra "- smoothing_val":
        # the target ends up weighted (1 - smoothing), every other class smoothing_val.
        loss = (1.0 - smoothing - smoothing_val) * nll_loss + smoothing_val * smooth_loss

        if self.reduction == "sum":
            return loss.sum(), nll_loss.sum()
        elif self.reduction == "mean":
            # Average over non-ignored targets only, consistent with F.cross_entropy
            # in the non-smoothed branch (NaN when every target is ignored).
            num_valid = (~ignore_mask).sum()
            return loss.sum() / num_valid, nll_loss.sum() / num_valid
        else:  # none
            return loss, nll_loss

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "CrossEntropyLossWithSmoothing":
        """
        Build the loss from an OmegaConf config.
        :param cfg:     config whose resolved keys are passed as keyword arguments to ``__init__``
        :return:        new CrossEntropyLossWithSmoothing instance
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
