from typing import Optional, Tuple

import torch
from pado.core import PadoModule
from pado.nn.modules import LayerNorm, Swish
from pado.nn.transformer import NormFeedForwardResidual, ProjectionResidual
from pado.models.conformer.conformer_conv import ConformerConvModule

from module.multihead_attention_reuse import ReuseMultiheadSelfAttentionWithV

# Public API of this module: only the Conformer layer class is exported.
__all__ = ["ReuseAttnConformerLayerWithV"]


class ReuseAttnConformerLayerWithV(PadoModule):
    """Conformer layer whose self-attention is driven by caller-supplied scores.

    Pipeline: FF1 -> {self-attention, convolution} -> FF2 -> output LayerNorm.
    ``attn_first`` selects whether attention runs before or after the
    convolution module. Both feed-forward modules are built with
    ``add_weight=0.5``.

    The attention module receives ``attn_scores`` from the caller instead of
    computing them here. Presumably the scores come from an earlier layer
    (TODO confirm). The module's extra outputs (``prob`` and ``v``) are
    returned alongside the layer output.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 conv_kernel_size: int = 31,
                 feedforward_dim: Optional[int] = None,
                 proj_drop_prob: float = 0.1,
                 feedforward_drop_prob: Optional[float] = None,
                 eps: float = 1e-5,
                 momentum: float = 0.005, *,
                 attn_first: bool = True,
                 conv_norm_type: str = "bn",
                 conv_sync_bn: bool = True,
                 conv_gn_groups: int = 2,
                 conv_partial: bool = False,
                 share_v_proj_dim: Optional[int] = None,
                 memory_efficient: bool = False) -> None:
        super().__init__()

        def make_feed_forward():
            # FF1 and FF2 share one configuration, but each call returns a
            # fresh, independently parameterised module.
            return NormFeedForwardResidual(
                hidden_dim, feedforward_dim, proj_drop_prob, feedforward_drop_prob,
                eps=eps, add_weight=0.5, act_layer=Swish, memory_efficient=memory_efficient,
            )

        # NOTE: keep the construction order below unchanged. It sets the order
        # in which submodules/parameters are registered. For torch-style
        # modules, that order also fixes the sequence of random
        # initialisations and the parameter indexing used by optimizer state.
        self.feed_forward1 = make_feed_forward()

        self.attn_first = attn_first
        self.attn_norm = LayerNorm(hidden_dim, eps=eps)
        self.attn = ReuseMultiheadSelfAttentionWithV(hidden_dim, num_heads, v_proj_dim=share_v_proj_dim)
        # The projection's input width is read back from the attention module
        # (configured through share_v_proj_dim) rather than assumed to equal
        # hidden_dim.
        self.projection = ProjectionResidual(
            self.attn.v_proj_dim, hidden_dim, proj_drop_prob, memory_efficient=memory_efficient
        )

        self.conv = ConformerConvModule(
            hidden_dim, conv_kernel_size, proj_drop_prob,
            eps=eps, momentum=momentum,
            norm_type=conv_norm_type, gn_groups=conv_gn_groups, sync_bn=conv_sync_bn,
            partial_conv=conv_partial, memory_efficient=memory_efficient,
        )

        self.feed_forward2 = make_feed_forward()
        self.out_norm = LayerNorm(hidden_dim, eps=eps)

    def _attention_sublayer(self,
                            residual: torch.Tensor,
                            attn_scores: torch.Tensor
                            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply pre-norm attention plus projection.

        :param residual:     input to the sub-layer; it is also the residual branch
                             passed to ``self.projection``
        :param attn_scores:  scores forwarded unchanged to ``self.attn``
        :return: ``(projected, prob, v)``, where ``prob`` and ``v`` are the second
                 and third outputs of ``self.attn``
        """
        normed = self.attn_norm(residual)
        attended, prob, v = self.attn(normed, attn_scores)
        projected = self.projection(attended, residual)
        return projected, prob, v

    def forward(self,
                hidden: torch.Tensor,
                attn_scores: torch.Tensor, *,
                conv_mask: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        """Run one Conformer layer.

        :param hidden:       (batch_size, seq_length, hidden_dim) input features
        :param attn_scores:  scores consumed by the attention module, expected as
                             (batch_size, seq_length, seq_length). Whether a per-head
                             axis is also required depends on
                             ReuseMultiheadSelfAttentionWithV (TODO confirm).
        :param conv_mask:    (batch_size, seq_length); only the convolution module
                             receives it
        :return: ``(output, (prob,), (v,))``. ``output`` is
                 (batch_size, seq_length, hidden_dim). ``prob`` and ``v`` are the
                 extra outputs of the attention module, each wrapped in a
                 1-tuple.
        """
        hidden = self.feed_forward1(hidden)

        if not self.attn_first:
            conv_out = self.conv(hidden, mask=conv_mask)
            mixed, prob, v = self._attention_sublayer(conv_out, attn_scores)
        else:
            mixed, prob, v = self._attention_sublayer(hidden, attn_scores)
            mixed = self.conv(mixed, mask=conv_mask)

        output = self.out_norm(self.feed_forward2(mixed))  # (b, s, d)
        return output, (prob,), (v,)
