from mmdet.registry import MODELS
import torch
import torch.nn as nn
import torch.nn.functional as F

@MODELS.register_module()
class DeconfMLP(nn.Module):
    """Self-attention + SwiGLU block that adds a gated offset to its input.

    The input is layer-normalized and passed through multi-head
    self-attention; the (dropped-out) attention output is layer-normalized
    again and passed through a SwiGLU MLP. The MLP output is scaled by
    ``alpha`` and added to the original input::

        out = x + alpha * mlp(ln2(drop(attn(ln1(x)))))

    There are no residual connections inside the branch; the only skip path
    is the outer one around the whole block.

    Args:
        d: Token embedding dimension.
        nhead: Number of attention heads.
        attn_drop: Dropout probability applied to the attention weights
            inside ``nn.MultiheadAttention``.
        proj_drop: Dropout probability applied to the attention output.
        mlp_mult: Hidden width of the MLP as a multiple of ``d``.
        mlp_drop: Dropout probability applied to the MLP output.
        layerscale_init: Accepted for backward compatibility; LayerScale is
            disabled in this implementation, so the value is unused.
        use_alpha: If True, ``alpha`` is a learnable scalar initialised to
            0 (the block starts as the identity). If False, ``alpha`` is a
            fixed buffer equal to 1.
    """

    def __init__(
        self,
        d: int = 256,
        nhead: int = 4,
        attn_drop: float = 0.0,
        proj_drop: float = 0.1,
        mlp_mult: int = 4,
        mlp_drop: float = 0.1,
        layerscale_init: float = 1e-4,
        use_alpha: bool = False
    ):
        super().__init__()
        hidden = d * mlp_mult

        # ---- Self-attention sub-layer (pre-LN) ----
        self.ln1 = nn.LayerNorm(d)
        self.mha = nn.MultiheadAttention(d, nhead, dropout=attn_drop, batch_first=True)
        self.drop_attn = nn.Dropout(proj_drop)

        # ---- MLP sub-layer (pre-LN + SwiGLU) ----
        self.ln2 = nn.LayerNorm(d)
        # fc1 produces both SwiGLU halves at once: (u, v) -> silu(u) * v
        self.fc1 = nn.Linear(d, 2 * hidden)
        self.fc2 = nn.Linear(hidden, d)
        self.drop_mlp = nn.Dropout(mlp_drop)

        # ---- alpha gate: scales the whole offset added to the input ----
        if use_alpha:
            # Learnable, starts at 0 so the block is initially the identity.
            self.alpha = nn.Parameter(torch.tensor(0.0))
        else:
            # Fixed scale of 1; a buffer so it follows .to()/state_dict.
            self.register_buffer('alpha', torch.tensor(1.0))

        self._init_weights()

    def _init_weights(self):
        """Initialise attention and MLP weights.

        Matrices of the attention module and ``fc1`` use Xavier-uniform
        init. ``fc2`` uses Xavier-uniform with a very small gain and a zero
        bias so the block's initial offset is close to zero.
        """
        # Attention: only re-initialise weight matrices, leave 1-D params.
        for p in self.mha.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # MLP input projection.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)

        # MLP output projection: tiny weights -> near-identity block at init.
        nn.init.xavier_uniform_(self.fc2.weight, gain=1e-3)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None, return_attn: bool = False):
        """Apply the block.

        Args:
            x: Input tokens of shape ``[B, N, D]``.
            key_padding_mask: Optional bool tensor of shape ``[B, N]``;
                True marks a token that must be ignored as attention key.
            return_attn: Whether to also return the attention weights
                averaged over heads.

        Returns:
            ``out`` of shape ``[B, N, D]`` when ``return_attn`` is False,
            otherwise a tuple ``(out, attn)`` where ``attn`` has shape
            ``[B, N, N]`` (or is None if the attention module returned no
            weights).
        """
        x_in = x
        # The branch is computed on a float32 copy of the input.
        x = x.float()

        # ---- Self-attention stage (no inner residual) ----
        q = self.ln1(x)
        attn_out, attn_w = self.mha(
            q, q, q, key_padding_mask=key_padding_mask, need_weights=return_attn
        )
        x = self.drop_attn(attn_out)

        # ---- MLP stage (SwiGLU, no inner residual) ----
        z = self.ln2(x)
        u, v = self.fc1(z).chunk(2, dim=-1)
        mlp_out = self.fc2(F.silu(u) * v)
        x = self.drop_mlp(mlp_out)

        # ---- alpha-gated offset: out = x_in + alpha * branch(x_in) ----
        out = x_in + self.alpha * x

        if return_attn:
            # nn.MultiheadAttention already averages over heads by default
            # and returns [B, N, N]. Only reduce when a per-head
            # [B, nhead, N, N] tensor is present; reducing a 3-D tensor over
            # dim=1 would wrongly collapse the query axis.
            if attn_w is not None and attn_w.dim() == 4:
                attn_w = attn_w.mean(dim=1)
            return out, attn_w
        return out