from typing import Tuple
import torch

from pado.tasks.asr.asr_model import BaseASRModel
from pado.tasks.asr.asr_encoder import BaseASREncoder
from pado.tasks.asr.asr_subsampling import BaseASRSubsampling
from pado.tasks.asr.ctc_beamsearch import ASRCTCBeamsearch
from pado.tasks.asr.utils import BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE

# Public API of this module: only the CTC model class is exported.
__all__ = ["ASRCTCModel"]


class ASRCTCModel(BaseASRModel):
    """
    CTC-based ASR model composed of three parts:
    subsampling -> encoder (see ``encode``/``forward``) and a CTC
    beam-search decoder (see ``decode``).
    """

    def __init__(self,
                 encoder: BaseASREncoder,
                 subsampling: BaseASRSubsampling,
                 decoder: ASRCTCBeamsearch) -> None:
        super().__init__()
        # Assignment order is intentionally kept (encoder, subsampling, decoder):
        # if BaseASRModel is a torch.nn.Module, this order fixes sub-module
        # registration order.
        self.encoder = encoder
        self.subsampling = subsampling
        self.decoder = decoder

        # set_name() is not defined here; presumably inherited from BaseASRModel.
        self.set_name()

    def encode(self,
               input_features: torch.Tensor,
               input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the subsampling front-end, then the encoder.

        :param input_features:          (batch_size, max_feat_length, feature_dim)
        :param input_lengths:           (batch_size,)
        :return:
                enc:                    (batch_size, max_seq_length, vocab_size)
                enc_lengths:            (batch_size,)
        """
        sub_feats, sub_lengths = self.subsampling(input_features, input_lengths)
        encoded, encoded_lengths = self.encoder(sub_feats, sub_lengths)
        return encoded, encoded_lengths

    def forward(self,
                input_features: torch.Tensor,
                input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Training-time forward pass; delegates to ``self.encode``.

        :param input_features:          (batch_size, max_feat_length, feature_dim)
        :param input_lengths:           (batch_size,)
        :return:
                enc:                    (batch_size, max_seq_length, vocab_size)
                enc_lengths:            (batch_size,)
        """
        logits, logit_lengths = self.encode(input_features, input_lengths)

        # Outputs are returned as-is: no log_softmax is applied here.
        # NOTE(review): torch.nn.CTCLoss itself expects log-probabilities, so the
        # loss paired with this model must apply log_softmax internally — confirm.
        return logits, logit_lengths

    def decode(self,
               enc_features: torch.Tensor,
               enc_lengths: torch.Tensor,
               beam_width: int = 1) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """
        Convert encoder outputs into token hypotheses with the CTC decoder.

        :param enc_features:        (batch_size, max_seq_length, vocab_size)
        :param enc_lengths:         (batch_size,)
        :param beam_width:          passed through to the decoder;
                                    1 means greedy decoding, larger values beam search.
        :return:
                prediction:     list of list of token indices
                logp_scores:    list of log-probabilities
        """
        hypotheses, hyp_scores = self.decoder(enc_features, enc_lengths, beam_width)
        return hypotheses, hyp_scores
