#!/usr/bin/env python3
"""
Created on 15:03, Dec. 22nd, 2022

@author: Anonymous
"""
# `os` & `sys` are needed at module level: the data loaders call `os.path.join`, and the
# `__main__` guard below only runs when this file is executed as a script.
import os
import sys
import re, time
import copy as cp
from collections import Counter
import numpy as np
import scipy as sp
# Import the submodule explicitly so that `sp.signal` (used by the seeg loader) is guaranteed to exist.
import scipy.signal
import tensorflow as tf
from sklearn import datasets
from sklearn.discriminant_analysis import _cov
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.pardir)
import utils; import utils.model
import utils.data.eeg; import utils.data.meg; import utils.data.seeg
from models.naive_cnn import naive_cnn as naive_cnn_model

# Public API of this module.
# NOTE(review): `train` is not defined in the visible part of this file; presumably it is defined further down — confirm.
__all__ = ["train"]

# Module-level state shared by the functions below; both names are (re)bound by `init`.
params, paths = None, None

"""
init funcs
"""
# def init func
def init(base_, params_):
    """
    Set up the module-level state needed before training `naive_cnn`.

    :param base_: The base path of current project.
    :param params_: The parameters of current training process. A deep copy is stored in the
        module-level `params`, so later changes to the caller's object are not seen here.
    """
    global params, paths
    # Keep a private copy of the parameters and derive the project paths from that copy.
    params = cp.deepcopy(params_)
    paths = utils.Paths(base=base_, params=params)
    # First configure the model backend, then the training process itself.
    _init_model()
    _init_train()

# def _init_model func
def _init_model():
    """
    Prepare the TensorFlow runtime for the model used in the training process.

    Reads `params._precision` (passed to Keras as the default float type) and
    `params.train.use_graph_mode` (graph execution when true, eager execution otherwise).
    The random seed is deliberately not set here; that should be done before the model
    is instantiated.
    """
    # `params` is only read here, so no `global` statement is required.
    tf.keras.backend.set_floatx(params._precision)
    eager = not params.train.use_graph_mode
    tf.config.run_functions_eagerly(eager)

# def _init_train func
def _init_train():
    """
    Initialize the training process.
    """
    pass

"""
data funcs
"""
# def load_data func
def load_data(load_params):
    """
    Load data from the dataset selected by `params.train.dataset`.

    Dataset `<name>` is dispatched to the module-level function `_load_data_<name>`.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).

    Raises:
        ValueError: If no loader exists for `params.train.dataset`. Errors raised by the loader
            itself propagate unchanged instead of being misreported as an unknown dataset.
    """
    # Resolve the loader from this module's own namespace.
    func = globals().get("_load_data_{}".format(params.train.dataset))
    if not callable(func):
        raise ValueError("ERROR: Unknown dataset type {} in train.naive_cnn.".format(params.train.dataset))
    # Call the loader outside of any blanket `except`, so its real errors stay visible.
    dataset_train_, dataset_validation_, dataset_test_ = func(load_params)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return dataset_train_, dataset_validation_, dataset_test_

# def _load_data_eeg_anonymous func
def _load_data_eeg_anonymous(load_params):
    """
    Load eeg data from the specified subject in `eeg_anonymous`.

    Every name in `load_params.trainset` / `load_params.testset` has the form
    `<loader>-<session...>-<data_type>`. Datasets listed in both sets are split sequentially by
    `params.train.train_ratio`; the others go entirely to the train-set or the test-set. The
    held-out part is finally split at random into two halves, the validation-set & the test-set.

    NOTE: This draws from the global NumPy RNG; the order of the random calls below determines
    the resulting splits, so keep it stable when editing.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        All three are `([], [])` when the experiment is skipped (no train/test data, or the
        train-set & test-set do not contain the same 15 classes).

    Raises:
        ValueError: If `load_params.type` is neither `default` nor `lvbj`.
    """
    # Initialize path_run.
    path_run = (
        os.path.join(paths.base, "data", "eeg.anonymous", "020", "20230405")
        if not hasattr(load_params, "path_run") else load_params.path_run
    )
    # Load data from specified subject run.
    # - Dataset names are sorted because iterating a `set` of strings depends on hash
    #   randomization. Otherwise the order in which datasets are processed would vary between
    #   runs, and so would the seeded random splits below.
    # - The container is not named `datasets`, to avoid shadowing the module-level
    #   `sklearn.datasets` import.
    dataset_items = utils.DotDict()
    dataset_names = sorted(set(load_params.trainset) | set(load_params.testset))
    for dataset_name_i in dataset_names:
        # Split `<loader>-<session...>-<data_type>` into session type (with loader prefix) & data type.
        session_type_i = "-".join(dataset_name_i.split("-")[:-1]); data_type_i = dataset_name_i.split("-")[-1]
        func_i = getattr(utils.data.eeg.anonymous, "_".join(["load_run", session_type_i.split("-")[0]]))
        X_i, y_i = func_i(path_run, session_type="-".join(session_type_i.split("-")[1:]))
        # Check whether current dataset has data items.
        if X_i[data_type_i] is None:
            msg = "WARNING: Skip dataset {} with no data item.".format(dataset_name_i)
            print(msg); paths.run.logger.summaries.info(msg); continue
        X_i = X_i[data_type_i].astype(np.float32); y_i = y_i[data_type_i].astype(np.int64)
        # Truncate `X_i` to let them have the same length.
        # TODO: Here, we only keep the [0.0~0.8]s-part of [audio,image] that after onset. And we should
        # note that the [0.0~0.8]s-part of image is the whole onset time of image, the [0.0~0.8]s-part
        # of audio is the sum of the whole onset time of audio and the following 0.3s padding.
        # X_i - (n_samples, seq_len, n_channels)
        if load_params.type in ("default", "lvbj"):
            X_i = X_i[:,20:100,:]
        else:
            raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
        # Do cross-trial normalization (one mean & std over the whole dataset).
        X_i = (X_i - np.mean(X_i)) / np.std(X_i)
        # Set the corresponding item of dataset.
        dataset_items[dataset_name_i] = utils.DotDict({"X":X_i,"y":y_i,})
    # Initialize trainset & testset.
    X_train = []; y_train = []; X_test = []; y_test = []
    for dataset_name_i, dataset_i in dataset_items.items():
        # If trainset and testset are the same dataset, split X sequentially into trainset & testset.
        if dataset_name_i in load_params.trainset and dataset_name_i in load_params.testset:
            train_ratio = params.train.train_ratio
            X_train.append(dataset_i.X[:int(train_ratio * dataset_i.X.shape[0]),:,:])
            y_train.append(dataset_i.y[:int(train_ratio * dataset_i.y.shape[0])])
            X_test.append(dataset_i.X[int(train_ratio * dataset_i.X.shape[0]):,:,:])
            y_test.append(dataset_i.y[int(train_ratio * dataset_i.y.shape[0]):])
        # Dataset only used for training.
        elif dataset_name_i in load_params.trainset and dataset_name_i not in load_params.testset:
            X_train.append(dataset_i.X); y_train.append(dataset_i.y)
        # Dataset only used for testing.
        elif dataset_name_i not in load_params.trainset and dataset_name_i in load_params.testset:
            X_test.append(dataset_i.X); y_test.append(dataset_i.y)
        # Wrong cases.
        else:
            raise ValueError("ERROR: Unknown dataset name {}.".format(dataset_name_i))
    # Check whether trainset & testset both have data items.
    if len(X_train) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    X_train = np.concatenate(X_train, axis=0); y_train = np.concatenate(y_train, axis=0)
    X_test = np.concatenate(X_test, axis=0); y_test = np.concatenate(y_test, axis=0)
    # Make sure there is no overlap between X_train & X_test. For 10 random time points, collect
    # the test indices whose channel-0 value also occurs in the train-set; the intersection of
    # these index sets must be empty.
    samples_same = None; n_samples = 10; assert X_train.shape[1] == X_test.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_train.shape[1])
        sample_same_i = np.intersect1d(X_train[:,sample_idx,0], X_test[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough, then transform y to one-hot encoding.
    # - Both sets must hold exactly the same 15 classes. Comparing the sets, not only their
    #   sizes, keeps a test-only label from crashing the index lookup below.
    # - An explicit `if` (rather than `assert`) keeps this check active under `python -O`.
    labels_train = set(y_train); labels_test = set(y_test)
    if labels_train != labels_test or len(labels_train) != 15:
        msg = (
            "WARNING: Skip experiment (train:{};test:{}) due to that the classes of test cases are not enough."
        ).format(set(load_params.trainset), set(load_params.testset))
        print(msg); paths.run.logger.summaries.info(msg); return ([], []), ([], []), ([], [])
    # y - (n_samples, n_labels)
    labels = sorted(labels_train); label2idx = {label_i: idx for idx, label_i in enumerate(labels)}
    y_train = np.eye(len(labels))[np.array([label2idx[y_i] for y_i in y_train], dtype=np.int64)]
    y_test = np.eye(len(labels))[np.array([label2idx[y_i] for y_i in y_test], dtype=np.int64)]
    # Downsample trainset & testset to test the time redundancy.
    if load_params.type in ("default", "lvbj"):
        sample_rate = 100
    else:
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    assert sample_rate / load_params.resample_rate == sample_rate // load_params.resample_rate
    downsample_rate = sample_rate // load_params.resample_rate
    if downsample_rate != 1:
        # Each trial becomes `downsample_rate` phase-shifted sub-sampled copies stacked along the
        # sample axis; the labels are repeated in the same order.
        assert X_train.shape[1] / downsample_rate == X_train.shape[1] // downsample_rate
        X_train = np.concatenate([X_train[:,np.arange(start_i, X_train.shape[1], downsample_rate),:]
            for start_i in range(downsample_rate)], axis=0)
        y_train = np.concatenate([y_train for _ in range(downsample_rate)], axis=0)
        assert X_test.shape[1] / downsample_rate == X_test.shape[1] // downsample_rate
        X_test = np.concatenate([X_test[:,np.arange(start_i, X_test.shape[1], downsample_rate),:]
            for start_i in range(downsample_rate)], axis=0)
        y_test = np.concatenate([y_test for _ in range(downsample_rate)], axis=0)
    # Execute sample permutation: shuffle only the rows of `y_train`, breaking the X/y pairing.
    if load_params.permutation: np.random.shuffle(y_train)
    # Further split test-set into validation-set & test-set.
    validation_idxs = np.random.choice(np.arange(X_test.shape[0]), size=int(X_test.shape[0]/2), replace=False)
    validation_mask = np.zeros((X_test.shape[0],), dtype=np.bool_); validation_mask[validation_idxs] = True
    X_validation = X_test[validation_mask,:,:]; y_validation = y_test[validation_mask,:]
    X_test = X_test[~validation_mask,:,:]; y_test = y_test[~validation_mask,:]
    # Randomly select `load_params.n_samples.trainset` samples to format train-set.
    if load_params.n_samples.trainset is not None:
        train_idxs = np.random.choice(np.arange(len(X_train)), size=load_params.n_samples.trainset, replace=False)
        X_train = X_train[train_idxs,:,:]; y_train = y_train[train_idxs,:]
    # Shuffle the indices of train-set & validation-set & test-set.
    train_idxs = np.arange(len(X_train)); np.random.shuffle(train_idxs)
    validation_idxs = np.arange(len(X_validation)); np.random.shuffle(validation_idxs)
    test_idxs = np.arange(len(X_test)); np.random.shuffle(test_idxs)
    X_train = X_train[train_idxs,:,:]; y_train = y_train[train_idxs,:]
    X_validation = X_validation[validation_idxs,:,:]; y_validation = y_validation[validation_idxs,:]
    X_test = X_test[test_idxs,:,:]; y_test = y_test[test_idxs,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

# def _load_data_meg_anonymous func
def _load_data_meg_anonymous(load_params):
    """
    Load meg data from the specified subject in `meg_anonymous`.

    Every name in `load_params.trainset` / `load_params.testset` has the form
    `<loader>-<session...>-<data_type>`. The modalities listed in `load_params.data_modality` are
    normalized separately and then concatenated along the channel axis. Datasets listed in both
    sets are split sequentially by `params.train.train_ratio`; the others go entirely to the
    train-set or the test-set. The held-out part is finally split at random into two halves,
    the validation-set & the test-set.

    NOTE: This draws from the global NumPy RNG; the order of the random calls below determines
    the resulting splits, so keep it stable when editing.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        All three are `([], [])` when the experiment is skipped (no train/test data, or the
        train-set & test-set do not contain the same 15 classes).

    Raises:
        ValueError: If `load_params.data_modality` is empty, or `load_params.type` is neither
            `default` nor `lvbj`.
    """
    # Make sure that `load_params.data_modality` has at least one data modality. An explicit
    # `raise` (rather than `assert`) keeps this validation active under `python -O`.
    if len(load_params.data_modality) == 0:
        raise ValueError("ERROR: `load_params.data_modality` should contain at least one data modality.")
    # Initialize path_run.
    path_run = (
        os.path.join(paths.base, "data", "meg.anonymous", "008", "20230308")
        if not hasattr(load_params, "path_run") else load_params.path_run
    )
    # Load data from specified subject run.
    # - Dataset names are sorted because iterating a `set` of strings depends on hash
    #   randomization. Otherwise the order in which datasets are processed would vary between
    #   runs, and so would the seeded random splits below.
    # - The container is not named `datasets`, to avoid shadowing the module-level
    #   `sklearn.datasets` import.
    dataset_items = utils.DotDict()
    dataset_names = sorted(set(load_params.trainset) | set(load_params.testset))
    for dataset_name_i in dataset_names:
        # Split `<loader>-<session...>-<data_type>` into session type (with loader prefix) & data type.
        session_type_i = "-".join(dataset_name_i.split("-")[:-1]); data_type_i = dataset_name_i.split("-")[-1]
        func_i = getattr(utils.data.meg.anonymous, "_".join(["load_run", session_type_i.split("-")[0], load_params.type]))
        Xs_i, y_i = func_i(path_run, session_type="-".join(session_type_i.split("-")[1:]))
        # Check whether current dataset has data items.
        if Xs_i[data_type_i] is None:
            msg = "WARNING: Skip dataset {} with no data item.".format(dataset_name_i)
            print(msg); paths.run.logger.summaries.info(msg); continue
        Xs_i = [Xs_i[data_type_i][data_modality_i].astype(np.float32)
            for data_modality_i in load_params.data_modality]; y_i = y_i[data_type_i].astype(np.int64)
        # Truncate `X_i` to let them have the same length.
        # TODO: Here, we only keep the [0.0~0.8]s-part of [audio,image] that after onset. And we should
        # note that the [0.0~0.8]s-part of image is the whole onset time of image, the [0.0~0.8]s-part
        # of audio is the sum of the whole onset time of audio and the following 0.3s padding.
        # X_i - (n_samples, seq_len, n_channels)
        if load_params.type in ("default", "lvbj"):
            Xs_i = [X_i[:,20:100,:] for X_i in Xs_i]
        else:
            raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
        # Do cross-trial normalization (one mean & std per modality over the whole dataset).
        Xs_i = [(X_i - np.mean(X_i)) / np.std(X_i) for X_i in Xs_i]
        # Concatenate `X_i`s of `Xs_i` along `n_channels`-axis.
        Xs_i = np.concatenate(Xs_i, axis=-1)
        # Set the corresponding item of dataset.
        dataset_items[dataset_name_i] = utils.DotDict({"X":Xs_i,"y":y_i,})
    # Initialize trainset & testset.
    X_train = []; y_train = []; X_test = []; y_test = []
    for dataset_name_i, dataset_i in dataset_items.items():
        # If trainset and testset are the same dataset, split X sequentially into trainset & testset.
        if dataset_name_i in load_params.trainset and dataset_name_i in load_params.testset:
            train_ratio = params.train.train_ratio
            X_train.append(dataset_i.X[:int(train_ratio * dataset_i.X.shape[0]),:,:])
            y_train.append(dataset_i.y[:int(train_ratio * dataset_i.y.shape[0])])
            X_test.append(dataset_i.X[int(train_ratio * dataset_i.X.shape[0]):,:,:])
            y_test.append(dataset_i.y[int(train_ratio * dataset_i.y.shape[0]):])
        # Dataset only used for training.
        elif dataset_name_i in load_params.trainset and dataset_name_i not in load_params.testset:
            X_train.append(dataset_i.X); y_train.append(dataset_i.y)
        # Dataset only used for testing.
        elif dataset_name_i not in load_params.trainset and dataset_name_i in load_params.testset:
            X_test.append(dataset_i.X); y_test.append(dataset_i.y)
        # Wrong cases.
        else:
            raise ValueError("ERROR: Unknown dataset name {}.".format(dataset_name_i))
    # Check whether trainset & testset both have data items.
    if len(X_train) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    X_train = np.concatenate(X_train, axis=0); y_train = np.concatenate(y_train, axis=0)
    X_test = np.concatenate(X_test, axis=0); y_test = np.concatenate(y_test, axis=0)
    # Make sure there is no overlap between X_train & X_test. For 10 random time points, collect
    # the test indices whose channel-0 value also occurs in the train-set; the intersection of
    # these index sets must be empty.
    samples_same = None; n_samples = 10; assert X_train.shape[1] == X_test.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_train.shape[1])
        sample_same_i = np.intersect1d(X_train[:,sample_idx,0], X_test[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough, then transform y to one-hot encoding.
    # - Both sets must hold exactly the same 15 classes. Comparing the sets, not only their
    #   sizes, keeps a test-only label from crashing the index lookup below.
    # - An explicit `if` (rather than `assert`) keeps this check active under `python -O`.
    labels_train = set(y_train); labels_test = set(y_test)
    if labels_train != labels_test or len(labels_train) != 15:
        msg = (
            "WARNING: Skip experiment (train:{};test:{}) due to that the classes of test cases are not enough."
        ).format(set(load_params.trainset), set(load_params.testset))
        print(msg); paths.run.logger.summaries.info(msg); return ([], []), ([], []), ([], [])
    # y - (n_samples, n_labels)
    labels = sorted(labels_train); label2idx = {label_i: idx for idx, label_i in enumerate(labels)}
    y_train = np.eye(len(labels))[np.array([label2idx[y_i] for y_i in y_train], dtype=np.int64)]
    y_test = np.eye(len(labels))[np.array([label2idx[y_i] for y_i in y_test], dtype=np.int64)]
    # Downsample trainset & testset to test the time redundancy.
    if load_params.type in ("default", "lvbj"):
        sample_rate = 100
    else:
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    assert sample_rate / load_params.resample_rate == sample_rate // load_params.resample_rate
    downsample_rate = sample_rate // load_params.resample_rate
    if downsample_rate != 1:
        # Each trial becomes `downsample_rate` phase-shifted sub-sampled copies stacked along the
        # sample axis; the labels are repeated in the same order.
        assert X_train.shape[1] / downsample_rate == X_train.shape[1] // downsample_rate
        X_train = np.concatenate([X_train[:,np.arange(start_i, X_train.shape[1], downsample_rate),:]
            for start_i in range(downsample_rate)], axis=0)
        y_train = np.concatenate([y_train for _ in range(downsample_rate)], axis=0)
        assert X_test.shape[1] / downsample_rate == X_test.shape[1] // downsample_rate
        X_test = np.concatenate([X_test[:,np.arange(start_i, X_test.shape[1], downsample_rate),:]
            for start_i in range(downsample_rate)], axis=0)
        y_test = np.concatenate([y_test for _ in range(downsample_rate)], axis=0)
    # Execute sample permutation: shuffle only the rows of `y_train`, breaking the X/y pairing.
    if load_params.permutation: np.random.shuffle(y_train)
    # Further split test-set into validation-set & test-set.
    validation_idxs = np.random.choice(np.arange(X_test.shape[0]), size=int(X_test.shape[0]/2), replace=False)
    validation_mask = np.zeros((X_test.shape[0],), dtype=np.bool_); validation_mask[validation_idxs] = True
    X_validation = X_test[validation_mask,:,:]; y_validation = y_test[validation_mask,:]
    X_test = X_test[~validation_mask,:,:]; y_test = y_test[~validation_mask,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

# def _load_data_meg_lv2023cpnl func
def _load_data_meg_lv2023cpnl(load_params):
    """
    Load meg data from the specified subject in `meg_lv2023cpnl`.

    The `image` trials of the modalities listed in `load_params.data_modality` are truncated,
    normalized per modality and concatenated along the channel axis. The first
    `params.train.train_ratio` fraction of trials forms the train-set. The remaining trials are
    split at random into two halves, the validation-set & the test-set.

    NOTE: This draws from the global NumPy RNG; the order of the random calls below determines
    the resulting splits, so keep it stable when editing.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        All three are `([], [])` when either the train-set or the test-set is empty.

    Raises:
        ValueError: If `load_params.data_modality` is empty, `load_params.type` is unknown, or
            the train-set & test-set do not contain the same 12 classes.
    """
    # Make sure that `load_params.data_modality` has at least one data modality; otherwise the
    # channel-wise concatenation below would fail with an unhelpful error.
    if len(load_params.data_modality) == 0:
        raise ValueError("ERROR: `load_params.data_modality` should contain at least one data modality.")
    # Initialize path_run.
    path_run = (
        os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub018")
        if not hasattr(load_params, "path_run") else load_params.path_run
    )
    # Load data from specified subject run.
    func = getattr(utils.data.meg.lv2023cpnl, "load_run_{}".format(load_params.type))
    Xs, y = func(path_run)
    Xs = [Xs["image"][data_modality_i].astype(np.float32)
        for data_modality_i in load_params.data_modality]
    y = y["image"].astype(np.int64)
    # Truncate `X_i` to let them have the same length (identical for all known dataset types).
    # TODO: Here, we only keep the [0.0~0.8]s-part of [audio,image] that after onset. And we should
    # note that the [0.0~0.8]s-part of image is the whole onset time of image, the [0.0~0.8]s-part
    # of audio is the sum of the whole onset time of audio and the following 0.3s padding.
    # X_i - (n_samples, seq_len, n_channels)
    if load_params.type in ("default", "lvbj", "baseline"):
        Xs = [X_i[:,20:100,:] for X_i in Xs]
    else:
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    # Do cross-trial normalization (one mean & std per modality over all trials).
    Xs = [(X_i - np.mean(X_i)) / np.std(X_i) for X_i in Xs]
    # Concatenate `X_i`s of `Xs` along `n_channels`-axis.
    X = np.concatenate(Xs, axis=-1)
    # Initialize trainset & testset by a sequential split.
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    train_ratio = params.train.train_ratio
    X_train = X[:int(train_ratio * X.shape[0]),:,:]
    y_train = y[:int(train_ratio * y.shape[0])]
    X_test = X[int(train_ratio * X.shape[0]):,:,:]
    y_test = y[int(train_ratio * y.shape[0]):]
    # Check whether trainset & testset both have data items.
    if len(X_train) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # Make sure there is no overlap between X_train & X_test. For 10 random time points, collect
    # the test indices whose channel-0 value also occurs in the train-set; the intersection of
    # these index sets must be empty.
    samples_same = None; n_samples = 10; assert X_train.shape[1] == X_test.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_train.shape[1])
        sample_same_i = np.intersect1d(X_train[:,sample_idx,0], X_test[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough, then transform y to sorted order.
    # - Comparing the label sets, not only their sizes, rejects a test-only label up front
    #   instead of crashing inside the index lookup.
    # - An explicit `raise` keeps the check active under `python -O`.
    labels_train = set(y_train); labels_test = set(y_test)
    if labels_train != labels_test or len(labels_train) != 12:
        raise ValueError((
            "ERROR: The train-set & test-set should contain the same 12 classes, but got train:{} & test:{}."
        ).format(sorted(labels_train), sorted(labels_test)))
    # y - (n_samples, n_labels)
    labels = sorted(labels_train); label2idx = {label_i: idx for idx, label_i in enumerate(labels)}
    y_train = np.eye(len(labels))[np.array([label2idx[y_i] for y_i in y_train], dtype=np.int64)]
    y_test = np.eye(len(labels))[np.array([label2idx[y_i] for y_i in y_test], dtype=np.int64)]
    # Downsample trainset & testset to test the time redundancy.
    if load_params.type in ("default", "lvbj", "baseline"):
        sample_rate = 100
    else:
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    assert sample_rate / load_params.resample_rate == sample_rate // load_params.resample_rate
    downsample_rate = sample_rate // load_params.resample_rate
    if downsample_rate != 1:
        # Each trial becomes `downsample_rate` phase-shifted sub-sampled copies stacked along the
        # sample axis; the labels are repeated in the same order.
        assert X_train.shape[1] / downsample_rate == X_train.shape[1] // downsample_rate
        X_train = np.concatenate([X_train[:,np.arange(start_i, X_train.shape[1], downsample_rate),:]
            for start_i in range(downsample_rate)], axis=0)
        y_train = np.concatenate([y_train for _ in range(downsample_rate)], axis=0)
        assert X_test.shape[1] / downsample_rate == X_test.shape[1] // downsample_rate
        X_test = np.concatenate([X_test[:,np.arange(start_i, X_test.shape[1], downsample_rate),:]
            for start_i in range(downsample_rate)], axis=0)
        y_test = np.concatenate([y_test for _ in range(downsample_rate)], axis=0)
    # Execute sample permutation: shuffle only the rows of `y_train`, breaking the X/y pairing.
    if load_params.permutation: np.random.shuffle(y_train)
    # Further split test-set into validation-set & test-set.
    validation_idxs = np.random.choice(np.arange(X_test.shape[0]), size=int(X_test.shape[0]/2), replace=False)
    validation_mask = np.zeros((X_test.shape[0],), dtype=np.bool_); validation_mask[validation_idxs] = True
    X_validation = X_test[validation_mask,:,:]; y_validation = y_test[validation_mask,:]
    X_test = X_test[~validation_mask,:,:]; y_test = y_test[~validation_mask,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

# def _load_data_seeg_he2023xuanwu func
def _load_data_seeg_he2023xuanwu(load_params):
    """
    Load seeg data from the specified subject in `seeg_he2023xuanwu`.

    The `speak` trials are resampled from the original sample rate of `load_params.type` to
    `load_params.resample_rate` and normalized. For every label, the first
    `params.train.train_ratio` fraction of its trials forms the train-set. The remaining trials
    are split at random into two halves, the validation-set & the test-set.

    NOTE: This draws from the global NumPy RNG; the order of the random calls below determines
    the resulting splits, so keep it stable when editing.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        All three are `([], [])` when either the train-set or the test-set is empty.

    Raises:
        ValueError: If `load_params.type` is unknown, or the train-set & test-set do not contain
            the same `params.model.n_labels` classes.
    """
    # Initialize path_run.
    path_run = (
        os.path.join(paths.base, "data", "seeg.he2023xuanwu", "010")
        if not hasattr(load_params, "path_run") else load_params.path_run
    )
    # Load data from specified subject run.
    func = getattr(getattr(utils.data.seeg.he2023xuanwu, load_params.task), "load_subj_{}".format(load_params.type))
    X, y = func(path_run)
    X = X["speak"]["seeg"].astype(np.float32); y = y["speak"].astype(np.int64)
    # Select the original sample rate of the requested dataset type.
    # X - (n_samples, seq_len, n_channels)
    if load_params.type in ("default", "epoch") or load_params.type.startswith("bipolar"):
        sample_rate = 1000
    elif load_params.type == "moses2021neuroprosthesis":
        sample_rate = 200
    else:
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    # Resample the original data to the specified `resample_rate` along the time axis.
    X = sp.signal.resample(X, int(np.round(X.shape[1] /
        (sample_rate / load_params.resample_rate))), axis=1)
    if load_params.type.startswith("bipolar"):
        # Truncate data according to epoch range (-0.2,1.0), the original epoch range is (-0.5,2.0).
        X = X[:,int(np.round((-0.2 - (-0.5)) * load_params.resample_rate)):
            int(np.round((1.0 - (-0.5)) * load_params.resample_rate)),:]
        # Z-score every (trial, channel) pair over its own time axis.
        X = (X - np.mean(X, axis=1, keepdims=True)) / np.std(X, axis=1, keepdims=True)
    else:
        # Do cross-trial normalization (one mean & std over all trials).
        X = (X - np.mean(X)) / np.std(X)
    # Initialize trainset & testset by a sequential split inside every label.
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    train_ratio = params.train.train_ratio; train_idxs = []; test_idxs = []
    for label_i in sorted(set(y)):
        label_idxs = np.where(y == label_i)[0].tolist()
        n_train_i = int(train_ratio * len(label_idxs))
        train_idxs.extend(label_idxs[:n_train_i]); test_idxs.extend(label_idxs[n_train_i:])
    # The two index lists are disjoint by construction; verify it in linear time with a set.
    assert set(train_idxs).isdisjoint(test_idxs)
    train_idxs = np.array(train_idxs, dtype=np.int64); test_idxs = np.array(test_idxs, dtype=np.int64)
    X_train = X[train_idxs,:,:]; y_train = y[train_idxs]; X_test = X[test_idxs,:,:]; y_test = y[test_idxs]
    # Check whether trainset & testset both have data items.
    if len(X_train) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # Make sure there is no overlap between X_train & X_test. For 10 random time points, collect
    # the test indices whose channel-0 value also occurs in the train-set; the intersection of
    # these index sets must be empty.
    samples_same = None; n_samples = 10; assert X_train.shape[1] == X_test.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_train.shape[1])
        sample_same_i = np.intersect1d(X_train[:,sample_idx,0], X_test[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough, then transform y to sorted order.
    # - Comparing the label sets, not only their sizes, rejects a test-only label up front
    #   instead of crashing inside the index lookup.
    # - An explicit `raise` keeps the check active under `python -O`.
    labels_train = set(y_train); labels_test = set(y_test)
    if labels_train != labels_test or len(labels_train) != params.model.n_labels:
        raise ValueError((
            "ERROR: The train-set & test-set should contain the same {} classes, but got train:{} & test:{}."
        ).format(params.model.n_labels, sorted(labels_train), sorted(labels_test)))
    # y - (n_samples, n_labels)
    labels = sorted(labels_train); label2idx = {label_i: idx for idx, label_i in enumerate(labels)}
    y_train = np.eye(len(labels))[np.array([label2idx[y_i] for y_i in y_train], dtype=np.int64)]
    y_test = np.eye(len(labels))[np.array([label2idx[y_i] for y_i in y_test], dtype=np.int64)]
    # Execute sample permutation: shuffle only the rows of `y_train`, breaking the X/y pairing.
    if load_params.permutation: np.random.shuffle(y_train)
    # Further split test-set into validation-set & test-set.
    validation_idxs = np.random.choice(np.arange(X_test.shape[0]), size=int(X_test.shape[0]/2), replace=False)
    validation_mask = np.zeros((X_test.shape[0],), dtype=np.bool_); validation_mask[validation_idxs] = True
    X_validation = X_test[validation_mask,:,:]; y_validation = y_test[validation_mask,:]
    X_test = X_test[~validation_mask,:,:]; y_test = y_test[~validation_mask,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

# def _load_data_meg_hebart2023things func
def _load_data_meg_hebart2023things(load_params):
    """
    Load meg data from the specified subject in `meg_hebart2023things`.

    Only repeated images are used: every image name that occurs more than once and does not contain
    `catch`, truncated to the first `params.model.n_labels` names in sorted order. The trials of each
    selected image are split in order into train-set & test-set (by `params.train.train_ratio`), and
    the test-set is then randomly split in half into validation-set & test-set.

    Args:
        load_params: DotDict - The load parameters of specified dataset. Fields read here: `type`
            (`preprocessed` or `whole`), `load_image`, `resample_rate`, `permutation`, and optionally
            `path_run` (whose basename must start with `sub-BIGMEG<subj_idx>`).

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        X arrays are (n_samples, seq_len, n_channels); y arrays are one-hot (n_samples, n_labels).
        All three are `([], [])` if either train-set or test-set ends up empty.

    Raises:
        AttributeError: If `utils.data.meg.hebart2023things` has no `load_data_<type>` loader.
        ValueError: If `load_params.type` is neither `preprocessed` nor `whole`.
    """
    # `os` is imported at module level only under `if __name__ == "__main__"`; import it here so that
    # this loader also works when the module is imported by another script.
    import os
    global params, paths
    # Initialize pattern to match subject.
    pattern = re.compile(r"sub-BIGMEG(\d+)")
    # Initialize path_dataset.
    if not hasattr(load_params, "path_run"):
        path_dataset = os.path.join(paths.base, "data", "meg.hebart2023things"); subj_idx = 1
    else:
        path_dataset, subj_i = os.path.split(load_params.path_run)
        subj_match = pattern.match(subj_i); assert subj_match is not None; subj_idx = int(subj_match.group(1))
    # Load data from specified subject.
    func = getattr(utils.data.meg.hebart2023things, "load_data_{}".format(load_params.type))
    data = func(path_dataset, subj_idx=subj_idx, load_image=load_params.load_image)
    # Select the test images: repeated images (count > 1) that are not catch trials, keeping only the
    # first `n_labels` of them in sorted order.
    labels = [data_i.image.name for data_i in data]
    labels_testset_all = sorted([label_i for label_i, count_i in Counter(labels).items()\
        if (count_i > 1) and ("catch" not in label_i)])
    labels_testset = labels_testset_all[:params.model.n_labels]
    # Report the number of detected test images (before truncation) and the ones actually used.
    print((
        "INFO: The number of detected test images is {:d}, but the images used to test signal quality include ({})."
    ).format(len(labels_testset_all), labels_testset))
    # Only materialize the selected trials, instead of stacking every trial and copying the selected
    # ones out again.
    labels_testset_ = set(labels_testset)
    # X - (n_samples, seq_len, n_channels)
    X = np.transpose(np.array([data_i.image.data.mag for data_i, label_i in zip(data, labels)\
        if label_i in labels_testset_], dtype=np.float32), axes=[0,2,1])
    labels = [label_i for label_i in labels if label_i in labels_testset_]
    # Use the remaining `X` & `labels` to construct `x` & `y`; `y` indexes into the sorted label set.
    labelset = sorted(set(labels)); label2idx = {label_i: label_idx for label_idx, label_i in enumerate(labelset)}
    y = np.array([label2idx[label_i] for label_i in labels], dtype=np.int64)
    # Get the original sample rate of the specified type of dataset, truncating `X` if needed.
    if load_params.type == "preprocessed":
        # Keep time steps [20,220), i.e. 1.0s at 200Hz.
        # NOTE(review): assumes time step 20 is the stimulus onset -- TODO confirm against the epoching.
        sample_rate = 200; X = X[:,20:220,:]
    elif load_params.type == "whole":
        # Use the whole epoch without truncation.
        sample_rate = 1000
    else:
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    # Resample the original data to the specified `resample_rate`.
    X = sp.signal.resample(X, int(np.round(X.shape[1] / (sample_rate / load_params.resample_rate))), axis=1)
    # Do cross-trial normalization.
    X = (X - np.mean(X)) / np.std(X)
    # Initialize trainset & testset: the items of each label are split in order, the first `train_ratio`
    # part goes to train-set and the rest goes to test-set.
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    train_ratio = params.train.train_ratio; train_idxs = []; test_idxs = []
    for label_i in sorted(set(y)):
        label_idxs = np.where(y == label_i)[0].tolist(); n_train = int(train_ratio * len(label_idxs))
        train_idxs.extend(label_idxs[:n_train]); test_idxs.extend(label_idxs[n_train:])
    # The two splits must be disjoint (set intersection instead of an O(n^2) list scan).
    assert len(set(train_idxs) & set(test_idxs)) == 0
    train_idxs = np.array(train_idxs, dtype=np.int64); test_idxs = np.array(test_idxs, dtype=np.int64)
    X_train = X[train_idxs,:,:]; y_train = y[train_idxs]; X_test = X[test_idxs,:,:]; y_test = y[test_idxs]
    # Check whether trainset & testset both have data items.
    if len(X_train) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # Sanity check against duplicated trials: at 10 random time steps, collect the test items whose
    # channel-0 value also appears in train-set, and require that no test item is collected at all of them.
    samples_same = None; n_samples = 10; assert X_train.shape[1] == X_test.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_train.shape[1])
        sample_same_i = np.intersect1d(X_train[:,sample_idx,0], X_test[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough, then transform y to sorted order.
    assert len(set(y_train)) == len(set(y_test)) == params.model.n_labels; labels = sorted(set(y_train))
    # y - (n_samples, n_labels)
    y_train = np.array([labels.index(y_i) for y_i in y_train], dtype=np.int64); y_train = np.eye(len(labels))[y_train]
    y_test = np.array([labels.index(y_i) for y_i in y_test], dtype=np.int64); y_test = np.eye(len(labels))[y_test]
    # Execute sample permutation: shuffle only the rows of `y_train`, breaking the X-y pairing of train-set.
    if load_params.permutation: np.random.shuffle(y_train)
    # Further split test-set into validation-set & test-set (a random half goes to validation-set).
    validation_idxs = np.random.choice(np.arange(X_test.shape[0]), size=int(X_test.shape[0]/2), replace=False)
    validation_mask = np.zeros((X_test.shape[0],), dtype=np.bool_); validation_mask[validation_idxs] = True
    X_validation = X_test[validation_mask,:,:]; y_validation = y_test[validation_mask,:]
    X_test = X_test[~validation_mask,:,:]; y_test = y_test[~validation_mask,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

# def _load_data_eeg_palazzo2020decoding func
def _load_data_eeg_palazzo2020decoding(load_params):
    """
    Load eeg data from the specified subject in `eeg_palazzo2020decoding`.

    Only the `image` part of the loaded data is used. The items of each label are split in order into
    train-set & test-set (by `params.train.train_ratio`), and the test-set is then randomly split in
    half into validation-set & test-set.

    Args:
        load_params: DotDict - The load parameters of specified dataset. Fields read here: `permutation`,
            and optionally `path_run` (whose basename is the integer subject index).

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        X arrays are (n_samples, seq_len, n_channels); y arrays are one-hot (n_samples, n_labels).
        All three are `([], [])` if either train-set or test-set ends up empty.
    """
    # `os` is imported at module level only under `if __name__ == "__main__"`; import it here so that
    # this loader also works when the module is imported by another script.
    import os
    global params, paths
    # Initialize path_dataset.
    if not hasattr(load_params, "path_run"):
        path_dataset = os.path.join(paths.base, "data", "eeg.palazzo2020decoding"); subj_idx = 1
    else:
        path_dataset, subj_i = os.path.split(load_params.path_run); subj_idx = int(subj_i)
    # Load data from specified subject, and keep only the `image` part.
    X, y = utils.data.eeg.palazzo2020decoding.load_data(path_dataset, subj_idx=subj_idx)
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    X = X["image"].astype(np.float32); y = y["image"].astype(np.int64)
    # Use the whole epoch without truncation, and do cross-trial normalization.
    X = (X - np.mean(X)) / np.std(X)
    # Initialize trainset & testset: the items of each label are split in order, the first `train_ratio`
    # part goes to train-set and the rest goes to test-set.
    train_ratio = params.train.train_ratio; train_idxs = []; test_idxs = []
    for label_i in sorted(set(y)):
        label_idxs = np.where(y == label_i)[0].tolist(); n_train = int(train_ratio * len(label_idxs))
        train_idxs.extend(label_idxs[:n_train]); test_idxs.extend(label_idxs[n_train:])
    # The two splits must be disjoint (set intersection instead of an O(n^2) list scan).
    assert len(set(train_idxs) & set(test_idxs)) == 0
    train_idxs = np.array(train_idxs, dtype=np.int64); test_idxs = np.array(test_idxs, dtype=np.int64)
    X_train = X[train_idxs,:,:]; y_train = y[train_idxs]; X_test = X[test_idxs,:,:]; y_test = y[test_idxs]
    # Check whether trainset & testset both have data items.
    if len(X_train) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # Sanity check against duplicated trials: at 10 random time steps, collect the test items whose
    # channel-0 value also appears in train-set, and require that no test item is collected at all of them.
    samples_same = None; n_samples = 10; assert X_train.shape[1] == X_test.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_train.shape[1])
        sample_same_i = np.intersect1d(X_train[:,sample_idx,0], X_test[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough, then transform y to sorted order.
    assert len(set(y_train)) == len(set(y_test)) == params.model.n_labels; labels = sorted(set(y_train))
    # y - (n_samples, n_labels)
    y_train = np.array([labels.index(y_i) for y_i in y_train], dtype=np.int64); y_train = np.eye(len(labels))[y_train]
    y_test = np.array([labels.index(y_i) for y_i in y_test], dtype=np.int64); y_test = np.eye(len(labels))[y_test]
    # Execute sample permutation: shuffle only the rows of `y_train`, breaking the X-y pairing of train-set.
    if load_params.permutation: np.random.shuffle(y_train)
    # Further split test-set into validation-set & test-set (a random half goes to validation-set).
    validation_idxs = np.random.choice(np.arange(X_test.shape[0]), size=int(X_test.shape[0]/2), replace=False)
    validation_mask = np.zeros((X_test.shape[0],), dtype=np.bool_); validation_mask[validation_idxs] = True
    X_validation = X_test[validation_mask,:,:]; y_validation = y_test[validation_mask,:]
    X_test = X_test[~validation_mask,:,:]; y_test = y_test[~validation_mask,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

# def _load_data_eeg_gifford2022large func
def _load_data_eeg_gifford2022large(load_params):
    """
    Load eeg data from the specified subject in `eeg_gifford2022large`.

    Each item is labeled by the index of its `concept` among the sorted concept names. For `raw*`
    data, items can optionally be whitened per session (`use_sigma`). All items are then truncated to
    time steps [27,60) and normalized. The items of each label are split in order into train-set &
    test-set (by `params.train.train_ratio`), and the test-set is then randomly split in half into
    validation-set & test-set.

    Args:
        load_params: DotDict - The load parameters of specified dataset. Fields read here: `type`
            (starting with `preprocessed` or `raw`; `.` is replaced by `_` to find the loader),
            `load_image`, `permutation`, `use_sigma` & `use_sigma_cond` (only for `raw*`), and
            optionally `path_run` (whose basename is the integer subject index).

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        X arrays are (n_samples, seq_len, n_channels); y arrays are one-hot (n_samples, n_labels).
        All three are `([], [])` if either train-set or test-set ends up empty.

    Raises:
        AttributeError: If `utils.data.eeg.gifford2022large` has no matching `load_data_<type>` loader.
        ValueError: If `load_params.type` starts with neither `preprocessed` nor `raw`.
    """
    # `os` is imported at module level only under `if __name__ == "__main__"`; import it here so that
    # this loader also works when the module is imported by another script.
    import os
    global params, paths
    # Initialize path_dataset.
    if not hasattr(load_params, "path_run"):
        path_dataset = os.path.join(paths.base, "data", "eeg.gifford2022large"); subj_idx = 1
    else:
        path_dataset, subj_i = os.path.split(load_params.path_run); subj_idx = int(subj_i)
    # Load data from specified subject.
    func = getattr(utils.data.eeg.gifford2022large, "load_data_{}".format(load_params.type.replace(".", "_")))
    _, dataset = func(path_dataset, subj_idx=subj_idx, load_image=load_params.load_image)
    # Only `preprocessed*` & `raw*` types of dataset are supported.
    if not load_params.type.startswith(("preprocessed", "raw")):
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    # Initialize `X` & `y` according to `dataset`, using a dict to map each concept to its label index.
    concepts = sorted(set([data_i.concept for data_i in dataset]))
    concept2label = {concept_i: label_i for label_i, concept_i in enumerate(concepts)}
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    X = np.array([data_i.data.T for data_i in dataset], dtype=np.float32)
    y = np.array([concept2label[data_i.concept] for data_i in dataset], dtype=np.int64)
    # For `raw*` data, optionally whiten the data of each session separately.
    if load_params.type.startswith("raw") and load_params.use_sigma:
        sess_idxs = np.array([data_i.session for data_i in dataset], dtype=np.int64)
        X_ = []; y_ = []
        for sess_idx in sorted(set(sess_idxs.tolist())):
            # Use `sess_idx` to get the data items of current session.
            # X_i - (n_samples, seq_len, n_channels), y_i - (n_samples,)
            data_idxs = np.where(sess_idxs == sess_idx)[0]
            X_i = X[data_idxs,:,:]; y_i = y[data_idxs]
            # Calculate (shrinkage) covariance matrices at each time point, then average them across time
            # points (per condition first, then across conditions, if `use_sigma_cond`).
            # sigma_i - (n_channels, n_channels)
            if load_params.use_sigma_cond:
                cond_idxs = [np.where(y_i == label_i)[0] for label_i in sorted(set(y_i.tolist()))]
                sigma_i = np.mean([np.mean([_cov(X_i[cond_idxs_i,time_idx,:], shrinkage="auto")\
                    for time_idx in range(X_i.shape[1])], axis=0) for cond_idxs_i in cond_idxs], axis=0)
            else:
                sigma_i = np.mean([_cov(X_i[:,time_idx,:], shrinkage="auto")\
                    for time_idx in range(X_i.shape[1])], axis=0)
            # Whiten the data with the inverse matrix square root of `sigma_i`.
            # sigma_inv_i - (n_channels, n_channels)
            sigma_inv_i = sp.linalg.fractional_matrix_power(sigma_i, t=-0.5)
            X_.append(X_i @ sigma_inv_i); y_.append(y_i)
        # Note: the data items are regrouped by session (in sorted session order) here.
        X = np.concatenate(X_, axis=0); y = np.concatenate(y_, axis=0)
    # Truncate `X` to time steps [27,60) to let them have the same length.
    # NOTE(review): the time window these steps correspond to depends on the epoching -- TODO confirm.
    # X - (n_samples, seq_len, n_channels)
    X = X[:,27:60,:]
    # Do cross-trial normalization.
    X = (X - np.mean(X)) / np.std(X)
    # Initialize trainset & testset: the items of each label are split in order, the first `train_ratio`
    # part goes to train-set and the rest goes to test-set.
    train_ratio = params.train.train_ratio; train_idxs = []; test_idxs = []
    for label_i in sorted(set(y)):
        label_idxs = np.where(y == label_i)[0].tolist(); n_train = int(train_ratio * len(label_idxs))
        train_idxs.extend(label_idxs[:n_train]); test_idxs.extend(label_idxs[n_train:])
    # The two splits must be disjoint (set intersection instead of an O(n^2) list scan).
    assert len(set(train_idxs) & set(test_idxs)) == 0
    train_idxs = np.array(train_idxs, dtype=np.int64); test_idxs = np.array(test_idxs, dtype=np.int64)
    X_train = X[train_idxs,:,:]; y_train = y[train_idxs]; X_test = X[test_idxs,:,:]; y_test = y[test_idxs]
    # Check whether trainset & testset both have data items.
    if len(X_train) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # Sanity check against duplicated trials: at 10 random time steps, collect the test items whose
    # channel-0 value also appears in train-set, and require that no test item is collected at all of them.
    samples_same = None; n_samples = 10; assert X_train.shape[1] == X_test.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_train.shape[1])
        sample_same_i = np.intersect1d(X_train[:,sample_idx,0], X_test[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough, then transform y to sorted order.
    assert len(set(y_train)) == len(set(y_test)) == params.model.n_labels; labels = sorted(set(y_train))
    # y - (n_samples, n_labels)
    y_train = np.array([labels.index(y_i) for y_i in y_train], dtype=np.int64); y_train = np.eye(len(labels))[y_train]
    y_test = np.array([labels.index(y_i) for y_i in y_test], dtype=np.int64); y_test = np.eye(len(labels))[y_test]
    # Execute sample permutation: shuffle only the rows of `y_train`, breaking the X-y pairing of train-set.
    if load_params.permutation: np.random.shuffle(y_train)
    # Further split test-set into validation-set & test-set (a random half goes to validation-set).
    validation_idxs = np.random.choice(np.arange(X_test.shape[0]), size=int(X_test.shape[0]/2), replace=False)
    validation_mask = np.zeros((X_test.shape[0],), dtype=np.bool_); validation_mask[validation_idxs] = True
    X_validation = X_test[validation_mask,:,:]; y_validation = y_test[validation_mask,:]
    X_test = X_test[~validation_mask,:,:]; y_test = y_test[~validation_mask,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

"""
train funcs
"""
# def train func
def train(base_, params_):
    """
    Train the model.
    :param base_: The base path of current project.
    :param params_: The parameters of current training process.
    """
    global params, paths
    # Initialize parameters & variables of current training process.
    init(base_, params_)
    # Log the start of current training process.
    paths.run.logger.summaries.info("Training started with dataset {}.".format(params.train.dataset))
    # Initialize load_params. Each load_params_i corresponds to a sub-dataset.
    if params.train.dataset == "eeg_anonymous":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            os.path.join(paths.base, "data", "demo-dataset", "001"),
            os.path.join(paths.base, "data", "demo-dataset", "002"),
            os.path.join(paths.base, "data", "demo-dataset", "003"),
        ] * 20; load_type = "default"; n_samples = utils.DotDict({"trainset":None,})
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image",
                "trainset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "testset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "type": load_type, "permutation": False, "resample_rate": 100,
            }),
            # train-task-all-audio-test-task-all-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-task-all-audio",
                "trainset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "testset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "type": load_type, "permutation": False, "resample_rate": 100,
            }),
            # train-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "n_samples": n_samples,
            }),
        ]
    elif params.train.dataset == "meg_anonymous":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            #os.path.join(paths.base, "data", "meg.anonymous", "004", "20221218"),
            #os.path.join(paths.base, "data", "meg.anonymous", "007", "20230308"),
            #os.path.join(paths.base, "data", "meg.anonymous", "009", "20230310"),
            #os.path.join(paths.base, "data", "meg.anonymous", "011", "20230314"),
            #os.path.join(paths.base, "data", "meg.anonymous", "012", "20230314"),
            #os.path.join(paths.base, "data", "meg.anonymous", "015", "20230325"),
            #os.path.join(paths.base, "data", "meg.anonymous", "016", "20230325"),
            #os.path.join(paths.base, "data", "meg.anonymous", "018", "20230406"),
            #os.path.join(paths.base, "data", "meg.anonymous", "019", "20230410"),
            #os.path.join(paths.base, "data", "meg.anonymous", "022", "20230414"),
            #os.path.join(paths.base, "data", "meg.anonymous", "023", "20230414"),
            #os.path.join(paths.base, "data", "meg.anonymous", "024", "20230417"),
            #os.path.join(paths.base, "data", "meg.anonymous", "029", "20230424"),
            #os.path.join(paths.base, "data", "meg.anonymous", "030", "20230425"),
            #os.path.join(paths.base, "data", "meg.anonymous", "031", "20230425"),
            #os.path.join(paths.base, "data", "meg.anonymous", "032", "20230427"),
            #os.path.join(paths.base, "data", "meg.anonymous", "033", "20230427"),
            #os.path.join(paths.base, "data", "meg.anonymous", "034", "20230504"),
            #os.path.join(paths.base, "data", "meg.anonymous", "036", "20230508"),
            #os.path.join(paths.base, "data", "meg.anonymous", "038", "20230511"),
            #os.path.join(paths.base, "data", "meg.anonymous", "039", "20230511"),
            #os.path.join(paths.base, "data", "meg.anonymous", "041", "20230512"),
            #os.path.join(paths.base, "data", "meg.anonymous", "042", "20230516"),
            #os.path.join(paths.base, "data", "meg.anonymous", "043", "20230516"),
            #os.path.join(paths.base, "data", "meg.anonymous", "044", "20230517"),
            #os.path.join(paths.base, "data", "meg.anonymous", "045", "20230517"),
            #os.path.join(paths.base, "data", "meg.anonymous", "046", "20230518"),
            #os.path.join(paths.base, "data", "meg.anonymous", "047", "20230518"),
            #os.path.join(paths.base, "data", "meg.anonymous", "048", "20230522"),
            #os.path.join(paths.base, "data", "meg.anonymous", "049", "20230522"),
            #os.path.join(paths.base, "data", "meg.anonymous", "051", "20230523"),
            os.path.join(paths.base, "data", "meg.anonymous", "052", "20230525"),
            os.path.join(paths.base, "data", "meg.anonymous", "054", "20230530"),
            os.path.join(paths.base, "data", "meg.anonymous", "055", "20230530"),
        ]; load_type = "lvbj"; load_modality = ["eeg", "mag", "grad",][-1:]
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-all-test-task-all-all
            utils.DotDict({
                "name": "train-task-all-all-test-task-all-all",
                "trainset": [
                    "task-image-audio-audio", "task-image-audio-image",
                    "task-audio-image-audio", "task-audio-image-image",
                ],
                "testset": [
                    "task-image-audio-audio", "task-image-audio-image",
                    "task-audio-image-audio", "task-audio-image-image",
                ],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image",
                "trainset": ["task-image-audio-image", "task-audio-image-image",],
                "testset": ["task-image-audio-image", "task-audio-image-image",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-task-all-audio-test-task-all-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-task-all-audio",
                "trainset": ["task-image-audio-audio", "task-audio-image-audio",],
                "testset": ["task-image-audio-audio", "task-audio-image-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-tmr-n23-audio-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-tmr-rem-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-tmr-n23-audio-test-task-all-image
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-task-all-image",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["task-image-audio-image", "task-audio-image-image",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-tmr-n23-audio-test-task-all-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-task-all-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["task-image-audio-audio", "task-audio-image-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-task-all-image-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-image-test-tmr-n23-audio",
                "trainset": ["task-image-audio-image", "task-audio-image-image",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-task-all-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-tmr-n23-audio",
                "trainset": ["task-image-audio-audio", "task-audio-image-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-task-all-image-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-task-all-image-test-tmr-rem-audio",
                "trainset": ["task-image-audio-image", "task-audio-image-image",],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-task-all-audio-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-tmr-rem-audio",
                "trainset": ["task-image-audio-audio", "task-audio-image-audio",],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-task-all-image-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-image-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": ["task-image-audio-image", "task-audio-image-image", "tmr-N2/3-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
            # train-task-all-audio-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-audio-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": ["task-image-audio-audio", "task-audio-image-audio", "tmr-N2/3-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
        ]
    elif params.train.dataset == "meg_lv2023cpnl":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub001"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub002"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub003"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub004"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub005"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub006"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub007"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub008"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub009"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub010"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub011"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub012"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub013"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub014"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub015"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub016"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub017"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub018"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub019"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub020"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub021"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub022"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub023"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub024"),
            os.path.join(paths.base, "data", "meg.lv2023cpnl", "sz-sub001"),
            os.path.join(paths.base, "data", "meg.lv2023cpnl", "sz-sub002"),
        ]; load_type = "lvbj"; load_modality = ["mag", "grad"][-1:]
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image", "type": load_type,
                "permutation": False, "resample_rate": 100, "data_modality": load_modality,
            }),
        ]
    elif params.train.dataset == "seeg_he2023xuanwu":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "003"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "004"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "005"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "006"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "007"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "008"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "009"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "010"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "011"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "012"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "013"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "014"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "015"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "016"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "017"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "018"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "019"),
            os.path.join(paths.base, "data", "seeg.he2023xuanwu", "020"),
            #os.path.join(paths.base, "data", "seeg.he2023xuanwu", "021"),
        ]; load_type = "bipolar_default"; load_task = "word_recitation"
        # Set `resample_rate` according to `load_type`.
        if load_type == "default":
            resample_rate = 1000
        elif load_type == "epoch":
            resample_rate = 1000
        elif load_type == "moses2021neuroprosthesis":
            resample_rate = 200
        elif load_type.startswith("bipolar"):
            resample_rate = 200
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-speak-test-task-all-speak
            utils.DotDict({
                "name": "train-task-all-speak-test-task-all-speak", "type": load_type,
                "permutation": False, "resample_rate": resample_rate, "task": load_task,
            }),
        ]
    elif params.train.dataset == "meg_hebart2023things":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            os.path.join(paths.base, "data", "meg.hebart2023things", "sub-BIGMEG1"),
            os.path.join(paths.base, "data", "meg.hebart2023things", "sub-BIGMEG2"),
            os.path.join(paths.base, "data", "meg.hebart2023things", "sub-BIGMEG3"),
            os.path.join(paths.base, "data", "meg.hebart2023things", "sub-BIGMEG4"),
        ]; load_type = "preprocessed"; load_image = False
        # Set `resample_rate` according to `load_type`.
        if load_type == "preprocessed":
            resample_rate = 200
        elif load_type == "whole":
            resample_rate = 1000
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image", "type": load_type,
                "permutation": False, "resample_rate": resample_rate, "load_image": load_image,
            }),
        ]
    elif params.train.dataset == "eeg_palazzo2020decoding":
        # Initialize the paths of runs that we want to execute experiments.
        source_dataset = [
            # Possible normal dataset.
            "eeg_5_95_std.npy", "eeg_14_70_std.npy", "eeg_55_95_std.npy",
            # Possible wrong dataset. Cannot pass sample unique check.
            "eeg_signals_raw_with_mean_std.npy"
        ][0]
        path_runs = [
            #os.path.join(paths.base, "data", "eeg.palazzo2020decoding", source_dataset, "0"),
            os.path.join(paths.base, "data", "eeg.palazzo2020decoding", source_dataset, "1"),
            os.path.join(paths.base, "data", "eeg.palazzo2020decoding", source_dataset, "2"),
            os.path.join(paths.base, "data", "eeg.palazzo2020decoding", source_dataset, "3"),
            os.path.join(paths.base, "data", "eeg.palazzo2020decoding", source_dataset, "4"),
            os.path.join(paths.base, "data", "eeg.palazzo2020decoding", source_dataset, "5"),
            os.path.join(paths.base, "data", "eeg.palazzo2020decoding", source_dataset, "6"),
        ]
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image", "permutation": False,
            }),
        ]
    elif params.train.dataset == "eeg_gifford2022large":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            os.path.join(paths.base, "data", "eeg.gifford2022large", "1"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "2"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "3"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "4"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "5"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "6"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "7"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "8"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "9"),
            #os.path.join(paths.base, "data", "eeg.gifford2022large", "10"),
        ]; load_type = "preprocessed.normal"; load_image = False; use_sigma = True; use_sigma_cond = True
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image",
                "permutation": False, "type": load_type, "load_image": load_image,
                "use_sigma": use_sigma, "use_sigma_cond": use_sigma_cond,
            }),
        ]
    else:
        raise ValueError("ERROR: Unknown dataset {} in train.naive_cnn.".format(params.train.dataset))
    # Execute experiments for each dataset run.
    for path_run_i in path_runs:
        # Log the start of current training iteration.
        msg = "Training started with sub-dataset {}.".format(path_run_i)
        print(msg); paths.run.logger.summaries.info(msg)
        for load_params_idx in range(len(load_params)):
            # Add `path_run` to `load_params_i`.
            load_params_i = cp.deepcopy(load_params[load_params_idx]); load_params_i.path_run = path_run_i
            # Load data from specified experiment.
            dataset_train, dataset_validation, dataset_test = load_data(load_params_i)
            # Check whether train-set & validation-set & test-set exists.
            if len(dataset_train[0]) == 0 or len(dataset_validation[0]) == 0 or len(dataset_test[0]) == 0:
                msg = (
                    "INFO: Skip experiment {} with train-set ({:d} items)" +\
                    " & validation-set ({:d} items) & test-set ({:d} items)."
                ).format(load_params_i.name, len(dataset_train[0]), len(dataset_validation[0]), len(dataset_test[0]))
                print(msg); paths.run.logger.summaries.info(msg); continue
            # Start training process of current specified experiment.
            msg = "Start the training process of experiment {}.".format(load_params_i.name)
            print(msg); paths.run.logger.summaries.info(msg)

            # Train the model for each time segment.
            accuracy_validation = []; accuracy_test = []

            # Initialize model of current time segment.
            model = naive_cnn_model(params.model)
            for epoch_idx in range(params.train.n_epochs):
                # Record the start time of preparing data.
                time_start = time.time()
                # Fit the model using [X_train,y_train].
                model.fit(dataset_train[0], dataset_train[1], epochs=1, batch_size=params.train.batch_size)
                # Predict the corresponding `y_pred` of `X_validation`.
                _, accuracy_validation_i = model.evaluate(dataset_validation[0], dataset_validation[1])
                accuracy_validation.append(accuracy_validation_i)
                # Predict the corresponding `y_pred` of `X_test`.
                _, accuracy_test_i = model.evaluate(dataset_test[0], dataset_test[1]); accuracy_test.append(accuracy_test_i)
                # Record current time segment.
                time_stop = time.time()
                msg = (
                    "Finish epoch {:d} in {:.2f} seconds, with validation-accuracy ({:.2f}%) and test-accuracy ({:.2f}%)."
                ).format(epoch_idx, time_stop-time_start, accuracy_validation_i*100., accuracy_test_i*100.)
                print(msg); paths.run.logger.summaries.info(msg)
                # Summarize model information.
                if epoch_idx == 0:
                    model.summary(print_fn=print); model.summary(print_fn=paths.run.logger.summaries.info)
            # Convert `accuracy_validation` & `accuracy_test` to `np.array`.
            accuracy_validation = np.round(np.array(accuracy_validation, dtype=np.float32), decimals=4)
            accuracy_test = np.round(np.array(accuracy_test, dtype=np.float32), decimals=4)
            epoch_maxacc_idxs = np.where(accuracy_validation == np.max(accuracy_validation))[0]
            epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(accuracy_test[epoch_maxacc_idxs])]
            # Finish training process of current specified experiment.
            msg = (
                "Finish the training process of experiment {}, with test-accuracy ({:.2f}%)" +\
                " according to max validation-accuracy ({:.2f}%) at epoch {:d}."
            ).format(load_params_i.name, accuracy_test[epoch_maxacc_idx]*100.,
                accuracy_validation[epoch_maxacc_idx]*100., epoch_maxacc_idx)
            print(msg); paths.run.logger.summaries.info(msg)
        # Log the end of current training iteration.
        msg = "Training finished with sub-dataset {}.".format(path_run_i)
        print(msg); paths.run.logger.summaries.info(msg)
    # Log the end of current training process.
    msg = "Training finished with dataset {}.".format(params.train.dataset)
    print(msg); paths.run.logger.summaries.info(msg)

if __name__ == "__main__":
    # Script entry point: train `naive_cnn` on one hard-coded dataset.
    import os
    # local dep
    from params.naive_cnn_params import naive_cnn_params

    # macro: dataset identifier handed to `naive_cnn_params`.
    dataset = "eeg_anonymous"

    # Seed every RNG (value 4) before any params/model object is created.
    utils.model.set_seeds(4)

    # The project base is the parent of the current working directory.
    # NOTE(review): this presumably assumes the script is launched from its own
    # directory (the top-of-file `sys.path` insert makes the same assumption).
    base = os.path.join(os.getcwd(), os.pardir)
    # Build the dataset-specific parameters, then run the full training loop.
    train(base, naive_cnn_params(dataset=dataset))

