#!/usr/bin/env python3
"""
Created on 20:20, Apr. 12th, 2023

@author: Anonymous
"""
import numpy as np
import tensorflow as tf
import tensorflow.keras as K
# local dep
# Only when this file is executed directly as a script: prepend the directory
# five levels above the current working directory to the module search path.
# NOTE(review): presumably this makes the project root importable for local
# dependencies - confirm against the repository layout. The path is built from
# `os.pardir` components, so it is relative to the working directory at launch,
# not to the location of this file.
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir, os.pardir, os.pardir))

# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "PositionEmbedding",
]

class PositionEmbedding(K.layers.Layer):
    """
    Additive position embedding for non-recurrent neural networks.

    A `(max_len, d_model)` (or broadcastable `(max_len, 1)`) table of position
    encodings is created in `build` according to `pe_mode`, and its first
    `seq_len` rows are added to the input embedding in `call`.

    Supported `pe_mode` values:
        None      - Fixed (non-trainable) table drawn once from U(-0.02, 0.02).
        "zero"    - Trainable `(max_len, 1)` table drawn from U(-0.02, 0.02),
                    broadcast over the feature axis.
        "zeros"   - Trainable `(max_len, d_model)` table drawn from U(-0.02, 0.02).
        "normal"  - Trainable `(max_len, d_model)` table drawn from N(0, 0.1^2).
        "uniform" - Trainable `(max_len, d_model)` table drawn from U(0, 0.1).
        "sincos"  - Fixed (non-trainable) sinusoidal table.
    """

    def __init__(self, max_len, pe_mode=None, **kwargs):
        """
        Initialize `PositionEmbedding` object.

        Args:
            max_len: int - The maximum length of the element sequence.
            pe_mode: str - The mode of position embedding, `None` by default to measure the impact of position embedding.
            kwargs: The arguments related to initialize `tf.keras.layers.Layer`-style object.

        Returns:
            None

        Raises:
            AssertionError: If `pe_mode` is not one of the supported modes.
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super(PositionEmbedding, self).__init__(**kwargs)

        # Initialize parameters.
        assert pe_mode in [None, "zero", "zeros", "normal", "uniform", "sincos"], (
            "ERROR: Get unknown position embedding mode {} in PositionEmbedding."
        ).format(pe_mode)
        self.max_len = max_len
        self.pe_mode = pe_mode

    # ------------------------------------------------------------------
    # init funcs
    # ------------------------------------------------------------------
    # def build func
    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        Args:
            input_shape: tuple - The shape of input data, e.g. (batch_size, seq_len, d_model).

        Returns:
            None

        Raises:
            AssertionError: If the last dimension of `input_shape` is odd.
        """
        # Initialize `d_model` from `input_shape`. An even `d_model` is required
        # so that the sin/cos columns of the `sincos` table interleave exactly.
        self.d_model = input_shape[-1]
        assert self.d_model % 2 == 0, (
            "ERROR: The last dimension of input must be even, got {} in PositionEmbedding."
        ).format(self.d_model)
        # Build position embedding variables according to `pe_mode`,
        # e.g. `None` -> `_build_none`, "sincos" -> `_build_sincos`.
        builder = getattr(self, "_build_{}".format(str(self.pe_mode).lower()))
        builder()

    # def _build_none func
    def _build_none(self):
        """
        Build the `None` position embedding, which is used to measure the impact of position embedding.

        The table is sampled once from U(-0.02, 0.02) and stored as a constant,
        i.e. it is never updated during training.

        Args:
            None

        Returns:
            None
        """
        # Set non-trainable random `position_encodings` matrix.
        # position_encodings - (max_len, d_model)
        self.position_encodings = tf.constant(
            tf.random.uniform((self.max_len, self.d_model), minval=-2e-2, maxval=2e-2)
        )

    # def _build_zero func
    def _build_zero(self):
        """
        Build the `zero` position embedding variables.

        One trainable scalar per position, broadcast over the feature axis in `call`.

        Args:
            None

        Returns:
            None
        """
        # Set trainable `position_encodings` column, initialized from U(-0.02, 0.02).
        # position_encodings - (max_len, 1)
        self.position_encodings = tf.Variable(
            tf.random.uniform((self.max_len, 1), minval=-2e-2, maxval=2e-2),
            name="position_encodings",
        )

    # def _build_zeros func
    def _build_zeros(self):
        """
        Build the `zeros` position embedding variables.

        Args:
            None

        Returns:
            None
        """
        # Set trainable `position_encodings` matrix, initialized from U(-0.02, 0.02).
        # position_encodings - (max_len, d_model)
        self.position_encodings = tf.Variable(
            tf.random.uniform((self.max_len, self.d_model), minval=-2e-2, maxval=2e-2),
            name="position_encodings",
        )

    # def _build_normal func
    def _build_normal(self):
        """
        Build the `normal` position embedding variables.

        Args:
            None

        Returns:
            None
        """
        # Set trainable normal-distributed `position_encodings` matrix.
        # position_encodings - (max_len, d_model)
        self.position_encodings = tf.Variable(
            tf.random.normal((self.max_len, self.d_model), mean=0., stddev=1e-1),
            name="position_encodings",
        )

    # def _build_uniform func
    def _build_uniform(self):
        """
        Build the `uniform` position embedding variables.

        Args:
            None

        Returns:
            None
        """
        # Set trainable uniform-distributed `position_encodings` matrix.
        # position_encodings - (max_len, d_model)
        self.position_encodings = tf.Variable(
            tf.random.uniform((self.max_len, self.d_model), minval=0., maxval=1e-1),
            name="position_encodings",
        )

    # def _build_sincos func
    def _build_sincos(self):
        """
        Build the `sincos` position embedding variables.

        Args:
            None

        Returns:
            None
        """
        # Empty `position_encodings` matrix.
        # position_encodings - (max_len, d_model)
        position_encodings = np.zeros((self.max_len, self.d_model), dtype=np.float32)
        # Get the indexes of available positions (i.e. within `max_len`).
        # position_idxs - (max_len, 1)
        position_idxs = np.expand_dims(np.arange(0, self.max_len, dtype=np.float32), axis=-1)
        # Get the divide term, i.e. $exp(-2i * \frac{ln(10000)}{d_{model}}) = 10000^{-\frac{2i}{d_{model}}}$.
        # div_term - (d_model//2,)
        div_term = np.exp(np.arange(0, self.d_model, 2, dtype=np.float32) * -(np.log(1e4) / self.d_model))
        # The sin and cos columns share the same angles, so compute them once.
        # angles - (max_len, d_model//2)
        angles = position_idxs * div_term
        # $PE_{p,2i} = sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$.
        position_encodings[:,0::2] = np.sin(angles)
        # $PE_{p,2i + 1} = cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
        position_encodings[:,1::2] = np.cos(angles)
        # Set `position_encodings` as constant, i.e. not trainable.
        self.position_encodings = tf.constant(position_encodings, dtype=tf.float32)

    # def get_config func
    def get_config(self):
        """
        Get the configuration of `PositionEmbedding`, so the layer can be serialized and re-created.

        Args:
            None

        Returns:
            config: dict - The base layer configuration extended with `max_len` and `pe_mode`.
        """
        config = super(PositionEmbedding, self).get_config()
        config.update({"max_len": self.max_len, "pe_mode": self.pe_mode})
        return config

    # ------------------------------------------------------------------
    # network funcs
    # ------------------------------------------------------------------
    # def call func
    def call(self, emb):
        """
        Forward layers in `PositionEmbedding` to get the position-embedded result.

        Args:
            emb: (batch_size, seq_len, d_model) - The sequence of elements, `seq_len` must not exceed `max_len`.

        Returns:
            emb: (batch_size, seq_len, d_model) - The sequence of position-embedded elements.
        """
        # Prefer the static sequence length. When it is unknown at trace time
        # (dynamic `seq_len`), fall back to the runtime shape; slicing with `None`
        # would silently take all `max_len` rows and break the addition below.
        seq_len = emb.shape[1]
        if seq_len is None:
            seq_len = tf.shape(emb)[1]
        # Get the position embeddings `pos_emb` according to the `seq_len`.
        # pos_emb - (seq_len, d_model), or (seq_len, 1) in `zero` mode.
        pos_emb = self.position_encodings[:seq_len,:]
        # Add `pos_emb` to `emb` to get the position-embedded embedding.
        # Note: We have to make sure that `emb` is 0-mean 1-var distribution.
        # If we apply layer normalization over `emb`, `emb` is 0-mean 1/sqrt(d_model)-var
        # distribution, i.e. we have to multiply `emb` with `sqrt(d_model)`.
        emb = emb + tf.expand_dims(pos_emb, axis=0)
        # Return the final `emb`.
        return emb

if __name__ == "__main__":
    # Smoke test: run the layer once on an all-zero batch.
    # Initialize macros.
    batch_size = 32
    max_len = 80
    d_model = 128
    pe_mode = None
    # Instantiate `PositionEmbedding`.
    pe_inst = PositionEmbedding(max_len=max_len, pe_mode=pe_mode)
    # Build an all-zero input of shape (batch_size, max_len, d_model).
    zero_input = tf.zeros((batch_size, max_len, d_model), dtype=tf.float32)
    # Forward `pe_inst` with 0s, so the output equals the broadcast position table.
    emb = pe_inst(zero_input)

