import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size (last axis of ``x``).
        mult: hidden-width multiplier; the hidden layer has ``dim * mult`` units.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = dim * mult
        layers = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        # One nn.Sequential named ``net`` keeps parameter names (net.0.*, net.3.*)
        # stable for checkpoint loading.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last axis of ``x`` (which must have size ``dim``)."""
        return self.net(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: feature size of the input and output (last axis of ``x``).
        dim_head: per-head width of queries, keys and values.
        heads: number of attention heads.
        dropout: dropout probability applied after the output projection
            (the attention weights themselves are not dropped).
    """

    def __init__(self, dim, dim_head=64, heads=8, dropout=0.0):
        super(Attention, self).__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        # Single projection producing q, k and v side by side (split in forward).
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    @staticmethod
    def _broadcast_mask(mask):
        """Reshape ``mask`` so it broadcasts against scores of shape (b, h, n, n).

        Accepts a key mask ``(b, n)``, a query-by-key mask ``(b, n, n)``, or a
        4-D mask that already broadcasts (e.g. ``(b, 1, n, n)`` / ``(b, h, n, n)``).

        Raises:
            ValueError: if ``mask`` has any other number of dimensions.
        """
        if mask.dim() == 2:
            return mask[:, None, None, :]
        if mask.dim() == 3:
            return mask[:, None, :, :]
        if mask.dim() == 4:
            return mask
        raise ValueError(
            f"mask must be 2-, 3- or 4-D, got shape {tuple(mask.shape)}"
        )

    def forward(self, x, mask=None):
        """Attend over the sequence axis of ``x``.

        Args:
            x: input of shape (b, n, dim).
            mask: optional mask where nonzero/True marks positions that may be
                attended to and 0/False marks blocked ones. Shapes accepted:
                (b, n) key mask, (b, n, n) query-by-key mask, or a 4-D mask
                broadcastable to (b, heads, n, n).

        Returns:
            Tensor of shape (b, n, dim). A query whose keys are all blocked
            receives zero attention (output equals the projection bias).
        """
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, h * dim_head) -> (b, h, n, dim_head)
        q, k, v = map(lambda t: t.reshape(b, n, h, -1).transpose(1, 2), qkv)

        dots = (q @ k.transpose(-1, -2)) * self.scale  # (b, h, n, n)

        if mask is not None:
            blocked = self._broadcast_mask(mask) == 0
            dots = dots.masked_fill(blocked, float('-inf'))
            # A row whose keys are all blocked is softmax over all -inf -> NaN;
            # zeroing the blocked entries afterwards turns such a row into zeros.
            attn = dots.softmax(dim=-1).masked_fill(blocked, 0.0)
        else:
            attn = dots.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

class PreNorm(nn.Module):
    """Layer-normalize the input, then hand it to ``fn``.

    Computes ``fn(LayerNorm(x), **kwargs)``.

    Args:
        dim: size of the last axis normalized by the LayerNorm.
        fn: callable or module applied to the normalized tensor; any keyword
            arguments passed to ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class Scale(nn.Module):
    """Multiply the output of ``fn`` by a constant factor.

    Computes ``fn(x, **kwargs) * scale``; keyword arguments are forwarded to
    ``fn`` unchanged.

    Args:
        scale: multiplier applied to ``fn``'s output.
        fn: callable or module producing the value to be scaled.
    """

    def __init__(self, scale, fn):
        super().__init__()
        self.scale = scale
        self.fn = fn

    def forward(self, x, **kwargs):
        result = self.fn(x, **kwargs)
        return result * self.scale

class ConformerConvModule(nn.Module):
    """Convolution sub-module of a Conformer block, for (batch, seq, dim) inputs.

    Pipeline: pointwise conv (dim -> dim * expansion_factor) -> depthwise conv
    -> BatchNorm -> SiLU -> pointwise conv (back to dim) -> dropout.
    The sequence length of the output always equals that of the input.

    NOTE(review): the Conformer paper's conv module begins with a LayerNorm;
    this one applies no input normalization — confirm that is intended.

    Args:
        dim: number of input/output channels (last axis of ``x``).
        causal: if True, the depthwise convolution only sees the current and
            earlier time steps (all padding is placed on the left).
        expansion_factor: channel multiplier for the inner layers.
        kernel_size: depthwise kernel width; odd and even sizes are supported.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
        super(ConformerConvModule, self).__init__()
        inner_dim = dim * expansion_factor
        self.causal = causal

        # A length-preserving conv needs kernel_size - 1 total padding. Causal
        # mode puts it all on the left; otherwise it is split like PyTorch's
        # padding='same' (identical to padding=kernel_size // 2 for odd kernels).
        # Padding is done by a parameter-free module, so state_dict is unchanged.
        if causal:
            pad = (kernel_size - 1, 0)
        else:
            pad = ((kernel_size - 1) // 2, kernel_size // 2)
        self.pad = nn.ConstantPad1d(pad, 0.0)

        self.pointwise_conv1 = nn.Conv1d(dim, inner_dim, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(inner_dim, inner_dim, kernel_size=kernel_size, groups=inner_dim)
        # NOTE(review): in training mode BatchNorm1d pools statistics over all
        # time steps, so causal mode is strictly causal only at inference.
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.swish = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(inner_dim, dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the module to ``x`` of shape (b, n, dim); returns (b, n, dim)."""
        x = x.transpose(1, 2)  # Conv1d expects channels first: (b, dim, n)
        x = self.pointwise_conv1(x)
        x = self.depthwise_conv(self.pad(x))
        x = self.batch_norm(x)
        x = self.swish(x)
        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        return x.transpose(1, 2)

class ConformerBlock(nn.Module):
    """One Conformer block: FF/2 -> self-attention -> conv -> FF/2 -> LayerNorm.

    Each sub-layer is a residual branch, applied in order:
        x = x + 0.5 * FF(LayerNorm(x))
        x = x + Attention(LayerNorm(x), mask)
        x = x + Conv(x)            (no pre-norm around the conv branch here)
        x = x + 0.5 * FF(LayerNorm(x))
        x = LayerNorm(x)

    Args:
        dim: model width (last axis of the input).
        dim_head, heads, attn_dropout: forwarded to ``Attention``.
        ff_mult, ff_dropout: forwarded to both ``FeedForward`` branches.
        conv_expansion_factor, conv_kernel_size, conv_dropout: forwarded to
            ``ConformerConvModule``.
    """

    def __init__(
        self,
        dim,
        dim_head=64,
        heads=10,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
    ):
        super().__init__()

        def half_step_ff():
            # Pre-normalized feed-forward whose output is scaled by 0.5.
            return Scale(0.5, PreNorm(dim, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)))

        # Construction order matches the original so parameter initialization
        # consumes the RNG identically and state_dict keys keep their order.
        self.ff1 = half_step_ff()
        self.attn = PreNorm(
            dim,
            Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout),
        )
        self.conv = ConformerConvModule(
            dim=dim,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = half_step_ff()
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        """Run the block; ``mask`` is passed only to the attention branch."""
        out = x + self.ff1(x)
        out = out + self.attn(out, mask=mask)
        out = out + self.conv(out)
        out = out + self.ff2(out)
        return self.post_norm(out)