#!/usr/bin/env python3
"""
Created on 16:25, Dec. 21st, 2022

@author: Anonymous
"""
import time
import copy as cp
import numpy as np
from sklearn import datasets
from collections import Counter
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.pardir)
import utils; import utils.model
import utils.data.eeg; import utils.data.meg
from models.lasso_glm import lasso_glm as lasso_glm_model

# Public API of this module. NOTE(review): `train` is presumably defined further down
# the file (it is not visible in this section) — confirm it exists, otherwise
# `from <module> import *` raises AttributeError.
__all__ = [
    "train",
]

# Global variables.
# Both stay None until `init` rebinds them: `params` becomes a deep copy of the
# training parameters, `paths` becomes `utils.Paths(base=..., params=params)`.
params = None; paths = None

"""
init funcs
"""
# def init func
def init(base_, params_):
    """
    Set up the module-level state used by `lasso_glm` training.

    The globals `params` and `paths` are rebound here, then the model and the
    training process are initialized, in that order.

    Args:
        base_: The base path of current project.
        params_: The parameters of current training process. A deep copy is kept,
            so later changes made by the caller to `params_` are not seen here.

    Returns:
        None
    """
    global params, paths
    # Work on a private copy of the caller's parameters.
    params_copy = cp.deepcopy(params_)
    params = params_copy
    paths = utils.Paths(base=base_, params=params_copy)
    # Model first, then the training process.
    for init_step in (_init_model, _init_train):
        init_step()

# def _init_model func
def _init_model():
    """
    Hook for building the model used in the training process.

    Currently a placeholder: it builds nothing and returns None.
    """
    return None

# def _init_train func
def _init_train():
    """
    Hook for preparing the training process.

    Currently a placeholder: it prepares nothing and returns None.
    """
    return None

"""
data funcs
"""
# def load_data func
def load_data(load_params):
    """
    Load data from specified subject.

    Dispatches to the module-level loader named `_load_data_<params.train.dataset>`
    (e.g. `_load_data_eeg_anonymous`) and returns whatever it produces.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).

    Raises:
        ValueError: If no loader exists for `params.train.dataset`, or if the loader itself
            fails (the loader's exception is chained as `__cause__`).
    """
    # `sys` is only imported at module level when this file runs as `__main__`; import it
    # here so the dispatch also works when the module is imported.
    import sys
    global params
    # Look up the dataset-specific loader. Only a failed lookup means "unknown dataset type".
    try:
        func = getattr(sys.modules[__name__], "_".join(["_load_data", params.train.dataset]))
    except (AttributeError, TypeError) as err:
        raise ValueError("ERROR: Unknown dataset type {} in train.lasso_glm.".format(params.train.dataset)) from err
    # Load data from specified dataset. Failures inside the loader keep their original cause
    # instead of being misreported as an unknown dataset type.
    try:
        dataset_train_, dataset_validation_, dataset_test_ = func(load_params)
    except Exception as err:
        raise ValueError("ERROR: Failed to load dataset {} in train.lasso_glm.".format(params.train.dataset)) from err
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return dataset_train_, dataset_validation_, dataset_test_

# def _load_data_eeg_anonymous func
def _load_data_eeg_anonymous(load_params):
    """
    Load eeg data from the specified subject in `eeg_anonymous`.

    Every name in `load_params.trainset | load_params.testset` is parsed as
    "<loader>-<session...>-<data_type>". `<loader>` selects
    `utils.data.eeg.anonymous.load_run_<loader>`, the middle part is passed as `session_type`,
    and `<data_type>` selects the item of the loaded run. For `default`/`lvbj` types each dataset
    is truncated to time steps [20, 100). It is then optionally downsampled, z-scored and
    class-balanced, and finally split as follows:
      - datasets in both trainset & testset are split into train/validation/test,
      - datasets only in trainset go entirely to train,
      - datasets only in testset are split evenly into validation/test.

    Args:
        load_params: DotDict - The load parameters of specified dataset. Reads `path_run` (optional),
            `trainset`, `testset`, `type`, `downsample_rate`, `use_loo`, `drop_ratio` & `permutation`.
            `params.train.train_ratio` is also read when `use_loo` is disabled.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        All three are `([], [])` when a split ends up empty or the required 15 classes are not present.
    """
    # `os` is only imported at module level when this file runs as `__main__`; import it
    # locally so this loader also works when the module is imported.
    import os
    global params, paths

    # def _assert_disjoint func
    def _assert_disjoint(X_a, X_b, n_checks=10):
        """
        Heuristically assert that no sample of `X_b` duplicates a sample of `X_a`.

        Channel-0 values are compared at `n_checks` random time steps. Only indices (into `X_b`)
        that match at every sampled step survive the intersection, and none may survive.
        """
        assert X_a.shape[1] == X_b.shape[1]
        samples_same = None
        for _ in range(n_checks):
            sample_idx = np.random.randint(X_a.shape[1])
            sample_same_i = set(np.intersect1d(
                X_a[:,sample_idx,0], X_b[:,sample_idx,0], return_indices=True
            )[-1].tolist())
            samples_same = sample_same_i if samples_same is None else (sample_same_i & samples_same)
        assert len(samples_same) == 0

    # Initialize path_run.
    if hasattr(load_params, "path_run"):
        path_run = load_params.path_run
    else:
        path_run = os.path.join(paths.base, "data", "eeg.anonymous", "020", "20230405")
    # Load data from specified subject run.
    datasets = utils.DotDict(); dataset_names = list(set(load_params.trainset) | set(load_params.testset))
    for dataset_name_i in dataset_names:
        # Load data from specified dataset name ("<loader>-<session...>-<data_type>").
        session_type_i = "-".join(dataset_name_i.split("-")[:-1]); data_type_i = dataset_name_i.split("-")[-1]
        func_i = getattr(utils.data.eeg.anonymous, "_".join(["load_run", session_type_i.split("-")[0]]))
        X_i, y_i = func_i(path_run, session_type="-".join(session_type_i.split("-")[1:]))
        # Check whether current dataset has data items.
        if X_i[data_type_i] is None:
            msg = "WARNING: Skip dataset {} with no data item.".format(dataset_name_i)
            print(msg); paths.run.logger.summaries.info(msg); continue
        X_i = X_i[data_type_i].astype(np.float32); y_i = y_i[data_type_i].astype(np.int64)
        # Truncate `X_i` to let them have the same length.
        # TODO: Here, we only keep the [0.0~0.8]s-part of [audio,image] that after onset. And we should
        # note that the [0.0~0.8]s-part of image is the whole onset time of image, the [0.0~0.8]s-part
        # of audio is the sum of the whole onset time of audio and the following 0.3s padding.
        # X_i - (n_samples, seq_len, n_channels)
        # Both the `default` and the `lvbj` types keep time steps [20, 100); other types are left as-is.
        if load_params.type in ("default", "lvbj"):
            X_i = X_i[:,20:100,:]
        # Downsample sequence data X by averaging non-overlapping windows of `downsample_rate` steps.
        if load_params.downsample_rate > 1:
            assert X_i.shape[1] % load_params.downsample_rate == 0
            X_i = np.split(X_i, np.arange(load_params.downsample_rate, X_i.shape[1], load_params.downsample_rate), axis=1)
            X_i = np.concatenate([np.mean(x_i, axis=1, keepdims=True) for x_i in X_i], axis=1)
        # Do cross-trial normalization: one mean/std over the whole dataset. Dividing by the
        # standard deviation (not the variance) is what actually yields unit scale.
        X_i = (X_i - np.mean(X_i)) / np.std(X_i)
        # Set the corresponding item of dataset.
        datasets[dataset_name_i] = utils.DotDict({"X":X_i,"y":y_i,})
    # Initialize train-set & validation-set & test-set.
    X_train = []; y_train = []; X_validation = []; y_validation = []; X_test = []; y_test = []
    for dataset_name_i, dataset_i in datasets.items():
        # Force the number of samples belonging to different categories is the same, by randomly
        # dropping redundant samples so that every category keeps `min_samples` samples.
        label_counter = Counter(dataset_i.y); min_samples = np.min(list(label_counter.values())); drop_idxs = []
        for label_i in sorted(label_counter.keys()):
            drop_idxs.extend(np.random.choice(np.where(dataset_i.y == label_i)[0],
                size=((label_counter[label_i] - min_samples),), replace=False).tolist())
        # The smallest category drops nothing, so not every label can appear among the dropped samples.
        if len(drop_idxs) > 0: assert len(set(dataset_i.y[drop_idxs].tolist())) < len(set(dataset_i.y))
        assert len(set(drop_idxs)) == len(drop_idxs)
        keep_idxs = sorted(set(range(dataset_i.y.shape[0])) - set(drop_idxs))
        dataset_i.X = dataset_i.X[keep_idxs,:,:]; dataset_i.y = dataset_i.y[keep_idxs]
        # Make sure that the number of samples in each category is equal!
        label_counter = Counter(dataset_i.y); assert (np.diff(list(label_counter.values())) == 0).all()
        labels_i = sorted(label_counter.keys()); counts_i = np.array(list(label_counter.values()))
        in_trainset = dataset_name_i in load_params.trainset; in_testset = dataset_name_i in load_params.testset
        # If dataset belongs to both train-set and test-set, split X into train-set & validation-set & test-set.
        if in_trainset and in_testset:
            # Check whether the number of samples per category is satisfied before splitting.
            if not (counts_i >= 3).all():
                msg = (
                    "WARNING: Skip dataset {} with insufficient data items ({})"+\
                    " to be split into train-set & validation-set & test-set."
                ).format(dataset_name_i, label_counter)
                print(msg); paths.run.logger.summaries.info(msg); continue
            # Leave-one-out mode: validation-set & test-set get exactly one sample per category.
            if load_params.use_loo:
                n_samples_test = {label_i: 1 for label_i in labels_i}
            # Pseudo-K-fold mode: validation-set & test-set each take `(1 - train_ratio) / 2` of
            # every category, but never fewer than one sample per category.
            else:
                train_ratio = params.train.train_ratio; test_ratio = (1. - train_ratio) / 2.
                n_samples_test = {label_i: max(1, int(label_counter[label_i] * test_ratio)) for label_i in labels_i}
            n_samples_validation = dict(n_samples_test)
        # If dataset only belongs to train-set, append X to train-set.
        elif in_trainset:
            # Check whether the number of samples per category is satisfied as train-set.
            if not (counts_i >= 1).all():
                msg = (
                    "WARNING: Skip dataset {} with insufficient data items ({}) as train-set."
                ).format(dataset_name_i, label_counter)
                print(msg); paths.run.logger.summaries.info(msg); continue
            n_samples_test = {label_i: 0 for label_i in labels_i}; n_samples_validation = dict(n_samples_test)
        # If dataset only belongs to test-set, split X into validation-set & test-set.
        elif in_testset:
            # Check whether the number of samples per category is satisfied before splitting.
            if not (counts_i >= 2).all():
                msg = (
                    "WARNING: Skip dataset {} with insufficient data items ({})"+\
                    " to be split into validation-set & test-set."
                ).format(dataset_name_i, label_counter)
                print(msg); paths.run.logger.summaries.info(msg); continue
            n_samples_validation = {label_i: label_counter[label_i] // 2 for label_i in labels_i}
            n_samples_test = {label_i: label_counter[label_i] - n_samples_validation[label_i] for label_i in labels_i}
        # Wrong cases.
        else:
            raise ValueError("ERROR: Unknown dataset name {}.".format(dataset_name_i))
        # Whatever is not reserved for validation-set & test-set goes to train-set.
        n_samples_train = {label_i: (label_counter[label_i] - n_samples_validation[label_i] - n_samples_test[label_i])
            for label_i in labels_i}
        # Use `n_samples_*` to get `*_idxs`; validation samples are drawn from what test-set left over.
        test_idxs = []; validation_idxs = []
        for label_i in sorted(n_samples_test.keys()):
            test_idxs.extend(np.random.choice(
                np.where(dataset_i.y == label_i)[0],
                size=(n_samples_test[label_i],), replace=False
            ).tolist())
        for label_i in sorted(n_samples_validation.keys()):
            validation_idxs.extend(np.random.choice(
                list(set(np.where(dataset_i.y == label_i)[0]) - set(test_idxs)),
                size=(n_samples_validation[label_i],), replace=False
            ).tolist())
        train_idxs = sorted(set(range(dataset_i.y.shape[0])) - set(validation_idxs) - set(test_idxs))
        # Convert `*_idxs` to `np.array`.
        train_idxs = np.array(train_idxs, dtype=np.int64)
        validation_idxs = np.array(validation_idxs, dtype=np.int64)
        test_idxs = np.array(test_idxs, dtype=np.int64)
        # Get the corresponding `X_*_i` & `y_*_i`.
        X_train_i = dataset_i.X[train_idxs,:,:]; y_train_i = dataset_i.y[train_idxs]
        X_validation_i = dataset_i.X[validation_idxs,:,:]; y_validation_i = dataset_i.y[validation_idxs]
        X_test_i = dataset_i.X[test_idxs,:,:]; y_test_i = dataset_i.y[test_idxs]
        # Check whether `train_idxs` is consistent with `n_samples_train`.
        for label_i in sorted(n_samples_train.keys()): assert np.sum(y_train_i == label_i) == n_samples_train[label_i]
        # Append `X_*_i` & `y_*_i` to `X_*` & `y_*`.
        X_train.append(X_train_i); y_train.append(y_train_i)
        X_validation.append(X_validation_i); y_validation.append(y_validation_i)
        X_test.append(X_test_i); y_test.append(y_test_i)
    # Check whether train-set & validation-set & test-set all have data items.
    if len(X_train) == 0 or len(X_validation) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    X_train = np.concatenate(X_train, axis=0); y_train = np.concatenate(y_train, axis=0)
    X_validation = np.concatenate(X_validation, axis=0); y_validation = np.concatenate(y_validation, axis=0)
    X_test = np.concatenate(X_test, axis=0); y_test = np.concatenate(y_test, axis=0)
    # If train-set & test-set come from totally different datasets, randomly remove
    # `drop_ratio` of every category from train-set.
    if (len(set(load_params.trainset) & set(load_params.testset)) == 0) and load_params.drop_ratio is not None:
        label_counter = Counter(y_train); drop_idxs = []
        for label_i in sorted(label_counter.keys()):
            drop_idxs.extend(np.random.choice(np.where(y_train == label_i)[0],
                size=(int(label_counter[label_i] * load_params.drop_ratio),), replace=False
            ).tolist())
        # Every category is thinned by the same ratio, so (unlike class balancing above) every
        # label may legitimately appear among the dropped samples; only check uniqueness.
        assert len(set(drop_idxs)) == len(drop_idxs)
        keep_idxs = sorted(set(range(y_train.shape[0])) - set(drop_idxs))
        # Drop `X_train` & `y_train` according to `keep_idxs`.
        X_train = X_train[keep_idxs,:,:]; y_train = y_train[keep_idxs]
    # Make sure there is no overlap between train-set, validation-set & test-set.
    _assert_disjoint(X_train, X_test)
    _assert_disjoint(X_train, X_validation)
    _assert_disjoint(X_validation, X_test)
    # Check whether labels are enough, then transform y to sorted order. An explicit check
    # (not `assert`) keeps this skip working under `python -O`.
    if not (len(set(y_train)) == len(set(y_validation)) == len(set(y_test)) == 15):
        msg = (
            "WARNING: Skip experiment (train:{};test:{}) due to that the classes of test cases are not enough."
        ).format(set(load_params.trainset), set(load_params.testset))
        print(msg); paths.run.logger.summaries.info(msg); return ([], []), ([], []), ([], [])
    labels = sorted(set(y_train))
    # y - (n_samples,)
    y_train = np.array([labels.index(y_i) for y_i in y_train], dtype=np.int64)
    y_validation = np.array([labels.index(y_i) for y_i in y_validation], dtype=np.int64)
    y_test = np.array([labels.index(y_i) for y_i in y_test], dtype=np.int64)
    # Execute sample permutation: shuffle only the train labels (X_train is left untouched).
    if load_params.permutation: np.random.shuffle(y_train)
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

# def _load_data_meg_anonymous func
def _load_data_meg_anonymous(load_params):
    """
    Load meg data from the specified subject in `meg_anonymous`.

    Every name in `load_params.trainset | load_params.testset` is parsed as
    "<loader>-<session...>-<data_type>". `<loader>` and `load_params.type` select
    `utils.data.meg.anonymous.load_run_<loader>_<type>`, the middle part is passed as `session_type`,
    and `<data_type>` selects the item of the loaded run. Per dataset:
      - each modality in `load_params.data_modality` is truncated to time steps [20, 100),
        optionally downsampled and z-scored on its own;
      - the modalities are concatenated along the last (channel) axis;
      - the result is class-balanced and split as follows: datasets in both trainset & testset go to
        train/validation/test, trainset-only datasets go to train, testset-only datasets are split
        evenly into validation/test.

    Args:
        load_params: DotDict - The load parameters of specified dataset. Reads `path_run` (optional),
            `trainset`, `testset`, `type`, `data_modality`, `downsample_rate`, `use_loo`, `drop_ratio`
            & `permutation`. `params.train.train_ratio` is also read when `use_loo` is disabled.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        All three are `([], [])` when a split ends up empty or the required 15 classes are not present.

    Raises:
        ValueError: If `load_params.type` is neither `default` nor `lvbj`, or a dataset name is unknown.
    """
    # `os` is only imported at module level when this file runs as `__main__`; import it
    # locally so this loader also works when the module is imported.
    import os
    global params, paths

    # def _assert_disjoint func
    def _assert_disjoint(X_a, X_b, n_checks=10):
        """
        Heuristically assert that no sample of `X_b` duplicates a sample of `X_a`.

        Channel-0 values are compared at `n_checks` random time steps. Only indices (into `X_b`)
        that match at every sampled step survive the intersection, and none may survive.
        """
        assert X_a.shape[1] == X_b.shape[1]
        samples_same = None
        for _ in range(n_checks):
            sample_idx = np.random.randint(X_a.shape[1])
            sample_same_i = set(np.intersect1d(
                X_a[:,sample_idx,0], X_b[:,sample_idx,0], return_indices=True
            )[-1].tolist())
            samples_same = sample_same_i if samples_same is None else (sample_same_i & samples_same)
        assert len(samples_same) == 0

    # Initialize path_run.
    if hasattr(load_params, "path_run"):
        path_run = load_params.path_run
    else:
        path_run = os.path.join(paths.base, "data", "meg.anonymous", "008", "20230308")
    # Load data from specified subject run.
    datasets = utils.DotDict(); dataset_names = list(set(load_params.trainset) | set(load_params.testset))
    for dataset_name_i in dataset_names:
        # Load data from specified dataset name ("<loader>-<session...>-<data_type>").
        session_type_i = "-".join(dataset_name_i.split("-")[:-1]); data_type_i = dataset_name_i.split("-")[-1]
        func_i = getattr(utils.data.meg.anonymous, "_".join(["load_run", session_type_i.split("-")[0], load_params.type]))
        Xs_i, y_i = func_i(path_run, session_type="-".join(session_type_i.split("-")[1:]))
        # Check whether current dataset has data items.
        if Xs_i[data_type_i] is None:
            msg = "WARNING: Skip dataset {} with no data item.".format(dataset_name_i)
            print(msg); paths.run.logger.summaries.info(msg); continue
        # One array per requested data modality.
        Xs_i = [Xs_i[data_type_i][data_modality_i].astype(np.float32)\
            for data_modality_i in load_params.data_modality]; y_i = y_i[data_type_i].astype(np.int64)
        # Truncate `X_i` to let them have the same length.
        # TODO: Here, we only keep the [0.0~0.8]s-part of [audio,image] that after onset. And we should
        # note that the [0.0~0.8]s-part of image is the whole onset time of image, the [0.0~0.8]s-part
        # of audio is the sum of the whole onset time of audio and the following 0.3s padding.
        # X_i - (n_samples, seq_len, n_channels)
        # Both the `default` and the `lvbj` types keep time steps [20, 100).
        if load_params.type in ("default", "lvbj"):
            Xs_i = [X_i[:,20:100,:] for X_i in Xs_i]
        # Get unknown type of dataset.
        else:
            raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
        # Downsample sequence data X by averaging non-overlapping windows of `downsample_rate` steps.
        if load_params.downsample_rate > 1:
            # Every modality (not just the first one) must split into whole windows.
            assert all(X_i.shape[1] % load_params.downsample_rate == 0 for X_i in Xs_i)
            Xs_i = [np.split(X_i, np.arange(load_params.downsample_rate,
                X_i.shape[1], load_params.downsample_rate), axis=1) for X_i in Xs_i]
            Xs_i = [np.concatenate([np.mean(x_i, axis=1, keepdims=True) for x_i in X_i], axis=1) for X_i in Xs_i]
        # Do cross-trial normalization: one mean/std per modality over the whole dataset.
        Xs_i = [(X_i - np.mean(X_i)) / np.std(X_i) for X_i in Xs_i]
        # Concatenate the modalities along `n_channels`-axis.
        X_i = np.concatenate(Xs_i, axis=-1)
        # Set the corresponding item of dataset.
        datasets[dataset_name_i] = utils.DotDict({"X":X_i,"y":y_i,})
    # Initialize train-set & validation-set & test-set.
    X_train = []; y_train = []; X_validation = []; y_validation = []; X_test = []; y_test = []
    for dataset_name_i, dataset_i in datasets.items():
        # Force the number of samples belonging to different categories is the same, by randomly
        # dropping redundant samples so that every category keeps `min_samples` samples.
        label_counter = Counter(dataset_i.y); min_samples = np.min(list(label_counter.values())); drop_idxs = []
        for label_i in sorted(label_counter.keys()):
            drop_idxs.extend(np.random.choice(np.where(dataset_i.y == label_i)[0],
                size=((label_counter[label_i] - min_samples),), replace=False).tolist())
        # The smallest category drops nothing, so not every label can appear among the dropped samples.
        if len(drop_idxs) > 0: assert len(set(dataset_i.y[drop_idxs].tolist())) < len(set(dataset_i.y))
        assert len(set(drop_idxs)) == len(drop_idxs)
        keep_idxs = sorted(set(range(dataset_i.y.shape[0])) - set(drop_idxs))
        dataset_i.X = dataset_i.X[keep_idxs,:,:]; dataset_i.y = dataset_i.y[keep_idxs]
        # Make sure that the number of samples in each category is equal!
        label_counter = Counter(dataset_i.y); assert (np.diff(list(label_counter.values())) == 0).all()
        labels_i = sorted(label_counter.keys()); counts_i = np.array(list(label_counter.values()))
        in_trainset = dataset_name_i in load_params.trainset; in_testset = dataset_name_i in load_params.testset
        # If dataset belongs to both train-set and test-set, split X into train-set & validation-set & test-set.
        if in_trainset and in_testset:
            # Check whether the number of samples per category is satisfied before splitting.
            if not (counts_i >= 3).all():
                msg = (
                    "WARNING: Skip dataset {} with insufficient data items ({})"+\
                    " to be split into train-set & validation-set & test-set."
                ).format(dataset_name_i, label_counter)
                print(msg); paths.run.logger.summaries.info(msg); continue
            # Leave-one-out mode: validation-set & test-set get exactly one sample per category.
            if load_params.use_loo:
                n_samples_test = {label_i: 1 for label_i in labels_i}
            # Pseudo-K-fold mode: validation-set & test-set each take `(1 - train_ratio) / 2` of
            # every category, but never fewer than one sample per category.
            else:
                train_ratio = params.train.train_ratio; test_ratio = (1. - train_ratio) / 2.
                n_samples_test = {label_i: max(1, int(label_counter[label_i] * test_ratio)) for label_i in labels_i}
            n_samples_validation = dict(n_samples_test)
        # If dataset only belongs to train-set, append X to train-set.
        elif in_trainset:
            # Check whether the number of samples per category is satisfied as train-set.
            if not (counts_i >= 1).all():
                msg = (
                    "WARNING: Skip dataset {} with insufficient data items ({}) as train-set."
                ).format(dataset_name_i, label_counter)
                print(msg); paths.run.logger.summaries.info(msg); continue
            n_samples_test = {label_i: 0 for label_i in labels_i}; n_samples_validation = dict(n_samples_test)
        # If dataset only belongs to test-set, split X into validation-set & test-set.
        elif in_testset:
            # Check whether the number of samples per category is satisfied before splitting.
            if not (counts_i >= 2).all():
                msg = (
                    "WARNING: Skip dataset {} with insufficient data items ({})"+\
                    " to be split into validation-set & test-set."
                ).format(dataset_name_i, label_counter)
                print(msg); paths.run.logger.summaries.info(msg); continue
            n_samples_validation = {label_i: label_counter[label_i] // 2 for label_i in labels_i}
            n_samples_test = {label_i: label_counter[label_i] - n_samples_validation[label_i] for label_i in labels_i}
        # Wrong cases.
        else:
            raise ValueError("ERROR: Unknown dataset name {}.".format(dataset_name_i))
        # Whatever is not reserved for validation-set & test-set goes to train-set.
        n_samples_train = {label_i: (label_counter[label_i] - n_samples_validation[label_i] - n_samples_test[label_i])
            for label_i in labels_i}
        # Use `n_samples_*` to get `*_idxs`; validation samples are drawn from what test-set left over.
        test_idxs = []; validation_idxs = []
        for label_i in sorted(n_samples_test.keys()):
            test_idxs.extend(np.random.choice(
                np.where(dataset_i.y == label_i)[0],
                size=(n_samples_test[label_i],), replace=False
            ).tolist())
        for label_i in sorted(n_samples_validation.keys()):
            validation_idxs.extend(np.random.choice(
                list(set(np.where(dataset_i.y == label_i)[0]) - set(test_idxs)),
                size=(n_samples_validation[label_i],), replace=False
            ).tolist())
        train_idxs = sorted(set(range(dataset_i.y.shape[0])) - set(validation_idxs) - set(test_idxs))
        # Convert `*_idxs` to `np.array`.
        train_idxs = np.array(train_idxs, dtype=np.int64)
        validation_idxs = np.array(validation_idxs, dtype=np.int64)
        test_idxs = np.array(test_idxs, dtype=np.int64)
        # Get the corresponding `X_*_i` & `y_*_i`.
        X_train_i = dataset_i.X[train_idxs,:,:]; y_train_i = dataset_i.y[train_idxs]
        X_validation_i = dataset_i.X[validation_idxs,:,:]; y_validation_i = dataset_i.y[validation_idxs]
        X_test_i = dataset_i.X[test_idxs,:,:]; y_test_i = dataset_i.y[test_idxs]
        # Check whether `train_idxs` is consistent with `n_samples_train`.
        for label_i in sorted(n_samples_train.keys()): assert np.sum(y_train_i == label_i) == n_samples_train[label_i]
        # Append `X_*_i` & `y_*_i` to `X_*` & `y_*`.
        X_train.append(X_train_i); y_train.append(y_train_i)
        X_validation.append(X_validation_i); y_validation.append(y_validation_i)
        X_test.append(X_test_i); y_test.append(y_test_i)
    # Check whether train-set & validation-set & test-set all have data items.
    if len(X_train) == 0 or len(X_validation) == 0 or len(X_test) == 0: return ([], []), ([], []), ([], [])
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    X_train = np.concatenate(X_train, axis=0); y_train = np.concatenate(y_train, axis=0)
    X_validation = np.concatenate(X_validation, axis=0); y_validation = np.concatenate(y_validation, axis=0)
    X_test = np.concatenate(X_test, axis=0); y_test = np.concatenate(y_test, axis=0)
    # If train-set & test-set come from totally different datasets, randomly remove
    # `drop_ratio` of every category from train-set.
    if (len(set(load_params.trainset) & set(load_params.testset)) == 0) and load_params.drop_ratio is not None:
        label_counter = Counter(y_train); drop_idxs = []
        for label_i in sorted(label_counter.keys()):
            drop_idxs.extend(np.random.choice(np.where(y_train == label_i)[0],
                size=(int(label_counter[label_i] * load_params.drop_ratio),), replace=False
            ).tolist())
        # Every category is thinned by the same ratio, so (unlike class balancing above) every
        # label may legitimately appear among the dropped samples; only check uniqueness.
        assert len(set(drop_idxs)) == len(drop_idxs)
        keep_idxs = sorted(set(range(y_train.shape[0])) - set(drop_idxs))
        # Drop `X_train` & `y_train` according to `keep_idxs`.
        X_train = X_train[keep_idxs,:,:]; y_train = y_train[keep_idxs]
    # Make sure there is no overlap between train-set, validation-set & test-set.
    _assert_disjoint(X_train, X_test)
    _assert_disjoint(X_train, X_validation)
    _assert_disjoint(X_validation, X_test)
    # Check whether labels are enough, then transform y to sorted order. An explicit check
    # (not `assert`) keeps this skip working under `python -O`.
    if not (len(set(y_train)) == len(set(y_validation)) == len(set(y_test)) == 15):
        msg = (
            "WARNING: Skip experiment (train:{};test:{}) due to that the classes of test cases are not enough."
        ).format(set(load_params.trainset), set(load_params.testset))
        print(msg); paths.run.logger.summaries.info(msg); return ([], []), ([], []), ([], [])
    labels = sorted(set(y_train))
    # y - (n_samples,)
    y_train = np.array([labels.index(y_i) for y_i in y_train], dtype=np.int64)
    y_validation = np.array([labels.index(y_i) for y_i in y_validation], dtype=np.int64)
    y_test = np.array([labels.index(y_i) for y_i in y_test], dtype=np.int64)
    # Execute sample permutation: shuffle only the train labels (X_train is left untouched).
    if load_params.permutation: np.random.shuffle(y_train)
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

# def _load_data_meg_lv2023cpnl func
def _load_data_meg_lv2023cpnl(load_params):
    """
    Load meg data from the specified subject in `meg_lv2023cpnl`.

    The image-locked data of every requested data modality is truncated to a common window,
    optionally downsampled, normalized, concatenated along the channel axis, balanced so that
    every category keeps the same number of samples, and finally split per category into
    train-set & validation-set & test-set.

    Args:
        load_params: DotDict - The load parameters of specified dataset. Reads `type`,
            `data_modality`, `downsample_rate`, `use_loo`, `permutation` and, optionally,
            `path_run`. When `use_loo` is False, `params.train.train_ratio` is used as well.

    Returns:
        dataset_train_: tuple - The train dataset, including (X_train, y_train).
        dataset_validation_: tuple - The validation dataset, including (X_validation, y_validation).
        dataset_test_: tuple - The test dataset, including (X_test, y_test).
        When the data cannot be split (too few samples per category, or one of the sets ends
        up empty), every returned tuple is `([], [])` instead.

    Raises:
        ValueError: If `load_params.type` is not one of `default`, `lvbj`, `baseline`.
    """
    global params, paths
    # `os` is only imported at module level under the `__main__` guard, so import it locally
    # to keep this function usable when the module is imported by another script.
    import os

    # def _assert_no_overlap func
    def _assert_no_overlap(X_a, X_b, n_probes=10):
        """
        Spot-check that `X_a` & `X_b` share no sample.

        At `n_probes` randomly chosen time points, the first-channel values of both sets are
        intersected; a sample of `X_b` is flagged only if its value is found in `X_a` at every
        probed time point.
        """
        assert X_a.shape[1] == X_b.shape[1]
        samples_same = None
        for _ in range(n_probes):
            sample_idx = np.random.randint(X_a.shape[1])
            sample_same_i = set(np.intersect1d(
                X_a[:,sample_idx,0], X_b[:,sample_idx,0], return_indices=True
            )[-1].tolist())
            samples_same = sample_same_i if samples_same is None else sample_same_i & samples_same
        assert len(samples_same) == 0

    # def _skip func
    def _skip(msg):
        """
        Report why the current subject is skipped and return empty datasets.
        """
        print(msg); paths.run.logger.summaries.info(msg)
        return ([], []), ([], []), ([], [])

    # Initialize path_run.
    path_run = load_params.path_run if hasattr(load_params, "path_run")\
        else os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub018")
    # Validate the dataset type before loading, so that an unknown type raises `ValueError`
    # instead of an `AttributeError` from the `getattr` lookup below.
    if load_params.type not in ("default", "lvbj", "baseline"):
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    # Load data from specified subject run.
    load_run = getattr(utils.data.meg.lv2023cpnl, "load_run_{}".format(load_params.type))
    Xs, y = load_run(path_run)
    Xs = [Xs["image"][data_modality_i].astype(np.float32)\
        for data_modality_i in load_params.data_modality]
    y = y["image"].astype(np.int64)
    # Truncate `X_i` to let them have the same length; all supported types use the same window.
    # TODO: Here, we only keep the [0.0~0.8]s-part of [audio,image] that after onset. And we should
    # note that the [0.0~0.8]s-part of image is the whole onset time of image, the [0.0~0.8]s-part
    # of audio is the sum of the whole onset time of audio and the following 0.3s padding.
    # X_i - (n_samples, seq_len, n_channels)
    Xs = [X_i[:,20:100,:] for X_i in Xs]
    # Downsample sequence data X by averaging non-overlapping windows of `downsample_rate` steps.
    if load_params.downsample_rate > 1:
        assert Xs[0].shape[1] / load_params.downsample_rate == Xs[0].shape[1] // load_params.downsample_rate
        Xs = [np.split(X_i, np.arange(load_params.downsample_rate,
            X_i.shape[1], load_params.downsample_rate), axis=1) for X_i in Xs]
        Xs = [np.concatenate([np.mean(x_i, axis=1, keepdims=True) for x_i in X_i], axis=1) for X_i in Xs]
    # Do cross-trial normalization (one global mean & std per modality).
    Xs = [(X_i - np.mean(X_i)) / np.std(X_i) for X_i in Xs]
    # Concatenate `X_i`s of `Xs` along `n_channels`-axis.
    X = np.concatenate(Xs, axis=-1)
    # Drop redundant samples of each category to make sure that the number of samples in each category is equal.
    label_counter = Counter(y); min_samples = np.min(list(label_counter.values())); drop_idxs = []
    for label_i in sorted(label_counter.keys()):
        drop_idxs.extend(np.random.choice(np.where(y == label_i)[0],
            size=((label_counter[label_i] - min_samples),), replace=False).tolist())
    # The category with `min_samples` items never loses samples, so not every label can be dropped.
    if len(drop_idxs) > 0: assert len(set(y[drop_idxs].tolist())) < len(set(y))
    assert len(set(drop_idxs)) == len(drop_idxs)
    keep_idxs = sorted(set(range(y.shape[0])) - set(drop_idxs))
    X = X[keep_idxs,:,:]; y = y[keep_idxs]
    # Make sure that the number of samples in each category is equal!
    label_counter = Counter(y); assert (np.diff(list(label_counter.values())) == 0).all()
    # Each category needs at least one sample for each of train-set & validation-set & test-set.
    if not (np.array(list(label_counter.values())) >= 3).all():
        return _skip((
            "WARNING: Skip dataset {} with insufficient data items ({})"+\
            " to be split into train-set & validation-set & test-set."
        ).format(path_run, label_counter))
    labels_sorted = sorted(label_counter.keys())
    # Get the number of samples per category for validation-set & test-set.
    # - Leave-one-out mode: exactly one sample per category for each of them.
    # - Pseudo-K-fold mode: `test_ratio` of each category (at least one sample), where the
    #   validation-set & test-set evenly share the part not used by train-set.
    if load_params.use_loo:
        n_samples_test = {label_i: 1 for label_i in labels_sorted}
    else:
        test_ratio = (1. - params.train.train_ratio) / 2.
        n_samples_test = {
            label_i: max(1, int(label_counter[label_i] * test_ratio)) for label_i in labels_sorted
        }
    n_samples_validation = cp.deepcopy(n_samples_test)
    # Train-set takes all remaining samples of each category.
    n_samples_train = {
        label_i: label_counter[label_i] - n_samples_validation[label_i] - n_samples_test[label_i]
        for label_i in labels_sorted
    }
    # Use `n_samples_*` to get `*_idxs`: draw test samples first, then validation samples from the rest.
    test_idxs = []
    for label_i in labels_sorted:
        test_idxs.extend(np.random.choice(
            np.where(y == label_i)[0],
            size=(n_samples_test[label_i],), replace=False
        ).tolist())
    validation_idxs = []
    for label_i in labels_sorted:
        validation_idxs.extend(np.random.choice(
            np.setdiff1d(np.where(y == label_i)[0], np.array(test_idxs, dtype=np.int64)),
            size=(n_samples_validation[label_i],), replace=False
        ).tolist())
    train_idxs = sorted(set(range(y.shape[0])) - set(validation_idxs) - set(test_idxs))
    # Convert `*_idxs` to `np.array`.
    train_idxs = np.array(train_idxs, dtype=np.int64)
    validation_idxs = np.array(validation_idxs, dtype=np.int64)
    test_idxs = np.array(test_idxs, dtype=np.int64)
    # Get the corresponding `X_*` & `y_*`.
    X_train = X[train_idxs,:,:]; y_train = y[train_idxs]
    X_validation = X[validation_idxs,:,:]; y_validation = y[validation_idxs]
    X_test = X[test_idxs,:,:]; y_test = y[test_idxs]
    # Check whether `train_idxs` is consistent with `n_samples_train`.
    for label_i in labels_sorted: assert np.sum(y_train == label_i) == n_samples_train[label_i]
    # Check whether train-set & validation-set & test-set all have data items.
    if len(X_train) == 0 or len(X_validation) == 0 or len(X_test) == 0:
        return _skip((
            "WARNING: Skip dataset {} due to empty train-set ({}) or validation-set ({}) or test-set ({})."
        ).format(path_run, X_train.shape, X_validation.shape, X_test.shape))
    # All sets are drawn from the same run with disjoint indices; spot-check that no sample leaks.
    _assert_no_overlap(X_train, X_test)
    _assert_no_overlap(X_train, X_validation)
    _assert_no_overlap(X_validation, X_test)
    # Check whether labels are enough (this dataset is expected to provide 12 categories),
    # then transform y to sorted order.
    assert len(set(y_train)) == len(set(y_validation)) == len(set(y_test)) == 12; labels = sorted(set(y_train))
    label2idx = {label_i: idx for idx, label_i in enumerate(labels)}
    # y - (n_samples,)
    y_train = np.array([label2idx[y_i] for y_i in y_train], dtype=np.int64)
    y_validation = np.array([label2idx[y_i] for y_i in y_validation], dtype=np.int64)
    y_test = np.array([label2idx[y_i] for y_i in y_test], dtype=np.int64)
    # Execute sample permutation (label shuffling of train-set only, as a chance-level control).
    if load_params.permutation: np.random.shuffle(y_train)
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return (X_train, y_train), (X_validation, y_validation), (X_test, y_test)

"""
train funcs
"""
# def train func
def train(base_, params_):
    """
    Train the model.

    Args:
        base_: The base path of current project.
        params_: The parameters of current training process.

    Returns:
        None
    """
    global params, paths
    # Initialize parameters & variables of current training process.
    init(base_, params_)
    # Log the start of current training process.
    paths.run.logger.summaries.info("Training started with dataset {}.".format(params.train.dataset))
    # Initialize load_params. Each load_params_i corresponds to a sub-dataset.
    if params.train.dataset == "eeg_anonymous":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            #os.path.join(paths.base, "data", "eeg.anonymous", "005", "20221223"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "006", "20230103"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "007", "20230106"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "011", "20230214"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "013", "20230308"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "018", "20230331"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "019", "20230403"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "020", "20230405"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "021", "20230407"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "023", "20230412"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "024", "20230414"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "025", "20230417"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "026", "20230419"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "027", "20230421"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "028", "20230424"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "029", "20230428"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "030", "20230504"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "031", "20230510"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-002", "20230509"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "033", "20230517"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "034", "20230519"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-003", "20230524"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-004", "20230528"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-005", "20230601"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "036", "20230526"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "037", "20230529"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "038", "20230531"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "039", "20230605"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "040", "20230607"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-006", "20230603"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-007", "20230608"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-008", "20230610"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "042", "20230614"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "043", "20230616"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "044", "20230619"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "045", "20230626"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "046", "20230628"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-009", "20230613"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-010", "20230615"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-012", "20230623"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-013", "20230627"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-014", "20230629"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-015", "20230701"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "047", "20230703"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "048", "20230705"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-016", "20230703"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-017", "20230706"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "049", "20230710"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "050", "20230712"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "051", "20230717"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "052", "20230719"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-018", "20230710"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-019", "20230712"),
            #os.path.join(paths.base, "data", "eeg.anonymous", "sz-020", "20230714"),
            os.path.join(paths.base, "data", "eeg.anonymous", "sz-021", "20230718"),
        ]; load_type = "default"
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-all-test-task-all-all
            utils.DotDict({
                "name": "train-task-all-all-test-task-all-all",
                "trainset": [
                    "task-image-audio-pre-audio", "task-image-audio-pre-image",
                    "task-audio-image-pre-audio", "task-audio-image-pre-image",
                    "task-image-audio-post-audio", "task-image-audio-post-image",
                    "task-audio-image-post-audio", "task-audio-image-post-image",
                ],
                "testset": [
                    "task-image-audio-pre-audio", "task-image-audio-pre-image",
                    "task-audio-image-pre-audio", "task-audio-image-pre-image",
                    "task-image-audio-post-audio", "task-image-audio-post-image",
                    "task-audio-image-post-audio", "task-audio-image-post-image",
                ],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image",
                "trainset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "testset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-task-all-audio-test-task-all-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-task-all-audio",
                "trainset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "testset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-tmr-n23-audio-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-tmr-rem-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-tmr-n23-audio-test-task-all-image
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-task-all-image",
                "trainset": ["tmr-N2/3-audio",],
                "testset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-tmr-n23-audio-test-task-all-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-task-all-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-task-all-image-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-image-test-tmr-n23-audio",
                "trainset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-task-all-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-tmr-n23-audio",
                "trainset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-task-all-image-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-task-all-image-test-tmr-rem-audio",
                "trainset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                ],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-task-all-audio-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-tmr-rem-audio",
                "trainset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                ],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-task-all-image-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-image-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": [
                    "task-image-audio-pre-image", "task-audio-image-pre-image",
                    "task-image-audio-post-image", "task-audio-image-post-image",
                    "tmr-N2/3-audio",
                ],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
            # train-task-all-audio-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-audio-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": [
                    "task-image-audio-pre-audio", "task-audio-image-pre-audio",
                    "task-image-audio-post-audio", "task-audio-image-post-audio",
                    "tmr-N2/3-audio",
                ],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2, "use_loo": True,
            }),
        ]
    elif params.train.dataset == "meg_anonymous":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            #os.path.join(paths.base, "data", "meg.anonymous", "004", "20221218"),
            #os.path.join(paths.base, "data", "meg.anonymous", "007", "20230308"),
            #os.path.join(paths.base, "data", "meg.anonymous", "009", "20230310"),
            #os.path.join(paths.base, "data", "meg.anonymous", "011", "20230314"),
            #os.path.join(paths.base, "data", "meg.anonymous", "012", "20230314"),
            #os.path.join(paths.base, "data", "meg.anonymous", "015", "20230325"),
            #os.path.join(paths.base, "data", "meg.anonymous", "016", "20230325"),
            #os.path.join(paths.base, "data", "meg.anonymous", "018", "20230406"),
            #os.path.join(paths.base, "data", "meg.anonymous", "019", "20230410"),
            #os.path.join(paths.base, "data", "meg.anonymous", "022", "20230414"),
            #os.path.join(paths.base, "data", "meg.anonymous", "023", "20230414"),
            #os.path.join(paths.base, "data", "meg.anonymous", "024", "20230417"),
            #os.path.join(paths.base, "data", "meg.anonymous", "029", "20230424"),
            #os.path.join(paths.base, "data", "meg.anonymous", "030", "20230425"),
            #os.path.join(paths.base, "data", "meg.anonymous", "031", "20230425"),
            #os.path.join(paths.base, "data", "meg.anonymous", "032", "20230427"),
            #os.path.join(paths.base, "data", "meg.anonymous", "033", "20230427"),
            #os.path.join(paths.base, "data", "meg.anonymous", "034", "20230504"),
            #os.path.join(paths.base, "data", "meg.anonymous", "036", "20230508"),
            #os.path.join(paths.base, "data", "meg.anonymous", "038", "20230511"),
            #os.path.join(paths.base, "data", "meg.anonymous", "039", "20230511"),
            #os.path.join(paths.base, "data", "meg.anonymous", "041", "20230512"),
            #os.path.join(paths.base, "data", "meg.anonymous", "042", "20230516"),
            #os.path.join(paths.base, "data", "meg.anonymous", "043", "20230516"),
            #os.path.join(paths.base, "data", "meg.anonymous", "044", "20230517"),
            #os.path.join(paths.base, "data", "meg.anonymous", "045", "20230517"),
            #os.path.join(paths.base, "data", "meg.anonymous", "046", "20230518"),
            #os.path.join(paths.base, "data", "meg.anonymous", "047", "20230518"),
            #os.path.join(paths.base, "data", "meg.anonymous", "048", "20230522"),
            #os.path.join(paths.base, "data", "meg.anonymous", "049", "20230522"),
            #os.path.join(paths.base, "data", "meg.anonymous", "051", "20230523"),
            os.path.join(paths.base, "data", "meg.anonymous", "052", "20230525"),
        ]; load_type = "lvbj"; load_modality = ["eeg", "mag", "grad",][-1:]
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-all-test-task-all-all
            utils.DotDict({
                "name": "train-task-all-all-test-task-all-all",
                "trainset": [
                    "task-image-audio-audio", "task-image-audio-image",
                    "task-audio-image-audio", "task-audio-image-image",
                ],
                "testset": [
                    "task-image-audio-audio", "task-image-audio-image",
                    "task-audio-image-audio", "task-audio-image-image",
                ],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image",
                "trainset": ["task-image-audio-image", "task-audio-image-image",],
                "testset": ["task-image-audio-image", "task-audio-image-image",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-task-all-audio-test-task-all-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-task-all-audio",
                "trainset": ["task-image-audio-audio", "task-audio-image-audio",],
                "testset": ["task-image-audio-audio", "task-audio-image-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-tmr-n23-audio-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-tmr-rem-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-tmr-n23-audio-test-task-all-image
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-task-all-image",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["task-image-audio-image", "task-audio-image-image",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-tmr-n23-audio-test-task-all-audio
            utils.DotDict({
                "name": "train-tmr-n23-audio-test-task-all-audio",
                "trainset": ["tmr-N2/3-audio",],
                "testset": ["task-image-audio-audio", "task-audio-image-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-task-all-image-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-image-test-tmr-n23-audio",
                "trainset": ["task-image-audio-image", "task-audio-image-image",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-task-all-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-tmr-n23-audio",
                "trainset": ["task-image-audio-audio", "task-audio-image-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-task-all-image-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-task-all-image-test-tmr-rem-audio",
                "trainset": ["task-image-audio-image", "task-audio-image-image",],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-task-all-audio-test-tmr-rem-audio
            utils.DotDict({
                "name": "train-task-all-audio-test-tmr-rem-audio",
                "trainset": ["task-image-audio-audio", "task-audio-image-audio",],
                "testset": ["tmr-REM-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-task-all-image-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-image-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": ["task-image-audio-image", "task-audio-image-image", "tmr-N2/3-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
            # train-task-all-audio-tmr-n23-audio-test-tmr-n23-audio
            utils.DotDict({
                "name": "train-task-all-audio-tmr-n23-audio-test-tmr-n23-audio",
                "trainset": ["task-image-audio-audio", "task-audio-image-audio", "tmr-N2/3-audio",],
                "testset": ["tmr-N2/3-audio",],
                "type": load_type, "permutation": False, "downsample_rate": 1, "drop_ratio": 0.2,
                "use_loo": True, "data_modality": load_modality,
            }),
        ]
    elif params.train.dataset == "meg_lv2023cpnl":
        # Initialize the paths of runs that we want to execute experiments.
        path_runs = [
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub001"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub002"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub003"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub004"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub005"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub006"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub007"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub008"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub009"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub010"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub011"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub012"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub013"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub014"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub015"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub016"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub017"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub018"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub019"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub020"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub021"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub022"),
            #os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub023"),
            os.path.join(paths.base, "data", "meg.lv2023cpnl", "sub024"),
        ]; load_type = "baseline"; load_modality = ["mag", "grad"][-1:]
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = [
            # train-task-all-image-test-task-all-image
            utils.DotDict({
                "name": "train-task-all-image-test-task-all-image", "type": load_type, "use_loo": True,
                "permutation": False, "downsample_rate": 1, "data_modality": load_modality,
            }),
        ]
    else:
        raise ValueError("ERROR: Unknown dataset {} in train.lasso_glm.".format(params.train.dataset))
    # Execute experiments for each dataset run.
    for path_run_i in path_runs:
        # Log the start of current training iteration.
        msg = "Training started with sub-dataset {}.".format(path_run_i)
        print(msg); paths.run.logger.summaries.info(msg)
        for load_params_idx in range(len(load_params)):
            # Add `path_run` to `load_params_i`.
            load_params_i = cp.deepcopy(load_params[load_params_idx]); load_params_i.path_run = path_run_i
            # Load demo dataset to check whether train-set & validation-set & test-set exists.
            dataset_train, dataset_validation, dataset_test = load_data(load_params_i)
            if len(dataset_train[0]) == 0 or len(dataset_validation[0]) == 0 or len(dataset_test[0]) == 0:
                msg = (
                    "INFO: Skip experiment {} with train-set ({:d} items)" +\
                    " & validation-set ({:d} items) & test-set ({:d} items)."
                ).format(load_params_i.name, len(dataset_train[0]), len(dataset_validation[0]), len(dataset_test[0]))
                print(msg); paths.run.logger.summaries.info(msg); continue
            # Start training process of current specified experiment.
            msg = "Start the training process of experiment {}.".format(load_params_i.name)
            print(msg); paths.run.logger.summaries.info(msg)
            # Run model with different train-set & validation-set & test-set for `n_runs` runs, to average the accuracy curve.
            run_idx = 0; accuracies_validation = []; accuracies_test = []
            while run_idx < params.train.n_runs:
                # Record the start time of preparing data.
                time_start = time.time()
                # Load data from specified experiment.
                dataset_train, dataset_validation, dataset_test = load_data(load_params_i)
                # Get `X` & `y` from `dataset`.
                X_train, y_train = dataset_train
                X_validation, y_validation = dataset_validation
                X_test, y_test = dataset_test
                # Train the model for each time point.
                model = lasso_glm_model(params.model)
                try:
                    accuracy_validation, accuracy_test =\
                        model.fit((X_train, X_validation, X_test), (y_train, y_validation, y_test))
                except ValueError as e:
                    msg = (
                        "ERROR: Get ValueError {}, re-run current run."
                    ).format(e); print(msg); paths.run.logger.summaries.info(msg); continue
                # Append `accuracy` to `accuracies`.
                accuracies_validation.append(accuracy_validation); accuracies_test.append(accuracy_test)
                # Record current time point.
                time_stop = time.time()
                # Convert `accuracy_validation` & `accuracy_test` to `np.array`.
                accuracy_validation = np.round(np.array(accuracy_validation, dtype=np.float32), decimals=4)
                accuracy_test = np.round(np.array(accuracy_test, dtype=np.float32), decimals=4)
                time_maxacc_idxs = np.where(accuracy_validation == np.max(accuracy_validation))[0]
                time_maxacc_idx = time_maxacc_idxs[np.argmax(accuracy_test[time_maxacc_idxs])]
                # Log information of current run.
                msg = (
                    "Finish run index {:d} in {:.2f} seconds, with test-accuracy ({:.2f}%)" +\
                    " according to max validation-accuracy ({:.2f}%) at time index {:d}."
                ).format(run_idx, time_stop-time_start, accuracy_test[time_maxacc_idx]*100.,
                    accuracy_validation[time_maxacc_idx]*100., time_maxacc_idx)
                for time_idx, (accuracy_validation_i, accuracy_test_i) in enumerate(zip(accuracy_validation, accuracy_test)):
                    msg += (
                        "\nGet validation-accuracy ({:.2f}%) and test-accuracy ({:.2f}%) at time index {:d}."
                    ).format(accuracy_validation_i*100., accuracy_test_i*100., time_idx)
                print(msg); paths.run.logger.summaries.info(msg)
                # Update `run_idx` to enter next iteration.
                run_idx += 1
            # Average `accuracies` to get `avg_accuracy`.
            avg_accuracy_validation = np.mean(accuracies_validation, axis=0)
            avg_accuracy_test = np.mean(accuracies_test, axis=0)
            # Convert `avg_accuracy_validation` & `avg_accuracy_test` to `np.array`.
            avg_accuracy_validation = np.round(np.array(avg_accuracy_validation, dtype=np.float32), decimals=4)
            avg_accuracy_test = np.round(np.array(avg_accuracy_test, dtype=np.float32), decimals=4)
            time_maxacc_idxs = np.where(avg_accuracy_validation == np.max(avg_accuracy_validation))[0]
            time_maxacc_idx = time_maxacc_idxs[np.argmax(avg_accuracy_test[time_maxacc_idxs])]
            # Finish training process of current specified experiment.
            msg = (
                "Finish the training process of experiment {}, with average test-accuracy ({:.2f}%)" +\
                " according to max average validation-accuracy ({:.2f}%) at time index {:d}."
            ).format(load_params_i.name, avg_accuracy_test[time_maxacc_idx]*100.,
                avg_accuracy_validation[time_maxacc_idx]*100., time_maxacc_idx)
            for time_idx, (avg_accuracy_validation_i, avg_accuracy_test_i) in\
                enumerate(zip(avg_accuracy_validation, avg_accuracy_test)):
                msg += (
                    "\nGet average validation-accuracy ({:.2f}%) and average test-accuracy ({:.2f}%) at time index {:d}."
                ).format(avg_accuracy_validation_i*100., avg_accuracy_test_i*100., time_idx)
            print(msg); paths.run.logger.summaries.info(msg)
        # Log the end of current training iteration.
        msg = "Training finished with sub-dataset {}.".format(path_run_i)
        print(msg); paths.run.logger.summaries.info(msg)
    # Log the end of current training process.
    msg = "Training finished with dataset {}.".format(params.train.dataset)
    print(msg); paths.run.logger.summaries.info(msg)

if __name__ == "__main__":
    # Import `os` locally so this entry point does not depend on the
    # guarded import near the top of the file.
    import os
    # local dep
    from params.lasso_glm_params import lasso_glm_params

    # Project base path: the parent directory of the current working directory
    # (presumably the script is launched from its own folder — matches the
    # cwd-relative `sys.path` insert at the top of the file).
    base = os.path.join(os.getcwd(), os.pardir)
    # Dataset whose `lasso_glm` parameters are instantiated below.
    dataset = "eeg_anonymous"

    # Fix the random seeds (via `utils.model.set_seeds`) before any parameters
    # or models are constructed, so repeated runs are reproducible.
    utils.model.set_seeds(42)

    # Build the dataset-specific parameter set, then launch training.
    lasso_glm_params_inst = lasso_glm_params(dataset=dataset)
    train(base, lasso_glm_params_inst)

