import torch
from torch import nn

class Swish(nn.SiLU):
    """Swish activation: ``x * sigmoid(x)``.

    A named subclass of PyTorch's built-in ``nn.SiLU``; it adds no behavior
    of its own, so it computes exactly what ``nn.SiLU`` computes.
    """

class FeedForwardModule(nn.Module):
    """Macaron-style feed-forward sub-module of a Conformer block.

    Pipeline: LayerNorm -> Linear(d_model -> d_model * expansion) -> Swish
    -> Dropout -> Linear(back to d_model) -> Dropout. No residual connection
    is applied inside this module; it returns only the transformed tensor.

    Args:
        d_model: feature size of the input and output (last dimension).
        expansion: width multiplier of the hidden layer.
        dropout: dropout probability used after the activation and the
            output projection.
    """

    def __init__(self, d_model, expansion=4, dropout=0.1):
        super().__init__()
        hidden = d_model * expansion
        layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, hidden),  # expand
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),  # project back
            nn.Dropout(dropout),
        ]
        # A single nn.Sequential named ``net`` keeps the state_dict keys
        # (net.0.*, net.1.*, net.4.*) stable for existing checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward pipeline over the last dimension of ``x``."""
        out = self.net(x)
        return out

class ConvModule(nn.Module):
    """Convolution module of a Conformer block.

    Pipeline: LayerNorm -> pointwise conv (D -> 2D) -> GLU (back to D)
    -> depthwise conv -> BatchNorm -> Swish -> pointwise conv -> Dropout.
    Input and output are ``(B, T, D)``. The time length is preserved because
    ``kernel_size`` is odd and the depthwise conv pads ``kernel_size // 2``
    frames on each side.
    """

    def __init__(self, d_model, kernel_size=31, dropout=0.1):
        super().__init__()
        # An even kernel with padding=k//2 would yield T+1 output frames and
        # break the residual addition done by the caller.
        assert kernel_size % 2 == 1, "kernel_size必须为奇数"
        self.ln = nn.LayerNorm(d_model)
        self.pw1 = nn.Conv1d(d_model, 2 * d_model, 1)  # pointwise conv
        self.glu = nn.GLU(dim=1)
        self.dw = nn.Conv1d(d_model, d_model, kernel_size,
                            padding=kernel_size // 2, groups=d_model)  # depthwise conv
        self.bn = nn.BatchNorm1d(d_model)
        self.act = Swish()
        self.pw2 = nn.Conv1d(d_model, d_model, 1)
        self.do = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):  # x: (B, T, D)
        """Apply the convolution module.

        Args:
            x: ``(B, T, D)`` input.
            key_padding_mask: optional ``(B, T)`` mask, ``True`` at padded
                frames (the ``nn.MultiheadAttention`` convention). Non-bool
                masks are converted with ``.to(torch.bool)`` (non-zero ->
                padded). When given, padded frames are zeroed right before the
                depthwise conv, the only time-mixing op here, so they cannot
                leak into neighbouring valid frames. They are also zeroed in
                the output. When ``None``, behavior is unchanged.

        Returns:
            ``(B, T, D)`` tensor.

        NOTE: in training mode BatchNorm batch statistics still include the
        (zeroed) padded frames; in eval mode valid-frame outputs do not depend
        on the amount of padding.
        """
        pad = None
        if key_padding_mask is not None:
            # (B, T) -> (B, 1, T) so it broadcasts over the channel axis.
            pad = key_padding_mask.unsqueeze(1).to(torch.bool)
        x = self.ln(x)
        x = x.transpose(1, 2)  # (B, D, T)
        x = self.pw1(x)
        x = self.glu(x)
        if pad is not None:
            # Zeros here match the depthwise conv's own zero padding, so a
            # padded sequence is convolved exactly like its unpadded version.
            x = x.masked_fill(pad, 0.0)
        x = self.dw(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.pw2(x)
        if pad is not None:
            x = x.masked_fill(pad, 0.0)
        x = x.transpose(1, 2)  # (B, T, D)
        return self.do(x)

class ConformerBlock(nn.Module):
    """One complete Conformer block.

    x -> x + 0.5 * FFN1(x) -> x + MHSA(x) -> x + Conv(x) -> x + 0.5 * FFN2(x)
    -> LayerNorm. Attention uses ``batch_first=True``, so ``x`` is ``(B, T, D)``.
    """

    def __init__(self, d_model, num_heads, ff_mult=4, conv_kernel=31, dropout=0.1):
        super().__init__()
        self.ff1 = FeedForwardModule(d_model, ff_mult, dropout)
        self.mha_ln = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.do = nn.Dropout(dropout)
        self.conv = ConvModule(d_model, conv_kernel, dropout)
        self.ff2 = FeedForwardModule(d_model, ff_mult, dropout)
        self.out_ln = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Run the block.

        Args:
            x: ``(B, T, D)`` input.
            attn_mask: forwarded unchanged to ``nn.MultiheadAttention``.
            key_padding_mask: optional ``(B, T)`` padding mask (``True`` =
                padded). It is forwarded to the attention AND to the
                convolution module, so padded frames influence neither.

        Returns:
            ``(B, T, D)`` tensor.
        """
        # Macaron: first half-step FFN
        x = x + 0.5 * self.ff1(x)
        # Pre-LN multi-head self-attention
        y = self.mha_ln(x)
        y, _ = self.mha(y, y, y, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
        x = x + self.do(y)
        # Convolution; the mask stops padded frames leaking through the depthwise conv
        x = x + self.conv(x, key_padding_mask=key_padding_mask)
        # Macaron: second half-step FFN
        x = x + 0.5 * self.ff2(x)
        return self.out_ln(x)

class ConformerEncoder(nn.Module):
    """Stack of ``num_layers`` ConformerBlocks applied one after another.

    Every block is built with the same hyper-parameters. ``attn_mask`` and
    ``key_padding_mask`` are passed unchanged to each block.
    """

    def __init__(self, d_model, num_layers, num_heads, ff_mult=4, conv_kernel=31, dropout=0.1):
        super().__init__()
        # nn.MultiheadAttention splits d_model evenly across the heads, so
        # reject incompatible sizes up front with a clear message.
        assert d_model % num_heads == 0, f"[Error] d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        self.layers = nn.ModuleList(
            ConformerBlock(d_model, num_heads, ff_mult, conv_kernel, dropout)
            for _ in range(num_layers)
        )

    def forward(self, x, attn_mask=None, key_padding_mask=None):  # x: (B, T, D)
        """Feed ``x`` through every block in order and return the final output."""
        out = x
        for layer in self.layers:
            out = layer(out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        return out