# Replaces class XCA in xcit/xcit.py (official GitHub repo) with the version below.
# Note that PLuG gating is applied only to XCA, not to class attention.
# https://github.com/facebookresearch/xcit/blob/main/xcit.py

class XCA(nn.Module):
    """Cross-Covariance Attention (XCA) with PLuG-style pre-softmax gating.

    Channels attend to channels: for each head the attention map is D x D,
    built from query/key features L2-normalised along the token axis and
    scaled by a learnable per-head ``temperature`` (as in the official XCA).
    When ``use_gate`` is True, the logits are multiplied by ``(1 + G)`` before
    the softmax, where ``G = tanh(gA * gB)`` lies in (-1, 1). The gate is
    computed once per sample and shared across all heads.

    Args:
        dim: Embedding dimension ``C``; must be divisible by ``num_heads``.
        num_heads: Number of heads ``H``; head size is ``D = dim // H``.
        qkv_bias: Whether the qkv projection and the gate's qk projection
            use a bias term.
        qk_scale: Unused. Accepted only so the signature matches the official
            XCA; logits are scaled by the learnable ``temperature`` instead.
        attn_drop: Dropout probability applied to the attention map.
        proj_drop: Dropout probability applied after the output projection.
        use_gate: Enable the PLuG gating branch.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., use_gate=True):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Learnable per-head logit scale, shape (H, 1, 1), as in XCA.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # ---------- Main branch ----------
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # ---------- PLuG gating branch ----------
        self.use_gate = use_gate
        if self.use_gate:
            # A single (query, key) pair of width head_dim: the gate is
            # head-agnostic and shared by all heads.
            self.qk_gate = nn.Linear(dim, 2 * self.head_dim, bias=qkv_bias)
            # Maps each scalar gate logit to a pair (gA, gB).
            self.gate_linear = nn.Linear(1, 2)

    def _gate(self, x):
        """Return the head-shared gate G in (-1, 1) with shape (B, 1, D, D)."""
        N = x.shape[1]
        qg, kg = self.qk_gate(x).split(self.head_dim, dim=-1)  # each (B, N, D)
        qg = qg.transpose(-2, -1)  # (B, D, N)
        kg = kg.transpose(-2, -1)  # (B, D, N)

        # Dot product over the N tokens, scaled by N ** -0.5.
        raw_gate = (qg @ kg.transpose(-2, -1)) * (N ** -0.5)  # (B, D, D)

        gA, gB = self.gate_linear(raw_gate.unsqueeze(-1)).unbind(-1)
        # Insert a singleton head axis so the gate broadcasts over all heads.
        return torch.tanh(gA * gB).unsqueeze(1)

    def forward(self, x):
        """Apply (optionally gated) cross-covariance attention.

        Args:
            x: Input tensor of shape (B, N, C).

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        H, D = self.num_heads, self.head_dim

        # ----- Main branch -----
        qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, H, N, D)

        # Put channels on the row axis, (B, H, D, N), so attention is D x D.
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)

        # L2-normalise along the token axis, as in the official XCA.
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature  # (B, H, D, D)

        # ----- Gate branch: rescale logits before the softmax -----
        if self.use_gate:
            attn = attn * (1.0 + self._gate(x))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, H, D, N) -> (B, N, H, D) -> (B, N, C)
        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay (same as the official XCA)."""
        return {'temperature'}

