#!/usr/bin/env python3
"""
Created on 11:16, Dec. 27th, 2022

@author: Anonymous
"""
import copy as cp
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))
    from SubjectLayer import *
else:
    from .SubjectLayer import *

__all__ = [
    "SubjectBlock",
]

class SubjectBlock(K.layers.Layer):
    """
    `SubjectBlock` layer used to transform each channel with specified subject id.

    The forward pass applies, in order:
        1. the "CNN" block: a stack of point-wise (`kernel_size=1`) `Conv1D` layers,
           followed by one `Dropout` and one `BatchNormalization` layer;
        2. the "hidden" block: a stack of `Dense` layers (no activation);
        3. a `SubjectLayer` that receives the result together with the subject id.
    """

    def __init__(self, params, **kwargs):
        """
        Initialize `SubjectBlock` object.

        Args:
            params: DotDict - The dict of parameters, containing [cnn,hidden,subject].
                The fields read by this layer are `cnn.n_filters`, `cnn.dropout_ratio`,
                `hidden.n_hiddens` and `subject.n_channels_output`.
            kwargs: The arguments related to initialize `tf.keras.layers.Layer`-style object.

        Returns:
            None
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super(SubjectBlock, self).__init__(**kwargs)

        # Keep a private copy, so later mutation of the caller's `params`
        # cannot change how this layer gets built.
        self.params = cp.deepcopy(params)

    # ---- network funcs ----

    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        Args:
            input_shape: tuple - The shape of input data, forwarded to the base-class `build`.

        Returns:
            None
        """
        ## Construct CNN block.
        self.cnn_block = K.models.Sequential(name="CNN")
        # Add one point-wise `Conv1D` layer per entry of `n_filters`.
        for cnn_idx, n_filters in enumerate(self.params.cnn.n_filters):
            self.cnn_block.add(K.layers.Conv1D(
                # Modified `Conv1D` layer parameters.
                filters=n_filters, kernel_size=1, name="Conv1D_{:d}".format(cnn_idx),
                # Default `Conv1D` layer parameters.
                strides=1, padding="valid", data_format="channels_last", dilation_rate=1, groups=1, activation=None,
                use_bias=True, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
                bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ))
        # Add `Dropout` layer. Dropout owns no weights, so `trainable` is meaningless for it;
        # whether dropout is applied is decided by the `training` flag passed in `call`.
        self.cnn_block.add(K.layers.Dropout(rate=self.params.cnn.dropout_ratio))
        # Add `BatchNormalization` layer.
        # NOTE: do NOT pass `trainable=False` here. Since TF 2.0 a non-trainable
        # `BatchNormalization` always runs in inference mode: it normalizes with the
        # moving mean/variance (initialized to 0/1) and never updates them, which turns
        # the layer into a near-identity scaling instead of a normalization.
        self.cnn_block.add(K.layers.BatchNormalization(
            # Default `BatchNormalization` layer parameters.
            axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer="zeros",
            gamma_initializer="ones", moving_mean_initializer="zeros", moving_variance_initializer="ones",
            beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None
        ))
        ## Construct hidden block.
        self.hidden_block = K.models.Sequential(name="hidden")
        # Add one linear `Dense` layer per entry of `n_hiddens`.
        for n_hiddens in self.params.hidden.n_hiddens:
            self.hidden_block.add(K.layers.Dense(
                # Modified `Dense` layer parameters.
                units=n_hiddens, kernel_initializer="he_uniform",
                # Default `Dense` layer parameters.
                activation=None, use_bias=True, bias_initializer="zeros", kernel_regularizer=None,
                bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ))
        ## Construct subject layer.
        self.subject_layer = SubjectLayer(n_channels_output=self.params.subject.n_channels_output)
        # Build super to set up `K.layers.Layer`-style model and inherit its network.
        super(SubjectBlock, self).build(input_shape)

    def call(self, inputs, training=None):
        """
        Forward layers in `SubjectBlock` to get the final result.

        Args:
            inputs: (2[list],) - The input data `[X, subject_id]`, where
                X: (batch_size, seq_len, n_input_channels) - The input signal.
                subject_id: The subject identifier consumed by `SubjectLayer`; presumably a
                    (batch_size, n_subjects) one-hot tensor, as built by this module's demo
                    script — confirm against `SubjectLayer`.
            training: bool or None - Whether to run in training mode (active `Dropout`,
                batch statistics in `BatchNormalization`). When `None`, Keras resolves it
                from the enclosing call context.

        Returns:
            Z: (batch_size, seq_len, n_channels_output) - The subject-transformed data.
        """
        X, subject_id = inputs
        # Z - (batch_size, seq_len, n_output_channels)
        Z = self.cnn_block(X, training=training)
        Z = self.hidden_block(Z, training=training)
        Z = self.subject_layer(inputs=Z, subject_id=subject_id)
        return Z

if __name__ == "__main__":
    import numpy as np
    # local dep
    from params.conv_net_params import conv_net_params

    # Initialize macros (one statement per line, PEP 8).
    batch_size = 32
    seq_len = 100
    n_input_channels = 55
    n_subjects = 64
    conv_net_params_inst = conv_net_params(dataset="eeg_anonymous")

    # Instantiate SubjectBlock.
    sb_inst = SubjectBlock(conv_net_params_inst.model.subject)
    # Initialize input data & one-hot subject ids of shape (batch_size, n_subjects).
    X = tf.random.normal((batch_size, seq_len, n_input_channels), dtype=tf.float32)
    subject_id = tf.cast(np.eye(n_subjects)[np.random.randint(0, n_subjects, size=(batch_size,))], dtype=tf.float32)
    # Forward layers in `sb_inst`, and report the output shape so the smoke run is observable.
    Z = sb_inst((X, subject_id))
    print("Z.shape = {}".format(Z.shape))

