# The Attention class below replaces `class Attention` in classification/pvt.py
# of the official PVT GitHub repository:
# https://github.com/whai362/PVT/blob/v2/classification/pvt.py

class Attention(nn.Module):
    """Multi-head attention with optional spatial reduction and a gating branch.

    Queries are projected from every input token.  Keys and values are
    projected either from the same tokens (``sr_ratio == 1``) or from a
    token grid that is first down-sampled by a strided convolution and
    layer-normalised (``sr_ratio > 1``).

    A gating branch projects the query tokens and the key/value source
    tokens with one shared linear layer (``qk_gate``), takes their scaled
    dot product, maps every score through ``mix_gate`` (a ``Linear(1, 2)``)
    and multiplies the two outputs to obtain ``G``.  The attention logits
    are rescaled by ``1 + G`` before the softmax.  The gate projection has
    ``head_dim`` channels and is not split per head, so the same ``G`` is
    applied to every head.

    Args:
        dim: Channel size of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether ``q``, ``kv`` and ``qk_gate`` have a bias term.
        qk_scale: Scale applied to the dot products.  Any falsy value
            (``None``, ``0``) selects ``head_dim ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the projected output.
        sr_ratio: Kernel size and stride of the spatial-reduction
            convolution; ``1`` disables the reduction.
        gate_tanh: If true, ``G`` is passed through ``tanh``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1, gate_tanh=True):
        super().__init__()
        assert dim % num_heads == 0, \
            f"dim {dim} should be divided by num_heads {num_heads}."
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim ** -0.5
        self.sr_ratio = sr_ratio
        self.gate_tanh = gate_tanh

        # main projections
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

        # gating branch: the first head_dim output channels of qk_gate are
        # the query-side gate features, the last head_dim the key-side ones
        self.qk_gate = nn.Linear(dim, self.head_dim * 2, bias=qkv_bias)
        self.mix_gate = nn.Linear(1, 2)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        # NOTE: the gate layers deliberately keep PyTorch's default
        # initialisation.  Zero-initialising all of mix_gate would freeze the
        # gate: G = gA * gB, so with gA == gB == 0 every gradient flowing
        # through G is zero as well.  Also, qk_gate.bias is None when
        # qkv_bias is False, so it cannot be passed to nn.init.zeros_.

    def forward(self, x, H, W):
        """Apply gated attention to a token sequence.

        Args:
            x: Tensor of shape ``(B, N, C)``.
            H, W: Spatial size of the token grid.  Only used when
                ``sr_ratio > 1``, where ``x`` is reshaped to
                ``(B, C, H, W)`` and therefore ``N`` must equal ``H * W``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape

        # -------- Q: (B, heads, N, head_dim) --------
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # -------- K / V source: (B, N', C) --------
        if self.sr_ratio > 1:
            kv_src = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_src = self.sr(kv_src).reshape(B, C, -1).permute(0, 2, 1)
            kv_src = self.norm(kv_src)
        else:
            kv_src = x

        # -------- K, V: (B, heads, N', head_dim) each --------
        kv = self.kv(kv_src).reshape(B, -1, 2, self.num_heads, self.head_dim) \
                            .permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # -------- Gate branch --------
        # Query-side features come from x, key-side features from kv_src.
        # Without spatial reduction both are the same tensor, so a single
        # qk_gate pass provides both halves.
        gate_q, gate_k = self.qk_gate(x).chunk(2, dim=-1)
        if kv_src is not x:
            gate_k = self.qk_gate(kv_src).chunk(2, dim=-1)[1]

        # The gate features carry no head axis, so the gate is identical for
        # every head.  Compute it once per batch element, (B, N, N'), and let
        # broadcasting share it across heads instead of materialising
        # num_heads copies.
        raw_gate = (gate_q @ gate_k.transpose(-2, -1)) * self.scale
        gate_a, gate_b = self.mix_gate(raw_gate.unsqueeze(-1)).chunk(2, dim=-1)
        gate = (gate_a * gate_b).squeeze(-1)
        if self.gate_tanh:
            gate = torch.tanh(gate)
        gate = gate.unsqueeze(1)                      # (B, 1, N, N')

        # -------- Gated attention --------
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N')
        attn = attn * (1. + gate)
        attn = self.attn_drop(attn.softmax(-1))

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))