import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
from omegaconf import DictConfig, OmegaConf

from pado.core.base.loss import PadoLoss

__all__ = ["CTCLoss"]


class CTCLoss(PadoLoss):
    """CTC (Connectionist Temporal Classification) loss computed from raw logits.

    Thin wrapper around ``torch.nn.functional.ctc_loss``. The logits are
    turned into log-probabilities with a float32 log-softmax, with CUDA
    autocast switched off for both the log-softmax and the loss itself.

    :param blank_idx:       vocabulary index passed to ``F.ctc_loss`` as ``blank``.
    :param reduction:       handed to ``PadoLoss.__init__``; ``forward`` reads
                            ``self.reduction`` (presumably stored by the base
                            class -- not visible in this file).
    :param batch_first:     if True, ``forward`` swaps the first two axes of
                            the logits before computing the loss.
    :param zero_infinity:   passed through to ``F.ctc_loss``.
    """

    def __init__(
        self,
        blank_idx: int = 0,
        *,
        reduction: str = "mean",
        batch_first: bool = True,
        zero_infinity: bool = False,
    ):
        super().__init__(reduction=reduction)
        self.blank_idx = blank_idx
        self.batch_first = batch_first
        self.zero_infinity = zero_infinity

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        logit_lengths: torch.Tensor,
        label_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the CTC loss for a batch of unnormalized logits.

        :param logits:              (N, T, V) when ``batch_first`` is True;
                                    otherwise used as given, i.e. expected to
                                    already be (T, N, V) as ``F.ctc_loss`` requires
        :param labels:              (N, U)          long or int
        :param logit_lengths:       (N,) max T      long or int
        :param label_lengths:       (N,) max U      long or int
        :return:
                                    result of ``F.ctc_loss`` for the configured
                                    reduction (per the PyTorch docs: a scalar for
                                    "mean"/"sum", shape (N,) for "none").
        """
        # F.ctc_loss consumes time-major input: (N, T, V) -> (T, N, V).
        if self.batch_first:
            time_major = logits.transpose(0, 1).contiguous()
        else:
            time_major = logits

        # Run the normalization and the loss in float32 with CUDA autocast
        # disabled, regardless of the dtype of the incoming logits.
        with amp.autocast(enabled=False):
            fp32_logits = time_major.float()
            log_probs = F.log_softmax(fp32_logits, dim=-1)
            return F.ctc_loss(
                log_probs,
                labels,
                logit_lengths,
                label_lengths,
                blank=self.blank_idx,
                reduction=self.reduction,
                zero_infinity=self.zero_infinity,
            )

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "CTCLoss":
        """
        Build an instance from an OmegaConf config.

        The config is converted to a plain container with interpolations
        resolved, and its entries are unpacked as keyword arguments to the
        constructor.

        :param cfg:     config whose keys match the constructor's parameters.
        :return:        the constructed loss.
        """
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        return cls(**kwargs)

