from typing import Optional, Union, List
import torch
from omegaconf import DictConfig, OmegaConf

from pado.core import PadoModule
from pado.nn.modules import Linear, Dropout, TimeGradientScale
from pado.nn.modules.activation import get_activation_cls

__all__ = ["BaseASRTransducerJointer"]


class BaseASRTransducerJointer(PadoModule):
    """
    Base RNN-Transducer Joint module.

    Creates the layers that joint implementations share (activation, dropout, output
    projection and an optional gradient scaler) and dispatches ``forward`` either to a
    single padded batched call or to one call per sample on length-trimmed inputs.
    Subclasses must implement ``_forward_batch`` and ``step``.
    """

    def __init__(self,
                 num_tokens: int,
                 joint_dim: int,
                 drop_prob: float = 0.0,
                 rescale_grad: bool = False, *,
                 act_type: str = "tanh",
                 per_sample: bool = True) -> None:
        """
        :param num_tokens:      output size of the final projection ``fc``
        :param joint_dim:       input size of the final projection ``fc``
        :param drop_prob:       probability passed to ``Dropout``
        :param rescale_grad:    if True, ``grad_scaler`` is a ``TimeGradientScale(dim=1)``,
                                otherwise ``grad_scaler`` is None
        :param act_type:        activation name passed to ``get_activation_cls``
        :param per_sample:      if True, ``forward`` runs sample by sample whenever both
                                lengths are given and batch_size > 1
        """
        super().__init__()
        self.num_tokens = num_tokens
        self.joint_dim = joint_dim

        # Optional gradient scaler; subclasses must handle the ``None`` case.
        if rescale_grad:
            self.grad_scaler = TimeGradientScale(dim=1)
        else:
            self.grad_scaler = None

        # NOTE(review): stored as returned by ``get_activation_cls`` — confirm whether this is
        # an instantiated module or a class that subclasses still have to instantiate.
        self.act = get_activation_cls(act_type, inplace=True)
        self.drop = Dropout(drop_prob)
        self.fc = Linear(joint_dim, num_tokens, bias=True, init_type="proj")

        self.per_sample = per_sample

    def _forward_batch(self,
                       enc: torch.Tensor,
                       pred: torch.Tensor,
                       enc_lengths: Union[int, Optional[torch.Tensor]],
                       pred_lengths: Union[int, Optional[torch.Tensor]]) -> torch.Tensor:
        """Joint computation for a (possibly padded) batch; must be provided by subclasses.

        ``forward`` passes the lengths through unchanged (tensors or None), while
        ``_forward_sample`` passes plain Python ints for a batch of size one.

        :param enc:                 (batch_size, max_enc_length, joint_dim)
        :param pred:                (batch_size, max_pred_length, joint_dim)
        :param enc_lengths:         (batch_size,) tensor, a single int, or None
        :param pred_lengths:        (batch_size,) tensor, a single int, or None
        :return:
                result:             4-dim tensor, (batch_size, enc, pred, num_tokens)
        """
        raise NotImplementedError

    def _forward_sample(self,
                        enc: torch.Tensor,
                        pred: torch.Tensor,
                        enc_lengths: torch.Tensor,
                        pred_lengths: torch.Tensor) -> List[torch.Tensor]:
        """Run ``_forward_batch`` once per sample on inputs trimmed to their true lengths.

        :param enc:                 (batch_size, max_enc_length, joint_dim)
        :param pred:                (batch_size, max_pred_length, joint_dim)
        :param enc_lengths:         (batch_size,)
        :param pred_lengths:        (batch_size,)
        :return:
                result:             list of batch_size tensors, each (1, enc_len, pred_len + 1, v)
        """
        batch_size = enc.shape[0]

        # Copy every length to the host once, before the loop. Reading them with ``.item()``
        # per sample is a blocking device-to-host read on each iteration, which stalls the
        # host on the previous sample's queued work before the next sample can be launched.
        enc_len_list = enc_lengths.reshape(-1).tolist()
        pred_len_list = pred_lengths.reshape(-1).tolist()

        outs = []
        for batch_idx in range(batch_size):
            enc_len = enc_len_list[batch_idx]
            pred_len = pred_len_list[batch_idx] + 1  # should keep + 1 (front blank)
            enc_i = enc[batch_idx, :enc_len].unsqueeze(0)  # keep batch dimension
            pred_i = pred[batch_idx, :pred_len].unsqueeze(0)  # keep batch dimension

            out = self._forward_batch(enc_i, pred_i, enc_len, pred_len)
            # Sanity check on the subclass implementation: (1, enc, pred, v)
            assert out.ndim == 4, f"expected a 4-dim joint output, got shape {tuple(out.shape)}"
            outs.append(out)
        return outs

    def forward(self,
                enc: torch.Tensor,
                pred: torch.Tensor,
                enc_lengths: Optional[torch.Tensor] = None,
                pred_lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Compute the joint output, batched or per sample.

        The batched path is taken when ``per_sample`` is False, when either length is
        missing, or when batch_size is 1; otherwise each sample is processed separately.

        :param enc:                 (batch_size, max_enc_length, joint_dim)
        :param pred:                (batch_size, max_pred_length, joint_dim)
        :param enc_lengths:         (batch_size,)
        :param pred_lengths:        (batch_size,)
        :return:
                result:             batched path:
                                        (batch_size, max_enc_length, max_pred_length, num_tokens)
                                    per-sample path:
                                        list of batch_size tensors, each
                                        (1, enc_length, pred_length + 1, num_tokens)
        """
        batch_size = enc.shape[0]

        if (not self.per_sample) or (enc_lengths is None) or (pred_lengths is None) or (batch_size == 1):
            # not batched, or length is not given
            return self._forward_batch(enc, pred, enc_lengths, pred_lengths)
        else:
            return self._forward_sample(enc, pred, enc_lengths, pred_lengths)

    def step(self,
             enc: torch.Tensor,
             pred: torch.Tensor) -> torch.Tensor:
        """Run a single step (sequence_length == 1); must be provided by subclasses.

        :param enc:         (batch_size, 1, joint_dim) | (batch_size, joint_dim)
        :param pred:        (batch_size, 1, joint_dim) | (batch_size, joint_dim)
        :return:
                result:     (batch_size, 1, num_tokens) | (batch_size, num_tokens)
        """
        raise NotImplementedError

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "BaseASRTransducerJointer":
        """Build an instance from an OmegaConf config.

        The config is resolved into a plain container and its entries are passed to the
        constructor as keyword arguments, so its keys must match the ``__init__`` parameters.

        :param cfg:         OmegaConf config holding the constructor arguments
        :return:
                instance of ``cls``
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
