# Drop-in replacement for the `WindowAttentionGlobal` class in timm/models/gcvit.py (timm).
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/gcvit.py

class WindowAttentionGlobal(nn.Module):
    """Multi-head window attention with optional global queries and a logit gate.

    Compared with a plain window-attention layer, this variant adds a gating
    branch: a separate single-head Q/K projection (`qk_gate`) produces an
    (N, N) score map per batch entry, which is turned into a gate
    ``G = tanh(g_a * g_b)`` where ``(g_a, g_b) = gate_linear(score)``. The
    attention logits are multiplied by ``1 + G`` before the softmax. The gate
    is identical for every head.

    Args:
        dim: Channel dimension of the input and output tokens. The reshapes in
            `forward` require ``dim == num_heads * (dim // num_heads)``.
        num_heads: Number of attention heads.
        window_size: ``(height, width)`` of the attention window, forwarded to
            `RelPosBias`.
        use_global: If True, the main projection only produces K and V and
            `forward` must be given `q_global` to supply the queries. If
            False, Q, K and V are all projected from the input.
        qkv_bias: Whether the main and the gating Q/K(/V) projections use a
            bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Tuple[int, int],
        use_global: bool = True,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        window_size = (window_size[0], window_size[1])
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_global = use_global

        self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads)

        # --- Main Q, K, V branch ---
        # With global queries only K and V are projected here (2 * dim outputs).
        self.qkv = nn.Linear(dim, dim * (2 if use_global else 3), bias=qkv_bias)

        # --- Gating branch ---
        # One head-sized Q and K projection shared by all heads, followed by a
        # scalar -> 2 linear map applied to every score independently.
        self.qk_gate = nn.Linear(dim, self.head_dim * 2, bias=qkv_bias)
        self.gate_linear = nn.Linear(1, 2)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    # ---------------------------------------------------------------------
    def _compute_gate(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the attention-logit gate for a batch of token windows.

        Args:
            x: Input tokens of shape ``(B, N, C)``.

        Returns:
            Gate tensor of shape ``(B, num_heads, N, N)`` with values in
            ``(-1, 1)``. The head dimension is a broadcast (``expand``) view
            of a single ``(B, N, N)`` map, so it must not be written in place.
        """
        B, N, _ = x.shape
        q_gate, k_gate = self.qk_gate(x).chunk(2, dim=-1)  # each (B, N, head_dim)

        # The gating projection has no head dimension, so the score map is the
        # same for every head: compute it once instead of num_heads times.
        raw_gate = torch.matmul(q_gate, k_gate.transpose(-2, -1)) * self.scale  # (B, N, N)

        # GML: map each scalar score to two values and gate with tanh of their product.
        gate_out = self.gate_linear(raw_gate.unsqueeze(-1))  # (B, N, N, 2)
        g_a, g_b = gate_out.chunk(2, dim=-1)
        G = torch.tanh((g_a * g_b).squeeze(-1))  # (B, N, N)

        # Broadcast over heads without copying.
        return G.unsqueeze(1).expand(B, self.num_heads, N, N)

    # ---------------------------------------------------------------------
    def forward(self, x: torch.Tensor, q_global: Optional[torch.Tensor] = None):
        """Apply gated window attention.

        Args:
            x: Input tokens of shape ``(B, N, C)``.
            q_global: Externally computed queries, required when the module
                was built with ``use_global=True`` and ignored otherwise. It
                is tiled along dim 0 to batch size ``B`` and reshaped to
                ``(B, N, num_heads, head_dim)``; presumably shaped
                ``(B0, h, w, C)`` with ``h * w == N`` and ``B % B0 == 0`` --
                confirm against the caller.

        Returns:
            Tensor of shape ``(B, N, C)``.

        Raises:
            RuntimeError: If the module uses global queries but `q_global` is
                None.
        """
        B, N, C = x.shape

        if self.use_global and q_global is None:
            # Without this check the local branch below would try to split a
            # 2 * dim projection into Q, K and V and fail with an opaque
            # reshape error.
            raise RuntimeError(
                'WindowAttentionGlobal was built with use_global=True, so q_global must be provided'
            )

        # ===== 1) MAIN BRANCH (Q, K, V) =====
        if self.use_global and q_global is not None:
            kv = self.qkv(x)  # (B, N, 2 * C)
            kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)  # each (B, num_heads, N, head_dim)

            # NOTE(review): repeat() tiles q_global along dim 0, so row i of q
            # pairs with q_global[i % B0]; confirm this matches how the caller
            # orders windows in x's batch dimension.
            q = q_global.repeat(B // q_global.shape[0], 1, 1, 1)
            q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        else:
            qkv = self.qkv(x)  # (B, N, 3 * C)
            qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)

        q = q * self.scale

        # ===== 2) ATTENTION LOGITS =====
        attn = q @ k.transpose(-2, -1)  # (B, num_heads, N, N)
        attn = self.rel_pos(attn)  # presumably adds the relative position bias

        # ===== 3) GATING =====
        # G is in (-1, 1), so every logit is rescaled by a factor in (0, 2).
        G = self._compute_gate(x)
        attn = attn * (1.0 + G)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # ===== 4) OUTPUT =====
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
