from typing import Tuple
import torch
from omegaconf import DictConfig

from pado.core import PadoModule

__all__ = ["BaseASREncoder"]


class BaseASREncoder(PadoModule):
    """
    Base full-context (batch-processing) encoder module for ASR.

    Concrete encoders must override both ``forward`` and ``from_config``;
    the implementations here only raise ``NotImplementedError``.
    """

    def forward(self,
                features: torch.Tensor,
                lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode a batch of feature sequences.

        :param features:        (batch_size, max_seq_length, feature_dim)
        :param lengths:         (batch_size,) -- presumably the unpadded length
                                of each sequence; confirm against subclasses
        :return:
                results:        (batch_size, max_out_seq_length, hidden_dim)
                out_lengths:    (batch_size,)
        :raises NotImplementedError: always; subclasses must override this.
        """
        # Name the concrete subclass so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()")

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "BaseASREncoder":
        """
        Alternate constructor: build an encoder from a configuration node.

        :param cfg:             configuration for the concrete encoder
        :return:                the constructed encoder instance
        :raises NotImplementedError: always; subclasses must override this.
        """
        # ``cls`` is the subclass the call was made on, which identifies
        # exactly which encoder is missing its ``from_config`` override.
        raise NotImplementedError(
            f"{cls.__name__} must implement from_config()")
