
import copy as cp
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.getcwd()))
    print(sys.path)
    from layers import *
else:
    from .layers import *
import utils

__all__ = [
    "conv_net",
]

class conv_net(K.Model):
    """
    `conv_net` model specified for meg decoding.

    Forward pass (see `call`): `SubjectBlock` -> `conv_block` (stacked `Conv1D`,
    max-pooling, dropout, batch-norm) -> `feature_block` (`Dense`), whose output
    feeds both the contrastive `LossLayer` and the classification head.
    `call` returns `(loss, y_pred_bce)`.
    """

    # Contrastive prediction modes understood by `_contra_predict`.
    _CONTRA_PRED_MODES = ("max_z", "prob_z", "max_y", "prob_y")

    def __init__(self, *, n_labels=15, contra_pred_mode="prob_y", **kwargs):
        """
        Initialize `conv_net` object.

        Args:
            n_labels: int - Number of output units of the classification head
                (default 15, the value previously hard-coded here).
            contra_pred_mode: str - How `_contra_predict` turns the contrastive
                similarity matrix into label predictions; one of "max_z",
                "prob_z", "max_y", "prob_y" (default "prob_y").
            kwargs: The arguments related to initialize `tf.keras.Model`-style object.

        Returns:
            None

        Raises:
            ValueError: If `contra_pred_mode` is not one of the known modes.
        """
        # First call super class init function to set up `K.Model`
        # style model and inherit its functionality.
        super(conv_net, self).__init__(**kwargs)

        # Fail fast on a bad mode instead of on the first forward pass.
        if contra_pred_mode not in self._CONTRA_PRED_MODES:
            raise ValueError((
                "ERROR: Unknown contrastive prediction mode {} in train.conv_net."
            ).format(contra_pred_mode))
        self.n_labels = n_labels
        self.contra_pred_mode = contra_pred_mode
        # Create trainable vars (`_init_trainable` reads `self.n_labels`).
        self._init_trainable()

    @staticmethod
    def _conv1d(filters, kernel_size, dilation_rate, kernel_initializer):
        """
        Build a `Conv1D` layer with the settings shared by every conv in `conv_block`.

        All layers use stride 1, "same" padding, channels_last, a zero-initialized
        bias and no activation / regularizers / constraints; only the arguments
        below differ between them.

        Args:
            filters: int - Number of output channels.
            kernel_size: int - Temporal kernel width.
            dilation_rate: int - Temporal dilation.
            kernel_initializer: str - Keras initializer name for the kernel.

        Returns:
            layer: `K.layers.Conv1D` - The configured (not yet built) layer.
        """
        return K.layers.Conv1D(
            filters=filters, kernel_size=kernel_size, padding="same", dilation_rate=dilation_rate,
            strides=1, data_format="channels_last", groups=1, activation=None, use_bias=True,
            kernel_initializer=kernel_initializer, bias_initializer="zeros", kernel_regularizer=None,
            bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None,
        )

    def _init_trainable(self):
        """
        Initialize trainable variables (the sub-blocks of the model).

        Args:
            None

        Returns:
            None
        """
        # Subject-specific input block.
        # NOTE(review): the meaning of ([204, 256, 256], 32, 0.3) is defined by
        # `SubjectBlock` in `layers`, which is not visible here.
        self.subject_block = SubjectBlock([204, 256, 256], 32, 0.3)

        # Initialize convolution block.
        # NOTE(review): every Conv1D has `activation=None` and nothing non-linear sits
        # between them, so the conv stack is a composition of affine maps. Kept as-is
        # to preserve the architecture — confirm this is intended.
        self.conv_block = K.models.Sequential()
        self.conv_block.add(self._conv1d(
            filters=128, kernel_size=5, dilation_rate=1, kernel_initializer="glorot_uniform"))
        self.conv_block.add(self._conv1d(
            filters=256, kernel_size=7, dilation_rate=1, kernel_initializer="glorot_uniform"))
        self.conv_block.add(self._conv1d(
            filters=512, kernel_size=11, dilation_rate=2, kernel_initializer="he_uniform"))
        self.conv_block.add(self._conv1d(
            filters=256, kernel_size=9, dilation_rate=2, kernel_initializer="he_uniform"))
        self.conv_block.add(self._conv1d(
            filters=64, kernel_size=3, dilation_rate=2, kernel_initializer="he_uniform"))
        # Downsample the time axis by 4 (non-overlapping windows, trailing remainder dropped).
        self.conv_block.add(K.layers.MaxPool1D(
            pool_size=4, strides=None, padding="valid", data_format="channels_last"
        ))
        self.conv_block.add(K.layers.Dropout(rate=0.3))
        self.conv_block.add(K.layers.BatchNormalization(
            # Default `BatchNormalization` parameters.
            axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer="zeros",
            gamma_initializer="ones", moving_mean_initializer="zeros", moving_variance_initializer="ones",
            beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None
        ))

        # Initialize feature block.
        # NOTE(review): there is no `Flatten` before this `Dense`, so it acts on the last
        # axis and the features keep the time axis of the conv output.
        self.feature_block = K.models.Sequential()
        self.feature_block.add(K.layers.Dense(
            # Modified `Dense` layer parameters.
            units=160,
            # Default `Dense` layer parameters.
            activation=None, use_bias=True, kernel_initializer="he_uniform",
            bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        ))

        # Initialize contrastive block (alternative used previously: loss_mode="clip").
        self.contrastive_block = LossLayer(d_contra=128, loss_mode="clip_orig")

        # Initialize classification block. It outputs UNSCALED LOGITS: `_loss_bce`
        # uses `softmax_cross_entropy_with_logits`, so the softmax is applied in `call`
        # only for the returned probabilities (a Softmax layer here would double it).
        self.classification_block = K.models.Sequential()
        self.classification_block.add(K.layers.Dropout(rate=0.5))
        self.classification_block.add(K.layers.Flatten())
        self.classification_block.add(K.layers.Dense(
            # Modified `Dense` layer parameters.
            units=self.n_labels,
            # Default `Dense` layer parameters.
            activation=None, use_bias=True, kernel_initializer="he_uniform",
            bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        ))

    def call(self, inputs, training=None, mask=None):
        """
        Forward `conv_net` to get the final predictions.

        Args:
            inputs: tuple - `(X, y_true, Y_f, subject_id)`; only the first four items
                are read. Per the original comments, X is (batch_size, seq_len,
                n_channels). `y_true` is used as `labels` of the softmax cross entropy
                against the classifier logits, so it must be (batch_size, n_labels),
                e.g. one-hot. `Y_f` is handed to the contrastive block as-is.
            training: Boolean or boolean scalar tensor, indicating whether to run
                the `Network` in training mode or inference mode.
            mask: Unused; kept for the `tf.keras.Model.call` signature.

        Returns:
            loss: tf.float32 - Mean cross-entropy loss plus the contrastive loss
                returned by `contrastive_block`.
            y_pred_bce: (batch_size, n_labels) - Softmax class probabilities.
        """
        X, y_true, Y_f, subject_id = inputs[0], inputs[1], inputs[2], inputs[3]

        # Subject-specific projection, then temporal convolutions and dense features.
        outputs = self.subject_block((X, subject_id))
        X_f = self.feature_block(self.conv_block(outputs, training=training), training=training)

        # Contrastive loss and similarity matrix between the features and `Y_f`.
        loss_contra, contra_matrix = self.contrastive_block((X_f, Y_f))
        # NOTE(review): the contrastive prediction is computed (which also re-checks
        # `self.contra_pred_mode`) but is not part of the returned values at present.
        self._contra_predict(contra_matrix, y_true)

        # Classification: loss on the logits, probabilities for the caller.
        logits = self.classification_block(X_f, training=training)
        loss_bce = tf.reduce_mean(self._loss_bce(logits, y_true))
        y_pred_bce = tf.nn.softmax(logits, axis=-1)

        loss = loss_bce + loss_contra
        return loss, y_pred_bce

    def _contra_predict(self, contra_matrix, y_true):
        """
        Turn the contrastive similarity matrix into label predictions.

        Args:
            contra_matrix: (batch_size, batch_size) per the original comments - rows
                index the features (z), columns index the targets (y).
            y_true: (batch_size, n_labels) - Label rows to gather / mix.

        Returns:
            y_pred_contra: (batch_size, n_labels) - Contrastive label prediction.

        Raises:
            ValueError: If `self.contra_pred_mode` is unknown.
        """
        mode = self.contra_pred_mode
        # `max_z`: for every column, take the label of the row with max similarity.
        # (No squeeze: `argmax` already yields a vector, and squeezing would turn a
        # batch of one into a scalar index and drop the batch axis.)
        if mode == "max_z":
            contra_idxs = tf.argmax(contra_matrix, axis=0)
            return tf.gather(y_true, indices=contra_idxs, axis=0)
        # `prob_z`: column-normalized similarities as mixing weights over labels.
        if mode == "prob_z":
            contra_prob = tf.transpose(
                contra_matrix / tf.reduce_sum(contra_matrix, axis=0, keepdims=True), perm=[1, 0])
            return tf.matmul(contra_prob, y_true)
        # `max_y`: for every row, take the label of the column with max similarity.
        if mode == "max_y":
            contra_idxs = tf.argmax(contra_matrix, axis=-1)
            return tf.gather(y_true, indices=contra_idxs, axis=0)
        # `prob_y`: row-normalized similarities as mixing weights over labels.
        if mode == "prob_y":
            contra_prob = contra_matrix / tf.reduce_sum(contra_matrix, axis=-1, keepdims=True)
            return tf.matmul(contra_prob, y_true)
        raise ValueError((
            "ERROR: Unknown contrastive prediction mode {} in train.conv_net."
        ).format(mode))

    @utils.model.tf_scope
    def _loss_bce(self, value, target):
        """
        Calculate the softmax cross entropy between logits `value` and labels `target`.

        Despite the name this is categorical (softmax) cross entropy, not binary.
        `tf.nn.softmax_cross_entropy_with_logits` expects unscaled logits, so `value`
        must not have been passed through a softmax.

        :param value: (batch_size, n_labels) - Logits, or a list of such tensors.
        :param target: (batch_size, n_labels) - Label distribution (e.g. one-hot), or a
            list indexed in parallel with `value`.
        :return loss: (batch_size,) - Per-example loss, or a list of them when `value`
            is a list.
        """
        if isinstance(value, list):
            return [
                tf.nn.softmax_cross_entropy_with_logits(labels=target[i], logits=logits_i)
                for i, logits_i in enumerate(value)
            ]
        return tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=value)

if __name__ == "__main__":
    import numpy as np
    # local dep
    # from params import conv_net_params

    # Smoke test: build a `conv_net` and run one forward pass on random inputs.
    # macro
    # NOTE(review): `dataset` is not read below; kept as a record of the target dataset.
    dataset = "meg_anonymous"; batch_size = 16; seq_len = 160; n_features = 128
    # 204 is also the first entry of the `SubjectBlock` config inside `conv_net` —
    # presumably the number of MEG channels; confirm against `SubjectBlock`.
    n_channels = 204
    # NOTE(review): depth of the one-hot subject id given to `SubjectBlock`; assumed to
    # be compatible with its configuration — confirm.
    n_subjects = 6

    # Initialize training process.
    utils.model.set_seeds(42)

    # Instantiate conv_net.
    conv_net_inst = conv_net()
    # Labels must have the classifier's width, so read it from the model instead of
    # hard-coding a value (the old hard-coded 6 did not match the model).
    n_labels = conv_net_inst.n_labels
    # Freeze the convolution and feature blocks and list the layers (debugging aid).
    conv_net_inst.conv_block.trainable = False
    conv_net_inst.feature_block.trainable = False
    print(conv_net_inst.feature_block.layers)
    for layer in conv_net_inst.layers:
        print(layer.name)

    # Initialize input X, one-hot label y_true, representation Y_f and subject id.
    # `% depth` keeps every index in range so each row is a valid one-hot vector.
    X = tf.random.normal((batch_size, seq_len, n_channels), dtype=tf.float32)
    y_true = tf.one_hot(tf.range(batch_size) % n_labels, n_labels, dtype=tf.float32)
    Y_f = tf.random.normal((batch_size, n_features), dtype=tf.float32)
    subject_id = tf.one_hot(tf.range(batch_size) % n_subjects, n_subjects, dtype=tf.float32)
    # Forward layers in `conv_net_inst`: `call` reads inputs[0..3] as
    # (X, y_true, Y_f, subject_id) and returns `(loss, y_pred)`.
    inputs = (X, y_true, Y_f, subject_id)
    loss, y_pred = conv_net_inst(inputs)
    print(loss, y_pred.shape)
    # conv_net_inst.summary()
    # conv_net_inst.summary()

