from typing import Any, List, Optional, Tuple, Union
import torch
from omegaconf import DictConfig, OmegaConf

from pado.core import PadoModule
from pado.nn.modules import Embedding

# Public API of this module: only the predictor base class is exported.
__all__ = ["BaseASRTransducerPredictor"]


class BaseASRTransducerPredictor(PadoModule):
    """Base RNN-Transducer Predictor module.

    Owns the token embedding (``self.word_emb``) and declares the interface that
    concrete predictors must implement: ``forward`` (whole sequence), ``step``
    (single position) and the hidden-state helpers ``extract_hiddens`` /
    ``update_hiddens``. Every hook raises ``NotImplementedError`` here.

    :param num_tokens:      number of tokens in the vocabulary (first argument of
                            ``Embedding``)
    :param embed_dim:       embedding dimension (second argument of ``Embedding``)
    :param word_drop_prob:  forwarded unchanged to ``Embedding(word_drop_prob=...)``;
                            presumably the probability of dropping input tokens —
                            confirm against ``pado.nn.modules.Embedding``
    :param blank_idx:       index of the blank token; also passed to ``Embedding``
                            as ``padding_idx``
    """

    def __init__(self,
                 num_tokens: int,
                 embed_dim: int,
                 word_drop_prob: float = 0.0,
                 blank_idx: int = 0) -> None:
        super().__init__()
        self.num_tokens = num_tokens
        self.embed_dim = embed_dim
        self.blank_idx = blank_idx

        # The blank index doubles as the embedding's padding index.
        # NOTE(review): exact padding semantics depend on pado's Embedding — confirm.
        self.word_emb = Embedding(num_tokens, embed_dim,
                                  padding_idx=blank_idx, word_drop_prob=word_drop_prob)

    def forward(self,
                indices: torch.Tensor,
                lengths: torch.Tensor) -> torch.Tensor:
        """Run the predictor over a whole token sequence.

        Abstract: subclasses must override.

        :param indices:     (batch_size, seq_length)    long
        :param lengths:     (batch_size,)
        :return:
                result:     (batch_size, seq_length, ...)
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def step(self,
             index: torch.Tensor,
             hiddens: Optional[Any] = None) -> Tuple[torch.Tensor, Any]:
        """Run a single step (sequence_length == 1).

        Abstract: subclasses must override.

        :param index:       (batch_size, 1)
        :param hiddens:     implementation-defined hidden state from a previous
                            ``step`` call, or None (presumably meaning "start from
                            the initial state" — subclass-defined)
        :return:
                result:     (batch_size, 1, ...)
                hidden_states
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def extract_hiddens(self,
                        hiddens: Any,
                        batch_indices: Optional[Union[List[int], torch.Tensor]] = None) -> Any:
        """Extract hiddens of the given batch indices.

        Abstract: subclasses must override. The handling of
        ``batch_indices=None`` is subclass-defined (presumably "all rows").

        :param hiddens:         implementation-defined hidden state
        :param batch_indices:   batch rows to select, as a list or tensor of indices
        :return:                the selected hidden state
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def update_hiddens(self,
                       hiddens: Any,
                       new_hiddens: Any,
                       batch_indices: Optional[Union[List[int], torch.Tensor]] = None) -> Any:
        """Update hiddens by new_hiddens, only for the given batch indices.

        Abstract: subclasses must override. The handling of
        ``batch_indices=None`` is subclass-defined (presumably "all rows").

        :param hiddens:         implementation-defined hidden state to update
        :param new_hiddens:     replacement values for the selected rows
        :param batch_indices:   batch rows to overwrite, as a list or tensor of indices
        :return:                the updated hidden state
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @classmethod
    def from_config(cls, cfg: Union[DictConfig, dict]) -> "BaseASRTransducerPredictor":
        """Construct an instance from a configuration of constructor keyword arguments.

        :param cfg:     an OmegaConf config (interpolations are resolved) or a plain
                        dict whose keys are keyword arguments of ``cls.__init__``
        :return:        a new instance of ``cls``
        """
        # OmegaConf.to_container raises ValueError for non-OmegaConf inputs, so only
        # convert real configs; plain dicts can be splatted into the constructor as-is.
        if OmegaConf.is_config(cfg):
            cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
