#!/usr/bin/env python3
"""
Created on 16:04, Apr. 13th, 2023

@author: Anonymous
"""
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir, os.pardir, os.pardir))
    from MultiHeadAttention import MultiHeadAttention
    from FeedForward import FeedForward
else:
    from models.domain_adaptation.DomainContrastiveNet.DomainContrastiveTransformer.layers.MultiHeadAttention import MultiHeadAttention
    from models.domain_adaptation.DomainContrastiveNet.DomainContrastiveTransformer.layers.FeedForward import FeedForward

# Public API of this module: only the layer class is exported.
__all__ = ["TransformerBlock"]

class TransformerBlock(K.layers.Layer):
    """
    `TransformerBlock` acts as an encoder layer or a decoder layer.

    The block chains a multi-head self-attention sub-block (`mha`) and a feed-forward
    sub-block (`ffn`). Each sub-block is wrapped with a residual connection and a
    normalization layer, applied either before (`norm_first=True`) or after
    (`norm_first=False`) the sub-block.
    """

    # The normalization types accepted by the `norm_type` argument.
    _NORM_TYPES = ("batch_norm", "layer_norm")

    def __init__(self, d_model, n_heads, d_head, attn_dropout, proj_dropout,
        d_ff, ff_dropout, norm_type="batch_norm", norm_first=False, **kwargs):
        """
        Initialize `TransformerBlock` object.

        Args:
            d_model: int - The dimensions of model embedding.
            n_heads: int - The number of attention heads in `mha` block.
            d_head: int - The dimensions of attention head in `mha` block.
            attn_dropout: float - The dropout probability of attention score in `mha` block.
            proj_dropout: float - The dropout probability of projection in `mha` block.
            d_ff: int - The dimensions of the hidden layer in `ffn` block.
            ff_dropout: float - The dropout probability in `ffn` block. It is forwarded
                as-is to `FeedForward` (the demo at the bottom of this file passes a list).
            norm_type: str - The type of normalization, one of ["batch_norm", "layer_norm"].
            norm_first: bool - The flag that indicates whether normalize data first.
            **kwargs: The keyword arguments forwarded to `K.layers.Layer`, e.g. `name`.

        Returns:
            None

        Raises:
            ValueError: If `norm_type` is not one of the supported normalization types.
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super(TransformerBlock, self).__init__(**kwargs)

        # Validate with an explicit exception instead of `assert`: assertions are stripped
        # under `python -O`, and `build` would then silently fall back to batch normalization.
        if norm_type not in self._NORM_TYPES:
            raise ValueError((
                "ERROR: Get unknown normalization type {} in TransformerBlock.py"
            ).format(norm_type))

        # Initialize parameters.
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.d_ff = d_ff
        self.ff_dropout = ff_dropout
        self.norm_type = norm_type
        self.norm_first = norm_first

    # network funcs

    # def _make_norm func
    def _make_norm(self):
        """
        Create one fresh normalization layer according to `norm_type`.

        Returns:
            norm: K.layers.Layer - A `LayerNormalization` layer if `norm_type` is "layer_norm",
                otherwise a `BatchNormalization` layer. Both normalize over the last axis.
        """
        if self.norm_type == "layer_norm":
            return K.layers.LayerNormalization(
                # Modified `LayerNormalization` layer parameters.
                epsilon=1e-5,
                # Default `LayerNormalization` layer parameters.
                axis=-1, center=True, scale=True, beta_initializer="zeros", gamma_initializer="ones",
                beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None
            )
        # NOTE(review): Keras updates `moving = moving * momentum + batch * (1 - momentum)`, so
        # `momentum=0.1` lets the latest batch dominate the moving statistics. Confirm this is
        # intended and not a PyTorch-style momentum (whose Keras equivalent would be 0.9).
        return K.layers.BatchNormalization(
            # Modified `BatchNormalization` parameters.
            momentum=0.1, epsilon=1e-5,
            # Default `BatchNormalization` parameters.
            axis=-1, center=True, scale=True, beta_initializer="zeros",
            gamma_initializer="ones", moving_mean_initializer="zeros", moving_variance_initializer="ones",
            beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None
        )

    # def build func
    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        The sub-layers are created in the order `mha`, `norm_mha`, `ffn`, `norm_ffn`.

        Args:
            input_shape: tuple - The shape of input data, e.g. (batch_size, seq_len, d_model).

        Returns:
            None
        """
        # Initialize `mha` block.
        # mha - (batch_size, seq_len, d_model) -> attention embedding that is added to the
        # input in `call`, hence it has to be compatible with (batch_size, seq_len, d_model).
        self.mha = MultiHeadAttention(d_model=self.d_model, n_heads=self.n_heads, d_head=self.d_head,
            attn_dropout=self.attn_dropout, proj_dropout=self.proj_dropout, use_bias=True)
        # Initialize the normalization layer of `mha` block.
        self.norm_mha = self._make_norm()
        # Initialize `ffn` block.
        # ffn - (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
        self.ffn = FeedForward(d_ff=self.d_ff, ff_dropout=self.ff_dropout)
        # Initialize the normalization layer of `ffn` block.
        self.norm_ffn = self._make_norm()

    # def call func
    def call(self, emb, attn_score=None, attn_mask=None, key_padding_mask=None):
        """
        Forward layers in `TransformerBlock` to get the mha-ffn transformed result.

        Args:
            emb: (batch_size, seq_len, d_model) - The input embedding.
            attn_score: (batch_size, n_heads, seq_len, seq_len) - The attention score from the previous layer.
            attn_mask: (seq_len, seq_len) - The pre-defined attention mask within sequence.
            key_padding_mask: (batch_size, seq_len) - The pre-defined key mask within sequence.

        Returns:
            emb: (batch_size, seq_len, d_model) - The mha-ffn transformed embedding.
            attn_weight: (batch_size, n_heads, seq_len, seq_len) - The attention weight.
            attn_score: (batch_size, n_heads, seq_len, seq_len) - The attention score.
        """
        # Get the mha transformed embedding.
        # emb - (batch_size, seq_len, d_model)
        # attn_weight - (batch_size, n_heads, seq_len, seq_len)
        # attn_score - (batch_size, n_heads, seq_len, seq_len)
        # NOTE: With `norm_first=True` the residual branch below adds the *normalized*
        # embedding (not the raw input), because `emb` is overwritten before the addition.
        if self.norm_first:
            emb = self.norm_mha(emb)
        # Self-attention: the same embedding serves as query, key and value.
        attn_emb, attn_weight, attn_score = self.mha(
            (emb, emb, emb),
            attn_score=attn_score, attn_mask=attn_mask, key_padding_mask=key_padding_mask
        )
        emb = attn_emb + emb
        if not self.norm_first:
            emb = self.norm_mha(emb)
        # Get the ffn transformed embedding.
        # emb - (batch_size, seq_len, d_model)
        if self.norm_first:
            emb = self.norm_ffn(emb)
        emb = self.ffn(emb) + emb
        if not self.norm_first:
            emb = self.norm_ffn(emb)
        # Return the final `emb` & `attn_weight` & `attn_score`.
        return emb, attn_weight, attn_score

    # def get_config func
    def get_config(self):
        """
        Get the configuration needed to re-create this layer.

        Returns:
            config: dict - The base layer configuration extended with
                every constructor argument of `TransformerBlock`.
        """
        config = super(TransformerBlock, self).get_config()
        # Keras may wrap list attributes in a tracking container; hand back a plain list.
        ff_dropout = self.ff_dropout
        if isinstance(ff_dropout, (list, tuple)):
            ff_dropout = list(ff_dropout)
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "d_head": self.d_head,
            "attn_dropout": self.attn_dropout,
            "proj_dropout": self.proj_dropout,
            "d_ff": self.d_ff,
            "ff_dropout": ff_dropout,
            "norm_type": self.norm_type,
            "norm_first": self.norm_first,
        })
        return config

if __name__ == "__main__":
    # Smoke-test configuration, one setting per line.
    batch_size = 32
    seq_len = 80
    d_model = 128
    n_heads = 8
    d_head = 64
    attn_dropout = 0.4
    proj_dropout = 0.4
    d_ff = 128
    ff_dropout = [0.4, 0.4]
    norm_type = "batch_norm"
    norm_first = False

    # Build the `TransformerBlock` under test.
    tb_inst = TransformerBlock(
        d_model=d_model, n_heads=n_heads, d_head=d_head,
        attn_dropout=attn_dropout, proj_dropout=proj_dropout,
        d_ff=d_ff, ff_dropout=ff_dropout,
        norm_type=norm_type, norm_first=norm_first,
    )

    # Draw the random inputs (embedding, previous score, then the two masks).
    # emb - (batch_size, seq_len, d_model)
    emb_raw = tf.random.uniform((batch_size, seq_len, d_model))
    emb = tf.cast(emb_raw, dtype=tf.float32)
    # attn_score - (batch_size, n_heads, seq_len, seq_len)
    attn_score_raw = tf.random.uniform((batch_size, n_heads, seq_len, seq_len))
    attn_score = tf.cast(attn_score_raw, dtype=tf.float32)
    # attn_mask - (seq_len, seq_len), random 0/1 integers converted to booleans.
    attn_mask_bits = tf.random.uniform((seq_len, seq_len), minval=0, maxval=2, dtype=tf.int32)
    attn_mask = tf.cast(attn_mask_bits, dtype=tf.bool)
    # key_padding_mask - (batch_size, seq_len), random 0/1 integers converted to booleans.
    key_padding_bits = tf.random.uniform((batch_size, seq_len), minval=0, maxval=2, dtype=tf.int32)
    key_padding_mask = tf.cast(key_padding_bits, dtype=tf.bool)

    # Run one forward pass through `tb_inst` and unpack its three outputs.
    outputs = tb_inst(
        emb,
        attn_score=attn_score,
        attn_mask=attn_mask,
        key_padding_mask=key_padding_mask,
    )
    emb, attn_weight, attn_score = outputs

