
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))
    from SpatialAttention import *
    from SubjectLayer import *
else:
    from .SpatialAttention import *
    from .SubjectLayer import *

# Public API of this module: `from <module> import *` exports only `SubjectBlock`.
__all__ = [
    "SubjectBlock",
]

class SubjectBlock(K.layers.Layer):
    """
    `SubjectBlock` layer used to transform each channel with specified subject id.

    Current forward path:
        Conv1D(kernel_size=1, frozen) -> Dropout(0.3) -> BatchNormalization(frozen)
        -> Dense(512) -> Dense(256)
    The `SpatialAttention` front-end and the per-subject `SubjectLayer` are not applied
    in `call` at the moment, so `subject_id` is accepted but not used.
    """

    def __init__(self, n_output_channels, n_harmonics, drop_distance, **kwargs):
        """
        Initialize `SubjectBlock` object.
        :param n_output_channels: Sequence of output-channel counts. Index 1 sets the `Conv1D` filters and
            index 2 sets the `SubjectLayer` channels. Index 0 is reserved for the (currently disabled)
            `SpatialAttention` stage.
        :param n_harmonics: The number of harmonics of each attention weight. Only stored for now, since the
            `SpatialAttention` stage that consumes it is disabled.
        :param drop_distance: The radius of the circle field to be dropped. We use a uniform distribution
            to draw the center of drop circle field from input eeg locations. Only stored for now, since the
            `SpatialAttention` stage that consumes it is disabled.
        :param kwargs: The arguments related to initialize `tf.keras.layers.Layer`-style object.
        """
        # Set up the `K.layers.Layer` machinery first so attribute tracking works below.
        super().__init__(**kwargs)

        # Keep the constructor arguments; `get_config` returns them for re-creation.
        self.n_output_channels = n_output_channels
        self.n_harmonics = n_harmonics
        self.drop_distance = drop_distance

    """
    network funcs
    """
    def build(self, input_shape):
        """
        Create the sub-layers on the first call of `call`.
        :param input_shape: The shape(s) of the layer inputs (the `[X, subject_id]` pair).
        """
        # NOTE: a `SpatialAttention(n_output_channels[0], n_harmonics, drop_distance)` stage used to come
        # first; it is disabled, so `n_output_channels[0]` is currently unused.
        # Point-wise (kernel_size=1) convolution over the channel axis. `trainable=False` keeps its
        # initial kernel out of the trainable weights, so training does not update it.
        self.conv1d_layer = K.layers.Conv1D(
            self.n_output_channels[1], kernel_size=1, activation=None, trainable=False
        )
        self.drop = K.layers.Dropout(rate=0.3, trainable=False)
        # NOTE(review): in TF2 Keras a BatchNormalization layer with `trainable=False` runs in inference
        # mode, i.e. it normalizes with its moving statistics (initialized to mean 0 / variance 1 below)
        # and does not update them -- confirm this is intended.
        self.bm = K.layers.BatchNormalization(
            axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer="zeros",
            gamma_initializer="ones", moving_mean_initializer="zeros", moving_variance_initializer="ones",
            beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None,
            trainable=False,
        )
        self.dense_layer1 = K.layers.Dense(
            # Modified `Dense` layer parameters.
            units=512,
            # Default `Dense` layer parameters.
            activation=None, use_bias=True, kernel_initializer="he_uniform",
            bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        )
        self.dense_layer2 = K.layers.Dense(
            # Modified `Dense` layer parameters.
            units=256,
            # Default `Dense` layer parameters.
            activation=None, use_bias=True, kernel_initializer="he_uniform",
            bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        )
        # Per-subject transform. It is created here but is NOT applied in `call` at the moment.
        self.sl_layer = SubjectLayer(n_output_channels=self.n_output_channels[2])
        super().build(input_shape)

    def call(self, inputs, training=None):
        """
        Forward layers in `SubjectBlock` to get the final result.
        :param inputs: `[X, subject_id]`. X is assumed to be (batch_size, seq_len, n_input_channels) per the
            original design (the module's demo also feeds a 4-D tensor). subject_id is currently unused.
        :param training: Optional Python bool forwarded to `Dropout` and `BatchNormalization`. `None` lets
            Keras take the mode from the enclosing call context, which matches the previous behavior.
        :return outputs: Tensor with the leading dims of X and a last dim of 256 (`dense_layer2` units).
            `Conv1D` uses kernel_size=1, so the sequence length is preserved.
        """
        # NOTE: the original design also took `locations` for `SpatialAttention`; that path is disabled.
        X, subject_id = inputs
        outputs = self.conv1d_layer(inputs=X)
        outputs = self.drop(outputs, training=training)
        outputs = self.bm(outputs, training=training)
        outputs = self.dense_layer1(inputs=outputs)
        outputs = self.dense_layer2(outputs)
        # NOTE: `self.sl_layer(inputs=outputs, subject_id=subject_id)` is intentionally not applied here yet.
        return outputs

    def get_config(self):
        """
        Return the layer configuration so the layer can be saved and re-created via `from_config`.
        :return config: dict - The base `Layer` config plus this layer's constructor arguments.
        """
        config = super().get_config()
        config.update({
            # `list(...)` unwraps Keras' tracked-list wrapper into a plain, JSON-friendly list.
            "n_output_channels": list(self.n_output_channels),
            "n_harmonics": self.n_harmonics,
            "drop_distance": self.drop_distance,
        })
        return config

if __name__ == "__main__":
    import numpy as np

    # Smoke-test configuration.
    batch_size, seq_len = 16, 100
    n_input_channels, n_subjects = 32, 42
    n_output_channels = [16, 16, 32]
    n_harmonics, drop_distance = 32, 1.

    # Build the block under test.
    sb_inst = SubjectBlock(n_output_channels, n_harmonics, drop_distance)
    # Random input data, random sensor locations and one-hot subject ids (float32).
    X = tf.random.normal((batch_size, seq_len, 3, n_input_channels), dtype=tf.float32)
    locations = tf.random.normal((batch_size, n_input_channels, 2), dtype=tf.float32)
    subject_idx = np.random.randint(0, n_subjects, size=(batch_size,))
    subject_id = tf.one_hot(subject_idx, depth=n_subjects, dtype=tf.float32)
    # One forward pass; `locations` would only feed the disabled SpatialAttention path.
    outputs = sb_inst((X, subject_id))

