# Drop-in replacement for `class Attention` in ViTAE-Transformer/Image-Classification/vitae/NormalCell.py (official GitHub repo):
# https://github.com/ViTAE-Transformer/ViTAE-Transformer/blob/main/Image-Classification/vitae/NormalCell.py

class Attention(nn.Module):
    """Multi-head self-attention whose logits are rescaled by a PLuG gate.

    Two branches read the same input:

    * the main branch is ordinary multi-head attention producing per-head
      logits of shape (B, H, N, N);
    * the gating branch projects every token to a single query/key pair of
      width ``head_dim``, forms one (B, N, N) score map, feeds each scalar
      score through a 1->2 linear layer, multiplies the two outputs and
      squashes the product with ``tanh``.

    The resulting gate ``G`` is broadcast over heads and the logits are
    multiplied by ``1 + G`` before the softmax.

    Args:
        dim: Channel width of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the ``qkv`` and ``qk_gate`` projections carry a bias.
        qk_scale: Scale applied to both score maps. Any falsy value
            (``None`` or ``0``) selects ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}"

        per_head = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = per_head
        self.scale = qk_scale if qk_scale else per_head ** -0.5

        # Sub-modules are created in a fixed order (qkv, qk_gate, gate_linear,
        # proj) so that seeded initialisation and state_dict layout stay stable.

        # Main branch: one fused projection for queries, keys and values.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)

        # Gating branch: a single head_dim-wide query and key per token,
        # shared by every head.
        self.qk_gate = nn.Linear(dim, 2 * per_head, bias=qkv_bias)
        # Maps each pairwise scalar score to two values whose product is
        # later passed through tanh.
        self.gate_linear = nn.Linear(1, 2)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply gated self-attention.

        Args:
            x: Tensor of shape (B, N, C), where C equals the ``dim`` the
                module was built with.

        Returns:
            Tensor of shape (B, N, C).
        """
        batch, tokens, channels = x.shape
        heads, width = self.num_heads, self.head_dim

        # --- main branch -------------------------------------------------
        # (B, N, 3*C) -> (3, B, H, N, d), then peel off q / k / v.
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, width)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)
        logits = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale  # (B, H, N, N)

        # --- gating branch (computed once, not per head) ------------------
        gate_q, gate_k = self.qk_gate(x).split(width, dim=-1)                # 2 x (B, N, d)
        pair_scores = torch.matmul(gate_q, gate_k.transpose(1, 2)) * self.scale  # (B, N, N)

        # Each scalar score -> two linear responses; their product is
        # squashed by tanh, so the gate lies in [-1, 1].
        branch_a, branch_b = self.gate_linear(pair_scores.unsqueeze(-1)).unbind(-1)
        gate = (branch_a * branch_b).tanh()                                  # (B, N, N)

        # Multiplicative modulation of the logits; the new head axis lets
        # the single gate map broadcast across all heads.
        weights = torch.softmax(logits * (1 + gate.unsqueeze(1)), dim=-1)
        weights = self.attn_drop(weights)

        # Merge heads back into the channel axis and project.
        merged = torch.matmul(weights, values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
