#!/usr/bin/env python3
"""
Created on 21:27, Jul. 25th, 2023

@author: Anonymous
"""
import os
import sys
import time
import copy as cp
from collections import Counter
from itertools import product

import numpy as np
import scipy as sp
import scipy.linalg
import tensorflow as tf

# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir, os.pardir))
import utils; import utils.model
import utils.data.eeg
from models.domain_adaptation import SubdomainContrastiveTransformer as sdctn_model

# Public API of this module: only `train` is exported by `from ... import *`.
__all__ = [
    "train",
]

# Global variables.
# `params` (a deep copy of the training parameters) & `paths` (project paths & loggers) are set by `init`.
# `model` & `optimizer` are (re)created by `train` for every experiment, and read by `_forward` / `_train`.
params = None; paths = None
model = None; optimizer = None

"""
init funcs
"""
# def init func
def init(base_, params_):
    """
    Set up the module-level state used by `train`.

    Args:
        base_: str - The base path of current project.
        params_: DotDict - The parameters of current training process. A deep copy is stored in the
            module-level `params`, so later in-place changes (e.g. `params.model.n_modalities` in `train`)
            do not leak back to the caller's object.

    Returns:
        None
    """
    global params, paths
    params_copy = cp.deepcopy(params_)
    params = params_copy
    paths = utils.Paths(base=base_, params=params_copy)
    # Model-side (tf configuration) first, then training-process-side initialization.
    for initializer in (_init_model, _init_train):
        initializer()

# def _init_model func
def _init_model():
    """
    Configure TensorFlow before the model used in the training process is built.

    Sets the Keras default float type from `params._precision`, and makes `tf.function`s run eagerly
    unless `params.train.use_graph_mode` is truthy. The random seed is deliberately NOT set here; it has
    to be set before `model` is instantiated.

    Returns:
        None
    """
    floatx = params._precision
    tf.keras.backend.set_floatx(floatx)
    run_eagerly = not params.train.use_graph_mode
    tf.config.run_functions_eagerly(run_eagerly)

# def _init_train func
def _init_train():
    """
    Initialize the training process.

    Args:
        None

    Returns:
        None
    """
    pass

"""
preprocess funcs
"""
def preprocess_data(X, mode="baseline"):
    """
    Preprocess data according to preprocess mode.

    Args:
        X: (batch_size, seq_len, n_channels) - The raw signals (just cropped without baseline correction).
        mode: str - The preprocess mode to preprocess data, one of:
            - "baseline": subtract the global mean & divide by the global std over all elements of `X`.
            - "zscore": normalize each (sample, channel) series by its own mean & std along `seq_len`.
            - "euclidean_alignment": left-multiply every sample's (n_channels, seq_len) matrix by
              `R^{-1/2}`, where `R` is the mean over samples of the (n_channels, n_channels) matrix
              `X_i^T X_i`, so that the mean of the aligned matrices becomes the identity.

    Returns:
        X: (batch_size, seq_len, n_channels) - The preprocessed data.

    Raises:
        ValueError: If `mode` is unknown, or the alignment matrix contains any non-finite value.
    """
    # Preprocess data according to preprocess mode.
    if mode == "baseline":
        X = (X - np.mean(X)) / np.std(X)
    elif mode == "zscore":
        X = (X - np.mean(X, axis=1, keepdims=True)) / np.std(X, axis=1, keepdims=True)
    elif mode == "euclidean_alignment":
        # M - (batch_size, n_channels, seq_len)
        M = np.transpose(X, axes=[0,2,1])
        # R - (n_channels, n_channels), the mean of `M_i @ M_i^T` over samples.
        R = np.mean(np.matmul(M, np.transpose(M, axes=[0,2,1])), axis=0)
        # T - (n_channels, n_channels), the inverse matrix square root of `R`.
        T = sp.linalg.inv(sp.linalg.sqrtm(R))
        # `sqrtm` may return a complex array for real input; only the real part is kept.
        T = np.real(T).astype(np.float32) if np.iscomplexobj(T) else T
        # Every entry must be finite; a single inf/nan would corrupt all aligned samples.
        if not np.all(np.isfinite(T)):
            raise ValueError("ERROR: T matrix has non-finite values in SubdomainContrastiveTransformer.")
        # X - (batch_size, seq_len, n_channels)
        X = np.transpose(np.matmul(T, M), axes=[0,2,1])
    else:
        raise ValueError((
            "ERROR: Unknown preprocess mode {} in SubdomainContrastiveTransformer."
        ).format(mode))
    # Return the final `X`.
    return X

"""
data funcs
"""
# def load_data func
def load_data(load_params):
    """
    Load data from specified dataset.

    The dataset-specific loader is looked up by name as `_load_data_<params.train.dataset>` among this
    module's globals, and called with `load_params`.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_source: The source dataset(s), exactly as returned by the dataset-specific loader.
        dataset_target: The target dataset(s), exactly as returned by the dataset-specific loader.

    Raises:
        ValueError: If no loader exists for `params.train.dataset`. Errors raised by the loader itself
            propagate unchanged.
    """
    global params
    # Look the loader up in this module's globals, so the lookup does not depend on `sys` being in scope.
    func = globals().get("_".join(["_load_data", params.train.dataset]))
    if not callable(func):
        raise ValueError((
            "ERROR: Unknown dataset type {} in train.domain_adaptation.DomainContrastiveNet.SubdomainContrastiveTransformer."
        ).format(params.train.dataset))
    # Call the loader outside any `try`, so real loading errors keep their own type & traceback
    # instead of being reported as an unknown dataset.
    dataset_source, dataset_target = func(load_params)
    # Return the final `dataset_source` & `dataset_target`.
    return dataset_source, dataset_target

# def _load_data_eeg_anonymous func
def _load_data_eeg_anonymous(load_params):
    """
    Load eeg data from the specified subject run in `eeg_anonymous`.

    Args:
        load_params: DotDict - The load parameters of specified dataset. Fields read here: `sourceset` &
            `targetset` (lists of dataset names; a name must not appear in both), `type` ("default" or
            "lvbj"), `n_samples.targetset_train` (int or None), and optionally `path_run`.

    Returns:
        dataset_source: list - One batched `tf.data.Dataset` per source domain, each yielding
            (X, y_class, y_domain) with one-hot `y_class` & `y_domain`. Source samples are ERPs (means of
            `erp_samples` same-class, same-domain trials) unless `erp_samples` is 1. `[]` if skipped.
        dataset_target: tuple - The (train, validation, test) batched `tf.data.Dataset`s of the target
            domain, each yielding (X, y_class, y_domain). `([], [], [])` if the experiment is skipped.

    Raises:
        ValueError: If `load_params.type` is unknown, or a dataset name is in both source & target sets.
    """
    global params, paths
    # Initialize path_run, falling back to a default run when `load_params` has no `path_run`.
    path_run = os.path.join(paths.base, "data", "eeg.anonymous", "020", "20230405")\
        if not hasattr(load_params, "path_run") else load_params.path_run
    # Load data from specified subject run.
    datasets = utils.DotDict(); dataset_names = sorted(set(load_params.sourceset) | set(load_params.targetset))
    # Map each dataset name to its domain: `task-*` names keep their first & last fields
    # (e.g. "task-image-audio-pre-image" -> "task-image"), other names keep their first two fields
    # (e.g. "tmr-N2/3-audio" -> "tmr-N2/3").
    dataset_domains = ["-".join([dataset_name_i.split("-")[0], dataset_name_i.split("-")[-1]])\
        if dataset_name_i.split("-")[0] == "task" else "-".join(dataset_name_i.split("-")[0:2])\
        for dataset_name_i in dataset_names]
    available_dataset_domains = sorted(set(dataset_domains))
    for dataset_name_i, dataset_domain_i in zip(dataset_names, dataset_domains):
        # Split the name into session type (all fields but the last, e.g. "task-image-audio-pre") and data
        # type (the last field, e.g. "image"); the loader is `utils.data.eeg.anonymous.load_run_<first field>`.
        session_type_i = "-".join(dataset_name_i.split("-")[:-1]); data_type_i = dataset_name_i.split("-")[-1]
        func_i = getattr(utils.data.eeg.anonymous, "_".join(["load_run", session_type_i.split("-")[0]]))
        X_i, y_class_i = func_i(path_run, session_type="-".join(session_type_i.split("-")[1:]))
        # Check whether current dataset has data items.
        if X_i[data_type_i] is None:
            msg = "WARNING: Skip dataset {} with no data item.".format(dataset_name_i)
            print(msg); paths.run.logger.summaries.info(msg); continue
        X_i = X_i[data_type_i].astype(np.float32); y_class_i = y_class_i[data_type_i].astype(np.int64)
        # The raw domain id is the position of this dataset's domain in `available_dataset_domains`.
        y_domain_i = np.repeat(available_dataset_domains.index(dataset_domain_i), repeats=y_class_i.shape[0]).astype(np.int64)
        # Truncate `X_i` to let them have the same length.
        # TODO: Here, we only keep the [0.0~0.8]s-part of [audio,image] that after onset. And we should
        # note that the [0.0~0.8]s-part of image is the whole onset time of image, the [0.0~0.8]s-part
        # of audio is the sum of the whole onset time of audio and the following 0.3s padding.
        # X_i - (n_samples, seq_len, n_channels)
        # If the type of dataset is `default`.
        if load_params.type == "default":
            X_i = X_i[:,20:100,:]
        # If the type of dataset is `lvbj`.
        elif load_params.type == "lvbj":
            X_i = X_i[:,20:100,:]
        # Get unknown type of dataset.
        else:
            raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
        # Do cross-trial normalization (global mean & std over this whole dataset).
        X_i = preprocess_data(X_i, mode="baseline")
        # Set the corresponding item of dataset.
        datasets[dataset_name_i] = utils.DotDict({"X":X_i,"y_class":y_class_i,"y_domain":y_domain_i,})
    # Initialize sourceset & targetset.
    X_source = []; y_class_source = []; y_domain_source = []
    X_target = []; y_class_target = []; y_domain_target = []
    for dataset_name_i, dataset_i in datasets.items():
        # Datasets only in the sourceset go to the source side.
        if dataset_name_i in load_params.sourceset and dataset_name_i not in load_params.targetset:
            X_source.append(dataset_i.X); y_class_source.append(dataset_i.y_class)
            y_domain_source.append(dataset_i.y_domain)
        # Datasets only in the targetset go to the target side.
        elif dataset_name_i not in load_params.sourceset and dataset_name_i in load_params.targetset:
            X_target.append(dataset_i.X); y_class_target.append(dataset_i.y_class)
            y_domain_target.append(dataset_i.y_domain)
        # A dataset listed in both sets is a configuration error.
        else:
            raise ValueError("ERROR: Unknown dataset name {}.".format(dataset_name_i))
    # Check whether sourceset & targetset both have data items.
    if len(X_source) == 0 or len(X_target) == 0: return [], ([], [], [])
    # X - (n_samples, seq_len, n_channels); y_class - (n_samples,); y_domain - (n_samples,)
    X_source = np.concatenate(X_source, axis=0); y_class_source = np.concatenate(y_class_source, axis=0)
    y_domain_source = np.concatenate(y_domain_source, axis=0)
    X_target = np.concatenate(X_target, axis=0); y_class_target = np.concatenate(y_class_target, axis=0)
    y_domain_target = np.concatenate(y_domain_target, axis=0)
    # Make sure there is no overlap between X_source & X_target: a target trial is flagged only if its
    # first-channel value matches some source trial at every one of `n_samples` random time points.
    samples_same = None; n_samples = 10; assert X_source.shape[1] == X_target.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_source.shape[1])
        sample_same_i = np.intersect1d(X_source[:,sample_idx,0], X_target[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough. Explicit checks (not `assert`) keep this skip logic active under
    # `python -O`. Source & target must share the same label set (so every target label has a one-hot
    # column), with exactly `n_labels` labels, and each side must have fewer than `n_modalities` domains.
    labels_source = sorted(set(y_class_source)); domains_source = sorted(set(y_domain_source))
    labels_target = sorted(set(y_class_target)); domains_target = sorted(set(y_domain_target))
    labels_enough = (labels_source == labels_target) and (len(labels_source) == params.model.n_labels)
    domains_valid = (len(domains_source) < params.model.n_modalities) and (len(domains_target) < params.model.n_modalities)
    if not (labels_enough and domains_valid):
        msg = (
            "WARNING: Skip experiment (source:{};target:{}) due to that the classes of test cases are not enough."
        ).format(set(load_params.sourceset), set(load_params.targetset))
        print(msg); paths.run.logger.summaries.info(msg); return [], ([], [], [])
    labels = labels_source; domains = sorted(set(domains_source) | set(domains_target))
    # y_class - (n_samples, n_labels); column = position of the raw label in `labels`.
    y_class_source = np.array([labels.index(y_i) for y_i in y_class_source], dtype=np.int64)
    y_class_source = np.eye(len(labels))[y_class_source]
    y_class_target = np.array([labels.index(y_i) for y_i in y_class_target], dtype=np.int64)
    y_class_target = np.eye(len(labels))[y_class_target]
    # y_domain - (n_samples, n_domains); column = position of the raw domain id in `domains`.
    y_domain_source = np.array([domains.index(y_i) for y_i in y_domain_source], dtype=np.int64)
    y_domain_source = np.eye(len(domains))[y_domain_source]
    y_domain_target = np.array([domains.index(y_i) for y_i in y_domain_target], dtype=np.int64)
    y_domain_target = np.eye(len(domains))[y_domain_target]
    # Split target domain into train-set & validation-set & test-set; validation & test share the rest equally.
    train_ratio = params.train.train_ratio; validation_ratio = test_ratio = (1. - train_ratio) / 2.
    train_idxs = sorted(np.random.choice(np.arange(y_class_target.shape[0]),
        size=int(y_class_target.shape[0] * train_ratio), replace=False))
    validation_idxs = np.random.choice(sorted(set(np.arange(y_class_target.shape[0])) - set(train_idxs)),
        size=int(y_class_target.shape[0] * validation_ratio), replace=False)
    test_idxs = sorted(set(np.arange(y_class_target.shape[0])) - set(train_idxs) - set(validation_idxs))
    assert len(set(train_idxs) & set(validation_idxs)) == 0
    assert len(set(train_idxs) & set(test_idxs)) == 0
    assert len(set(validation_idxs) & set(test_idxs)) == 0
    X_target_train = X_target[train_idxs,:,:]; y_class_target_train = y_class_target[train_idxs,:]
    y_domain_target_train = y_domain_target[train_idxs,:]
    X_target_validation = X_target[validation_idxs,:,:]; y_class_target_validation = y_class_target[validation_idxs,:]
    y_domain_target_validation = y_domain_target[validation_idxs,:]
    X_target_test = X_target[test_idxs,:,:]; y_class_target_test = y_class_target[test_idxs,:]
    y_domain_target_test = y_domain_target[test_idxs,:]
    # Randomly select `load_params.n_samples.targetset_train` samples to format target-train-set.
    if load_params.n_samples.targetset_train is not None:
        train_idxs = np.random.choice(np.arange(len(X_target_train)),
            size=load_params.n_samples.targetset_train, replace=False)
        X_target_train = X_target_train[train_idxs,:,:]
        y_class_target_train = y_class_target_train[train_idxs,:]
        y_domain_target_train = y_domain_target_train[train_idxs,:]
    # Construct ERP over source dataset.
    n_samples_target_train = len(X_target_train); erp_samples = 15
    # Column indices (within `labels` / `domains`) of every source sample's one-hot label & domain.
    label_idxs_source = np.argmax(y_class_source, axis=-1); domain_idxs_source = np.argmax(y_domain_source, axis=-1)
    X_source_ = []; y_class_source_ = []; y_domain_source_ = []
    for domain_i in domains_source:
        # `domain_i` is a raw domain id; its one-hot column is its position in `domains`.
        domain_idx_i = domains.index(domain_i)
        # Initialize `*_source_i` in `*_source_`.
        X_source_.append([]); y_class_source_.append([]); y_domain_source_.append([])
        # If `erp_samples` is 1, directly split the dataset according to different modalities.
        if erp_samples == 1:
            # Get the indices of the corresponding modality, then get samples.
            sample_idxs = np.where(domain_idxs_source == domain_idx_i)[0]
            # Update `*_source_`.
            for sample_idx in sample_idxs:
                X_source_[-1].append(X_source[sample_idx,:,:])
                y_class_source_[-1].append(y_class_source[sample_idx,:])
                y_domain_source_[-1].append(y_domain_source[sample_idx,:])
        # Otherwise, get ERPs according to `erp_samples`: per domain, roughly as many ERPs as target-train samples.
        else:
            for _ in range(n_samples_target_train // len(labels_source)):
                # Loop over all labels to get the ERP of samples.
                for label_i in labels_source:
                    # `label_i` is a raw label; its one-hot column is its position in `labels`.
                    label_idx_i = labels.index(label_i)
                    # Get the indices of the corresponding label, then get samples.
                    sample_mask_i = (domain_idxs_source == domain_idx_i) & (label_idxs_source == label_idx_i)
                    sample_idxs = np.random.choice(np.where(sample_mask_i)[0], size=erp_samples, replace=False)\
                        if erp_samples is not None else np.where(sample_mask_i)[0]
                    # Update `*_source_`.
                    X_source_[-1].append(np.mean(X_source[sample_idxs,:,:], axis=0))
                    y_class_source_[-1].append(np.eye(len(labels))[label_idx_i])
                    y_domain_source_[-1].append(np.eye(len(domains))[domain_idx_i])
    # Convert each item in `*_source_` to `np.array`.
    X_source = [np.stack(X_source_i, axis=0) for X_source_i in X_source_]
    y_class_source = [np.stack(y_class_source_i, axis=0) for y_class_source_i in y_class_source_]
    y_domain_source = [np.stack(y_domain_source_i, axis=0) for y_domain_source_i in y_domain_source_]
    # Shuffle the dataset according to each source.
    for source_idx in range(len(X_source)):
        sample_idxs = np.arange(len(X_source[source_idx])); np.random.shuffle(sample_idxs)
        X_source[source_idx] = X_source[source_idx][sample_idxs,:,:]
        y_class_source[source_idx] = y_class_source[source_idx][sample_idxs,:]
        y_domain_source[source_idx] = y_domain_source[source_idx][sample_idxs,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with source-set ({}) & target-train-set ({})" +\
        " & target-validation-set ({}) & target-test-set ({})."
    ).format(X_source[0].shape, X_target_train.shape, X_target_validation.shape, X_target_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Construct dataset from data items.
    dataset_source = [tf.data.Dataset.from_tensor_slices((X_source_i, y_class_source_i, y_domain_source_i))\
        for X_source_i, y_class_source_i, y_domain_source_i in zip(X_source, y_class_source, y_domain_source)]
    dataset_target_train = tf.data.Dataset.from_tensor_slices((X_target_train, y_class_target_train, y_domain_target_train))
    dataset_target_validation = tf.data.Dataset.from_tensor_slices(
        (X_target_validation, y_class_target_validation, y_domain_target_validation)
    )
    dataset_target_test = tf.data.Dataset.from_tensor_slices((X_target_test, y_class_target_test, y_domain_target_test))
    # Shuffle and then batch the dataset. Without siamese mode, the batch size is shared among the sources.
    dataset_source = [dataset_source_i.shuffle(params.train.buffer_size).batch(
        params.train.batch_size if params.model.use_siamese else\
            (params.train.batch_size // len(dataset_source)), drop_remainder=False
    ) for dataset_source_i in dataset_source]
    dataset_target_train = dataset_target_train.shuffle(params.train.buffer_size).batch(
        params.train.batch_size, drop_remainder=False)
    dataset_target_validation = dataset_target_validation.shuffle(params.train.buffer_size).batch(
        params.train.batch_size, drop_remainder=False)
    dataset_target_test = dataset_target_test.shuffle(params.train.buffer_size).batch(
        params.train.batch_size, drop_remainder=False)
    # Return the final `dataset_source` & `dataset_target`.
    return dataset_source, (dataset_target_train, dataset_target_validation, dataset_target_test)

"""
train funcs
"""
# def train func
def train(base_, params_):
    """
    Train the model.

    For every run in the experiment list and every experiment (`load_params`) of that run, a fresh
    `SubdomainContrastiveTransformer` and Adam optimizer are built, trained for `params.train.n_epochs`
    epochs, and evaluated on the target validation & test sets after every epoch. Finally the test accuracy
    at the epoch with the highest validation accuracy (ties broken by test accuracy) is logged.

    Args:
        base_: str - The base path of current project.
        params_: DotDict - The parameters of current training process.

    Returns:
        None

    Raises:
        ValueError: If `params_.train.dataset` has no experiment definition.
    """
    global _forward, _train
    global params, paths, model, optimizer
    # Initialize parameters & variables of current training process.
    init(base_, params_)
    # Log the start of current training process.
    paths.run.logger.summaries.info("Training started with dataset {}.".format(params.train.dataset))
    # Initialize the runs & the experiments (each load_params_i corresponds to a sub-dataset).
    path_runs, load_params = _get_experiments()
    # Execute experiments for each dataset run.
    for path_run_i in path_runs:
        # Log the start of current training iteration.
        msg = "Training started with sub-dataset {}.".format(path_run_i)
        print(msg); paths.run.logger.summaries.info(msg)
        for load_params_base_i in load_params:
            # Add `path_run` to a private copy of the experiment's load parameters.
            load_params_i = cp.deepcopy(load_params_base_i); load_params_i.path_run = path_run_i
            # Update `params` according to `load_params_i`.
            params.model.n_modalities = load_params_i.n_domains
            # Load data from specified experiment.
            dataset_source, dataset_target = load_data(load_params_i)
            dataset_target_train, dataset_target_validation, dataset_target_test = dataset_target
            # Skipped experiments come back as `[]` & `([], [], [])`; check the containers themselves,
            # since indexing `dataset_source[0]` on the empty list would raise `IndexError`.
            if len(dataset_source) == 0 or len(dataset_target_train) == 0:
                msg = (
                    "INFO: Skip experiment {} with source-set ({:d} items) & target-set ({:d} items)."
                ).format(load_params_i.name, len(dataset_source), len(dataset_target_train))
                print(msg); paths.run.logger.summaries.info(msg); continue
            # Start training process of current specified experiment.
            msg = "Start the training process of experiment {}.".format(load_params_i.name)
            print(msg); paths.run.logger.summaries.info(msg)
            accuracies_train = []; accuracies_validation = []; accuracies_test = []
            # Reset the iteration information of params.
            params.iteration(iteration=0)
            # Initialize model & ADAM optimizer of current experiment.
            model = sdctn_model(params.model)
            optimizer = tf.keras.optimizers.Adam(learning_rate=params.train.lr_i, epsilon=1e-4)
            # Re-wrap the python functions so each experiment starts with an empty trace cache
            # (presumably earlier traces captured the previous `model` & `optimizer`).
            _forward = tf.function(_forward.python_function); _train = tf.function(_train.python_function)
            for epoch_idx in range(params.train.n_epochs):
                # Update iteration parameters.
                params.iteration(iteration=epoch_idx)
                # Record the start time of current epoch.
                time_start = time.time()
                accuracy_train, loss_train = _train_epoch(dataset_source, dataset_target_train)
                accuracies_train.append(accuracy_train)
                accuracy_validation, loss_validation = _evaluate_epoch(dataset_source, dataset_target_validation)
                accuracies_validation.append(accuracy_validation)
                accuracy_test, loss_test = _evaluate_epoch(dataset_source, dataset_target_test)
                accuracies_test.append(accuracy_test)
                # Log information related to current training epoch.
                time_stop = time.time()
                msg = (
                    "Finish train epoch {:d} in {:.2f} seconds, generating {:d} concrete functions."
                ).format(epoch_idx, time_stop-time_start, len(_train.pretty_printed_concrete_signatures().split("\n\n")))
                print(msg); paths.run.logger.summaries.info(msg)
                msg = (
                    "Accuracy(train): {:.2f}%. Loss(train): {:.5f}."
                ).format(accuracy_train * 100., loss_train)
                print(msg); paths.run.logger.summaries.info(msg)
                msg = (
                    "Accuracy(validation): {:.2f}%. Loss(validation): {:.5f}."
                ).format(accuracy_validation * 100., loss_validation)
                print(msg); paths.run.logger.summaries.info(msg)
                msg = (
                    "Accuracy(test): {:.2f}%. Loss(test): {:.5f}."
                ).format(accuracy_test * 100., loss_test)
                print(msg); paths.run.logger.summaries.info(msg)
                # Summarize model information.
                if epoch_idx == 0:
                    model.summary(print_fn=print); model.summary(print_fn=paths.run.logger.summaries.info)
            # With zero epochs there is no accuracy to select from (`np.max` of an empty array raises).
            if len(accuracies_validation) == 0:
                msg = "INFO: Skip summarizing experiment {} with no training epochs.".format(load_params_i.name)
                print(msg); paths.run.logger.summaries.info(msg); continue
            # Convert `accuracies_validation` & `accuracies_test` to `np.array`.
            accuracies_validation = np.round(np.array(accuracies_validation, dtype=np.float32), decimals=4)
            accuracies_test = np.round(np.array(accuracies_test, dtype=np.float32), decimals=4)
            # Among the epochs with the best validation accuracy, pick the one with the best test accuracy.
            epoch_maxacc_idxs = np.where(accuracies_validation == np.max(accuracies_validation))[0]
            epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(accuracies_test[epoch_maxacc_idxs])]
            # Finish training process of current specified experiment.
            msg = (
                "Finish the training process of experiment {}, with test-accuracy ({:.2f}%)" +\
                " according to max validation-accuracy ({:.2f}%) at epoch {:d}, generating {:d} concrete functions."
            ).format(load_params_i.name, accuracies_test[epoch_maxacc_idx]*100.,
                accuracies_validation[epoch_maxacc_idx]*100., epoch_maxacc_idx,
                len(_train.pretty_printed_concrete_signatures().split("\n\n")))
            print(msg); paths.run.logger.summaries.info(msg)
        # Log the end of current training iteration.
        msg = "Training finished with sub-dataset {}.".format(path_run_i)
        print(msg); paths.run.logger.summaries.info(msg)
    # Log the end of current training process.
    msg = "Training finished with dataset {}.".format(params.train.dataset)
    print(msg); paths.run.logger.summaries.info(msg)

# def _get_experiments func
def _get_experiments():
    """
    Get the runs & experiments to execute for `params.train.dataset`.

    Returns:
        path_runs: list - The paths of runs that we want to execute experiments on.
        load_params: list - The `DotDict` load parameters of every experiment executed for each run.

    Raises:
        ValueError: If `params.train.dataset` is not supported.
    """
    if params.train.dataset == "eeg_anonymous":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            os.path.join(paths.base, "data", "demo-dataset", "001"),
            os.path.join(paths.base, "data", "demo-dataset", "002"),
            os.path.join(paths.base, "data", "demo-dataset", "003"),
        ] * 20; load_type = "default"; n_samples = utils.DotDict({"targetset_train":None,})
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # source-task-all-all-target-tmr-n23-audio
            utils.DotDict({
                "name": "source-task-all-all-target-tmr-n23-audio",
                "sourceset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "targetset": ["tmr-N2/3-audio",],
                "type": load_type, "n_samples": n_samples, "n_domains": 3,
            }),
            # source-task-all-image-target-tmr-n23-audio
            utils.DotDict({
                "name": "source-task-all-image-target-tmr-n23-audio",
                "sourceset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "targetset": ["tmr-N2/3-audio",],
                "type": load_type, "n_samples": n_samples, "n_domains": 2,
            }),
            # source-task-all-audio-target-tmr-n23-audio
            utils.DotDict({
                "name": "source-task-all-audio-target-tmr-n23-audio",
                "sourceset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "targetset": ["tmr-N2/3-audio",],
                "type": load_type, "n_samples": n_samples, "n_domains": 2,
            }),
        ]
        # Alternative (inactive) experiment set, kept for reference.
        """
        load_params = [
            # source-task-all-image-target-task-all-audio
            utils.DotDict({
                "name": "source-task-all-image-target-task-all-audio",
                "sourceset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "targetset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "type": load_type, "n_samples": n_samples, "n_domains": 2,
            }),
            # source-task-all-all-target-tmr-n23-audio
            utils.DotDict({
                "name": "source-task-all-all-target-tmr-n23-audio",
                "sourceset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "targetset": ["tmr-N2/3-audio",],
                "type": load_type, "n_samples": n_samples, "n_domains": 3,
            }),
            # source-task-all-image-target-tmr-n23-audio
            utils.DotDict({
                "name": "source-task-all-image-target-tmr-n23-audio",
                "sourceset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "targetset": ["tmr-N2/3-audio",],
                "type": load_type, "n_samples": n_samples, "n_domains": 2,
            }),
            # source-task-all-audio-target-tmr-n23-audio
            utils.DotDict({
                "name": "source-task-all-audio-target-tmr-n23-audio",
                "sourceset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "targetset": ["tmr-N2/3-audio",],
                "type": load_type, "n_samples": n_samples, "n_domains": 2,
            }),
        ]
        """
    else:
        raise ValueError((
            "ERROR: Unknown dataset {} in train.domain_adaptation.DomainContrastiveNet.SubdomainContrastiveTransformer."
        ).format(params.train.dataset))
    return path_runs, load_params

# def _train_epoch func
def _train_epoch(dataset_source, dataset_target_train):
    """
    Train the model for one epoch.

    The epoch has as many batches as the longest source dataset; the target-train iterator and each source
    iterator are re-created as soon as they have been fully consumed.

    Args:
        dataset_source: list - The batched source datasets, one per source domain.
        dataset_target_train: tf.data.Dataset - The batched target-train dataset.

    Returns:
        accuracy: float - The batch-size-weighted mean accuracy against the target labels.
        loss: float - The batch-size-weighted mean training loss.
    """
    accuracies = []; losses = []; batch_sizes = []
    iter_target = iter(dataset_target_train)
    iter_source = [iter(dataset_source_i) for dataset_source_i in dataset_source]
    n_batches = max([len(dataset_source_i) for dataset_source_i in dataset_source])
    for batch_idx in range(n_batches):
        # The model input is `[batch_target, *batch_sources]`.
        batch_target_i = next(iter_target)
        batch_source_i = [next(iter_source_i) for iter_source_i in iter_source]
        batch_i = [batch_target_i, *batch_source_i]
        # The batch size is the number of target samples (`X` is the first element of the target batch).
        batch_sizes.append(len(batch_i[0][0]))
        # Re-create exhausted iterators so the next batch can be drawn.
        if (batch_idx + 1) % len(dataset_target_train) == 0: iter_target = iter(dataset_target_train)
        for source_idx, dataset_source_i in enumerate(dataset_source):
            if (batch_idx + 1) % len(dataset_source_i) == 0:
                iter_source[source_idx] = iter(dataset_source_i)
        # Train model for current batch.
        outputs_i, loss_i = _train(batch_i, params); outputs_i, loss_i = outputs_i.numpy(), loss_i.numpy()
        accuracies.append(_batch_accuracy(outputs_i, batch_i[0][1])); losses.append(loss_i)
    return _weighted_mean(accuracies, batch_sizes), _weighted_mean(losses, batch_sizes)

# def _evaluate_epoch func
def _evaluate_epoch(dataset_source, dataset_target_eval):
    """
    Evaluate the model over one full pass of `dataset_target_eval`.

    Source batches are drawn alongside the target batches, because the model input is
    `[batch_target, *batch_sources]`. Source iterators are re-created when exhausted, so a source dataset
    with fewer batches than `dataset_target_eval` does not raise `StopIteration`.

    Args:
        dataset_source: list - The batched source datasets, one per source domain.
        dataset_target_eval: tf.data.Dataset - The batched target validation/test dataset.

    Returns:
        accuracy: float - The batch-size-weighted mean accuracy against the target labels.
        loss: float - The batch-size-weighted mean loss.
    """
    accuracies = []; losses = []; batch_sizes = []
    iter_target = iter(dataset_target_eval); n_batches = len(dataset_target_eval)
    iter_source = [iter(dataset_source_i) for dataset_source_i in dataset_source]
    for batch_idx in range(n_batches):
        batch_target_i = next(iter_target)
        batch_source_i = [next(iter_source_i) for iter_source_i in iter_source]
        batch_i = [batch_target_i, *batch_source_i]
        batch_sizes.append(len(batch_i[0][0]))
        for source_idx, dataset_source_i in enumerate(dataset_source):
            if (batch_idx + 1) % len(dataset_source_i) == 0:
                iter_source[source_idx] = iter(dataset_source_i)
        # Forward model for current batch (no weight update).
        outputs_i, loss_i = _forward(batch_i); outputs_i, loss_i = outputs_i.numpy(), loss_i.numpy()
        accuracies.append(_batch_accuracy(outputs_i, batch_i[0][1])); losses.append(loss_i)
    return _weighted_mean(accuracies, batch_sizes), _weighted_mean(losses, batch_sizes)

# def _batch_accuracy func
def _batch_accuracy(outputs, y_class):
    """Return the fraction of rows where argmax(`outputs`) equals argmax(one-hot `y_class`)."""
    correct = np.argmax(outputs, axis=-1) == np.argmax(y_class, axis=-1)
    return np.sum(correct) / correct.size

# def _weighted_mean func
def _weighted_mean(values, weights):
    """Return the mean of per-batch `values` weighted by the per-batch sample counts `weights`."""
    return np.sum(np.array(values) * np.array(weights)) / np.sum(weights)

# def _forward func
@tf.function
def _forward(inputs, training=False):
    """
    Run one forward pass of the module-level `model` (inputs are expected to already be tensors).

    :param inputs: The batch passed straight to `model`; callers in this module pass
        `[batch_target, *batch_sources]`, each batch being an (X, y_class, y_domain) tuple.
    :param training: Whether to call `model` in training mode.
    :return outputs_: The model predictions (callers argmax them against the target one-hot labels).
    :return loss_: The loss returned by `model`.
    """
    # `model` is read from module scope; `train` rebuilds it before each experiment.
    predictions_and_loss = model(inputs, training=training)
    return predictions_and_loss

# def _train func
@tf.function
def _train(inputs, params_):
    """
    Perform one optimization step of the module-level `model` with the module-level `optimizer`.

    :param inputs: The batch passed to `_forward`; callers in this module pass `[batch_target, *batch_sources]`.
    :param params_: The training parameters; `params_.train.lr_i` is written to the optimizer learning rate.
    :return outputs_: The model predictions for `inputs`.
    :return loss_: The loss that was differentiated.
    """
    # Set the optimizer learning rate from the current parameters.
    optimizer.lr = params_.train.lr_i
    with tf.GradientTape() as tape:
        outputs_, loss_ = _forward(inputs, training=True)
    trainable_vars = model.trainable_variables
    grads = tape.gradient(loss_, trainable_vars)
    # Clip each gradient independently to an L2 norm of at most 2; variables without a gradient keep `None`.
    grads_and_vars = [
        (grad if grad is None else tf.clip_by_norm(grad, 2), var)
        for grad, var in zip(grads, trainable_vars)
    ]
    optimizer.apply_gradients(grads_and_vars)
    return outputs_, loss_

if __name__ == "__main__":
    import os
    # local dep
    from params.domain_adaptation_params import subdomain_contrastive_transformer_params

    # Dataset to train on.
    dataset = "eeg_anonymous"
    # Fix the random seed before any model is built.
    utils.model.set_seeds(42)
    # The project root is four directories above the current working directory.
    base = os.path.join(os.getcwd(), *([os.pardir] * 4))
    # Build the dataset-specific parameters, then train `SubdomainContrastiveTransformer`.
    train(base, subdomain_contrastive_transformer_params(dataset=dataset))

