from typing import Optional, Union
import torch

from pado.tasks.asr.transducer_jointer import BaseASRTransducerJointer

__all__ = ["AddTransducerJointer", "ConcatTransducerJointer"]


class AddTransducerJointer(BaseASRTransducerJointer):
    """Joint network that fuses encoder and predictor states by element-wise addition.

    Encoder and predictor features are summed, so both streams must share the
    same last (feature) dimension. The sum is passed through ``self.act``
    (plus ``self.drop`` in the batched path) and projected by ``self.fc``;
    these modules are presumably set up by ``BaseASRTransducerJointer`` —
    TODO confirm.
    """

    def _forward_batch(self,
                       enc: torch.Tensor,
                       pred: torch.Tensor,
                       enc_lengths: Union[int, Optional[torch.Tensor]],
                       pred_lengths: Union[int, Optional[torch.Tensor]]) -> torch.Tensor:
        """Joint every encoder frame with every predictor step of a batch.

        Args:
            enc: encoder states, assumed (b, enc_s, dim) — TODO confirm.
            pred: predictor states, assumed (b, pred_s, dim) — TODO confirm.
            enc_lengths: encoder lengths; handed to the grad scaler for ``pred``.
            pred_lengths: predictor lengths; handed to the grad scaler for ``enc``.

        Returns:
            ``self.fc`` applied to the activated, dropped-out broadcast sum of
            shape (b, enc_s, pred_s, dim).
        """
        scaler = self.grad_scaler
        if scaler is not None:
            # Each stream is scaled using the *other* stream's lengths.
            enc = scaler(enc, pred_lengths)
            pred = scaler(pred, enc_lengths)

        # (b, enc_s, 1, dim) + (b, 1, pred_s, dim) -> (b, enc_s, pred_s, dim)
        joint = enc.unsqueeze(2) + pred.unsqueeze(1)
        return self.fc(self.drop(self.act(joint)))

    def step(self,
             enc: torch.Tensor,
             pred: torch.Tensor) -> torch.Tensor:
        """Joint a single encoder/predictor pair (presumably for incremental decoding).

        Unlike ``_forward_batch``, neither the grad scaler nor dropout is applied
        here; ``enc`` and ``pred`` are combined with ordinary ``+`` broadcasting.
        """
        return self.fc(self.act(enc + pred))


class ConcatTransducerJointer(BaseASRTransducerJointer):
    """Joint network that fuses encoder and predictor states by concatenation.

    Encoder features (``dim1``) and predictor features (``dim2``) are
    concatenated along the last dimension, so the two streams may have
    different feature sizes. The result is passed through ``self.act``
    (plus ``self.drop`` in the batched path) and projected by ``self.fc``;
    these modules are presumably set up by ``BaseASRTransducerJointer`` —
    TODO confirm.
    """

    @staticmethod
    def _broadcast_cat(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
        """Concatenate two tensors on the last dim after broadcasting all leading dims.

        ``torch.cat`` itself does not broadcast: every dimension except the
        concatenation one must already match. Expanding (a view, no copy) to
        the common leading shape first makes the concatenation well-defined.
        """
        lead = torch.broadcast_shapes(first.shape[:-1], second.shape[:-1])
        first = first.expand(*lead, first.shape[-1])
        second = second.expand(*lead, second.shape[-1])
        return torch.cat([first, second], dim=-1)

    def _forward_batch(self,
                       enc: torch.Tensor,
                       pred: torch.Tensor,
                       enc_lengths: Union[int, Optional[torch.Tensor]],
                       pred_lengths: Union[int, Optional[torch.Tensor]]) -> torch.Tensor:
        """Joint every encoder frame with every predictor step of a batch.

        Args:
            enc: encoder states, assumed (b, enc_s, dim1) — TODO confirm.
            pred: predictor states, assumed (b, pred_s, dim2) — TODO confirm.
            enc_lengths: encoder lengths; handed to the grad scaler for ``pred``.
            pred_lengths: predictor lengths; handed to the grad scaler for ``enc``.

        Returns:
            ``self.fc`` applied to the activated, dropped-out concatenation of
            shape (b, enc_s, pred_s, dim1 + dim2).
        """
        if self.grad_scaler is not None:
            # Each stream is scaled using the *other* stream's lengths.
            enc = self.grad_scaler(enc, pred_lengths)
            pred = self.grad_scaler(pred, enc_lengths)

        enc = enc.unsqueeze(2)  # (b, enc_s, 1, dim1)
        pred = pred.unsqueeze(1)  # (b, 1, pred_s, dim2)

        # Plain torch.cat would fail here unless enc_s == pred_s == 1;
        # broadcast both to (b, enc_s, pred_s, ·) first.
        out = self._broadcast_cat(enc, pred)  # (b, enc_s, pred_s, dim1 + dim2)
        out = self.act(out)
        out = self.drop(out)
        out = self.fc(out)
        return out

    def step(self,
             enc: torch.Tensor,
             pred: torch.Tensor) -> torch.Tensor:
        """Joint a single encoder/predictor pair (presumably for incremental decoding).

        Unlike ``_forward_batch``, neither the grad scaler nor dropout is applied.
        Leading dimensions are broadcast before concatenation, mirroring the
        ``+`` broadcasting of ``AddTransducerJointer.step``; when the leading
        shapes already match, the result equals a plain ``torch.cat``.
        """
        out = self._broadcast_cat(enc, pred)
        out = self.act(out)
        out = self.fc(out)
        return out
