#!/usr/bin/env python3
"""
Created on 20:33, Aug. 5th, 2023

@author: Anonymous
"""
import copy as cp
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir, os.pardir))
    from layers import *
else:
    from .layers import *
import utils.model

# Public API of this module.
__all__ = ["SubdomainContrastiveConvNet"]

class SubdomainContrastiveConvNet(K.Model):
    """
    `SubdomainContrastiveConvNet` model, with considering time information.

    Each present modality is encoded by a cnn block followed by a feature block, and the resulting
    feature vector is fed to a per-modality classification block. Every pair of adjacent modalities
    `(i, i+1)` is additionally aligned by a contrastive `LossLayer`. The final prediction mixes the
    classification output of the first modality with a prediction derived from the similarity matrix
    of the first contrastive pair.
    """

    def __init__(self, params, **kwargs):
        """
        Initialize `SubdomainContrastiveConvNet` object.

        Args:
            params: Model parameters initialized by domain_contrastive_conv_params, updated by params.iteration.
                Fields read by this model: `n_modalities`, `use_siamese`, `cnn`, `feature`, `cls`, `d_contra`,
                `contra_data_mode`, `contra_loss_mode`, `contra_pred_mode`, `contra_pred_ratio`,
                `contra_loss_ratio` and `contra_loss_scale`.
            kwargs: The arguments related to initialize `tf.keras.Model`-style object.

        Returns:
            None
        """
        # First call super class init function to set up `K.Model`
        # style model and inherit its functionality.
        super(SubdomainContrastiveConvNet, self).__init__(**kwargs)

        # Copy hyperparameters (e.g. network sizes) from parameter dotdict, usually generated from
        # domain_contrastive_conv_params() in params/domain_adaptation_params.py. The deep copy keeps
        # later changes to the caller's `params` from leaking into this model.
        self.params = cp.deepcopy(params)

        # Create trainable vars.
        self._init_trainable()

    """
    init funcs
    """
    # def _init_trainable func
    def _init_trainable(self):
        """
        Initialize trainable variables.

        Builds `cnn_blocks`, `feature_blocks`, `contrastive_blocks` and `classification_blocks`.
        A block belonging to a modality that is disabled in `modality_mask` is stored as `None`.

        Args:
            None

        Returns:
            None
        """
        n_modalities = self.params.n_modalities
        # Every modality is currently enabled.
        self.modality_mask = [True for _ in range(n_modalities)]
        # With `use_siamese`, every modality owns its own encoder (cnn block + feature block); otherwise
        # a single encoder at index 0 is shared by all modalities (see `call`).
        n_encoders = n_modalities if self.params.use_siamese else 1
        # Note: Blocks are built type by type (all cnn blocks, then all feature blocks, ...). Keras derives
        # default layer names from creation order, so keep this order stable.
        ## Construct cnn blocks.
        self.cnn_blocks = []
        for modality_idx in range(n_encoders):
            self.cnn_blocks.append(
                self._init_cnn_block(modality_idx) if self.modality_mask[modality_idx] else None
            )
        ## Construct feature blocks.
        self.feature_blocks = []
        for modality_idx in range(n_encoders):
            self.feature_blocks.append(
                self._init_feature_block(modality_idx) if self.modality_mask[modality_idx] else None
            )
        ## Construct contrastive blocks, one for each pair of adjacent modalities `(i, i+1)`.
        self.contrastive_blocks = []
        for modality_idx in range(n_modalities - 1):
            if not (self.modality_mask[modality_idx] and self.modality_mask[modality_idx+1]):
                self.contrastive_blocks.append(None); continue
            self.contrastive_blocks.append(LossLayer(d_contra=self.params.d_contra,
                data_mode=self.params.contra_data_mode, loss_mode=self.params.contra_loss_mode,
                name="contrastive-{:d}".format(modality_idx)
            ))
        ## Construct classification blocks, one for each modality (never shared).
        self.classification_blocks = []
        for modality_idx in range(n_modalities):
            self.classification_blocks.append(
                self._init_classification_block(modality_idx) if self.modality_mask[modality_idx] else None
            )

    # def _init_cnn_block func
    def _init_cnn_block(self, modality_idx):
        """
        Build the cnn block of one modality: `Conv1D` layers (with `MaxPool1D`), optional `Dropout`,
        and a final `BatchNormalization`.

        Args:
            modality_idx: int - The index of the modality, only used to name the block.

        Returns:
            model_cnn: K.models.Sequential - The constructed cnn block.
        """
        params_cnn = self.params.cnn; n_layers = len(params_cnn.n_filters)
        model_cnn = K.models.Sequential(name="CNN-{:d}".format(modality_idx))
        for cnn_idx in range(n_layers):
            # Initialize `Conv1D` layer. `tf.keras.layers.Conv1D` is different from `torch.nn.Conv1d`. It doesn't have
            # `in_channels` argument. And `filters` argument equals to `out_channels` argument.
            model_cnn.add(K.layers.Conv1D(
                # Modified `Conv1D` layer parameters.
                filters=params_cnn.n_filters[cnn_idx], kernel_size=params_cnn.d_kernel[cnn_idx],
                strides=params_cnn.strides[cnn_idx], padding=params_cnn.padding[cnn_idx],
                dilation_rate=params_cnn.dilation_rate[cnn_idx], name="Conv1D_{:d}".format(cnn_idx),
                # Default `Conv1D` layer parameters.
                data_format="channels_last", groups=1, activation=None, use_bias=True,
                kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
                bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ))
            if isinstance(params_cnn.d_pooling_kernel, list):
                # A list of pooling kernels means one stride-1 `MaxPool1D` after every `Conv1D` layer.
                model_cnn.add(K.layers.MaxPool1D(
                    # Modified `MaxPool1D` layer parameters.
                    pool_size=params_cnn.d_pooling_kernel[cnn_idx], strides=1,
                    name="MaxPool1D_{:d}".format(cnn_idx),
                    # Default `MaxPool1D` layer parameters.
                    padding="valid", data_format="channels_last"
                ))
            elif cnn_idx == n_layers - 1:
                # A single pooling kernel means only one `MaxPool1D` layer, after the last `Conv1D` layer.
                model_cnn.add(K.layers.MaxPool1D(
                    # Modified `MaxPool1D` layer parameters.
                    pool_size=params_cnn.d_pooling_kernel, name="MaxPool1D_{:d}".format(cnn_idx),
                    # Default `MaxPool1D` layer parameters.
                    strides=None, padding="valid", data_format="channels_last"
                ))
        # Add `Dropout` after the conv/pool stack.
        if params_cnn.dropout > 0.:
            model_cnn.add(K.layers.Dropout(rate=params_cnn.dropout, name="Dropout_{}".format("cnn")))
        # Add `BatchNormalization` at the last layer of cnn layers.
        model_cnn.add(K.layers.BatchNormalization(
            # Modified `BatchNormalization` parameters.
            name="BatchNormalization_{}".format("cnn"),
            # Default `BatchNormalization` parameters.
            axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer="zeros",
            gamma_initializer="ones", moving_mean_initializer="zeros", moving_variance_initializer="ones",
            beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None
        ))
        return model_cnn

    # def _init_feature_block func
    def _init_feature_block(self, modality_idx):
        """
        Build the feature block of one modality: `Flatten`, optional `Dropout`, then hidden `Dense` layers.

        Args:
            modality_idx: int - The index of the modality, only used to name the block.

        Returns:
            model_feature: K.models.Sequential - The constructed feature block.
        """
        model_feature = K.models.Sequential(name="feature-{:d}".format(modality_idx))
        model_feature.add(K.layers.Flatten(data_format="channels_last"))
        # `Dropout` is applied to the flattened feature map, i.e. before the hidden `Dense` layers.
        if self.params.feature.dropout > 0.:
            model_feature.add(K.layers.Dropout(rate=self.params.feature.dropout))
        # Add hidden `Dense` layers.
        for d_hidden_i in self.params.feature.d_hidden:
            model_feature.add(K.layers.Dense(
                # Modified `Dense` layer parameters.
                units=d_hidden_i, activation="relu",
                kernel_initializer=K.initializers.random_normal(mean=0., stddev=0.01),
                bias_initializer=K.initializers.constant(value=0.01),
                # Default `Dense` layer parameters.
                use_bias=True, kernel_regularizer=None, bias_regularizer=None,
                activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ))
        return model_feature

    # def _init_classification_block func
    def _init_classification_block(self, modality_idx):
        """
        Build the classification block of one modality: hidden `Dense` layers, optional `Dropout`,
        the output `Dense` layer and a final `Softmax`.

        Args:
            modality_idx: int - The index of the modality, only used to name the block.

        Returns:
            model_classification: K.models.Sequential - The constructed classification block, whose
                output is a probability distribution over `params.cls.d_output` classes.
        """
        model_classification = K.models.Sequential(name="classification-{:d}".format(modality_idx))
        # Add hidden `Dense` layers.
        for d_hidden_i in self.params.cls.d_hidden:
            model_classification.add(K.layers.Dense(
                # Modified `Dense` layer parameters.
                units=d_hidden_i, activation="relu",
                kernel_initializer=K.initializers.random_normal(mean=0., stddev=0.01),
                bias_initializer=K.initializers.constant(value=0.01),
                # Default `Dense` layer parameters.
                use_bias=True, kernel_regularizer=None, bias_regularizer=None,
                activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ))
        # Add `Dropout` after the hidden `Dense` layers.
        if self.params.cls.dropout > 0.:
            model_classification.add(K.layers.Dropout(rate=self.params.cls.dropout))
        # Add the final classification `Dense` layer.
        model_classification.add(K.layers.Dense(
            # Modified `Dense` layer parameters.
            units=self.params.cls.d_output,
            kernel_initializer=K.initializers.random_normal(mean=0., stddev=0.01),
            bias_initializer=K.initializers.constant(value=0.01),
            # Default `Dense` layer parameters.
            activation=None, use_bias=True, kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        ))
        # The block outputs probabilities; `_loss_bce` therefore treats its input as probabilities.
        model_classification.add(K.layers.Softmax(axis=-1))
        return model_classification

    """
    network funcs
    """
    # def call func
    def call(self, inputs, training=None, mask=None):
        """
        Forward `SubdomainContrastiveConvNet` to get the final predictions.

        Args:
            inputs: list - One `(X_i, y_i)` pair per modality, where `X_i` is `None` for a missing modality.
            training: Boolean or boolean scalar tensor, indicating whether to run
                the `Network` in training mode or inference mode. Not referenced explicitly here.
            mask: A mask or list of masks. A mask can be either a tensor or None (no mask). Unused.

        Returns:
            outputs: (batch_size, n_labels) - The output labels.
            loss: tf.float32 - The corresponding loss.

        Raises:
            ValueError: If `params.contra_pred_mode` is unknown.
        """
        n_modalities = self.params.n_modalities
        # Initialize components of inputs.
        # X - (n_modalities[list], batch_size, seq_len, n_channels)
        # y - (n_modalities[list], batch_size, n_labels)
        X = [modality_i[0] for modality_i in inputs]; y_true = [modality_i[1] for modality_i in inputs]
        # Modalities disabled by `modality_mask` must not receive any input.
        for modality_idx in range(n_modalities):
            if not self.modality_mask[modality_idx]: assert X[modality_idx] is None
        # Forward the encoder (cnn block, then feature block) of each present modality.
        # X_f - (n_modalities[list], batch_size, n_features)
        X_f = []
        for modality_idx in range(n_modalities):
            if X[modality_idx] is None:
                X_f.append(None); continue
            encoder_idx = modality_idx if self.params.use_siamese else 0
            X_f.append(self.feature_blocks[encoder_idx](self.cnn_blocks[encoder_idx](X[modality_idx])))
        # Forward classification block to do classification.
        # Note: The layer before the final fc layers is considered as feature vectors.
        # y_pred_bce - (n_modalities[list], batch_size, n_labels), class probabilities.
        y_pred_bce = [(self.classification_blocks[modality_idx](X_f[modality_idx])
            if X_f[modality_idx] is not None else None) for modality_idx in range(n_modalities)]
        # Classification loss, averaged over the modalities that produced a prediction.
        # loss_bce - tf.float32
        loss_bce = self._mean_over_present([
            (tf.reduce_mean(self._loss_bce(y_pred_bce[modality_idx], y_true[modality_idx]))
            if y_pred_bce[modality_idx] is not None else None) for modality_idx in range(n_modalities)
        ])
        # Get the final `y_pred_bce` of `X_tmr`.
        # TODO: Make sure that the first modality is tmr!!!
        # y_pred_bce - (batch_size, n_labels)
        y_pred_bce = y_pred_bce[0]
        # Forward contrastive block to calculate contrastive loss.
        # TODO: Make sure that the first modality is tmr, the second modality is image!!!
        # contra_matrix - (batch_size, batch_size), the first dimension is z, and the second dimension is y.
        loss_contra = []; contra_matrix = []
        for modality_idx in range(n_modalities - 1):
            if (X_f[modality_idx] is None) or (X_f[modality_idx+1] is None):
                loss_contra.append(None); contra_matrix.append(None); continue
            loss_contra_i, contra_matrix_i = self.contrastive_blocks[modality_idx](
                ((X_f[modality_idx], X_f[modality_idx+1]), (y_true[modality_idx], y_true[modality_idx+1]))
            )
            loss_contra.append(loss_contra_i); contra_matrix.append(contra_matrix_i)
        # Only the similarity matrix of the first modality pair is used for prediction.
        contra_matrix = contra_matrix[0]
        # Contrastive loss, averaged over the modality pairs that were both present.
        # loss_contra - tf.float32
        loss_contra = self._mean_over_present(loss_contra)
        # y_pred_contra - (batch_size, n_labels)
        y_pred_contra = self._predict_contra(contra_matrix, y_true)
        # Calculate the final `y_pred` & `loss`.
        # TODO: Make sure that the first modality is tmr!!!
        # y_pred - (batch_size, n_labels); loss - tf.float32
        y_pred = (1 - self.params.contra_pred_ratio) * y_pred_bce + self.params.contra_pred_ratio * y_pred_contra
        loss = (1 - self.params.contra_loss_ratio) * loss_bce +\
            self.params.contra_loss_ratio * self.params.contra_loss_scale * loss_contra
        # Return the final `outputs` & `loss`.
        return y_pred, loss

    # def _predict_contra func
    def _predict_contra(self, contra_matrix, y_true):
        """
        Derive label predictions from the contrastive similarity matrix according to `params.contra_pred_mode`.

        Args:
            contra_matrix: (batch_size, batch_size) - Similarity matrix of the first modality pair,
                the first dimension is z, and the second dimension is y.
            y_true: list - Labels of each modality; only `y_true[0]` and `y_true[1]` are used.

        Returns:
            y_pred_contra: (batch_size, n_labels) - The contrastive prediction.

        Raises:
            ValueError: If `params.contra_pred_mode` is unknown.
        """
        contra_pred_mode = self.params.contra_pred_mode
        # `max_z`: assign each label with the data item with max similarity.
        if contra_pred_mode == "max_z":
            contra_idxs = tf.squeeze(tf.argmax(contra_matrix, axis=0))
            return tf.gather(y_true[0], indices=contra_idxs, axis=0)
        # `prob_z`: assign each label with the weighted probability of all data items.
        if contra_pred_mode == "prob_z":
            contra_prob = tf.transpose(contra_matrix / tf.reduce_sum(contra_matrix, axis=0, keepdims=True), perm=[1,0])
            return tf.matmul(contra_prob, y_true[0])
        # `max_y`: assign each data item with the label with max similarity.
        if contra_pred_mode == "max_y":
            contra_idxs = tf.squeeze(tf.argmax(contra_matrix, axis=-1))
            return tf.gather(y_true[1], indices=contra_idxs, axis=0)
        # `prob_y`: assign each data item with the weighted probability of all labels.
        if contra_pred_mode == "prob_y":
            contra_prob = contra_matrix / tf.reduce_sum(contra_matrix, axis=-1, keepdims=True)
            return tf.matmul(contra_prob, y_true[1])
        raise ValueError((
            "ERROR: Unknown contrastive prediction mode {} in train.conv_net."
        ).format(contra_pred_mode))

    # def _mean_over_present func
    @staticmethod
    def _mean_over_present(losses):
        """
        Average the scalar losses that are present, ignoring `None` entries.

        Args:
            losses: list - Scalar `tf.float32` losses, with `None` for missing ones.

        Returns:
            loss: tf.float32 - Mean of the present losses, or `0.` if none is present.
        """
        present = tf.constant([(1. if loss_i is not None else 0.) for loss_i in losses], dtype=tf.float32)
        # `divide_no_nan` returns 0 instead of NaN when no loss is present.
        weights = tf.math.divide_no_nan(present, tf.reduce_sum(present))
        stacked = tf.stack([(loss_i if loss_i is not None else tf.zeros(())) for loss_i in losses], axis=0)
        return tf.reduce_sum(stacked * weights)

    # def _loss_bce func
    @utils.model.tf_scope
    def _loss_bce(self, value, target):
        """
        Calculates categorical cross entropy between predicted probabilities and target labels.
        The loss of each sample is kept separate (no reduction over the batch).
        :param value: (batch_size, n_labels) - Predicted class probabilities (or a list of them).
        :param target: (batch_size, n_labels) - Target labels (or a list of them, matching `value`).
        :return loss: (batch_size,) - Loss between value and target (or a list of them).
        """
        # Note: `value` is the output of a classification block, which ends with a `Softmax` layer, so it
        # already holds probabilities. `tf.nn.softmax_cross_entropy_with_logits` would apply softmax a second
        # time; compute the cross entropy directly from probabilities instead.
        def _cross_entropy(value_i, target_i):
            return K.losses.categorical_crossentropy(target_i, value_i, from_logits=False)
        # loss - (batch_size,)
        if isinstance(value, list):
            return [_cross_entropy(value_i, target_i) for value_i, target_i in zip(value, target)]
        return _cross_entropy(value, target)

    """
    tool funcs
    """
    # def summary func
    @utils.model.tf_scope
    def summary(self, print_fn=None):
        """
        Summary built model.
        :param print_fn: callable - Print function to use. Defaults to `print`. It will be called on each
            line of the summary. You can set it to a custom function in order to capture the string summary.
        """
        super(SubdomainContrastiveConvNet, self).summary(print_fn=print_fn)

if __name__ == "__main__":
    import numpy as np
    # local dep
    from params.domain_adaptation_params import subdomain_contrastive_conv_params

    # Smoke-test configuration; `n_inputs` is the number of (X, y) pairs fed to the model.
    dataset = "eeg_anonymous"
    batch_size, seq_len, n_channels, n_features = 64, 80, 55, 128
    n_inputs = 3

    # Initialize training process.
    utils.model.set_seeds(42)

    # Build the model from the dataset-specific parameters.
    params_inst = subdomain_contrastive_conv_params(dataset=dataset)
    n_labels = params_inst.model.n_labels
    sdccn_inst = SubdomainContrastiveConvNet(params_inst.model)

    # Random inputs, and one-hot labels where sample `i` gets label `i` (all-zero rows when `i >= n_labels`).
    X = [tf.random.normal((batch_size, seq_len, n_channels), dtype=tf.float32) for _ in range(n_inputs)]
    y_true = [
        tf.cast(tf.one_hot(tf.cast(tf.range(batch_size), dtype=tf.int64), n_labels), dtype=tf.float32)
        for _ in range(n_inputs)
    ]

    # Run one forward pass, then print the model summary.
    outputs, loss = sdccn_inst(list(zip(X, y_true)))
    sdccn_inst.summary()

