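# Modified from torch 1.10 (torch.nn.MultiheadAttention).
# Paste this class into torch/nn/modules/activation.py of your environment,
# e.g. ~/anaconda3/envs/mask2former/lib/python3.8/site-packages/torch/nn/modules/activation.py
# Then export it: add PLuGMultiheadAttention to the `from .activation import ...` line and to
# `__all__` in torch/nn/modules/__init__.py so that `nn.PLuGMultiheadAttention` resolves.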

# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py
# In class SelfAttentionLayer, replace 'self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)' with 'self.self_attn = nn.PLuGMultiheadAttention(d_model, nhead, dropout=dropout)'

class PLuGMultiheadAttention(Module):
    r"""MultiheadAttention with PLuG (Pairwise Logit Gating).

    Drop-in replacement for ``torch.nn.MultiheadAttention`` (torch 1.10 API:
    same constructor arguments, same ``forward`` signature and return values)
    that multiplicatively gates the attention *logits* before the softmax::

        logits = (Q K^T) / sqrt(head_dim)
        raw    = (Q_g K_g^T) / sqrt(head_dim)
        G      = tanh(gateA(raw) * gateB(raw))
        logits <- logits * (1 + G)
        attn   = softmax(logits + masks)

    ``Q_g`` / ``K_g`` come from one light projection shared by all heads, so
    ``G`` is computed once per (query, key) pair and broadcast over heads:

      - ``qk_gate``: ``Linear(embed_dim -> 2 * head_dim)``; the first half of
        its output is used for queries, the second half for keys.
      - ``gate_linear``: ``Linear(1 -> 2)`` mapping each scalar ``raw`` to
        ``(gateA, gateB)``; their product is squashed by ``tanh`` so the
        multiplier ``1 + G`` stays within ``[0, 2]``.

    Masks (``attn_mask`` / ``key_padding_mask``) are applied *after* gating,
    so a ``-inf`` entry is never multiplied by the gate.

    Because ``qk_gate`` consumes the *unprojected* key, ``forward`` requires
    the key feature size to equal ``embed_dim``.

    Args:
        embed_dim: total model dimension; split evenly across heads.
        num_heads: number of attention heads (must divide ``embed_dim``).
        dropout: dropout probability applied to the attention weights.
        bias: whether the input/output projections (and ``qk_gate``) use bias.
        add_bias_kv: append a learned key/value token to the source sequence.
        add_zero_attn: append an all-zero key/value slot to the source sequence.
        kdim, vdim: feature sizes of key/value (default ``embed_dim``).
        batch_first: if True, inputs/outputs are ``(batch, seq, feature)``.
        device, dtype: factory kwargs for the created parameters.
    """
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = float(dropout)
        self.batch_first = batch_first

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            # Packed [W_q; W_k; W_v] weight, sliced by _in_proj.
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        # PLuG gate parameters (shared across heads).
        self.qk_gate = NonDynamicallyQuantizableLinear(embed_dim, 2 * self.head_dim, bias=bias, **factory_kwargs)
        self.gate_linear = NonDynamicallyQuantizableLinear(1, 2, bias=True, **factory_kwargs)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise all weights: Xavier for matrices, zeros for biases."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        # in_proj_bias and out_proj.bias are both controlled by `bias`.
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

        xavier_uniform_(self.qk_gate.weight)
        if self.qk_gate.bias is not None:
            constant_(self.qk_gate.bias, 0.)
        xavier_uniform_(self.gate_linear.weight)
        if self.gate_linear.bias is not None:
            constant_(self.gate_linear.bias, 0.)

    def __setstate__(self, state):
        # Older pickles may lack `_qkv_same_embed_dim`; default to the packed layout.
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True
        super().__setstate__(state)

    def _in_proj(self, input: Tensor, start: int, end: int) -> Tensor:
        """Apply rows ``start:end`` of the packed in-projection (and bias) to ``input``."""
        weight = self.in_proj_weight[start:end, :]
        bias = None if self.in_proj_bias is None else self.in_proj_bias[start:end]
        return F.linear(input, weight, bias)

    def _project_qkv(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Project sequence-first query/key/value to ``embed_dim`` features each."""
        e = self.embed_dim
        if self._qkv_same_embed_dim:
            return (
                self._in_proj(query, 0, e),
                self._in_proj(key, e, 2 * e),
                self._in_proj(value, 2 * e, 3 * e),
            )
        b = self.in_proj_bias
        return (
            F.linear(query, self.q_proj_weight, None if b is None else b[:e]),
            F.linear(key, self.k_proj_weight, None if b is None else b[e:2 * e]),
            F.linear(value, self.v_proj_weight, None if b is None else b[2 * e:]),
        )

    def _plug_gate(self, query: Tensor, key: Tensor, bsz: int) -> Tensor:
        """Return the head-shared logit gate ``G`` of shape ``(bsz, 1, tgt_len, src_len)``.

        ``query`` is ``(tgt_len, bsz, embed_dim)`` and ``key`` is
        ``(src_len, bsz, embed_dim)`` (sequence-first, unprojected). When
        ``add_bias_kv`` is enabled, the ``bias_k`` token is appended so that
        ``G`` lines up with the main logits. The zero-attention slot is *not*
        included: its logit is 0, so gating it would be a no-op.
        """
        q_in = query.transpose(0, 1)  # (bsz, tgt_len, embed_dim)
        k_in = key.transpose(0, 1)    # (bsz, src_len, embed_dim)
        if self.bias_k is not None:
            # NOTE(review): bias_k lives in the projected key space but is fed
            # through qk_gate like a raw key token (behaviour kept as-is).
            k_bias_in = self.bias_k.expand(1, bsz, self.embed_dim).transpose(0, 1)
            k_in = torch.cat([k_in, k_bias_in], dim=1)

        q_gate, _ = torch.tensor_split(self.qk_gate(q_in), 2, dim=-1)  # (bsz, tgt_len, head_dim)
        _, k_gate = torch.tensor_split(self.qk_gate(k_in), 2, dim=-1)  # (bsz, src_len, head_dim)

        # The gate projection is shared by all heads, so compute the pairwise
        # scores once and broadcast over heads rather than repeating the same
        # matmul num_heads times.
        raw_gate = torch.matmul(q_gate, k_gate.transpose(-2, -1)) * self.scaling  # (bsz, tgt_len, src_len)
        gate_a, gate_b = self.gate_linear(raw_gate.unsqueeze(-1)).unbind(dim=-1)
        return torch.tanh(gate_a * gate_b).unsqueeze(1)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Compute PLuG-gated multi-head attention.

        Args:
            query: ``(L, N, E)`` target sequence, or ``(N, L, E)`` if ``batch_first``.
            key: ``(S, N, E)`` source keys, or ``(N, S, E)`` if ``batch_first``;
                must have ``embed_dim`` features because the gate projects it.
            value: ``(S, N, vdim)``, or ``(N, S, vdim)`` if ``batch_first``.
            key_padding_mask: optional ``(N, S)`` mask; non-zero/True entries are
                excluded from attention.
            need_weights: if True, also return head-averaged attention weights.
            attn_mask: optional ``(L, S)`` or ``(N * num_heads, L, S)`` mask.
                Boolean (or uint8) masks mark disallowed positions with True;
                floating-point masks are added to the logits after gating.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the layout of
            ``query``. ``attn_weights`` is ``(N, L, S')`` averaged over heads
            (``S'`` counts any ``bias_k`` / zero-attention slots), taken after
            dropout, or ``None`` when ``need_weights`` is False.

        Raises:
            RuntimeError: if the key feature size differs from ``embed_dim`` or
                ``attn_mask`` has an unsupported rank/size.
        """
        if self.batch_first:
            query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

        tgt_len, bsz, embed_dim = query.shape
        src_len = key.shape[0]
        assert embed_dim == self.embed_dim
        if key.shape[-1] != self.embed_dim:
            raise RuntimeError(
                f"PLuG gating projects the raw key with qk_gate, so key must have "
                f"{self.embed_dim} features (kdim == embed_dim), got {key.shape[-1]}"
            )

        q, k, v = self._project_qkv(query, key, value)
        q = q * self.scaling

        # Boolean/uint8 masks mean "True = do not attend". Convert them to an
        # additive -inf mask (as torch.nn.MultiheadAttention does); adding the
        # raw bool would add +1.0 to masked logits instead of removing them.
        if attn_mask is not None and attn_mask.dtype in (torch.bool, torch.uint8):
            attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype).masked_fill_(
                attn_mask.to(torch.bool), float('-inf')
            )

        if self.bias_k is not None and self.bias_v is not None:
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)], dim=0)
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)], dim=0)
            src_len += 1
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))

        # (len, bsz, embed_dim) -> (bsz, num_heads, len, head_dim)
        q = q.contiguous().view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
        k = k.contiguous().view(src_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
        v = v.contiguous().view(src_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)

        # Main logits, multiplicatively gated; G broadcasts over the head dim.
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits * (1.0 + self._plug_gate(query, key, bsz))

        attn_weights = attn_logits.reshape(bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                if attn_mask.shape[0] != attn_weights.shape[0]:
                    raise RuntimeError(f"The size of the attn_mask is not correct, expected {attn_weights.shape} but got {attn_mask.shape}")
            else:
                raise RuntimeError(f"attn_mask dim should be 2 or 3, got {attn_mask.dim()}")
            attn_weights = attn_weights + attn_mask

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.view(bsz, 1, 1, src_len).to(torch.bool), float('-inf')
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if self.add_zero_attn:
            # A zero key yields logit 0 whatever the gate is (0 * (1 + G) == 0),
            # so the unmasked zero column is appended after gating/masking.
            # This is equivalent to torch zero-padding k/v and both masks.
            zero_attn = torch.zeros((bsz * self.num_heads, 1, self.head_dim), dtype=v.dtype, device=v.device)
            v = torch.cat([v.reshape(bsz * self.num_heads, src_len, self.head_dim), zero_attn], dim=1)
            zero_col = torch.zeros((bsz * self.num_heads, tgt_len, 1), dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.cat([attn_weights, zero_col], dim=-1)

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_weights, v.reshape(bsz * self.num_heads, -1, self.head_dim))

        # (bsz * num_heads, tgt_len, head_dim) -> (tgt_len, bsz, embed_dim)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output.transpose(0, 1))

        if need_weights:
            attn_weights_agg = attn_weights.view(bsz, self.num_heads, tgt_len, -1).mean(dim=1)
        else:
            attn_weights_agg = None

        if self.batch_first:
            return attn_output.transpose(0, 1), attn_weights_agg
        return attn_output, attn_weights_agg