import torch
import torch.nn as nn
import torch.nn.functional as F


class Linear(nn.Module):
    """Fully connected layer: ``nn.Linear`` re-initialized with Xavier-uniform weights.

    After the wrapped layer is built, its weight is overwritten by
    ``xavier_uniform_`` and, when a bias exists, the bias is set to zero.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        bias: whether the wrapped ``nn.Linear`` has a bias term.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        layer = nn.Linear(in_features, out_features, bias=bias)
        nn.init.xavier_uniform_(layer.weight)
        # nn.Linear registers ``bias`` as None when bias=False.
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map of the wrapped ``nn.Linear`` to ``x``."""
        layer = self.linear
        return layer(x)


class View(nn.Module):
    """Module form of ``tensor.view(*shape)``.

    Args:
        shape: target shape, unpacked into ``Tensor.view``.
        contiguous: if True, call ``.contiguous()`` on the input first so that
            ``view`` also works on non-contiguous tensors (e.g. after a transpose).
    """

    def __init__(self, shape: tuple, contiguous: bool = False):
        super().__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` viewed as ``self.shape``, made contiguous first if requested."""
        source = x.contiguous() if self.contiguous else x
        return source.view(*self.shape)


class Transpose(nn.Module):
    """Module form of ``tensor.transpose``.

    Args:
        shape: the pair of dimensions to swap (despite the name, this is a
            pair of dim indices, not a target shape), e.g. ``(1, 2)``.
    """

    def __init__(self, shape: tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Swap the two dimensions named in ``self.shape``."""
        return torch.transpose(x, *self.shape)


class ResidualConnectionModule(nn.Module):
    """Scaled residual wrapper around a sub-module.

    Computes ``module(inputs) * module_factor + inputs * input_factor``.

    Args:
        module: the wrapped sub-module; its output must be broadcast-compatible
            with ``inputs``.
        module_factor: multiplier applied to the sub-module output.
        input_factor: multiplier applied to the skip (identity) path.
    """

    def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0):
        super().__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of the branch output and the skip path."""
        # The branch runs before the skip term is computed, keeping the same
        # evaluation order as ``a * f + b * g`` (relevant if the branch mutates inputs).
        branch = self.module(inputs) * self.module_factor
        return branch + inputs * self.input_factor


class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``, applied elementwise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate each element of ``x`` by its own sigmoid."""
        gate = torch.sigmoid(x)
        return torch.mul(x, gate)


# ---------------------------------------------------------------------------
# Convolution module
# ---------------------------------------------------------------------------
class DepthwiseConv1d(nn.Module):
    """Depthwise 1-D convolution: ``nn.Conv1d`` with ``groups=in_channels``.

    Each input channel is convolved with its own filters, producing
    ``out_channels // in_channels`` output channels per input channel.

    Args:
        in_channels: number of input channels (also the number of groups).
        out_channels: number of output channels; must be a multiple of
            ``in_channels``.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added to both sides of the input.
        bias: whether the convolution has a bias term (off by default).

    Raises:
        ValueError: if ``in_channels`` is not positive, or ``out_channels`` is
            not a multiple of ``in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False):
        super().__init__()
        # Explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O``, and ``out_channels % 0`` would otherwise surface as an
        # unhelpful ZeroDivisionError.
        if in_channels <= 0:
            raise ValueError(f"in_channels must be positive, got {in_channels}")
        if out_channels % in_channels != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be a multiple of in_channels ({in_channels})"
            )
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=in_channels, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the depthwise convolution to ``x`` (channels on dim 1, as for ``nn.Conv1d``)."""
        return self.conv(x)


class PointwiseConv1d(nn.Module):
    """Pointwise (kernel size 1) 1-D convolution, i.e. a per-timestep channel mix.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride.
        padding: zero padding added to both sides of the input.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, padding: int = 0, bias: bool = True):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the 1x1 convolution to ``x`` (channels on dim 1, as for ``nn.Conv1d``)."""
        conv = self.conv
        return conv(x)


class ConformerConvModule(nn.Module):
    """
    Conformer convolution module:
    - LayerNorm -> Transpose -> PointwiseConv -> GLU -> DepthwiseConv -> BN -> Swish -> PointwiseConv -> Dropout

    The LayerNorm normalizes the last input dimension (size ``in_channels``);
    the block then swaps dims 1 and 2 so channels sit on dim 1 for the 1-D
    convolutions, and swaps them back before returning. It therefore expects
    a 3-D input whose last dim is ``in_channels`` (presumably
    ``[batch, time, in_channels]`` — confirm against callers) and returns a
    tensor in the same layout.

    Args:
        in_channels: feature size of the input and the output.
        kernel_size: depthwise kernel size; must be odd so that
            ``padding=(kernel_size - 1) // 2`` preserves the time length.
        expansion_factor: channel multiplier of the first pointwise conv.
            GLU halves the channel count, so the depthwise/BN/Swish stage runs
            on ``in_channels * expansion_factor // 2`` channels. The default of
            2 keeps that stage at ``in_channels`` (the standard Conformer layout).
        dropout_p: dropout probability applied to the output.

    Raises:
        ValueError: if ``kernel_size`` is even, ``expansion_factor`` is below 1,
            or ``in_channels * expansion_factor`` is odd (GLU needs an even split).
    """
    def __init__(self, in_channels: int, kernel_size: int = 5, expansion_factor: int = 2, dropout_p: float = 0.1):
        super().__init__()
        # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``.
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size should be odd, got {kernel_size}")
        if expansion_factor < 1:
            raise ValueError(f"expansion_factor must be >= 1, got {expansion_factor}")
        expanded_channels = in_channels * expansion_factor
        if expanded_channels % 2 != 0:
            raise ValueError(
                f"in_channels * expansion_factor must be even for GLU, got {expanded_channels}"
            )
        # GLU(dim=1) splits the channels in half and gates one half with the other.
        hidden_channels = expanded_channels // 2

        # Layer order (and hence state_dict keys) matches the original design;
        # with expansion_factor=2, hidden_channels == in_channels.
        self.sequential = nn.Sequential(
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),  # [B, dim, time]
            PointwiseConv1d(in_channels, expanded_channels),
            nn.GLU(dim=1),  # [B, hidden_channels, time]
            DepthwiseConv1d(hidden_channels, hidden_channels, kernel_size, padding=(kernel_size - 1) // 2),
            nn.BatchNorm1d(hidden_channels),
            Swish(),
            PointwiseConv1d(hidden_channels, in_channels),
            nn.Dropout(dropout_p)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolution pipeline and swap dims 1 and 2 back to the input layout."""
        return self.sequential(x).transpose(1, 2)


"""Feedforward module"""
class FeedForwardModule(nn.Module):
    """Conformer feed-forward block: pre-norm, expand, Swish, project back.

    Pipeline: LayerNorm -> Linear(encoder_dim -> encoder_dim * expansion_factor)
    -> Swish -> Dropout -> Linear(back to encoder_dim) -> Dropout.
    The last input dimension must equal ``encoder_dim``; the output keeps the
    input's shape.

    Args:
        encoder_dim: size of the last input/output dimension.
        expansion_factor: width multiplier of the hidden layer.
        dropout_p: dropout probability used after the activation and at the output.
    """
    def __init__(self, encoder_dim: int = 640, expansion_factor: int = 2, dropout_p: float = 0.1):
        super().__init__()
        hidden_dim = encoder_dim * expansion_factor
        layers = [
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, hidden_dim),
            Swish(),
            nn.Dropout(dropout_p),
            Linear(hidden_dim, encoder_dim),
            nn.Dropout(dropout_p),
        ]
        self.sequential = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward pipeline to ``x``."""
        pipeline = self.sequential
        return pipeline(x)