# Replaces the `Attention` class in Visformer's models.py (official GitHub repo) with the version below:
# https://github.com/danczs/Visformer/blob/main/models.py

class Attention(nn.Module):
    """
    Visformer attention with an extra PLuG-style gating branch.
    Keeps original API so existing Blocks work unchanged.

    Input and output are feature maps of shape (B, dim, H, W); each of the
    N = H * W spatial positions is a token. Q/K/V are produced by a 1x1
    convolution and multi-head attention is computed over the tokens.

    When ``use_gate`` is True, a second 1x1 convolution produces a single
    head-sized query/key pair. Their scaled dot product ``r`` (one N x N map
    per sample) is fed through ``gate_linear`` (``nn.Linear(1, 2)``) to get a
    pair ``(a, b)`` per token pair. The main attention logits are then
    multiplied by ``1 + tanh(a * b)``, a factor in (0, 2), before the softmax.
    The gate map is the same for every head, so it is computed once and
    broadcast over the head axis.

    Args:
        dim: number of input/output channels.
        num_heads: number of attention heads.
        head_dim_ratio: per-head width is ``round(dim // num_heads * head_dim_ratio)``.
        qkv_bias: whether the Q/K/V convolution (and the gate Q/K convolution)
            has a bias.
        qk_scale: *exponent* applied to ``head_dim`` to obtain ``self.scale``
            (default -0.25). The scale multiplies both q and k, so the default
            gives the usual ``head_dim ** -0.5`` on the logits.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        use_gate: enable the gating branch. It registers extra parameters
            (``qk_gate.*``, ``gate_linear.*``), so state dicts saved without
            the gate will not contain those keys.
    """
    def __init__(
        self,
        dim,
        num_heads=8,
        head_dim_ratio=1.,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.,
        proj_drop=0.,
        use_gate: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = round(dim // num_heads * head_dim_ratio)
        self.head_dim = head_dim

        # qk_scale is an exponent on head_dim, not a multiplier. The resulting
        # scale is applied to both q and k, so it enters the logits squared.
        qk_scale_factor = qk_scale if qk_scale is not None else -0.25
        self.scale = head_dim ** qk_scale_factor

        # ---------- Main QKV branch ----------
        self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=qkv_bias)

        # ---------- Gating branch ----------
        self.use_gate = use_gate
        if self.use_gate:
            # A single head-sized query/key pair, shared by all heads.
            self.qk_gate = nn.Conv2d(dim, head_dim * 2, 1, stride=1, padding=0, bias=qkv_bias)
            # Maps each scalar gate affinity r to the pair (a, b).
            self.gate_linear = nn.Linear(1, 2)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply (optionally gated) multi-head attention to ``x``.

        Args:
            x: feature map of shape (B, dim, H, W).

        Returns:
            Tensor of shape (B, dim, H, W).
        """
        B, _, H, W = x.shape
        N = H * W  # number of tokens

        # ---------- Main attention ----------
        # qkv channels are laid out as (3, num_heads, head_dim). Split them and
        # put tokens before per-head features: (3, B, heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, N).permute(1, 0, 2, 4, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # raw attention logits: (B, heads, N, N)
        attn = (q * self.scale) @ (k.transpose(-2, -1) * self.scale)

        if self.use_gate:
            # ---------- Gating branch ----------
            # Keep a singleton head axis and let broadcasting apply the gate to
            # every head. Expanding to num_heads would repeat the N x N matmul
            # and the gate MLP once per head with identical results (and keep
            # num_heads copies alive for backward).
            q_gate_t, k_gate_t = self.qk_gate(x).chunk(2, dim=1)       # 2 x (B, d, H, W)
            q_gate = q_gate_t.flatten(2).transpose(1, 2).unsqueeze(1)  # (B, 1, N, d)
            k_gate = k_gate_t.flatten(2).transpose(1, 2).unsqueeze(1)  # (B, 1, N, d)

            raw_gate = (q_gate * self.scale) @ (k_gate.transpose(-2, -1) * self.scale)  # (B, 1, N, N)

            gate_a, gate_b = self.gate_linear(raw_gate.unsqueeze(-1)).unbind(dim=-1)
            gate = torch.tanh(gate_a * gate_b)  # in (-1, 1)

            attn = attn * (1.0 + gate)  # broadcasts over heads

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, d) -> (B, heads * d, H, W), head-major channel order
        out = (attn @ v).transpose(-2, -1).reshape(B, self.num_heads * self.head_dim, H, W)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out
