from typing import Tuple
import torch
import torch.nn as nn
from omegaconf import DictConfig

from pado.core import PadoModule
from pado.nn.modules import Conv1d, Conv2d

# Public names exported by ``from <module> import *``.
__all__ = ["BaseASRSubsampling"]


class BaseASRSubsampling(PadoModule):
    """Base class for ASR subsampling modules.

    Concrete subclasses must implement :meth:`forward`,
    :meth:`calculate_length` and :meth:`from_config`. The base class itself
    only provides a shared weight-initialization helper,
    :meth:`_initialize_parameters`.
    """

    def _initialize_parameters(self) -> None:
        """Re-initialize the weights of every ``Conv1d``/``Conv2d`` submodule.

        Each matching module's ``weight`` is filled in place from a normal
        distribution with mean 0 and standard deviation 0.01. Biases and all
        other submodules are left untouched.
        """
        for m in self.modules():
            if isinstance(m, (Conv1d, Conv2d)):
                # ``nn.init.*`` functions already run under ``torch.no_grad()``,
                # so the parameter can be passed directly. Going through
                # ``.data`` bypasses autograd's in-place tracking and is
                # discouraged by PyTorch.
                nn.init.normal_(m.weight, mean=0.0, std=0.01)

    def forward(self,
                x: torch.Tensor,
                lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Subsample ``x``; must be implemented by subclasses.

        Args:
            x: Input tensor. The expected layout (e.g. batch, time, feature)
                is defined by the subclass (TODO: confirm against the
                concrete implementations).
            lengths: Per-sequence lengths associated with ``x``. Presumably
                these are the valid lengths before subsampling; verify
                against the subclasses.

        Returns:
            A pair of tensors. Presumably ``(subsampled_x, new_lengths)``;
            confirm against the subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()")

    def calculate_length(self, lengths: torch.Tensor) -> torch.Tensor:
        """Map input lengths to output lengths; must be implemented by subclasses.

        Args:
            lengths: Input lengths. Presumably these are the same values
                passed to :meth:`forward`.

        Returns:
            The corresponding lengths after this module's subsampling
            (assumed from the method name; TODO confirm).

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement calculate_length()")

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "BaseASRSubsampling":
        """Build an instance from an omegaconf ``DictConfig``.

        Must be implemented by subclasses. The expected config keys are
        subclass-specific.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement from_config()")
