from typing import Optional
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.cuda.amp as amp
from torch.utils.checkpoint import checkpoint

from pado.core import PadoModule
from pado.nn.modules import Linear, Conv1d, MaskedBatchNorm, LayerNorm, GroupNorm, GLU, Swish, Dropout, Add

# Public API of this module: only the Conformer convolution block is exported.
__all__ = [
    "ConformerConvModule",
]


class ConformerConvModule(PadoModule):
    """Convolution module of a Conformer block (pre-norm, residual).

    Pipeline::

        LayerNorm -> Linear(d, 2d) + GLU -> depthwise Conv1d -> norm ("bn" | "ln" | "gn")
        -> Swish -> Linear(d, d) + Dropout -> residual add with the input

    The depthwise convolution and the normalization that follows it run with
    autocast disabled on FP32 inputs. When ``memory_efficient`` is set, the two
    point-wise stages are wrapped in activation checkpointing during training.
    """

    def __init__(self,
                 in_channels: int,
                 kernel_size: int,
                 drop_prob: float = 0.1,
                 eps: float = 1e-5,
                 momentum: float = 0.01, *,
                 norm_type: str = "bn",
                 sync_bn: bool = True,
                 gn_groups: int = 2,
                 partial_conv: bool = False,
                 memory_efficient: bool = False):
        """
        :param in_channels:      hidden dimension ``d`` of the input and output.
        :param kernel_size:      depthwise conv kernel size; must be odd. The conv uses
                                 ``padding=kernel_size // 2`` and ``stride=1``.
        :param drop_prob:        dropout probability applied after the last Linear.
        :param eps:              epsilon passed to every normalization layer.
        :param momentum:         running-stat momentum for MaskedBatchNorm (``norm_type="bn"`` only).
        :param norm_type:        "bn", "ln" or "gn" (case-insensitive); normalization after the
                                 depthwise conv.
        :param sync_bn:          forwarded to MaskedBatchNorm (``norm_type="bn"`` only).
        :param gn_groups:        number of groups for GroupNorm (``norm_type="gn"`` only).
        :param partial_conv:     forwarded to Conv1d as ``partial``.
        :param memory_efficient: checkpoint the point-wise stages during training.
        :raises ValueError:          if ``kernel_size`` is even.
        :raises NotImplementedError: if ``norm_type`` is not "bn", "ln" or "gn".
        """
        super().__init__()

        self.in_channels = in_channels
        if kernel_size % 2 != 1:
            raise ValueError(f"ConformerConvModule kernel_size should be odd, but got {kernel_size}.")

        self.norm = LayerNorm(in_channels, eps=eps)

        # Implementation choice: Linear instead of Point-wise Conv.
        # Both are identical, but Linear does not require additional transpose
        self.conv1 = nn.Sequential(
            OrderedDict({
                "conv1": Linear(in_channels, in_channels * 2, bias=True),
                "act1": GLU(dim=-1)  # halves the last dim back to in_channels
            }))

        # groups=in_channels makes this a depthwise convolution.
        self.conv2 = Conv1d(in_channels, in_channels,
                            kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
                            bias=False, groups=in_channels, partial=partial_conv)

        norm_type = norm_type.lower()
        self.norm_type = norm_type
        if norm_type == "bn":
            self.bn2 = MaskedBatchNorm(in_channels, eps=eps, momentum=momentum, sync_bn=sync_bn)
        elif norm_type == "ln":
            self.bn2 = LayerNorm(in_channels, eps=eps)
        elif norm_type == "gn":
            self.bn2 = GroupNorm(gn_groups, in_channels, eps=eps)
        else:
            raise NotImplementedError(f"ConformerConvModule norm_type {norm_type} not supported.")
        self.act2 = Swish()

        self.conv3 = nn.Sequential(
            OrderedDict({
                "conv3": Linear(in_channels, in_channels, bias=True),
                "drop": Dropout(drop_prob, inplace=True)
            })
        )
        self.add = Add()
        self.memory_efficient = memory_efficient

        self._initialize_parameters()

    def _initialize_parameters(self) -> None:
        """Initialize the depthwise conv weight from N(0, 0.01^2)."""
        # nn.init functions already run under torch.no_grad(), so ``.data`` is not needed.
        nn.init.normal_(self.conv2.weight, std=0.01)

    def _use_checkpoint(self) -> bool:
        """Return True when the point-wise stages should use activation checkpointing."""
        # Checkpointing only helps when a backward pass will follow. In eval mode or
        # under torch.no_grad() it would only add overhead.
        return self.memory_efficient and self.training and torch.is_grad_enabled()

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Conformer ConvModule forward
        :param x:       (batch_size, seq_length, hidden_dim)
        :param mask:    (batch_size, seq_length) or None. Forwarded to the depthwise conv and,
                        for ``norm_type="bn"``, reshaped to (batch_size, 1, seq_length) for
                        MaskedBatchNorm. Unused by the "ln" / "gn" normalizations.
        :return:
                        (batch_size, seq_length, hidden_dim)
        """
        identity = x

        x = self.norm(x)  # (b, s, d)
        if self._use_checkpoint():
            x = checkpoint(self.conv1, x)
        else:
            x = self.conv1(x)  # (b, s, d) -> (b, s, 2d) -> (b, s, d)
        x = x.transpose(1, 2).contiguous()  # (b, d, s)

        with amp.autocast(enabled=False):
            # Depth-separable Conv and normalization seems to be unstable for FP16 training,
            # so this section is forced to FP32.
            x = x.float()
            x = self.conv2(x, mask=mask)  # (b, d, s)

            b, _, s = x.shape
            if self.norm_type == "bn":
                # mask is Optional: only reshape it when one is given.
                # NOTE(review): assumes MaskedBatchNorm accepts mask=None as "no padding" - confirm.
                bn_mask = None if mask is None else mask.reshape(b, 1, s)
                x = self.bn2(x, bn_mask)  # (b, d, s)
                x = x.transpose(1, 2).contiguous()  # (b, s, d)
            elif self.norm_type == "gn":
                x = self.bn2(x)  # (b, d, s)
                x = x.transpose(1, 2).contiguous()  # (b, s, d)
            else:  # ln normalizes over the last dim, so transpose first
                x = x.transpose(1, 2).contiguous()  # (b, s, d)
                x = self.bn2(x)  # (b, s, d)
        x = self.act2(x)

        if self._use_checkpoint():
            x = checkpoint(self.conv3, x)
        else:
            x = self.conv3(x)  # (b, s, d) -> (b, s, d)
        x = self.add(x, identity)
        return x
