#!/usr/bin/env python3
"""
Created on 17:42, Jul. 8th, 2023

@author: Anonymous
"""
import copy as cp
import numpy as np
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.pardir)
import utils.model

__all__ = [
    "naive_rnn",
]

class naive_rnn(K.Model):
    """
    `naive_rnn` model: a stack of recurrent layers followed by a fully-connected classifier.

    The compiled `K.models.Sequential` network is stored in `self.model`;
    `fit`, `evaluate` and `summary` delegate to it.
    """

    def __init__(self, params):
        """
        Initialize `naive_rnn` object.

        Args:
            params: Model parameters initialized by naive_rnn_params, updated by params.iteration.
                The fields read here are `rnn.model`, `rnn.n_units`, `rnn.dropout`,
                `rnn.recurrent_dropout`, `fc.d_hidden`, `fc.dropout`, `fc.d_output` and `lr_i`.

        Returns:
            None
        """
        # First call super class init function to set up `K.Model`
        # style model and inherit its functionality.
        super(naive_rnn, self).__init__()

        # Keep a private deep copy of the hyperparameters (e.g. network sizes), so that
        # later changes to the caller's parameter object do not leak into this model.
        self.params = cp.deepcopy(params)

        # Create trainable vars.
        self._init_trainable()

    # ---- init funcs ----

    # def _init_trainable func
    def _init_trainable(self):
        """
        Build the `RNN` and `FullConnect` sub-models, stack them, and compile the result.

        Args:
            None

        Returns:
            None
        """
        # Build the sub-models in this fixed order (recurrent stack first, classifier second).
        model_rnn = self._init_rnn()
        model_fc = self._init_fc()
        ## Stack all layers to get the final model.
        self.model = K.models.Sequential([model_rnn, model_fc,])
        optimizer = K.optimizers.Adam(learning_rate=self.params.lr_i)
        # `categorical_crossentropy` is given the softmax output of `model_fc` and one-hot style labels.
        self.model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy",])

    # def _init_rnn func
    def _init_rnn(self):
        """
        Build the recurrent part of the network.

        One recurrent layer is created per entry of `params.rnn.n_units`; the matching entries
        of `params.rnn.dropout` and `params.rnn.recurrent_dropout` configure its dropout rates.

        Args:
            None

        Returns:
            model_rnn: The `K.models.Sequential` model named "RNN".
        """
        rnn_params = self.params.rnn
        n_layers = len(rnn_params.n_units)
        model_rnn = K.models.Sequential(name="RNN")
        for rnn_idx, n_units in enumerate(rnn_params.n_units):
            model_rnn.add(self._init_rnn_layer(
                n_units=n_units, dropout=rnn_params.dropout[rnn_idx],
                recurrent_dropout=rnn_params.recurrent_dropout[rnn_idx],
                # Every layer but the last hands the full sequence to the next recurrent
                # layer; the last one only emits its final output.
                return_sequences=(rnn_idx < n_layers - 1),
            ))
        return model_rnn

    # def _init_rnn_layer func
    def _init_rnn_layer(self, n_units, dropout, recurrent_dropout, return_sequences):
        """
        Create one recurrent layer of the type selected by `params.rnn.model`.

        Args:
            n_units: Number of units of the layer.
            dropout: Dropout rate applied to the layer inputs.
            recurrent_dropout: Dropout rate applied to the recurrent state.
            return_sequences: Whether the layer returns the whole sequence or only the last output.

        Returns:
            layer: The created `K.layers.LSTM` or `K.layers.GRU` layer.

        Raises:
            ValueError: If `params.rnn.model` is neither "LSTM" nor "GRU".
        """
        layer_kwargs = dict(
            # Modified layer parameters.
            units=n_units, dropout=dropout, recurrent_dropout=recurrent_dropout, return_sequences=return_sequences,
            # Default layer parameters shared by `LSTM` and `GRU`.
            # NOTE: `time_major` is intentionally not passed. Its TF2 default is already `False`,
            # and newer Keras releases are understood to reject the keyword.
            activation="tanh", recurrent_activation="sigmoid", use_bias=True,
            kernel_initializer="glorot_uniform", recurrent_initializer="orthogonal", bias_initializer="zeros",
            kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None,
            kernel_constraint=None, recurrent_constraint=None, bias_constraint=None,
            return_state=False, go_backwards=False, stateful=False, unroll=False,
        )
        rnn_type = self.params.rnn.model
        if rnn_type == "LSTM":
            return K.layers.LSTM(unit_forget_bias=True, **layer_kwargs)
        if rnn_type == "GRU":
            return K.layers.GRU(reset_after=True, **layer_kwargs)
        raise ValueError("ERROR: Get unknown layer type {}.".format(rnn_type))

    # def _init_fc func
    def _init_fc(self):
        """
        Build the fully-connected classification head.

        Layout: `Flatten` -> one relu `Dense` per entry of `params.fc.d_hidden` -> optional
        `Dropout` (only if `params.fc.dropout > 0`) -> linear `Dense` with `params.fc.d_output`
        units -> `Softmax`.

        Args:
            None

        Returns:
            model_fc: The `K.models.Sequential` model named "FullConnect".
        """
        fc_params = self.params.fc
        # Default `Dense` parameters, shared by the hidden and the output layers.
        dense_defaults = dict(
            use_bias=True, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
            bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None,
        )
        model_fc = K.models.Sequential(name="FullConnect")
        # Flatten the incoming features to one vector per sample.
        model_fc.add(K.layers.Flatten(data_format="channels_last"))
        # Add hidden `Dense` layers.
        for d_hidden_i in fc_params.d_hidden:
            model_fc.add(K.layers.Dense(units=d_hidden_i, activation="relu", **dense_defaults))
        # Add a single `Dropout` after the last hidden `Dense` layer.
        if fc_params.dropout > 0.:
            model_fc.add(K.layers.Dropout(rate=fc_params.dropout, name="Dropout_fc"))
        # Add the final classification `Dense` layer. It must stay linear: squashing its output
        # with a sigmoid before the softmax would confine the softmax inputs to (0, 1), which
        # caps the largest attainable class probability and keeps the cross-entropy loss from
        # approaching zero.
        model_fc.add(K.layers.Dense(units=fc_params.d_output, activation=None, **dense_defaults))
        model_fc.add(K.layers.Softmax(axis=-1))
        return model_fc

    # ---- network funcs ----

    # def fit func
    @utils.model.tf_scope
    def fit(self, X_train, y_train, epochs=1, batch_size=16):
        """
        Train the wrapped model on the given trainset.
        :param X_train: (n_train, seq_len, n_channels) - The trainset data.
        :param y_train: (n_train, n_labels) - The trainset labels.
        :param epochs: int - The number of epochs.
        :param batch_size: int - The size of batch.
        :return history: The object returned by `self.model.fit`, which holds the training record.
        """
        # Fit the model using [X_train,y_train] and hand the training record back to the caller.
        return self.model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size)

    # def evaluate func
    @utils.model.tf_scope
    def evaluate(self, X_test, y_test):
        """
        Evaluate the wrapped model on the given testset.
        :param X_test: (n_test, seq_len, n_channels) - The testset data.
        :param y_test: (n_test, n_labels) - The testset labels.
        :return loss: float - The loss of current evaluation process.
        :return accuracy: float - The accuracy of current evaluation process.
        """
        return self.model.evaluate(X_test, y_test)

    # def summary func
    @utils.model.tf_scope
    def summary(self, print_fn=None):
        """
        Summary built model.
        :param print_fn: callable - Print function to use. Defaults to `print`. It will be called on each
            line of the summary. You can set it to a custom function in order to capture the string summary.
        """
        self.model.summary(
            # Modified summary parameters.
            print_fn=print_fn,
            # Default summary parameters.
            line_length=None, positions=None, expand_nested=True, show_trainable=True, layer_range=None
        )

if __name__ == "__main__":
    # local dep
    from params.naive_rnn_params import naive_rnn_params

    # Call the project's seeding helper with a fixed seed before anything random is created.
    utils.model.set_seeds(42)

    # Size of the toy batch and the dataset whose parameter preset is loaded.
    batch_size = 16
    seq_len = 80
    dataset = "eeg_palazzo2020decoding"

    # Load the parameter preset and read the data dimensions from its model section.
    naive_rnn_params_inst = naive_rnn_params(dataset=dataset)
    model_params = naive_rnn_params_inst.model
    n_channels = model_params.n_channels
    n_labels = model_params.n_labels

    # Build the model from the model section of the parameters.
    naive_rnn_inst = naive_rnn(model_params)

    # Random inputs of shape (batch_size, seq_len, n_channels) and random one-hot labels
    # of shape (batch_size, n_labels), obtained by picking rows of the identity matrix.
    X = tf.random.uniform((batch_size, seq_len, n_channels), dtype=tf.float32)
    label_ids = np.random.randint(0, n_labels, size=(batch_size,))
    y = tf.cast(np.eye(n_labels)[label_ids], dtype=tf.float32)

    # Smoke test: train, evaluate (exactly two values are expected back) and print the summary.
    naive_rnn_inst.fit(X, y)
    _loss, _accuracy = naive_rnn_inst.evaluate(X, y)
    naive_rnn_inst.summary()

