# Replacement for the `Attention` class in TinyViT/models/tiny_vit.py (official GitHub repo):
# https://github.com/wkcn/TinyViT/blob/main/models/tiny_vit.py

class Attention(nn.Module):
    """Multi-head self-attention with TinyViT relative-position biases and a
    learned gate on the attention logits.

    The gate has no head dimension: it is computed once from the input and
    shared by every head. It multiplies the (bias-augmented) logits by
    ``1 + tanh(...)`` before the softmax.

    Input ``x`` is ``(B, N, dim)`` with ``N == resolution[0] * resolution[1]``.
    The output has the same shape.

    Args:
        dim: input/output channel count.
        key_dim: per-head query/key width.
        num_heads: number of attention heads.
        attn_ratio: per-head value width is ``int(attn_ratio * key_dim)``.
        resolution: (rows, cols) token grid used to build the relative-bias
            table.
        gate_share: gate query/key width as a fraction of ``key_dim``.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability on the projected output.

    Raises:
        ValueError: if ``int(key_dim * gate_share) < 1``.
    """

    def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4.,
                 resolution=(14, 14),
                 gate_share=1.0,
                 attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.scale = key_dim ** -0.5
        self.d_v = int(attn_ratio * key_dim)

        # ——— projections ———
        self.qkv = nn.Linear(dim, (2 * key_dim + self.d_v) * num_heads)
        gdim = int(key_dim * gate_share)
        if gdim < 1:
            # Without this check, gdim ** -0.5 below would fail with an
            # opaque ZeroDivisionError (or build an empty gate projection).
            raise ValueError(
                f"gate_share={gate_share} with key_dim={key_dim} gives a gate "
                f"width of {gdim}; key_dim * gate_share must be >= 1")
        # One query/key pair for the gate, shared by all heads.
        self.qk_gate = nn.Linear(dim, gdim * 2)
        self.gate_scale = gdim ** -0.5
        # Maps each scalar gate score to two values whose product feeds tanh.
        self.gate_linear = nn.Linear(1, 2)
        self.proj = nn.Linear(self.d_v * num_heads, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.norm = nn.LayerNorm(dim)

        # ——— TinyViT's relative-bias table (unchanged) ———
        # One learnable bias per head for each distinct (|drow|, |dcol|)
        # offset on the grid. idxs maps every (query, key) pair to its entry.
        pts = list(itertools.product(range(resolution[0]), range(resolution[1])))
        off, idxs = {}, []
        for p1 in pts:
            for p2 in pts:
                d = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                off.setdefault(d, len(off))
                idxs.append(off[d])
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(len(pts), len(pts)),
                             persistent=False)
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(off)))

    @torch.no_grad()
    def train(self, mode=True):
        """Switch train/eval mode and manage the cached eval-time bias.

        Entering training drops the cache because training reads the
        parameter directly. Entering eval gathers the full
        ``(num_heads, N, N)`` bias once into ``self.ab``.

        Returns ``self``, like ``nn.Module.train``, so ``eval()`` chains.
        """
        super().train(mode)
        if mode:
            if hasattr(self, 'ab'):
                del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]
        return self

    def _relative_bias(self):
        """Return the ``(num_heads, N, N)`` relative-position bias.

        In eval mode the cache from ``train(False)`` is reused when it is
        present and matches the parameter's device and dtype. Otherwise the
        bias is gathered fresh, so a missing or stale cache cannot break
        ``forward``. Examples: ``training`` set to False by hand, or the
        module moved with ``.to()`` after ``eval()``.
        NOTE(review): an in-place update of ``attention_biases`` made after
        ``eval()`` (e.g. ``load_state_dict``) is not detected; call
        ``eval()`` again afterwards.
        """
        cached = None if self.training else getattr(self, 'ab', None)
        if (cached is None
                or cached.device != self.attention_biases.device
                or cached.dtype != self.attention_biases.dtype):
            return self.attention_biases[:, self.attention_bias_idxs]
        return cached

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, N, dim)``; return ``(B, N, dim)``.

        ``N`` must equal ``resolution[0] * resolution[1]`` so the bias table
        lines up with the attention logits.
        """
        B, N, _ = x.shape
        x = self.norm(x)

        # —— main Q, K, V: each (B, num_heads, N, width) ——
        qkv = self.qkv(x).view(B, N, self.num_heads, -1)
        q, k, v = qkv.split([self.key_dim, self.key_dim, self.d_v], dim=3)
        q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))

        logits = (q @ k.transpose(-2, -1)) * self.scale   # (B, H, N, N)
        logits = logits + self._relative_bias()           # (H, N, N) broadcasts over B

        # —— gate branch ——
        # The gate does not depend on the head, so compute it once and
        # broadcast. Expanding to num_heads would repeat identical work.
        qg, kg = self.qk_gate(x).chunk(2, dim=-1)                 # (B, N, gdim) each
        raw_gate = (qg @ kg.transpose(-2, -1)) * self.gate_scale  # (B, N, N)
        g_a, g_b = self.gate_linear(raw_gate.unsqueeze(-1)).chunk(2, dim=-1)
        G = torch.tanh(g_a * g_b).squeeze(-1).unsqueeze(1)        # (B, 1, N, N)

        # —— combine, softmax, dropout ——
        # tanh keeps (1 + G) in (0, 2), so the gate rescales each logit.
        attn = (1.0 + G) * logits
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, self.d_v * self.num_heads)
        out = self.proj_drop(self.proj(out))
        return out
