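# Replace the `Attention` class in TokenLabeling/tlt/models/layers.py (official GitHub repo) with the one below.
# https://github.com/zihangJiang/TokenLabeling/blob/main/tlt/models/layers.py
# NOTE: the class below uses `Optional`; make sure `from typing import Optional` is imported in that file.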

class Attention(nn.Module):
    """Multi-head self-attention whose logits are modulated by a learned gate.

    The main branch computes scaled dot-product scores per head.  A second,
    narrower query/key projection (``qk_gate``) produces one score map per
    sample; ``gate_linear`` maps each score to two values whose product is
    squashed with ``tanh`` to give a gate ``G`` in ``[-1, 1]``.  The attention
    logits are then ``scores * (1 + G)``, with the same gate shared by all
    heads.

    Args:
        dim: Size of the last dimension of the input and of the output.
        num_heads: Number of attention heads.
        head_dim: Per-head width.  Defaults to ``dim // num_heads``.  The
            internal width is ``num_heads * head_dim`` and may differ from
            ``dim``.
        qkv_bias: Whether the ``qkv`` and ``qk_gate`` projections use a bias.
        qk_scale: Multiplier for the main scores.  Defaults to
            ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        fraction: Width of the gating query/key projection as a fraction of
            ``head_dim``.  The resulting width is clamped to at least 1.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: Optional[int] = None,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        fraction: float = 1.0,              # gating head-width fraction
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim or (dim // num_heads)
        self.scale = qk_scale or (self.head_dim ** -0.5)

        # Width of the concatenated heads.  It equals ``dim`` in the default
        # configuration but not necessarily when ``head_dim`` is overridden,
        # so the projections are sized from it rather than from ``dim``.
        inner_dim = self.num_heads * self.head_dim

        # Main Q/K/V
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)

        # Gating branch.  Clamp to >= 1: a width of 0 would make
        # ``0 ** -0.5`` raise ZeroDivisionError below.
        self.gating_dim = max(1, int(self.head_dim * fraction))
        self.gating_scale = self.gating_dim ** -0.5
        self.qk_gate = nn.Linear(dim, 2 * self.gating_dim, bias=qkv_bias)
        self.gate_linear = nn.Linear(1, 2)

        # Output proj + drops
        self.proj = nn.Linear(inner_dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply gated self-attention.

        Args:
            x: Input of shape ``[B, N, dim]``.
            padding_mask: Optional ``[B, N]`` mask, true where a token is
                padding.  Padded tokens are excluded as attention keys.

        Returns:
            Tensor of shape ``[B, N, dim]``.
        """
        B, N, _ = x.shape

        # ——— Main attention ———
        # [B, N, 3*H*D] -> [3, B, H, N, D] -> three [B, H, N, D] tensors
        q, k, v = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        scores = (q @ k.transpose(-2, -1)) * self.scale  # [B, H, N, N]

        # ——— Gating branch ———
        # The gate does not depend on the head, so it is computed once per
        # sample and broadcast over the head axis instead of being expanded
        # to [B, H, N, N] first.
        qg, kg = self.qk_gate(x).chunk(2, dim=-1)                    # [B, N, Dg] each
        raw_gate = (qg @ kg.transpose(-2, -1)) * self.gating_scale   # [B, N, N]
        a, b = self.gate_linear(raw_gate.unsqueeze(-1)).unbind(-1)   # [B, N, N] each
        gate = torch.tanh(a * b).unsqueeze(1)                        # [B, 1, N, N]

        logits = scores * (1 + gate)

        if padding_mask is not None:
            # Mask AFTER gating and BEFORE the softmax.  Filling with the
            # most negative finite value gives padded keys a weight of
            # exactly 0 while keeping a fully padded row finite (it becomes
            # uniform) instead of NaN, which -inf would produce.
            key_mask = padding_mask.to(torch.bool)[:, None, None, :]  # [B, 1, 1, N]
            logits = logits.masked_fill(key_mask, torch.finfo(logits.dtype).min)

        weights = self.attn_drop(logits.softmax(dim=-1))

        # [B, H, N, D] -> [B, N, H*D]
        out = (weights @ v).transpose(1, 2).reshape(
            B, N, self.num_heads * self.head_dim
        )
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
