
import copy as cp
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir))
    from layers import *
else:
    from .layers import *
import utils.model

# Public API of this module: only the model class is exported.
__all__ = ["naive_transformer"]

class naive_transformer(K.Model):
    """
    `naive_transformer` model, with considering time information.

    Three input streams (awake image, awake audio, sleep) are each embedded by their own
    `SubjectBlock` + `PositionEmbedding`, then passed through an encoder and a feature block
    that are *shared* by all three streams. Two contrastive heads compare (audio, image) and
    (sleep, audio) features, and one classification head per stream predicts the labels.
    """

    def __init__(self, **kwargs):
        """
        Initialize `naive_transformer` object.

        All hyperparameters are hard-coded below; this class does not accept a params object.

        Args:
            kwargs: The arguments related to initialize `tf.keras.Model`-style object.

        Returns:
            None
        """
        # First call super class init function to set up `K.Model`
        # style model and inherit its functionality.
        super(naive_transformer, self).__init__(**kwargs)

        # Hyperparameters (hard-coded).
        # The number of labels predicted by each classification head.
        self.n_labels = 15
        # How `call` turns the contrastive similarity matrix into label predictions,
        # one of "max_z", "prob_z", "max_y", "prob_y".
        self.contra_pred_mode = "prob_y"
        # The dimensions of model embedding (currently not read by `_init_trainable`).
        self.d_model = 256
        # The maximum length of element sequence.
        self.max_len = 80
        # The depth of encoder.
        self.encoder_depth = 10
        # The number of attention heads.
        self.n_heads = 10
        # The dimensions of attention head.
        self.d_head = 256
        # The dropout probability of attention weights.
        self.mha_dropout_prob = 0.
        # The dimensions of the hidden layer in ffn.
        self.d_ff = 512
        # The dropout probability of the hidden layer in ffn.
        self.ff_dropout_prob = [0, 0.5]
        # The dimensions of the hidden layer in fc block (currently not read by `_init_trainable`).
        self.d_fc = 128
        # The dropout probability of the hidden layer in fc block (currently not read by `_init_trainable`).
        self.fc_dropout_prob = 0.

        # Create trainable vars.
        self._init_trainable()

    # ---------------------------------------------------------------- init funcs

    @staticmethod
    def _make_dense(units, activation=None):
        """
        Build a `Dense` layer with the initializers used throughout this model.

        Args:
            units: int - The number of output units.
            activation: str or None - The activation function, `None` for a linear layer.

        Returns:
            layer: `K.layers.Dense` - The newly created (unbuilt) layer.
        """
        return K.layers.Dense(
            # Modified `Dense` layer parameters.
            units=units, activation=activation,
            kernel_initializer=K.initializers.random_normal(mean=0., stddev=0.01),
            bias_initializer=K.initializers.constant(value=0.01),
            # Default `Dense` layer parameters.
            use_bias=True, kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        )

    def _build_classification_block(self, name):
        """
        Build one classification head:
        Dropout(0) -> Flatten -> Dense(128, relu) -> Dense(n_labels) -> Softmax.

        Args:
            name: str - The name of the returned `Sequential` model; must be unique within this model.

        Returns:
            block: `K.models.Sequential` - The classification head.
        """
        block = K.models.Sequential(name=name)
        block.add(K.layers.Dropout(rate=0.))
        block.add(K.layers.Flatten())
        block.add(self._make_dense(units=128, activation="relu"))
        block.add(self._make_dense(units=self.n_labels))
        # NOTE(review): the output of this head is fed to `_loss_bce`, which applies
        # `softmax_cross_entropy_with_logits` and thus a second softmax. Kept so the returned
        # predictions stay probabilities — confirm whether the loss should use logits instead.
        block.add(K.layers.Softmax(axis=-1))
        return block

    def _init_trainable(self):
        """
        Initialize trainable variables.

        Args:
            None

        Returns:
            None
        """
        # Per-stream input embedding followed by position embedding.
        self.image_emb_input = SubjectBlock([204, 512, 512], 32, 0.3)
        self.image_emb_pos = PositionEmbedding(self.max_len)
        self.audio_emb_input = SubjectBlock([204, 512, 512], 32, 0.3)
        self.audio_emb_pos = PositionEmbedding(self.max_len)
        self.sleep_emb_input = SubjectBlock([204, 512, 512], 32, 0.3)
        self.sleep_emb_pos = PositionEmbedding(self.max_len)
        # Encoder stack, shared by all three streams in `call`.
        self.encoder = K.models.Sequential(layers=[TransformerBlock(
            n_heads=self.n_heads, d_head=self.d_head, mha_dropout_prob=self.mha_dropout_prob,
            d_ff=self.d_ff, ff_dropout_prob=self.ff_dropout_prob) for _ in range(self.encoder_depth)
        ], name="encoder")
        # Feature block, shared by all three streams in `call`.
        self.feature_block = K.models.Sequential()
        self.feature_block.add(self._make_dense(units=256, activation="relu"))
        # Contrastive heads: (audio, image) and (sleep, audio).
        self.awake_contrastive_block = LossLayer(d_contra=64, loss_mode="clip_orig")
        self.sleep_contrastive_block = LossLayer(d_contra=64, loss_mode="clip_orig")
        # One classification head per stream. Names must be unique: duplicate layer names break
        # name-based weight saving and `get_layer`. The sleep head keeps the historical name "fc"
        # so `get_layer("fc")` still resolves to the same block as before.
        self.sleep_classification_block = self._build_classification_block(name="fc")
        self.awake_image_classification_block = self._build_classification_block(name="awake_image_fc")
        self.awake_audio_classification_block = self._build_classification_block(name="awake_audio_fc")

    # ---------------------------------------------------------------- network funcs

    def _encode_stream(self, emb_input, emb_pos, data):
        """
        Embed one input stream with its own embedding layers, then apply the shared
        encoder and feature block.

        Args:
            emb_input: layer - The stream-specific input embedding.
            emb_pos: layer - The stream-specific position embedding.
            data: tensor - The raw input of this stream.

        Returns:
            features: tensor - The output of `self.feature_block`.
        """
        emb = emb_pos(emb_input(data))
        emb = self.encoder(emb)
        return self.feature_block(emb)

    def call(self, inputs, training=None, mask=None):
        """
        Forward `naive_transformer` to get the final predictions and loss.

        Args:
            inputs: tuple - `(image_data, audio_data, sleep_data, true_label, subject_id)`.
                `true_label` is used as a float label matrix that is gathered/matrix-multiplied
                along axis 0, presumably one-hot (batch_size, n_labels) — TODO confirm.
                `subject_id` (inputs[4]) is currently not used.
            training: Boolean or boolean scalar tensor, indicating whether to run
                the `Network` in training mode or inference mode.
            mask: A mask or list of masks. A mask can be either a tensor or None (no mask). Unused.

        Returns:
            loss: tf.float32 - Mean of the three classification losses plus the mean of
                the two contrastive losses.
            sleep_pred: (batch_size, n_labels) - Softmax output of the sleep classification head.
            y_pred_clip: tensor - Softmax over the contrastive label predictions; returned for
                evaluation only, it does not contribute to `loss`.

        Raises:
            ValueError: If `self.contra_pred_mode` is not one of the supported modes.
        """
        image_data = inputs[0]
        audio_data = inputs[1]
        sleep_data = inputs[2]
        true_label = inputs[3]

        # Features of each stream (the encoder and feature block are shared).
        image = self._encode_stream(self.image_emb_input, self.image_emb_pos, image_data)
        audio = self._encode_stream(self.audio_emb_input, self.audio_emb_pos, audio_data)
        sleep = self._encode_stream(self.sleep_emb_input, self.sleep_emb_pos, sleep_data)
        awake_loss_contra, _ = self.awake_contrastive_block((audio, image))
        sleep_loss_contra, contra_matrix = self.sleep_contrastive_block((sleep, audio))

        # Turn the (sleep, audio) similarity matrix into label predictions. The `prob_*` modes
        # normalize by sums, presumably assuming non-negative similarities — TODO confirm.
        if self.contra_pred_mode == "max_z":
            # Each column item takes the label of its most similar row item.
            contra_idxs = tf.squeeze(tf.argmax(contra_matrix, axis=0))
            y_pred_contra = tf.gather(true_label, indices=contra_idxs, axis=0)
        elif self.contra_pred_mode == "prob_z":
            # Each column item takes the labels of all row items, weighted by column-normalized similarity.
            contra_prob = tf.transpose(
                contra_matrix / tf.reduce_sum(contra_matrix, axis=0, keepdims=True), perm=[1, 0])
            y_pred_contra = tf.matmul(contra_prob, true_label)
        elif self.contra_pred_mode == "max_y":
            # Each row item takes the label of its most similar column item.
            contra_idxs = tf.squeeze(tf.argmax(contra_matrix, axis=-1))
            y_pred_contra = tf.gather(true_label, indices=contra_idxs, axis=0)
        elif self.contra_pred_mode == "prob_y":
            # Each row item takes the labels of all column items, weighted by row-normalized similarity.
            contra_prob = contra_matrix / tf.reduce_sum(contra_matrix, axis=-1, keepdims=True)
            y_pred_contra = tf.matmul(contra_prob, true_label)
        else:
            raise ValueError((
                "ERROR: Unknown contrastive prediction mode {} in naive_transformer.call."
            ).format(self.contra_pred_mode))
        # Computed for every mode (previously only `prob_y` defined it, so other modes crashed on return).
        y_pred_clip = tf.nn.softmax(y_pred_contra, axis=-1)

        # Per-stream classification losses.
        sleep_pred = self.sleep_classification_block(sleep)
        loss_sleep = tf.reduce_mean(self._loss_bce(sleep_pred, true_label))
        image_pred = self.awake_image_classification_block(image)
        loss_image = tf.reduce_mean(self._loss_bce(image_pred, true_label))
        audio_pred = self.awake_audio_classification_block(audio)
        loss_audio = tf.reduce_mean(self._loss_bce(audio_pred, true_label))

        loss = (loss_sleep + loss_image + loss_audio) / 3 + (awake_loss_contra + sleep_loss_contra) / 2
        return loss, sleep_pred, y_pred_clip

    @utils.model.tf_scope
    def _loss_bce(self, value, target):
        """
        Calculate the softmax cross entropy between `value` (treated as logits) and `target`.

        Despite the name, this is `tf.nn.softmax_cross_entropy_with_logits` (categorical cross
        entropy over the last axis), so the losses of different samples stay separate.

        Args:
            value: tensor or list of tensors - Predictions, interpreted as logits.
            target: tensor or list of tensors - Target distributions, indexed in parallel with
                `value` when `value` is a list.

        Returns:
            loss: tensor with the last axis of `target` reduced away, or a list of such tensors
                when `value` is a list.
        """
        # Note: `tf.nn.softmax_cross_entropy_with_logits` expects unscaled logits, so no softmax
        # should precede it. NOTE(review): the classification heads of this model end with a
        # `Softmax` layer, so their outputs get a second softmax here — confirm intended.
        if isinstance(value, list):
            return [tf.nn.softmax_cross_entropy_with_logits(labels=target[i], logits=value_i)
                    for i, value_i in enumerate(value)]
        return tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=value)

    # ---------------------------------------------------------------- tool funcs

    @utils.model.tf_scope
    def summary(self, print_fn=None):
        """
        Summary built model.

        Args:
            print_fn: callable - Print function to use. Defaults to `print`. It will be called on each
                line of the summary. You can set it to a custom function in order to capture the string summary.

        Returns:
            None
        """
        super(naive_transformer, self).summary(
            # Modified summary parameters.
            print_fn=print_fn,
            # Default summary parameters.
            line_length=None, positions=None, expand_nested=True, show_trainable=True, layer_range=None
        )

if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass on random data and print its summary.
    # NOTE(review): the sizes below are placeholders. `seq_len` presumably must not exceed the
    # model's `max_len` (80), and `n_channels` presumably has to match the first entry of the
    # `SubjectBlock` sizes (204) — TODO confirm against `layers.SubjectBlock`.
    batch_size, seq_len, n_channels = 16, 80, 204

    # Initialize training process.
    utils.model.set_seeds(42)

    # Instantiate naive_transformer. Its hyperparameters are hard-coded and `__init__` only
    # accepts `tf.keras.Model` keyword arguments, so no params object is passed.
    naive_transformer_inst = naive_transformer()
    n_labels = naive_transformer_inst.n_labels
    # Initialize the three stream inputs, the one-hot labels and the (unused) subject ids.
    image_data = tf.random.normal((batch_size, seq_len, n_channels), dtype=tf.float32)
    audio_data = tf.random.normal((batch_size, seq_len, n_channels), dtype=tf.float32)
    sleep_data = tf.random.normal((batch_size, seq_len, n_channels), dtype=tf.float32)
    # `% n_labels` keeps every row a valid one-hot vector even when batch_size > n_labels.
    true_label = tf.one_hot(tf.range(batch_size) % n_labels, n_labels, dtype=tf.float32)
    subject_id = tf.zeros((batch_size,), dtype=tf.int64)
    # Forward layers in `naive_transformer_inst`; `call` returns (loss, sleep_pred, y_pred_clip).
    loss, sleep_pred, y_pred_clip = naive_transformer_inst(
        (image_data, audio_data, sleep_data, true_label, subject_id))
    naive_transformer_inst.summary()

