from typing import Optional, Tuple
import torch

from pado.core import PadoModule
from pado.nn.modules import LayerNorm, Swish
from pado.nn.transformer import (NormFeedForwardResidual, ProjectionResidual,
                                 MultiheadSelfAttention, AutoEncodingRelMultiheadAttention)
from pado.models.conformer.conformer_conv import ConformerConvModule

__all__ = ["ConformerLayer"]


class ConformerLayer(PadoModule):
    """Conformer block: FF1 -> (attention, conv) -> FF2 -> LayerNorm.

    The order of the attention and convolution sub-blocks is controlled by
    ``attn_first``:

    * ``attn_first=True``:  FF1 -> attention -> conv -> FF2 -> out_norm
    * ``attn_first=False``: FF1 -> conv -> attention -> FF2 -> out_norm

    Both feed-forward sub-blocks are built with ``add_weight=0.5``. This
    presumably scales their residual branch by 1/2, as in the Macaron-style
    half-step FFN. NOTE(review): confirm in ``NormFeedForwardResidual``.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 conv_kernel_size: int = 31,
                 feedforward_dim: Optional[int] = None,
                 attn_drop_prob: float = 0.1,
                 proj_drop_prob: float = 0.1,
                 feedforward_drop_prob: Optional[float] = None,
                 eps: float = 1e-5,
                 momentum: float = 0.01, *,
                 attn_first: bool = True,
                 rel_attn: bool = True,
                 attn_bias: bool = False,
                 share_query_bias: bool = False,
                 conv_norm_type: str = "bn",
                 conv_sync_bn: bool = True,
                 conv_gn_groups: int = 2,
                 conv_partial: bool = False,
                 was: bool = False,
                 was_gamma: float = 0.5,
                 memory_efficient: bool = False) -> None:
        """
        Build the Conformer layer sub-modules.
        :param hidden_dim:              model dimension used by every sub-block
        :param num_heads:               number of attention heads
        :param conv_kernel_size:        kernel size forwarded to ConformerConvModule
        :param feedforward_dim:         inner FF dimension forwarded to both FF sub-blocks
                                        (None is passed through unchanged)
        :param attn_drop_prob:          dropout forwarded to the attention module
        :param proj_drop_prob:          dropout forwarded to the FF, projection and conv sub-blocks
        :param feedforward_drop_prob:   extra dropout forwarded to both FF sub-blocks
        :param eps:                     epsilon for all LayerNorms (and the conv module's norm)
        :param momentum:                norm momentum forwarded to ConformerConvModule
        :param attn_first:              if True run attention before conv, otherwise conv first
        :param rel_attn:                use AutoEncodingRelMultiheadAttention instead of
                                        MultiheadSelfAttention
        :param attn_bias:               ``bias`` flag forwarded to the attention module
        :param share_query_bias:        forwarded to the relative attention module only
        :param conv_norm_type:          ``norm_type`` forwarded to ConformerConvModule
        :param conv_sync_bn:            ``sync_bn`` forwarded to ConformerConvModule
        :param conv_gn_groups:          ``gn_groups`` forwarded to ConformerConvModule
        :param conv_partial:            ``partial_conv`` forwarded to ConformerConvModule
        :param was:                     forwarded to the attention module
        :param was_gamma:               forwarded to the attention module
        :param memory_efficient:        forwarded to the FF, projection and conv sub-blocks
        """
        super().__init__()
        # NOTE: keep the sub-module creation order stable; it determines the
        # registration order (and thus state_dict key order) of this module.

        # FF1
        self.feed_forward1 = NormFeedForwardResidual(
            hidden_dim, feedforward_dim, proj_drop_prob, feedforward_drop_prob,
            eps=eps, add_weight=0.5, act_layer=Swish, memory_efficient=memory_efficient
        )

        # Attn
        self.attn_first = attn_first
        self.rel_attn = rel_attn

        self.attn_norm = LayerNorm(hidden_dim, eps=eps)
        if rel_attn:
            self.attn = AutoEncodingRelMultiheadAttention(
                hidden_dim, num_heads, attn_drop_prob, bias=attn_bias,
                share_query_bias=share_query_bias, was=was, was_gamma=was_gamma
            )
        else:
            self.attn = MultiheadSelfAttention(
                hidden_dim, num_heads, attn_drop_prob, bias=attn_bias, was=was, was_gamma=was_gamma
            )
        self.projection = ProjectionResidual(
            hidden_dim, hidden_dim, proj_drop_prob, memory_efficient=memory_efficient
        )

        # Conv
        self.conv = ConformerConvModule(
            hidden_dim, conv_kernel_size, proj_drop_prob, eps=eps, momentum=momentum, norm_type=conv_norm_type,
            gn_groups=conv_gn_groups, sync_bn=conv_sync_bn, partial_conv=conv_partial, memory_efficient=memory_efficient
        )

        # FF2
        self.feed_forward2 = NormFeedForwardResidual(
            hidden_dim, feedforward_dim, proj_drop_prob, feedforward_drop_prob,
            eps=eps, add_weight=0.5, act_layer=Swish, memory_efficient=memory_efficient
        )
        self.out_norm = LayerNorm(hidden_dim, eps=eps)

    def _self_attention(self,
                        hidden: torch.Tensor,
                        pos_emb: Optional[torch.Tensor],
                        query_key_bias: Optional[torch.Tensor],
                        query_pos_bias: Optional[torch.Tensor],
                        attn_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Pre-norm attention sub-block: attn_norm -> attention -> projection(+residual).
        :param hidden:              input to the sub-block; also passed as the residual
                                    argument of ``self.projection``
        :param pos_emb:             used only when ``self.rel_attn`` is True
        :param query_key_bias:      used only when ``self.rel_attn`` is True
        :param query_pos_bias:      used only when ``self.rel_attn`` is True
        :param attn_mask:           forwarded to the attention module
        :return:                    (projected output, attention probabilities returned
                                    by the attention module)
        """
        attn = self.attn_norm(hidden)
        if self.rel_attn:
            attn, prob = self.attn(attn, pos_emb, query_key_bias, query_pos_bias, attn_mask=attn_mask)
        else:
            # Absolute-position attention ignores the relative-position inputs.
            attn, prob = self.attn(attn, attn_mask=attn_mask)
        return self.projection(attn, hidden), prob

    def forward(self,
                hidden: torch.Tensor,
                pos_emb: Optional[torch.Tensor],
                query_key_bias: Optional[torch.Tensor],
                query_pos_bias: Optional[torch.Tensor], *,
                attn_mask: Optional[torch.Tensor] = None,
                conv_mask: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Conformer Layer forward
        :param hidden:              (batch_size, seq_length, hidden_dim)
        :param pos_emb:             (1, 2 * seq_length - 1, hidden_dim)
        :param query_key_bias:      (hidden_dim,)
        :param query_pos_bias:      (hidden_dim,)
        :param attn_mask:           (batch_size, seq_length, seq_length)
        :param conv_mask:           (batch_size, seq_length)
        :return:                    (output, (prob,)) where
                                    output: (batch_size, seq_length, hidden_dim)
                                    prob:   attention probabilities returned by the attention
                                            module (shape defined by that module)
        """
        hidden = self.feed_forward1(hidden)

        if self.attn_first:
            hidden, prob = self._self_attention(hidden, pos_emb, query_key_bias, query_pos_bias, attn_mask)
            hidden = self.conv(hidden, mask=conv_mask)
        else:
            hidden = self.conv(hidden, mask=conv_mask)
            hidden, prob = self._self_attention(hidden, pos_emb, query_key_bias, query_pos_bias, attn_mask)

        output = self.feed_forward2(hidden)  # (b, s, d)
        output = self.out_norm(output)  # (b, s, d)
        return output, (prob,)
