#!/usr/bin/env python3
"""
Created on 21:15, Jul. 25th, 2023

@author: Anonymous
"""
import copy as cp
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))
import utils.model
from utils import DotDict

__all__ = [
    "SubdomainAdversarialConvNet",
]

# def SubdomainAdversarialConvNet class
class SubdomainAdversarialConvNet(K.Model):
    """
    Sub-domain Adversarial Convolution Neural Network.

    The network is made of three `K.models.Sequential` sub-models:
        - `cnn`: feature extractor (`Conv1D`/`MaxPool1D` stack, optional `Dropout`, `BatchNormalization`,
            `Flatten` and a 256-unit `Dense` layer).
        - `cls_class`: class classifier head on top of the extracted features.
        - `cls_domain`: domain classifier head on top of the extracted features.
    Both classifier heads output unnormalized logits. `call` computes the softmax cross-entropy on those
    logits and only then converts them to probabilities with softmax.
    """

    def __init__(self, params):
        """
        Initialize `SubdomainAdversarialConvNet` object.

        Args:
            params: DotDict - Model parameters initialized by domain_adaptation_params, updated by params.iteration.

        Returns:
            None
        """
        # Initialize super, to get `K.Model`-style class.
        super(SubdomainAdversarialConvNet, self).__init__()

        # Initialize parameters. Deep-copy so later in-place updates never leak back into the caller's params.
        self.params = cp.deepcopy(params)

        # Initialize variables.
        self._init_trainable()
        self._init_optimizer()

    # ------------------------------------------------------------------
    # init funcs
    # ------------------------------------------------------------------
    # def _init_trainable func
    def _init_trainable(self):
        """
        Initialize trainable variables.

        Builds `self.cnn`, `self.cls_class` and `self.cls_domain` (in that order).

        Args:
            None

        Returns:
            None
        """
        ## Initialize trainable cnn layers.
        cnn_params = self.params.cnn
        n_cnn_layers = len(cnn_params.n_filters)
        model_cnn = K.models.Sequential(name="CNN")
        # Add `Conv1D` & `MaxPool1D` layers.
        for cnn_idx in range(n_cnn_layers):
            # `tf.keras.layers.Conv1D` is different from `torch.nn.Conv1d`: it doesn't have an `in_channels`
            # argument, and its `filters` argument plays the role of `out_channels`.
            out_channels, kernel_size = cnn_params.n_filters[cnn_idx], cnn_params.d_kernel[cnn_idx]
            strides, padding = cnn_params.strides[cnn_idx], cnn_params.padding[cnn_idx]
            dilation_rate = cnn_params.dilation_rate[cnn_idx]
            model_cnn.add(K.layers.Conv1D(
                # Modified `Conv1D` layer parameters.
                filters=out_channels, kernel_size=kernel_size, strides=strides,
                padding=padding, dilation_rate=dilation_rate, name="Conv1D_{:d}".format(cnn_idx),
                # Default `Conv1D` layer parameters.
                data_format="channels_last", groups=1, activation=None, use_bias=True,
                kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
                bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ))
            # Initialize `MaxPool1D` layer.
            if isinstance(cnn_params.d_pooling_kernel, list):
                # A list of pooling kernels: one stride-1 `MaxPool1D` after every `Conv1D`.
                model_cnn.add(K.layers.MaxPool1D(
                    # Modified `MaxPool1D` layer parameters.
                    pool_size=cnn_params.d_pooling_kernel[cnn_idx], strides=1,
                    name="MaxPool1D_{:d}".format(cnn_idx),
                    # Default `MaxPool1D` layer parameters.
                    padding="valid", data_format="channels_last"
                ))
            elif cnn_idx == n_cnn_layers - 1:
                # A scalar pooling kernel: a single `MaxPool1D` only after the last `Conv1D`
                # (`strides=None` makes Keras use `pool_size` as the stride).
                model_cnn.add(K.layers.MaxPool1D(
                    # Modified `MaxPool1D` layer parameters.
                    pool_size=cnn_params.d_pooling_kernel, name="MaxPool1D_{:d}".format(cnn_idx),
                    # Default `MaxPool1D` layer parameters.
                    strides=None, padding="valid", data_format="channels_last"
                ))
        # Add `Dropout` after `MaxPool1D` layer.
        if cnn_params.dropout > 0.:
            model_cnn.add(K.layers.Dropout(rate=cnn_params.dropout, name="Dropout_{}".format("cnn")))
        # Add `BatchNormalization` at the last layer of cnn layers.
        model_cnn.add(K.layers.BatchNormalization(
            # Modified `BatchNormalization` parameters.
            name="BatchNormalization_{}".format("cnn"),
            # Default `BatchNormalization` parameters.
            axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer="zeros",
            gamma_initializer="ones", moving_mean_initializer="zeros", moving_variance_initializer="ones",
            beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None
        ))
        # Flatten convolved features and project them to a 256-d embedding.
        model_cnn.add(K.layers.Flatten(data_format="channels_last"))
        model_cnn.add(K.layers.Dense(
            # Modified `Dense` parameters.
            units=256, activation="relu",
            # Default `Dense` parameters.
            use_bias=True, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
            bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        ))
        self.cnn = model_cnn
        ## Initialize trainable class classifier layers.
        self.cls_class = self._build_classifier(self.params.cls_class, name="CLS-class")
        ## Initialize trainable domain classifier layers.
        self.cls_domain = self._build_classifier(self.params.cls_domain, name="CLS-domain")

    # def _build_classifier func
    @staticmethod
    def _build_classifier(params_cls, name):
        """
        Build a classifier head that maps extracted features to unnormalized logits.

        Args:
            params_cls: DotDict - Classifier parameters. Reads `d_hidden` (iterable of hidden sizes), `dropout`
                (dropout rate; no `Dropout` layer when it is 0) and `d_output` (number of output logits).
            name: str - Name of the returned `K.models.Sequential` model.

        Returns:
            model_cls: K.models.Sequential - The classifier head, whose output width is `params_cls.d_output`.
        """
        model_cls = K.models.Sequential(name=name)
        # Flatten convolved features to 1D-vector.
        model_cls.add(K.layers.Flatten(data_format="channels_last"))
        # Add hidden `Dense` layers.
        for d_hidden_i in params_cls.d_hidden:
            model_cls.add(K.layers.Dense(
                # Modified `Dense` parameters.
                units=d_hidden_i, activation="relu",
                # Default `Dense` parameters.
                use_bias=True, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
                bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None
            ))
        # Add `Dropout` after hidden `Dense` layer.
        if params_cls.dropout > 0.:
            model_cls.add(K.layers.Dropout(rate=params_cls.dropout, name="Dropout_{}".format("fc")))
        # Add the final classification `Dense` layer. It deliberately has no activation and no `Softmax`:
        # `tf.nn.softmax_cross_entropy_with_logits` (used by `_loss_bce`) expects raw logits, and applying
        # sigmoid/softmax here would squash the logits and cripple the gradients.
        model_cls.add(K.layers.Dense(
            # Modified `Dense` parameters.
            units=params_cls.d_output, activation=None,
            # Default `Dense` parameters.
            use_bias=True, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
            bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        ))
        return model_cls

    # def _init_optimizer func
    def _init_optimizer(self):
        """
        Initialize optimizers used to optimize different parts of model.

        Args:
            None

        Returns:
            None
        """
        # Initialize optimizer for class classifier.
        self.optimizer_class = K.optimizers.Adam(learning_rate=self.params.lr_class_i)
        # Initialize optimizer for domain classifier.
        self.optimizer_domain = K.optimizers.Adam(learning_rate=self.params.lr_domain_i)

    # ------------------------------------------------------------------
    # update funcs
    # ------------------------------------------------------------------
    # def _update_params func
    def _update_params(self, params):
        """
        Update model parameters.

        Only iteration parameters (whose last key ends with "_i") are copied from `params`. The optimizers'
        learning rates are then re-synchronized with `lr_class_i` / `lr_domain_i`, so that per-iteration
        learning-rate updates actually take effect.

        Args:
            params: DotDict - The updated model parameters.

        Returns:
            None
        """
        ## Update iteration parameters of model.
        # If `params` is completely equal to the parameters of model, this is equal to setting `params` to None.
        iter_keys = [key_i for key_i in utils.DotDict.iter_keys(params) if key_i[-1].endswith("_i")]
        for key_i in iter_keys:
            utils.DotDict.iter_setattr(self.params, key_i, utils.DotDict.iter_getattr(params, key_i))
        # The optimizers captured their learning rates at construction time, so push the (possibly updated)
        # values into them. NOTE(review): assumes the learning rates are Python scalars — confirm for callers.
        self.optimizer_class.learning_rate = self.params.lr_class_i
        self.optimizer_domain.learning_rate = self.params.lr_domain_i

    # ------------------------------------------------------------------
    # network funcs
    # ------------------------------------------------------------------
    # def call func
    @tf.function
    def call(self, inputs, params=None, training=None, mask=None):
        """
        Forward `SubdomainAdversarialConvNet` to get the final predictions.

        Args:
            inputs: tuple - `(X, y_true_class, y_true_domain)`. `X` is fed to `cnn`, and `y_true_*` are the
                target distributions for the softmax cross-entropy (presumably one-hot — confirm with callers).
            params: DotDict - The updated model parameters.
            training: Boolean or boolean scalar tensor, indicating whether to run
                the `Network` in training mode or inference mode.
            mask: A mask or list of masks. A mask can be either a tensor or None (no mask). Unused.

        Returns:
            y_pred: DotDict - {"class": (batch_size, cls_class.d_output), "domain": (batch_size, cls_domain.d_output)}
                softmax probabilities.
            loss: DotDict - {"class": (batch_size,), "domain": (batch_size,)} per-sample cross entropy.
        """
        # Update iteration parameters.
        if params is not None:
            self._update_params(params)
        # Initialize `X` & `y_true_class` & `y_true_domain` from `inputs`.
        X, y_true_class, y_true_domain = inputs[0], inputs[1], inputs[2]
        # Get the feature embeddings according to `X`, then get the logits of both classifier heads.
        Z = self.cnn(X, training=training)
        logits_class = self.cls_class(Z, training=training)
        logits_domain = self.cls_domain(Z, training=training)
        # Calculate the corresponding loss on the raw logits.
        # loss_* - (batch_size,)
        loss_class = self._loss_bce(logits_class, y_true_class)
        loss_domain = self._loss_bce(logits_domain, y_true_domain)
        # Convert logits to probabilities for the returned predictions.
        y_pred_class = tf.nn.softmax(logits_class, axis=-1)
        y_pred_domain = tf.nn.softmax(logits_domain, axis=-1)
        # Get the final `y_pred` & `loss`.
        y_pred = DotDict({"class": y_pred_class, "domain": y_pred_domain,})
        loss = DotDict({"class": loss_class, "domain": loss_domain,})
        # Return the final `y_pred` & `loss`.
        return y_pred, loss

    # def _loss_bce func
    def _loss_bce(self, value, target):
        """
        Calculate softmax (categorical) cross entropy between logits `value` and targets `target`.
        Despite the `_bce` name (kept for backward compatibility), this uses
        `tf.nn.softmax_cross_entropy_with_logits`. That function reduces over the last (class) dimension, so
        the losses of different samples stay separate.
        :param value: (batch_size, n_classes) or list of such - Unnormalized logits (NOT probabilities).
        :param target: (batch_size, n_classes) or list of such - Target class distributions.
        :return loss: (batch_size,) or list of such - Per-sample cross entropy.
        """
        # Note: `tf.nn.softmax_cross_entropy_with_logits` needs unscaled logits, so the model must not
        # end with a softmax/sigmoid activation.
        if isinstance(value, list):
            return [tf.nn.softmax_cross_entropy_with_logits(labels=target[i], logits=value_i)
                for i, value_i in enumerate(value)]
        return tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=value)

    # def _loss_mse func
    def _loss_mse(self, value, target, keepdims=False):
        """
        Calculate mean squared error (L2 norm) between (list of) tensors value and target. Include a factor
        0.5 to squared error by convention. The squared error is summed over axis 1, so losses of different
        samples stay separate.
        :param value: (batch_size, n_features) or list of such - Value of the object.
        :param target: (batch_size, n_features) or list of such - Target of the object.
        :param keepdims: boolean - Parameter of `tf.reduce_sum`.
        :return loss: (batch_size,) (or (batch_size, 1) if `keepdims`), or list of such - Loss between value and target.
        """
        if isinstance(value, list):
            return [0.5 * tf.reduce_sum(tf.square(target[i] - value_i), axis=1, keepdims=keepdims)
                for i, value_i in enumerate(value)]
        return 0.5 * tf.reduce_sum(tf.square(target - value), axis=1, keepdims=keepdims)

    # ------------------------------------------------------------------
    # train funcs
    # ------------------------------------------------------------------
    # def train func
    @tf.function
    def train(self, inputs, params=None):
        """
        Train `SubdomainAdversarialConvNet` with source & target domain.

        Args:
            inputs: DotDict - The input data, with `source` and `target` entries (each a `call`-style tuple).
            params: DotDict - The updated model parameters.

        Returns:
            y_pred: DotDict - The output labels (from the class-classifier step).
            loss: DotDict - The corresponding batch-averaged losses (from the class-classifier step).
        """
        # Update iteration parameters.
        if params is not None:
            self._update_params(params)
        # Execute class classifier training, which requires:
        # 1) Extracted feature can confuse domain classifier.
        # 2) Extracted feature can be distinguished w.r.t. classes.
        y_pred, loss = self._train_class(inputs)
        # Execute domain classifier training, which requires:
        # 1) Domain classifier try to classify the extracted feature which is trained to confuse domain classifier.
        self._train_domain(inputs)
        # Return the final `y_pred` & `loss`.
        return y_pred, loss

    # def _train_class func
    def _train_class(self, inputs):
        """
        Train the class classifier part (`cnn` + `cls_class`) of `SubdomainAdversarialConvNet`.

        Minimizes `mean(class loss) - w_loss_domain_i * mean(domain loss)` over source & target samples.
        The feature extractor is thereby pushed to help the class classifier and to confuse the domain classifier.

        Args:
            inputs: DotDict - The input data.

        Returns:
            y_pred: DotDict - The output labels.
            loss: DotDict - The corresponding loss.
        """
        # Initialize `inputs_source` & `inputs_target` according to `inputs`.
        inputs_source, inputs_target = inputs.source, inputs.target
        with tf.GradientTape() as gt:
            # Forward `inputs_*` to get `y_pred_*` & `loss_*`.
            y_pred_source, loss_source = self(inputs_source, training=True)
            y_pred_target, loss_target = self(inputs_target, training=True)
            # Calculate the final `loss` of the class classifier part.
            loss_class = tf.reduce_mean(tf.concat([loss_source["class"], loss_target["class"]], axis=0))
            loss_domain = tf.reduce_mean(tf.concat([loss_source["domain"], loss_target["domain"]], axis=0))
            loss = loss_class - self.params.w_loss_domain_i * loss_domain
        # Get the corresponding gradients of [cnn,cls_class].
        trainable_variables = self.cnn.trainable_variables + self.cls_class.trainable_variables
        gradients = gt.gradient(loss, trainable_variables)
        self.optimizer_class.apply_gradients(zip(gradients, trainable_variables))
        # Get the final `y_pred` & `loss`.
        y_pred = DotDict({"source": y_pred_source, "target": y_pred_target,})
        return y_pred, self._summarize_loss(loss_source, loss_target)

    # def _train_domain func
    def _train_domain(self, inputs):
        """
        Train the domain classifier part (`cls_domain` only) of `SubdomainAdversarialConvNet`.

        Args:
            inputs: DotDict - The input data.

        Returns:
            y_pred: DotDict - The output labels.
            loss: DotDict - The corresponding loss.
        """
        # Initialize `inputs_source` & `inputs_target` according to `inputs`.
        inputs_source, inputs_target = inputs.source, inputs.target
        with tf.GradientTape() as gt:
            # Forward `inputs_*` to get `y_pred_*` & `loss_*`.
            y_pred_source, loss_source = self(inputs_source, training=True)
            y_pred_target, loss_target = self(inputs_target, training=True)
            # Calculate the final `loss` of the domain classifier part.
            loss = tf.reduce_mean(tf.concat([loss_source["domain"], loss_target["domain"]], axis=0))
        # Get the corresponding gradients of [cls_domain,].
        trainable_variables = self.cls_domain.trainable_variables
        gradients = gt.gradient(loss, trainable_variables)
        self.optimizer_domain.apply_gradients(zip(gradients, trainable_variables))
        # Get the final `y_pred` & `loss`.
        y_pred = DotDict({"source": y_pred_source, "target": y_pred_target,})
        return y_pred, self._summarize_loss(loss_source, loss_target)

    # def _summarize_loss func
    @staticmethod
    def _summarize_loss(loss_source, loss_target):
        """
        Average per-sample losses over the batch for reporting.

        Args:
            loss_source: DotDict - Per-sample {"class", "domain"} losses of the source batch.
            loss_target: DotDict - Per-sample {"class", "domain"} losses of the target batch.

        Returns:
            loss: DotDict - {"source": {"class", "domain"}, "target": {"class", "domain"}} batch-mean losses.
        """
        return DotDict({
            "source": {"class": tf.reduce_mean(loss_source["class"]), "domain": tf.reduce_mean(loss_source["domain"]),},
            "target": {"class": tf.reduce_mean(loss_target["class"]), "domain": tf.reduce_mean(loss_target["domain"]),},
        })

    # ------------------------------------------------------------------
    # tool funcs
    # ------------------------------------------------------------------
    # def summary func
    def summary(self, print_fn=None):
        """
        Summary built model.
        :param print_fn: callable - Print function to use. Defaults to `print`. It will be called on each
            line of the summary. You can set it to a custom function in order to capture the string summary.
        """
        super(SubdomainAdversarialConvNet, self).summary(print_fn=print_fn)

if __name__ == "__main__":
    import numpy as np
    # local dep
    from params.domain_adaptation_params import subdomain_adversarial_conv_params

    # Demo configuration.
    dataset = "eeg_anonymous"
    batch_size = 32
    seq_len = 80

    # Make the demo reproducible.
    utils.model.set_seeds(42)

    ## Build the model parameters for the chosen dataset.
    sdacn_params_inst = subdomain_adversarial_conv_params(dataset=dataset)
    model_params = sdacn_params_inst.model
    n_channels = model_params.n_channels
    n_labels = model_params.n_labels
    n_domains = model_params.n_domains

    ## Smoke-test `SubdomainAdversarialConvNet` with one training step on random data.
    sdacn_inst = SubdomainAdversarialConvNet(model_params)
    # Source samples are labelled domain 0 and target samples domain 1, so exactly two domains are expected.
    assert n_domains == 2

    def to_one_hot(indices, depth):
        """Turn integer class `indices` into a float32 one-hot tensor with `depth` columns."""
        return tf.cast(np.eye(depth)[indices], dtype=tf.float32)

    # Source batch: random features, random class labels, domain index 0.
    X_source = tf.random.normal((batch_size, seq_len, n_channels), dtype=tf.float32)
    y_true_class_source = to_one_hot(np.random.randint(0, n_labels, size=(batch_size,)), n_labels)
    y_true_domain_source = to_one_hot(np.full((batch_size,), 0, dtype=np.int64), n_domains)
    # Target batch: random features, random class labels, domain index 1.
    X_target = tf.random.normal((batch_size, seq_len, n_channels), dtype=tf.float32)
    y_true_class_target = to_one_hot(np.random.randint(0, n_labels, size=(batch_size,)), n_labels)
    y_true_domain_target = to_one_hot(np.full((batch_size,), 1, dtype=np.int64), n_domains)

    # Pack both domains, run one adversarial training step, then print the model structure.
    inputs = DotDict({
        "source": (X_source, y_true_class_source, y_true_domain_source,),
        "target": (X_target, y_true_class_target, y_true_domain_target,),
    })
    y_pred, loss = sdacn_inst.train(inputs)
    sdacn_inst.summary()

