# PyTorch
import torch.nn as nn


class MHSA(nn.Module):
    """Multihead self-attention block with custom residual connections.

    The block computes (see ``forward``)::

        X_res  = W_res(X) + MHSA(LayerNorm(X))
        output = X_res + rFF(LayerNorm(X_res))

    Note that a single ``nn.LayerNorm`` instance (``self.layer_norm``) is
    used at both normalization sites, so both share the same affine
    parameters.

    Parameters
    ----------
    hidden_dim: int
      Hidden dimension (in this setting, it matches the embeddings dimension)
    num_heads: int
      nn.MultiheadAttention number of heads.
    eps_layer_norm: float
      epsilon parameter for the nn.LayerNorm applied before the attention
      block and before the row-wise feed forward block.
    p_dropout: float
      Dropout probability. It is forwarded to nn.MultiheadAttention and is
      also used for the dropout layers placed inside the row-wise feed
      forward block. A falsy value (``0``/``0.0``) means no dropout layer is
      added to the feed forward block.
    rff_depth: int, default 1
      Number of hidden (expanding) layers of the row-wise feed forward
      block. Values below 1 behave like 1.
    batch_first: bool, default True
      Forwarded to nn.MultiheadAttention. ``True`` matches the
      ``(batch, n_nodes, embed_dim)`` layout documented in ``forward``.
    """

    def __init__(
            self,
            hidden_dim: int,
            num_heads: int,
            eps_layer_norm: float,
            p_dropout: float,
            rff_depth: int = 1,
            batch_first: bool = True,
    ):
        super().__init__()

        # *** Encoder definition ***
        # NOTE: submodules are created in a fixed order (W_res, layer_norm,
        # mhsa_block, rff) so that seeded parameter initialisation and
        # state_dict keys stay stable.

        # h x h linear transform of the residual branch, with h the dimension
        # of axis=-1
        self.W_res = nn.Linear(hidden_dim, hidden_dim)

        # We just normalize over the embeddings (last axis)
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=eps_layer_norm)

        self.mhsa_block = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=p_dropout,
            kdim=hidden_dim,
            vdim=hidden_dim,
            batch_first=batch_first,
        )

        # Row-wise feed forward
        self.hidden_dropout = nn.Dropout(p=p_dropout) if p_dropout else None
        self.init_rff(hidden_dim, rff_depth)

    def init_rff(self, dim_out, rff_depth):
        """Build the row-wise feed forward block and store it in ``self.rff``.

        Layout (``D`` is the shared dropout layer, present only when
        ``self.hidden_dropout`` is not None)::

            Linear(d, 4d) GELU [D]
            (Linear(4d, 4d) GELU [D]) * (rff_depth - 1)
            Linear(4d, d) [D]

        Parameters
        ----------
        dim_out: int
          Input and output dimension ``d`` of the block.
        rff_depth: int
          Number of expanding hidden layers; values below 1 behave like 1.
        """
        # Same as NPT: https://github.com/OATML/non-parametric-transformers/blob/main/npt/model/npt_modules.py
        inner_dim = 4 * dim_out

        # The same nn.Dropout instance is reused at every position.
        dropout = [] if self.hidden_dropout is None else [self.hidden_dropout]

        # Collect the layers in a local list: assigning a plain list to
        # ``self.rff`` would raise a TypeError once ``rff`` is already a
        # registered submodule (i.e. when this method is called again).
        layers = [nn.Linear(dim_out, inner_dim), nn.GELU(), *dropout]

        for _ in range(rff_depth - 1):
            layers += [nn.Linear(inner_dim, inner_dim), nn.GELU(), *dropout]

        layers += [nn.Linear(inner_dim, dim_out), *dropout]

        self.rff = nn.Sequential(*layers)

    def forward(self, X, key_padding_mask=None):
        """Multihead Self-Attention forward.

        Parameters
        ----------
        X: Tensor of shape `(batch_size*n_samples, n_nodes, embed_dim)`
          (layout when the module was built with ``batch_first=True``).
          The second dimension might have padded entries.
        key_padding_mask: Tensor of shape `(batch_size*n_samples, n_nodes)`
          Mask of X padded entries. A `True` value denotes a padding token
          to be ignored for the purpose of attention.

        Returns
        -------
        Tensor with the same shape as ``X``.
        """
        X_norm = self.layer_norm(X)

        # Equation (4) NPT paper; the attention weights are not used.
        attn_values, _ = self.mhsa_block(
            query=X_norm,
            key=X_norm,
            value=X_norm,
            key_padding_mask=key_padding_mask,
        )
        # The residual branch goes through a learned linear map instead of
        # the identity, and is applied to the un-normalized input.
        X_res = self.W_res(X) + attn_values

        # Equation (5) NPT paper
        output = X_res + self.rff(self.layer_norm(X_res))

        return output