from typing import Tuple, Union, List
import torch

from pado.tasks.asr.asr_model import BaseASRModel
from pado.tasks.asr.asr_encoder import BaseASREncoder
from pado.tasks.asr.asr_subsampling import BaseASRSubsampling
from pado.tasks.asr.transducer_predictor import BaseASRTransducerPredictor
from pado.tasks.asr.transducer_jointer import BaseASRTransducerJointer
from pado.tasks.asr.transducer_beamsearch import ASRTransducerBeamsearch
from pado.tasks.asr.utils import BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE

# Public API of this module: only the transducer model class is exported.
__all__ = [
    "ASRTransducerModel",
]


class ASRTransducerModel(BaseASRModel):
    """Transducer ASR model assembled from swappable components.

    Training path (``forward``): subsampling followed by the encoder on the
    acoustic features, the predictor on the target tokens, and the jointer
    combining both streams.

    Inference path (``decode``): delegated to a beam-search decoder that is
    handed the very same predictor and jointer objects as this model.
    """

    def __init__(self,
                 encoder: BaseASREncoder,
                 subsampling: BaseASRSubsampling,
                 predictor: BaseASRTransducerPredictor,
                 jointer: BaseASRTransducerJointer,
                 decoder: ASRTransducerBeamsearch) -> None:
        """
        :param encoder:         sequence encoder, applied after ``subsampling``.
        :param subsampling:     front-end applied to the raw input features.
        :param predictor:       network run on the target token indices.
        :param jointer:         combines encoder and predictor outputs.
        :param decoder:         search module used by ``decode``; it is given
                                references to ``predictor`` and ``jointer``.
        """
        super().__init__()

        # Acoustic side.
        self.encoder = encoder
        self.subsampling = subsampling

        # Label side and the joint network.
        self.predictor = predictor
        self.jointer = jointer

        # Search module used at inference time.
        self.decoder = decoder

        # Hand the decoder references (not copies) to the same predictor and
        # jointer objects, so both the model and the decoder use one instance.
        self.decoder.predictor = predictor
        self.decoder.jointer = jointer

        # Presumably provided by BaseASRModel (not visible here) -- confirm.
        self.set_name()

    def encode(self,
               input_features: torch.Tensor,
               input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the subsampling front-end, then the encoder.

        :param input_features:  (batch_size, max_feat_length, feature_dim)
        :param input_lengths:   (batch_size,)
        :return: ``(enc, enc_lengths)``:
                 enc:           (batch_size, max_seq_length, hidden_dim)
                 enc_lengths:   (batch_size,)
        """
        sub_features, sub_lengths = self.subsampling(input_features, input_lengths)
        encoded, encoded_lengths = self.encoder(sub_features, sub_lengths)
        return encoded, encoded_lengths

    def forward(self,
                input_features: torch.Tensor,
                input_lengths: torch.Tensor,
                target_indices: torch.Tensor,
                target_lengths: torch.Tensor
                ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor, torch.Tensor]:
        """Compute the (unnormalized) joint network output for training.

        :param input_features:  (batch_size, max_feat_length, feature_dim)
        :param input_lengths:   (batch_size,)
        :param target_indices:  (batch_size, max_token_length)
        :param target_lengths:  (batch_size,)
        :return: ``(out, enc, enc_lengths)``:
                 out:           jointer output -- a tensor or a list of tensors
                                per the return annotation. Documented by the
                                author as (batch_size, max_token_length,
                                vocab_size). NOTE(review): confirm against the
                                jointer, since the joint combines both the
                                encoder and the predictor sequences.
                 enc:           (batch_size, max_seq_length, hidden_dim)
                 enc_lengths:   (batch_size,)
        """
        encoded, encoded_lengths = self.encode(input_features, input_lengths)

        predicted = self.predictor(target_indices, target_lengths)
        joint_out = self.jointer(encoded, predicted, encoded_lengths, target_lengths)

        # log_softmax is deliberately NOT applied here: NumbaRNNTLoss
        # handles it internally.
        return joint_out, encoded, encoded_lengths

    def decode(self,
               enc_features: torch.Tensor,
               enc_lengths: torch.Tensor,
               beam_width: int = 1) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """Turn encoder outputs into token hypotheses via ``self.decoder``.

        :param enc_features:    (batch_size, max_seq_length, hidden_dim)
        :param enc_lengths:     (batch_size,)
        :param beam_width:      forwarded to the decoder; 1 means greedy
                                decoding, larger values mean beam search.
        :return: ``(prediction, logp_scores)``:
                 prediction:    list of list of token indices
                 logp_scores:   list of log-probabilities
        """
        hypotheses, scores = self.decoder(enc_features, enc_lengths, beam_width)
        return hypotheses, scores
