#!/usr/bin/env python3
"""
Created on 23:17, Aug. 3rd, 2023

@author: Anonymous
"""
import copy as cp
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir))
    from layers import *
else:
    from .layers import *
import utils.model

__all__ = [
    "roformer",
]

class roformer(K.Model):
    """
    `roformer` model, with considering time information.

    Pipeline (see `call`): token embedding -> transformer encoder stack ->
    fully-connected head producing class logits, which `call` turns into
    probabilities with a softmax. No separate position-embedding layer is
    created here; presumably position information is handled inside the
    encoder (TODO confirm in `TransformerStack`).
    """

    def __init__(self, params, **kwargs):
        """
        Initialize `roformer` object.

        Args:
            params: Model parameters initialized by roformer_params, updated by params.iteration.
                Must expose `encoder` (with `encoder.d_model`), `fc.d_hidden`, `fc.dropout`
                and `fc.d_output`, as read by `_init_trainable`.
            kwargs: The arguments related to initialize `tf.keras.Model`-style object.

        Returns:
            None
        """
        # First call super class init function to set up `K.Model`
        # style model and inherit its functionality.
        super().__init__(**kwargs)

        # Deep-copy the hyperparameters (e.g. network sizes) so that later mutation of
        # the caller's `params` object cannot silently change this model's configuration.
        self.params = cp.deepcopy(params)

        # Create trainable vars.
        self._init_trainable()

    """
    init funcs
    """
    # def _init_trainable func
    def _init_trainable(self):
        """
        Initialize trainable variables.

        Builds `self.emb_input`, `self.encoder` and `self.fc`. `self.fc` ends with a
        linear `Dense` layer and therefore outputs unnormalized logits; the softmax is
        applied in `call`, so the loss can be computed from the raw logits that
        `tf.nn.softmax_cross_entropy_with_logits` requires.

        Args:
            None

        Returns:
            None
        """
        ## Construct embedding layers.
        # Initialize the embedding layer for input.
        # emb_input - (batch_size, seq_len, d_input) -> (batch_size, seq_len, d_model)
        self.emb_input = TokenEmbedding(d_model=self.params.encoder.d_model, kernel_size=3, name="emb_input")
        # Note: no position embedding layer is created here (intentionally).
        ## Construct encoder block.
        # Initialize encoder block.
        # encoder - (batch_size, seq_len, d_model) -> (batch_size, seq_len // pool_size, d_model)
        self.encoder = TransformerStack(self.params.encoder, name="encoder")
        ## Construct fc block.
        # Initialize fc block.
        model_fc = K.models.Sequential(name="FullConnect")
        # Flatten the encoder features to a 1D-vector per sample.
        model_fc.add(K.layers.Flatten(data_format="channels_last"))
        # Add hidden `Dense` layers.
        for d_hidden_i in self.params.fc.d_hidden:
            model_fc.add(K.layers.Dense(
                # Modified `Dense` parameters.
                units=d_hidden_i, activation="relu",
                kernel_initializer=K.initializers.random_normal(mean=0., stddev=0.01),
                bias_initializer=K.initializers.constant(value=0.01),
                # Default `Dense` parameters.
                use_bias=True, kernel_regularizer=None, bias_regularizer=None,
                activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ))
        # Add a single `Dropout` after the last hidden `Dense` layer (not one per layer).
        if self.params.fc.dropout > 0.:
            model_fc.add(K.layers.Dropout(rate=self.params.fc.dropout, name="Dropout_fc"))
        # Add the final classification `Dense` layer. It is deliberately linear (no activation
        # and no trailing `Softmax` layer): the head must output logits, because the loss uses
        # `tf.nn.softmax_cross_entropy_with_logits`, which applies the softmax itself.
        model_fc.add(K.layers.Dense(
            # Modified `Dense` parameters.
            units=self.params.fc.d_output,
            kernel_initializer=K.initializers.random_normal(mean=0., stddev=0.01),
            bias_initializer=K.initializers.constant(value=0.01),
            # Default `Dense` parameters.
            activation=None, use_bias=True, kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        ))
        self.fc = model_fc

    """
    network funcs
    """
    # def call func
    def call(self, inputs, training=None, mask=None):
        """
        Forward `roformer` to get the final predictions.

        Args:
            inputs: tuple - `(X, y_true, ...)`. `X` is the input data; `y_true` is the target
                class distribution (e.g. one-hot) with the same shape as the logits,
                i.e. (batch_size, n_labels).
            training: Boolean or boolean scalar tensor, indicating whether to run
                the `Network` in training mode or inference mode.
            mask: A mask or list of masks. A mask can be either a tensor or None (no mask).
                Not used by this model.

        Returns:
            outputs: (batch_size, n_labels) - Class probabilities (softmax over the fc logits).
            loss: tf.float32 - Batch-mean softmax cross entropy.
        """
        # Initialize components of inputs.
        # X - (batch_size, seq_len, n_channels)
        # y_true - (batch_size, n_labels)
        X = inputs[0]
        y_true = inputs[1]
        # Embed `X` to get `X_emb`.
        # X_emb - (batch_size, seq_len, d_model)
        X_emb = self.emb_input(X)
        # Forward `encoder` to get the encoder-transformed embedding (other outputs unused).
        # X_emb - (batch_size, seq_len // pool_size, d_model)
        X_emb, _, _ = self.encoder(X_emb)
        # Forward `fc` to get the unnormalized class scores. `training` is passed
        # explicitly because the head may contain a `Dropout` layer.
        # logits - (batch_size, n_labels)
        logits = self.fc(X_emb, training=training)
        # Normalize the logits to class probabilities for the returned prediction.
        # y_pred - (batch_size, n_labels)
        y_pred = tf.nn.softmax(logits, axis=-1)
        # The loss must see the raw logits; feeding `y_pred` would apply softmax twice.
        # loss - tf.float32
        loss = tf.reduce_mean(self._loss_bce(logits, y_true))
        # Return the final `outputs` & `loss`.
        return y_pred, loss

    # def _loss_bce func
    @utils.model.tf_scope
    def _loss_bce(self, value, target):
        """
        Calculates softmax cross entropy between logits `value` and `target`.

        Note: despite the `_bce` name this is a (multi-class) softmax cross entropy.
        It reduces over the last (class) dimension only, so per-sample losses stay separate.
        :param value: (batch_size, n_labels) - Unnormalized logits, or a list of such tensors.
        :param target: (batch_size, n_labels) - Target class distribution, or an indexable
            sequence whose i-th element pairs with `value[i]` when `value` is a list.
        :return loss: (batch_size,) - Per-sample loss, or a list of them if `value` is a list.
        """
        # `tf.nn.softmax_cross_entropy_with_logits` applies softmax internally, so `value`
        # must be raw logits, never probabilities.
        if isinstance(value, list):
            # Index `target` (rather than zipping) so a tensor `target` is sliced along axis 0.
            return [
                tf.nn.softmax_cross_entropy_with_logits(labels=target[i], logits=value_i)
                for i, value_i in enumerate(value)
            ]
        # loss - (batch_size,)
        return tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=value)

    """
    tool funcs
    """
    # def summary func
    @utils.model.tf_scope
    def summary(self, print_fn=None):
        """
        Summary built model.
        :param print_fn: callable - Print function to use. Defaults to `print`. It will be called on each
            line of the summary. You can set it to a custom function in order to capture the string summary.
        """
        super().summary(print_fn=print_fn)

if __name__ == "__main__":
    # Smoke test: build the model from dataset params, run one forward pass, print a summary.
    # local dep
    from params.roformer_params import roformer_params

    # macro
    dataset = "eeg_anonymous"
    batch_size = 32
    seq_len = 80
    n_channels = 55

    # Initialize training process.
    utils.model.set_seeds(42)

    # Instantiate params.
    roformer_params_inst = roformer_params(dataset=dataset)
    n_labels = roformer_params_inst.model.n_labels
    # Instantiate roformer.
    roformer_inst = roformer(roformer_params_inst.model)
    # Initialize input X & one-hot label y_true.
    X = tf.random.normal((batch_size, seq_len, n_channels), dtype=tf.float32)
    # Cycle the class ids with `% n_labels`: `tf.one_hot` yields an all-zero row for ids
    # >= n_labels, which would leave those samples without any target class whenever
    # batch_size > n_labels.
    y_true = tf.one_hot(tf.range(batch_size, dtype=tf.int64) % n_labels, n_labels, dtype=tf.float32)
    # Forward layers in `roformer_inst` (this also builds the model), then print its summary.
    outputs, loss = roformer_inst((X, y_true))
    roformer_inst.summary()

