# Drop-in replacement for the `Attention` class in T2T-ViT/models/transformer_block.py (official GitHub repository):
# https://github.com/yitu-opensource/T2T-ViT/blob/main/models/transformer_block.py

class Attention(nn.Module):
    """Multi-head self-attention whose logits are modulated by a tanh gate.

    On top of the usual scaled dot-product logits ``A_main`` (one ``N x N``
    map per head), a separate single-head query/key projection produces a
    gate ``G`` with values in ``(-1, 1)``. The attention weights are
    ``softmax(A_main * (1 + G))``; the same gate is applied to every head.

    Args:
        dim: Channel size of the input and output tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the ``qkv`` and ``qk_gate`` projections use a bias.
        qk_scale: Multiplier applied to the query/key dot products. A falsy
            value (``None`` or ``0``) selects ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``dim``.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 attn_drop: float = 0.0,
                 proj_drop: float = 0.0):
        super().__init__()
        # Fail early with a clear message instead of an opaque reshape error
        # the first time forward() is called.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads "
                f"({num_heads})"
            )

        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim ** -0.5

        # Fused projection producing queries, keys and values for all heads.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # Single-head query/key projection used only by the gating branch.
        self.qk_gate = nn.Linear(dim, self.head_dim * 2, bias=qkv_bias)
        # Maps each scalar gate logit to the two factors multiplied inside tanh.
        self.gate_linear = nn.Linear(1, 2)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    # -------------------------------------------------------------------------
    # Helper for gradient-checkpointing the gating branch
    # -------------------------------------------------------------------------
    def _gate_forward(self, q_gate: torch.Tensor, k_gate: torch.Tensor) -> torch.Tensor:
        """Compute the tanh gating matrix from the gate queries and keys.

        Written to be run through ``torch.utils.checkpoint`` so that the
        intermediate ``(..., N, N, 2)`` activations are recomputed in the
        backward pass instead of being stored. Only tensor arguments are
        taken; ``self`` is captured by the bound method.

        Args:
            q_gate: Gate queries of shape ``(..., N, head_dim)``.
            k_gate: Gate keys of shape ``(..., N, head_dim)``.

        Returns:
            Gate tensor of shape ``(..., N, N)`` with values in ``(-1, 1)``.
        """
        raw_gate = (q_gate @ k_gate.transpose(-2, -1)) * self.scale
        # Add a trailing feature axis so the Linear(1, 2) acts on each logit.
        raw_gate = raw_gate.unsqueeze(-1)
        gate_a, gate_b = self.gate_linear(raw_gate).chunk(2, dim=-1)
        return torch.tanh(gate_a * gate_b).squeeze(-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gated multi-head self-attention.

        Args:
            x: Input tokens of shape ``(B, N, C)`` with ``C == dim``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape

        # (B, N, 3C) -> (3, B, H, N, head_dim), then split into q / k / v.
        qkv = self.qkv(x)
        q_main, k_main, v_main = qkv.reshape(
            B, N, 3, self.num_heads, self.head_dim
        ).permute(2, 0, 3, 1, 4).unbind(0)

        A_main = (q_main @ k_main.transpose(-2, -1)) * self.scale  # (B, H, N, N)

        q_gate, k_gate = self.qk_gate(x).chunk(2, dim=-1)  # each (B, N, head_dim)

        # The gate is shared by all heads, so keep a singleton head axis and
        # let broadcasting apply it to every head. Expanding to num_heads
        # before the matmul would repeat the same N x N computation (and its
        # activations) once per head without changing the result.
        q_gate = q_gate.unsqueeze(1)  # (B, 1, N, head_dim)
        k_gate = k_gate.unsqueeze(1)  # (B, 1, N, head_dim)

        if torch.is_grad_enabled():
            # Non-reentrant checkpointing: newer PyTorch releases expect
            # use_reentrant to be passed explicitly, and this variant does not
            # require the inputs themselves to have requires_grad=True.
            G = checkpoint(self._gate_forward, q_gate, k_gate, use_reentrant=False)
        else:
            # Nothing is stored for backward, so checkpointing has no benefit.
            G = self._gate_forward(q_gate, k_gate)

        # G is (B, 1, N, N) and broadcasts over the head axis of A_main.
        A_gated = (A_main * (1 + G)).softmax(dim=-1)
        A_gated = self.attn_drop(A_gated)

        # (B, H, N, head_dim) -> (B, N, H * head_dim) == (B, N, C)
        out = (A_gated @ v_main).transpose(1, 2).reshape(B, N, C)
        out = self.proj_drop(self.proj(out))
        return out
