from mmdet.registry import MODELS
import torch
import torch.nn as nn
import torch.nn.functional as F

@MODELS.register_module()
class TextKV2Aux(nn.Module):
    """Refine old-class embeddings by cross-attending class queries to tokens.

    Old- and new-class embeddings are projected into attention queries that
    attend over the tokens ``Tq``. The result goes through a post-norm residual
    block (attention + FFN). Only the rows belonging to the old classes are
    then kept, added to the original old-class embeddings with a learnable
    weight ``alpha``, and L2-normalized.

    Args:
        d (int): Feature dimension shared by ``Tq``, ``E_old`` and ``E_new``.
            ``nn.MultiheadAttention`` requires it to be divisible by ``nhead``.
        nhead (int): Number of attention heads.
        dropout (float): Dropout probability inside the attention layer.
        alpha_init (float): Initial value of the learnable residual weight
            ``alpha``. Defaults to 0.5.
    """

    def __init__(self, d=256, nhead=4, dropout=0.0, alpha_init=0.5):
        super().__init__()
        self.q_proj = nn.Linear(d, d)
        self.k_proj = nn.Linear(d, d)
        self.v_proj = nn.Linear(d, d)
        self.attn = nn.MultiheadAttention(d, nhead, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.norm1 = nn.LayerNorm(d)
        self.norm2 = nn.LayerNorm(d)
        # Kept so that reset_parameters() can restore alpha as well.
        self.alpha_init = float(alpha_init)
        self.alpha = nn.Parameter(torch.tensor(self.alpha_init))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize every learnable parameter of the module.

        Linear layers: Xavier-uniform weights, zero biases. LayerNorms:
        weight 1, bias 0. ``alpha`` is reset to ``alpha_init``.
        """
        # Linear projections and FFN layers: Xavier-uniform weight, zero bias.
        for m in [self.q_proj, self.k_proj, self.v_proj]:
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)
        for m in self.ffn:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
        # LayerNorm: weight=1, bias=0.
        for norm in (self.norm1, self.norm2):
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)
        # MHA internals. PyTorch's own init already uses xavier_uniform_ and
        # zero bias for in_proj. out_proj.weight would otherwise keep
        # nn.Linear's default init, so it is set explicitly here.
        nn.init.xavier_uniform_(self.attn.in_proj_weight)
        if self.attn.in_proj_bias is not None:
            nn.init.zeros_(self.attn.in_proj_bias)
        nn.init.xavier_uniform_(self.attn.out_proj.weight)
        nn.init.zeros_(self.attn.out_proj.bias)
        # Learnable residual weight for the auxiliary update.
        with torch.no_grad():
            self.alpha.fill_(self.alpha_init)

    def forward(self, Tq, E_old, E_new, key_padding_mask=None, return_attn=False):
        """Compute the auxiliary (refined) old-class embeddings.

        Args:
            Tq (Tensor): Key/value tokens, shape ``[B, Nq, D]``.
            E_old (Tensor): Old-class embeddings, shape ``[C_old, D]``.
            E_new (Tensor): New-class embeddings, shape ``[C_new, D]``. They
                are used as extra attention queries but are dropped from the
                output.
            key_padding_mask (Tensor, optional): ``[B, Nq]`` mask forwarded to
                ``nn.MultiheadAttention``. ``True`` marks tokens of ``Tq`` to
                ignore. Defaults to ``None`` (attend to every token).
            return_attn (bool): If ``True``, also return the head-averaged
                attention weights of the old-class queries. Defaults to
                ``False``, which keeps the single-tensor return and skips
                computing the weights.

        Returns:
            Tensor: ``[B, C_old, D]`` L2-normalized embeddings. If
            ``return_attn`` is ``True``, returns a tuple ``(aux, attn_old)``
            instead, where ``attn_old`` has shape ``[B, C_old, Nq]``.
        """
        B = Tq.size(0)
        C_old = E_old.size(0)
        Q_all = torch.cat([E_old, E_new], dim=0)                   # [C_old+C_new, D]
        # Class queries are shared by every sample; expand() is a view, not a copy.
        Q = self.q_proj(Q_all).unsqueeze(0).expand(B, -1, -1)      # [B, C_all, D]
        K = self.k_proj(Tq)                                        # [B, Nq, D]
        V = self.v_proj(Tq)                                        # [B, Nq, D]

        # Request attention weights only when the caller wants them.
        # Otherwise PyTorch does not have to materialize the full weight matrix.
        out, attn = self.attn(
            Q, K, V,
            key_padding_mask=key_padding_mask,
            need_weights=return_attn,
        )                                                          # out: [B, C_all, D]
        y = self.norm1(Q + out)
        y = self.norm2(y + self.ffn(y))                            # [B, C_all, D]

        # Keep only the old-class rows as the auxiliary signal.
        aux0 = F.normalize(E_old.unsqueeze(0) + self.alpha * y[:, :C_old, :], dim=-1)  # [B, C_old, D]
        if return_attn:
            # Old-class attention weights, e.g. for a downstream alignment loss.
            return aux0, attn[:, :C_old, :]                        # [B, C_old, Nq]
        return aux0