import copy
from typing import Tuple

import torch

from pado.tasks.asr.ctc_model import ASRCTCModel
from pado.models.blocks.asr.common_subsampling import StrideSubsampling, VGGSubsampling
from pado.tasks.asr.ctc_beamsearch import ASRCTCBeamsearch
from pado.models.utils import replace_weight_variational_noise
from conformer.conformer_encoder import ReuseAttnConformerEncoder

__all__ = ["ReuseAttnConformerCTC"]


class ReuseAttnConformerCTC(ASRCTCModel):
    """CTC ASR model: subsampling front-end + ``ReuseAttnConformerEncoder`` + CTC beam-search decoder.

    Besides the encoded features and their lengths, the encoder returns three extra
    outputs (``scores``, ``values``, ``hiddens``) that this model passes through unchanged.
    Their meaning is defined by ``ReuseAttnConformerEncoder`` (presumably attention
    scores/values shared across layers and intermediate hidden states -- confirm there).
    """

    def __init__(self, cfg):
        """Build the encoder, subsampling module and decoder from ``cfg``.

        Args:
            cfg: Config mapping with ``"encoder"``, ``"subsampling"`` and ``"decoder"``
                entries. ``cfg["subsampling"]["type"]`` (compared lower-cased, by
                substring) selects the subsampling module: a type containing
                ``"stride"`` selects ``StrideSubsampling``; otherwise one containing
                ``"vgg"`` selects ``VGGSubsampling``. ``cfg`` itself is not modified.

        Raises:
            ValueError: if the subsampling type contains neither ``"stride"`` nor ``"vgg"``.
            A missing ``"type"`` entry surfaces as the config mapping's key-lookup error
            (``KeyError`` for a plain ``dict``).
        """
        # ------------------------------------------------------------------------------------------------ #
        # Build config and modules
        # ------------------------------------------------------------------------------------------------ #
        encoder = ReuseAttnConformerEncoder.from_config(cfg["encoder"])

        # Work on a copy: deleting "type" from cfg["subsampling"] in place made a second
        # construction from the same cfg fail with KeyError and stripped the key from the
        # caller's config. deepcopy (rather than dict(...)) keeps the sub-config's own type.
        subsampling_cfg = copy.deepcopy(cfg["subsampling"])
        subsampling_type = subsampling_cfg["type"].lower()
        # from_config receives the sub-config without the "type" selector, as before.
        del subsampling_cfg["type"]
        if "stride" in subsampling_type:
            subsampling = StrideSubsampling.from_config(subsampling_cfg)
        elif "vgg" in subsampling_type:
            subsampling = VGGSubsampling.from_config(subsampling_cfg)
        else:
            raise ValueError(f"ReuseAttnConformerCTC invalid subsampling type {subsampling_type}.")

        decoder = ASRCTCBeamsearch.from_config(cfg["decoder"])

        # ------------------------------------------------------------------------------------------------ #
        # Init
        # ------------------------------------------------------------------------------------------------ #
        super().__init__(encoder=encoder,
                         subsampling=subsampling,
                         decoder=decoder)
        # Apply variational weight noise to the model, excluding the keys below
        # (presumably matched against submodule/parameter names -- see the helper).
        replace_weight_variational_noise(self, exclude_keys=("subsampling", "conv"))
        self.set_name()

    def encode(self,
               input_features: torch.Tensor,
               input_lengths: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Run the subsampling module, then the encoder.

        Args:
            input_features: Input features for ``self.subsampling`` (layout defined by that
                module; presumably (batch, time, feature) -- confirm).
            input_lengths: Valid lengths of ``input_features`` (presumably one per batch item).

        Returns:
            ``(enc, enc_lengths, scores, values, hiddens)`` exactly as returned by ``self.encoder``.
        """
        features, lengths = self.subsampling(input_features, input_lengths)
        enc, enc_lengths, scores, values, hiddens = self.encoder(features, lengths)
        return enc, enc_lengths, scores, values, hiddens

    def forward(self,
                input_features: torch.Tensor,
                input_lengths: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Encode the inputs; returns ``(enc, enc_lengths, scores, values, hiddens)`` from ``encode``."""
        enc, enc_lengths, scores, values, hiddens = self.encode(input_features, input_lengths)

        # No log_softmax here: per the original note, the CTC loss applies it internally.
        return enc, enc_lengths, scores, values, hiddens
