#!/usr/bin/env python3
"""
Created on 21:42, Jul. 26th, 2023

@author: Anonymous
"""
import numpy as np
import tensorflow as tf
import tensorflow.keras as K
# local dep
# When this file is executed directly as a script, prepend the directory three levels up
# (expressed relative to the current working directory) to `sys.path`, presumably so that
# project-local packages become importable — no such import is visible in this file.
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))

# Public API of this module: only `MultiHeadAttention` is exported by a star-import.
# `MHAMatrix` and `ScaledDotProductAttention` are defined in this file but are not listed here.
__all__ = [
    "MultiHeadAttention",
]

class MultiHeadAttention(K.layers.Layer):
    """
    `MultiHeadAttention` computes the scaled multi-headed attention.

    The query, key and value embeddings are each projected into `n_heads` heads of size `d_head`
    (via `MHAMatrix`), attended with `ScaledDotProductAttention`, concatenated over heads and
    finally projected to `d_model` dimensions.
    """

    def __init__(self, d_model, n_heads, d_head, attn_dropout=0., proj_dropout=0., emb_rotary=None, use_bias=True, **kwargs):
        """
        Initialize `MultiHeadAttention` object.

        Args:
            d_model: int - The dimensions of model embedding.
            n_heads: int - The number of attention heads.
            d_head: int - The dimensions of attention head.
            attn_dropout: float - The probability of attention score dropout.
            proj_dropout: float - The probability of projection dropout.
            emb_rotary: K.layers.Layer - The rotary embedding layer, applied to query & key if not None.
            use_bias: bool - The flag indicates whether the query & key & value projections use bias.

        Returns:
            None
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super().__init__(**kwargs)

        # Initialize parameters.
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.emb_rotary = emb_rotary
        self.use_bias = use_bias

    # network funcs
    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        Args:
            input_shape: tuple - The shape of input data. It is not used here, because the input is
                a tuple of embeddings; each `MHAMatrix` infers `d_input` from its own input instead.

        Returns:
            None
        """
        # Initialize the query & key & value transformation matrices (perhaps w. bias).
        # W_[q,k,v] - (batch_size, seq_len, d_input) -> (batch_size, seq_len, n_heads, d_head)
        self.W_q = MHAMatrix(n_heads=self.n_heads, d_head=self.d_head, use_bias=self.use_bias)
        self.W_k = MHAMatrix(n_heads=self.n_heads, d_head=self.d_head, use_bias=self.use_bias)
        self.W_v = MHAMatrix(n_heads=self.n_heads, d_head=self.d_head, use_bias=self.use_bias)
        # Initialize the scaled dot-product attention layer (with a fixed, non-trainable scale).
        self.attention = ScaledDotProductAttention(
            d_head=self.d_head, attn_dropout=self.attn_dropout, scale_trainable=False
        )
        # Initialize the project layer. Note that the output projection always uses a bias,
        # `use_bias` only controls the query & key & value projections.
        self.proj = K.Sequential(layers=[
            K.layers.Dense(
                # Modified `Dense` layer parameters.
                units=self.d_model,
                kernel_initializer=K.initializers.RandomNormal(mean=0., stddev=0.01),
                bias_initializer=K.initializers.Constant(value=0.01),
                # Default `Dense` layer parameters.
                activation=None, use_bias=True, kernel_regularizer=None, bias_regularizer=None,
                activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ),
            K.layers.Dropout(rate=self.proj_dropout, noise_shape=None, seed=None),
        ])

    def call(self, embs, attn_score=None, attn_mask=None, key_padding_mask=None):
        """
        Forward layers in `MultiHeadAttention` to get the multi-head attention result.

        Args:
            embs: tuple - The embeddings containing emb_[q,k,v], each element is (batch_size, seq_len, d_input).
            attn_score: (batch_size, n_heads, seq_len, seq_len) - The attention score from the previous layer.
            attn_mask: (seq_len, seq_len) - The pre-defined attention mask within sequence.
            key_padding_mask: (batch_size, seq_len) - The pre-defined key mask within sequence.

        Returns:
            emb: (batch_size, seq_len, d_model) - The multi-head attention embedding.
            attn_weight: (batch_size, n_heads, seq_len, seq_len) - The attention weight.
            attn_score: (batch_size, n_heads, seq_len, seq_len) - The attention score.
        """
        # Initialize `emb_q` & `emb_k` & `emb_v` from `embs`.
        # emb_[q,k,v] - (batch_size, seq_len, d_input)
        emb_q, emb_k, emb_v = embs
        # Prepare query & key & value for attention computation, i.e. move the head axis forward.
        # emb_[q,k,v] - (batch_size, n_heads, seq_len, d_head)
        emb_q = tf.transpose(self.W_q(emb_q), perm=[0, 2, 1, 3])
        emb_k = tf.transpose(self.W_k(emb_k), perm=[0, 2, 1, 3])
        emb_v = tf.transpose(self.W_v(emb_v), perm=[0, 2, 1, 3])
        # If `emb_rotary` is not None, further embed query & key (value is left untouched).
        if self.emb_rotary is not None:
            emb_q = self.emb_rotary(emb_q)
            emb_k = self.emb_rotary(emb_k)
        # Calculate attention result from `emb_*`.
        # emb - (batch_size, n_heads, seq_len, d_head)
        # attn_weight - (batch_size, n_heads, seq_len, seq_len)
        # attn_score - (batch_size, n_heads, seq_len, seq_len)
        emb, attn_weight, attn_score = self.attention(
            (emb_q, emb_k, emb_v),
            attn_score=attn_score, attn_mask=attn_mask, key_padding_mask=key_padding_mask
        )
        # Transpose `emb` to the original dimensions.
        # emb - (batch_size, seq_len, n_heads, d_head)
        emb = tf.transpose(emb, perm=[0, 2, 1, 3])
        # Concatenate multiple heads. The leading dimensions are read from the dynamic shape,
        # because the static shape may contain `None` (e.g. unknown batch size) which `tf.reshape`
        # cannot accept; the last dimension is given explicitly so that it stays statically known.
        # emb - (batch_size, seq_len, n_heads * d_head)
        emb_shape = tf.shape(emb)
        emb = tf.reshape(emb, (emb_shape[0], emb_shape[1], self.n_heads * self.d_head))
        # Project `emb` to the original dimensions.
        # emb - (batch_size, seq_len, d_model)
        emb = self.proj(emb)
        # Return the final `emb` & `attn_weight` & `attn_score`.
        return emb, attn_weight, attn_score

class MHAMatrix(K.layers.Layer):
    """
    `MHAMatrix` model does a linear transformation and splits the vector into given number of heads
    for multi-head attention. This is used to transform key, query, and value vectors.
    """

    def __init__(self, n_heads, d_head, use_bias=True, **kwargs):
        """
        Initialize `MHAMatrix` object.

        Args:
            n_heads: int - The number of attention heads.
            d_head: int - The dimensions of attention head.
            use_bias: bool - The flag indicates whether use bias.

        Returns:
            None
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super().__init__(**kwargs)

        # Initialize parameters.
        self.n_heads = n_heads
        self.d_head = d_head
        self.use_bias = use_bias

    # network funcs
    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        Args:
            input_shape: tuple - The shape of input data, e.g. (batch_size, seq_len, d_input).

        Returns:
            None
        """
        # Initialize `d_input` from `input_shape`.
        self.d_input = input_shape[-1]
        # Initialize the transformation matrix (perhaps w. bias).
        # W - (batch_size, seq_len, d_input) -> (batch_size, seq_len, n_heads * d_head)
        self.W = K.layers.Dense(
            # Modified `Dense` layer parameters.
            units=self.n_heads * self.d_head, use_bias=self.use_bias,
            kernel_initializer=K.initializers.RandomNormal(mean=0., stddev=0.01),
            bias_initializer=K.initializers.Constant(value=0.01),
            # Default `Dense` layer parameters.
            activation=None, kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        )

    def call(self, emb):
        """
        Forward layers in `MHAMatrix` to get the linear transformed result.

        Args:
            emb: (batch_size, seq_len, d_input) - The input embedding.

        Returns:
            emb: (batch_size, seq_len, n_heads, d_head) - The linear transformed embedding.
        """
        # Linearly transform `emb` using `W`.
        # projected - (batch_size, seq_len, n_heads * d_head)
        projected = self.W(emb)
        # Get the leading (non-feature) dimensions from the dynamic shape of `emb`. The static
        # `emb.shape` may contain `None` (e.g. unknown batch size), which `tf.reshape` rejects.
        # head_shape - int32 vector, should be [batch_size, seq_len]
        head_shape = tf.shape(emb)[:-1]
        # Split the last axis into heads.
        # target_shape - int32 vector, should be [batch_size, seq_len, n_heads, d_head]
        target_shape = tf.concat(
            [head_shape, tf.constant([self.n_heads, self.d_head], dtype=tf.int32)], axis=0
        )
        # Return the final `emb`.
        # emb - (batch_size, seq_len, n_heads, d_head)
        return tf.reshape(projected, target_shape)

# def ScaledDotProductAttention class
class ScaledDotProductAttention(K.layers.Layer):
    """
    Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with
    optional residual attention from previous layer (Realformer: Transformer likes residual attention by He et al, 2020)
    and locality self attention (Vision Transformer for Small-Size Datasets by Lee et al, 2021).
    """

    def __init__(self, d_head, attn_dropout=0., scale_trainable=False, **kwargs):
        """
        Initialize `ScaledDotProductAttention` object.

        Args:
            d_head: int - The dimensions of attention head.
            attn_dropout: float - The probability of dropout.
            scale_trainable: bool - The flag that indicates whether scale factor is trainable.

        Returns:
            None
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super().__init__(**kwargs)

        # Initialize parameters.
        self.d_head = d_head
        self.attn_dropout = attn_dropout
        self.scale_trainable = scale_trainable

    # network funcs
    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        Args:
            input_shape: tuple - The shape of input data (not used).

        Returns:
            None
        """
        # Initialize the dropout layer.
        self.dropout = K.layers.Dropout(rate=self.attn_dropout, noise_shape=None, seed=None)
        # Initialize scale factor as a scalar layer weight of value `1 / sqrt(d_head)`.
        # It is created through `add_weight` (instead of a raw `tf.Variable`) so that the layer
        # itself registers it as a (non-)trainable weight and its dtype follows the layer's dtype.
        self.scale = self.add_weight(
            name="scale", shape=(),
            initializer=K.initializers.Constant(value=float(1. / np.sqrt(self.d_head))),
            trainable=self.scale_trainable,
        )

    def call(self, embs, attn_score=None, attn_mask=None, key_padding_mask=None):
        """
        Forward layers in `ScaledDotProductAttention` to get the attentioned result.

        Args:
            embs: tuple - The embeddings containing emb_[q,k,v], each element is (batch_size, n_heads, seq_len, d_head).
            attn_score: (batch_size, n_heads, seq_len, seq_len) - The attention score from the previous layer.
            attn_mask: (seq_len, seq_len) - The pre-defined attention mask within sequence,
                True (non-zero) entries are excluded from attention.
            key_padding_mask: (batch_size, seq_len) - The pre-defined key mask within sequence,
                True (non-zero) entries mark the keys to ignore.

        Returns:
            emb: (batch_size, n_heads, seq_len, d_head) - The attention embedding.
            attn_weight: (batch_size, n_heads, seq_len, seq_len) - The attention weight (after dropout).
            attn_score: (batch_size, n_heads, seq_len, seq_len) - The attention score (after masking).
        """
        # Initialize `emb_q` & `emb_k` & `emb_v` from `embs`.
        # emb_[q,k,v] - (batch_size, n_heads, seq_len, d_head)
        emb_q, emb_k, emb_v = embs
        # Calculate scaled similarity score for all pairs of positions in an input sequence.
        # scores - (batch_size, n_heads, seq_len, seq_len)
        scores = tf.matmul(emb_q, emb_k, transpose_b=True) * self.scale
        # Here, we support residual attention in Realformer by He et al, 2020.
        if attn_score is not None:
            scores = scores + attn_score
        # Masked positions are filled with `-inf`, so that they get zero weight after softmax.
        # NOTE: If all keys of a query are masked, the whole row is `-inf` and softmax yields NaN.
        neg_inf = tf.constant(-np.inf, dtype=scores.dtype)
        # Use pre-defined attention mask to introduce inductive bias.
        # Here, we support locality self attention in Small-Size Dataset ViT by Lee et al, 2021.
        if attn_mask is not None:
            # attn_mask - (seq_len, seq_len) -> (1, 1, seq_len, seq_len)
            attn_mask = tf.cast(attn_mask, tf.bool)[tf.newaxis, tf.newaxis, :, :]
            scores = tf.where(attn_mask, neg_inf, scores)
        # Use pre-defined key padding mask to ignore some keys and their corresponding values.
        if key_padding_mask is not None:
            # key_padding_mask - (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
            key_padding_mask = tf.cast(key_padding_mask, tf.bool)[:, tf.newaxis, tf.newaxis, :]
            scores = tf.where(key_padding_mask, neg_inf, scores)
        # Normalize the attention score to get attention weight.
        # attn_weight - (batch_size, n_heads, seq_len, seq_len)
        attn_weight = self.dropout(tf.nn.softmax(scores, axis=-1))
        # Calculate the attention embedding.
        # emb - (batch_size, n_heads, seq_len, d_head)
        emb = tf.matmul(attn_weight, emb_v)
        # Return the final `emb` & `attn_weight` & `attn_score`.
        return emb, attn_weight, scores

if __name__ == "__main__":
    # Smoke test: build a `MultiHeadAttention` layer and run one forward pass on random input.
    # Initialize macros.
    batch_size = 32
    seq_len = 80
    d_model = 128
    n_heads = 16
    d_head = 64
    attn_dropout = 0.4
    proj_dropout = 0.
    use_bias = True
    # Instantiate `MultiHeadAttention`.
    mha_inst = MultiHeadAttention(d_model=d_model, n_heads=n_heads, d_head=d_head,
        attn_dropout=attn_dropout, proj_dropout=proj_dropout, use_bias=use_bias)
    # Prepare random query & key & value embeddings. They are packed into a tuple (not a one-shot
    # generator), i.e. a regular nested structure of tensors that can be passed to a layer.
    # embs - tuple of 3 tensors, each is (batch_size, seq_len, d_model)
    embs = tuple(tf.random.uniform((batch_size, seq_len, d_model)) for _ in range(3))
    # Forward `mha_inst` with random input, the layer returns three tensors.
    emb, attn_weight, attn_score = mha_inst(embs)

