#!/usr/bin/env python3
"""
Created on 21:08, Aug. 3rd, 2023

@author: Anonymous
"""
import numpy as np
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))

# Public API exported by `from <module> import *`. `plot_rope_decay` is not
# listed here, so it is only reachable through an explicit import.
__all__ = [
    # Value Embeddings.
    "TokenEmbedding",
    # Position Embeddings.
    "RotaryEmbedding",
]

"""
value embeddings
"""
# def TokenEmbedding class
class TokenEmbedding(K.layers.Layer):
    """
    Token embedding to transform the raw time series.

    A bias-free `Conv1D` with "same" padding maps the channel vectors inside a
    `kernel_size`-wide window around each time step to a `d_model`-dim vector,
    so the sequence length is preserved.
    """

    def __init__(self, d_model, kernel_size=3, **kwargs):
        """
        Initialize `TokenEmbedding` object.

        Args:
            d_model: int - The dimensions of model embedding.
            kernel_size: int - The size of convolution kernel.
            kwargs: dict - The arguments related to initialize `tf.keras.layers.Layer`-style object.

        Returns:
            None
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super(TokenEmbedding, self).__init__(**kwargs)

        # Initialize parameters.
        self.d_model = d_model; self.kernel_size = kernel_size

    """
    init funcs
    """
    # def build func
    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        Args:
            input_shape: tuple - The shape of input data, e.g., (batch_size, seq_len, n_channels).

        Returns:
            None
        """
        # Initialize token embedding layer.
        # TODO: Support `circular` padding mode as in `nn.Conv1d`, and configurable `he_normal` as in
        # Time-Series-Library (https://github.com/thuml/Time-Series-Library/blob/main/layers/Embed.py).
        # emb_token - (batch_size, seq_len, n_channels) -> (batch_size, seq_len, d_model)
        self.emb_token = K.layers.Conv1D(
            # Modified `Conv1D` layer parameters.
            filters=self.d_model, kernel_size=self.kernel_size,
            padding="same", use_bias=False, kernel_initializer="he_normal",
            # Default `Conv1D` layer parameters.
            strides=1, data_format="channels_last", dilation_rate=1, groups=1, activation=None,
            bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        )

    """
    network funcs
    """
    # def call func
    def call(self, X):
        """
        Forward layers in `TokenEmbedding` to get the token-embedded result.

        Args:
            X: (batch_size, seq_len, n_channels) - The raw time series.

        Returns:
            emb: (batch_size, seq_len, d_model) - The sequence of token-embedded elements.
        """
        return self.emb_token(X)

    """
    serialization funcs
    """
    # def get_config func
    def get_config(self):
        """
        Get the configuration used to re-create this layer (e.g., when loading a saved model).

        Args:
            None

        Returns:
            config: dict - The base layer config extended with `d_model` & `kernel_size`.
        """
        config = super(TokenEmbedding, self).get_config()
        # Constructor arguments must be included, otherwise `from_config` cannot rebuild the layer.
        config.update({"d_model": self.d_model, "kernel_size": self.kernel_size,})
        return config

"""
position embeddings
"""
# def RotaryEmbedding class
class RotaryEmbedding(K.layers.Layer):
    """
    Rotary embedding to inject relative position information.

    Each adjacent feature pair (x_{2i}, x_{2i+1}) at position `m` is rotated
    by the angle `m * freqs[i]`.
    """

    def __init__(self, theta=1e4, **kwargs):
        """
        Initialize `RotaryEmbedding` object.

        Args:
            theta: float - The power base of rotation angle.
            kwargs: dict - The arguments related to initialize `tf.keras.layers.Layer`-style object.

        Returns:
            None
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super(RotaryEmbedding, self).__init__(**kwargs)

        # Initialize parameters.
        self.theta = theta

    """
    init funcs
    """
    # def build func
    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        Args:
            input_shape: tuple - The shape of input data, e.g., (batch_size, *, seq_len, d_model).

        Returns:
            None

        Raises:
            ValueError: If the last dimension of `input_shape` is unknown or not a multiple of 2.
        """
        # Initialize `d_model` from `input_shape`; it must be static because it fixes the size of `freqs`.
        d_model = input_shape[-1]
        if d_model is None:
            raise ValueError("ERROR: The dimensions of model must be known when building `RotaryEmbedding`.")
        self.d_model = int(d_model)
        # Features are rotated in (even, odd) pairs, so `d_model` must be even.
        # Raise instead of `assert`, which is stripped under `python -O`.
        if self.d_model % 2 != 0:
            raise ValueError("ERROR: The dimensions of model must be a multiples of 2.")
        # TODO: Support theta rescale according to `theta_rescale_factor`, which is proposed by reddit user `bloc97`.
        # To rescale rotary embeddings to longer sequence length without fine-tuning has some connection to NTK literature
        # (https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/).
        # >>> self.theta *= theta_rescale_factor ** (self.d_model / (self.d_model - 2))
        # Initialize `freqs`, i.e., the rotation angles, only support language `freqs` currently.
        # TODO: Support `freqs` for different modalities (e.g., language, pixel, constant) as in `lucidrains`'s implementation
        # (https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py).
        # freqs - (d_model // 2,)
        freqs = 1. / (self.theta ** (tf.range(0, self.d_model, 2)[:(self.d_model // 2)] / self.d_model))
        # TODO: Support trainable `freqs` according to trainable flag.
        self.freqs = tf.cast(freqs, dtype=tf.float32)

    """
    network funcs
    """
    # def call func
    def call(self, emb):
        """
        Forward layers in `RotaryEmbedding` to get the rotary-embedded result.

        Args:
            emb: (batch_size, *, seq_len, d_model) - The sequence of elements.

        Returns:
            emb: (batch_size, *, seq_len, d_model) - The sequence of rotary-embedded elements.
        """
        # Use the dynamic shape so that unknown (`None`) batch/sequence dimensions work in graph mode.
        emb_shape = tf.shape(emb)
        # Initialize the indices of positions `0, ..., seq_len - 1`.
        # position_idxs - (seq_len,)
        position_idxs = tf.cast(tf.range(emb_shape[-2]), dtype=emb.dtype)
        # Match the input dtype (e.g., float16 under mixed precision); `self.freqs` is stored as float32.
        # freqs - (d_model // 2,)
        freqs = tf.cast(self.freqs, dtype=emb.dtype)
        # Construct rotation angles as the outer product of positions & frequencies.
        # angles - (seq_len, d_model // 2)
        angles = position_idxs[:, None] * freqs[None, :]
        # Repeat each angle for both members of its (even, odd) feature pair.
        # angles - (seq_len, d_model)
        angles = tf.repeat(angles, repeats=2, axis=-1)
        # Split the last axis into feature pairs.
        # emb_pairs - (batch_size, *, seq_len, d_model // 2, 2)
        emb_pairs = tf.reshape(emb, shape=tf.concat([emb_shape[:-1], [self.d_model // 2, 2]], axis=0))
        # Rotate each pair by 90 degrees: (x_even, x_odd) -> (-x_odd, x_even).
        # emb_rot - (batch_size, *, seq_len, d_model)
        emb_rot = tf.reshape(tf.stack([-emb_pairs[..., 1], emb_pairs[..., 0]], axis=-1), shape=emb_shape)
        # Keep the static shape information of `emb` (lost through the dynamic `reshape`).
        emb_rot.set_shape(emb.shape)
        # Calculate the rotary-embedded result.
        # emb - (batch_size, *, seq_len, d_model)
        emb = tf.cos(angles) * emb + tf.sin(angles) * emb_rot
        # Return the final `emb`.
        return emb

    """
    serialization funcs
    """
    # def get_config func
    def get_config(self):
        """
        Get the configuration used to re-create this layer (e.g., when loading a saved model).

        Args:
            None

        Returns:
            config: dict - The base layer config extended with `theta`.
        """
        config = super(RotaryEmbedding, self).get_config()
        # `d_model` & `freqs` are derived in `build`, so only `theta` needs to be stored.
        config.update({"theta": self.theta,})
        return config

"""
plot funcs
"""
# def plot_rope_decay func
def plot_rope_decay(freqs, max_len=80, path_img=None):
    """
    Plot the long-term decay of RoPE.

    For every relative distance `x` in `[-(max_len - 1), max_len - 1]`, the plotted value is
    `sum_j sum_{k <= j} cos(x * freqs[k]) / len(freqs)`.

    Args:
        freqs: (d_model // 2,) - The rotation angles.
        max_len: int - The maximum length of sequence.
        path_img: str - The directory to save `rope_decay.png` into; show the plot instead if None.

    Returns:
        None

    Raises:
        ValueError: If `freqs` is empty.
    """
    # Import locally: `os` is only imported at module level under `__main__`,
    # and plotting libraries are optional dependencies of this module.
    import os
    import seaborn as sns
    import matplotlib.pyplot as plt
    # freqs - (n_freqs,)
    freqs = np.asarray(freqs, dtype=np.float64).reshape(-1)
    if freqs.size == 0:
        raise ValueError("ERROR: `freqs` must contain at least one rotation angle.")
    # Generate data for the corresponding function.
    # x - (2 * max_len - 1,)
    x = np.linspace(start=-(max_len - 1), stop=(max_len - 1), num=(2 * max_len - 1))
    # cos_terms - (2 * max_len - 1, n_freqs)
    cos_terms = np.cos(x[:, None] * freqs[None, :])
    # Sum of the cumulative (prefix) sums over frequencies, averaged over `n_freqs`.
    # NOTE(review): RoFormer's bound averages |S_j| of complex partial sums; this uses
    # the real parts without absolute value -- confirm this is intended.
    # y - (2 * max_len - 1,)
    y = np.cumsum(cos_terms, axis=-1).sum(axis=-1) / freqs.shape[0]
    # Create a line plot.
    sns.set(style="whitegrid"); sns.lineplot(x=x, y=y)
    # Add labels and a title.
    plt.xlabel("Relative Distance"); plt.ylabel("Relative Upper Bound"); plt.title("Long-term Decay of RoPE")
    # Show or save the plot.
    if path_img is None:
        plt.show()
    else:
        plt.savefig(os.path.join(path_img, "rope_decay.png"))
    plt.close("all")

if __name__ == "__main__":
    # Initialize macros.
    batch_size = 32; seq_len = 80; n_channels = 55; d_model = 128
    path_img = os.path.join(os.getcwd(), "__image__")
    # `exist_ok=True` avoids the check-then-create race of `os.path.exists` + `os.makedirs`.
    os.makedirs(path_img, exist_ok=True)

    ## Forward value embeddings.
    # Initialize raw input `X`.
    # X - (batch_size, seq_len, n_channels)
    X = tf.random.uniform((batch_size, seq_len, n_channels), dtype=tf.float32)
    # Instantiate `TokenEmbedding`.
    emb_token_inst = TokenEmbedding(d_model=d_model)
    # Forward layers in `TokenEmbedding`.
    # emb - (batch_size, seq_len, d_model)
    emb = emb_token_inst(X)
    ## Forward position embeddings.
    # Initialize embedded input `emb` (a fresh random input, independent of the token embedding above).
    # emb - (batch_size, seq_len, d_model)
    emb = tf.random.uniform((batch_size, seq_len, d_model), dtype=tf.float32)
    # Instantiate `RotaryEmbedding`.
    emb_rot_inst = RotaryEmbedding(theta=1e4)
    # Forward layers in `RotaryEmbedding`.
    # emb - (batch_size, seq_len, d_model)
    emb = emb_rot_inst(emb)
    # Plot RoPE decay of `RotaryEmbedding`.
    plot_rope_decay(emb_rot_inst.freqs.numpy(), max_len=seq_len, path_img=path_img)

