# Usage: replace the `Attention` class in your installed timm package
# (e.g. anaconda3/envs/<env_name>/lib/python3.7/site-packages/timm/models/vision_transformer.py)
# with the class below.
# Modified from timm==0.4.12
# https://github.com/facebookresearch/deit

class Attention(nn.Module):
    """Multi-head self-attention whose logits are rescaled by a learned gate.

    Keeps the constructor signature used below for the timm==0.4.12
    ``vision_transformer.Attention`` replacement.

    Besides the usual fused Q/K/V projection, a small gating branch projects
    each token to a single head-sized (query, key) pair. Their scaled dot
    product gives one ``N x N`` score map per batch element, shared by all
    heads. ``gate_linear`` maps every score to two values ``a`` and ``b``, and
    the gate ``G = tanh(a * b)`` (range (-1, 1)) multiplies the attention
    logits by ``1 + G`` before the softmax.

    Args:
        dim: Channel dimension ``C`` of the input/output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the Q/K/V and gate projections use a bias term.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        # Fail fast: otherwise the reshape in forward() is the first thing to break.
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # -----------------------------
        # Main branch: fused Q, K, V projection
        # -----------------------------
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # -----------------------------
        # Gating branch: one (Q_gate, K_gate) pair of size head_dim per token,
        # shared across all heads
        # -----------------------------
        self.qk_gate = nn.Linear(dim, self.head_dim * 2, bias=qkv_bias)

        # Lifts each scalar gate score to two values whose product feeds tanh.
        self.gate_linear = nn.Linear(1, 2)

        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gated multi-head self-attention.

        Args:
            x: Tensor of shape (B, N, C); C must equal ``dim``.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape

        # 1) Main branch: q, k, v each (B, H, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q_main, k_main, v_main = qkv.unbind(0)
        A_main = torch.matmul(q_main, k_main.transpose(-2, -1)) * self.scale  # (B, H, N, N)

        # 2) Gating branch: q_gate, k_gate each (B, N, head_dim).
        # They carry no head axis, so the gate map is identical for every head;
        # compute it once and broadcast over H instead of repeating it H times.
        q_gate, k_gate = self.qk_gate(x).chunk(2, dim=-1)
        raw_gate = torch.matmul(q_gate, k_gate.transpose(-2, -1)) * self.scale  # (B, N, N)
        gate_a, gate_b = self.gate_linear(raw_gate.unsqueeze(-1)).unbind(-1)
        G = torch.tanh(gate_a * gate_b).unsqueeze(1)  # (B, 1, N, N)

        # 3) Rescale logits by (1 + G), a factor in (0, 2), then softmax.
        A_gated = (A_main * (1 + G)).softmax(dim=-1)
        A_gated = self.attn_drop(A_gated)

        out = torch.matmul(A_gated, v_main)
        out = out.transpose(1, 2).reshape(B, N, C)

        # Final projection
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
        
# Alternatively, to train with the latest timm source from GitHub, edit
# pytorch-image-models/timm/layers/attention.py and replace its `Attention`
# class with the one below.
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/attention.py

class Attention(nn.Module):
    """Multi-head self-attention with a learned multiplicative gate on the logits.

    Keeps the constructor arguments and ``forward(x, attn_mask=None)``
    signature of the ``Attention`` class it replaces in
    ``timm/layers/attention.py``.

    A gating projection maps each token to one head-sized (query, key) pair.
    Their scaled dot product yields a single ``N x N`` score map per batch
    element. ``gate_linear`` lifts each score to two values ``a`` and ``b``,
    and ``tanh(a * b)``, shared by every head, rescales the attention logits
    by ``1 + gate`` before the mask and softmax are applied.
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        scale_norm: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: Optional[Type[nn.Module]] = None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        wants_norm = qk_norm or scale_norm
        assert not wants_norm or norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'

        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        # The gate must act on explicit logits, so the fused kernel path is off.
        self.fused_attn = False

        def build_norm(width: int, enabled: bool) -> nn.Module:
            # Real normalisation layer when requested, otherwise a pass-through.
            return norm_layer(width) if enabled else nn.Identity()

        # Submodules are registered in a fixed order (state_dict keys, init RNG order).
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.q_norm = build_norm(head_dim, qk_norm)
        self.k_norm = build_norm(head_dim, qk_norm)
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = build_norm(dim, scale_norm)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        # Gating branch: one (query, key) pair of width head_dim per token,
        # plus a 1 -> 2 map from each gate score to the two tanh factors.
        self.qk_gate = nn.Linear(dim, 2 * head_dim, bias=qkv_bias)
        self.gate_linear = nn.Linear(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply gated multi-head self-attention.

        Args:
            x: Tensor of shape (B, N, C); C must equal ``dim``.
            attn_mask: Optional mask passed to ``maybe_add_mask`` together
                with the gated logits.

        Returns:
            Tensor of shape (B, N, C).
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # Main branch: q, k, v each (B, H, N, head_dim)
        q, k, v = (
            self.qkv(x)
            .reshape(batch, tokens, 3, heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        q = self.q_norm(q)
        k = self.k_norm(k)
        scores = (q * self.scale) @ k.transpose(-2, -1)  # (B, H, N, N)

        # Gating branch: a single (B, N, N) map, replicated across heads.
        q_gate, k_gate = self.qk_gate(x).chunk(2, dim=-1)
        gate_scores = (q_gate @ k_gate.transpose(-2, -1)) * self.scale  # (B, N, N)
        gate_ab = self.gate_linear(gate_scores.unsqueeze(-1))  # (B, N, N, 2)
        gate = torch.tanh(gate_ab[..., 0] * gate_ab[..., 1])
        gate = gate.unsqueeze(1).expand(batch, heads, tokens, tokens).to(scores.dtype)

        # Rescale logits by (1 + gate), then mask, softmax and dropout.
        scores = maybe_add_mask(scores * (1 + gate), attn_mask)
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(self.norm(out)))
