import sys
import os
# Get the absolute path of the current file and walk up to the project root (your_project)
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)  # 根据实际层级调整
sys.path.append(project_root)
# Import a function from ../datasets/datasetloader.py
from datasets.datasetloader import getdatasetpath
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import root_mean_squared_error, r2_score, mean_absolute_error, mean_absolute_percentage_error, mean_squared_error
from sklearn.model_selection import train_test_split
import warnings
from sklearn.model_selection import KFold
warnings.filterwarnings("ignore")

def winsorize(data, lower=0.05, upper=0.95):
    """Clamp *data* to the interval spanned by its own quantiles.

    Args:
        data: Array-like of numeric values.
        lower: Quantile (in [0, 1]) used as the lower clipping bound.
        upper: Quantile (in [0, 1]) used as the upper clipping bound.

    Returns:
        The result of ``np.clip``: values below the ``lower`` quantile are
        raised to it and values above the ``upper`` quantile are lowered to it.
    """
    floor, ceiling = (np.quantile(data, level) for level in (lower, upper))
    return np.clip(data, floor, ceiling)

n_exp = 30
n_group = 30
dataset_name = 'kin8nm'

# ------------------------assessment of iid, id and bias (see test_names)-------------------------
model_names = ['linear', 'mlp', 'svm', 'rf', 'lgbm', 'xgb', 'catboost']  # 'tabpfn' is currently disabled
test_names = ['iid', 'id', 'bias']
target_names = ['rmse', 'r2']
subject_names = ['lr', 'nn', 'fusion']
shap_features = ['fusion', 'subject', 'proxy']

# One empty result list per (model, test, target, subject) combination.
p_values = {
    f"{model}_{test}_{target}_{subject}": []
    for model in model_names
    for test in test_names
    for target in target_names
    for subject in subject_names
}
# -------------------------------------------------------------------

# ------------------------assessment of errors-------------------------
# Three tables share the same (model, target, subject) keys; each key maps to
# its own, independent empty list.
_model_error_keys = [
    f"{model}_{target}_{subject}"
    for model in model_names
    for target in target_names
    for subject in subject_names
]
errors = {key: [] for key in _model_error_keys}
errors_subject = {key: [] for key in _model_error_keys}
errors_proxy = {key: [] for key in _model_error_keys}

# Baseline estimators are stored in `errors` only, keyed as
# "<target>_<validation scheme>_<subject>".
_baseline_valid_names = ['holdout100', 'holdout50', 'holdout20', 'holdout10', 'cv5', 'cv10', 'bootstrap']
errors.update(
    (f"{target}_{valid}_{subject}", [])
    for target in ('rmse', 'r2')
    for subject in ('fusion', 'lr', 'nn')
    for valid in _baseline_valid_names
)
# -------------------------------------------------------------------

for i_exp in range(n_exp):
    print(f"i_exp: {i_exp}")
    # ---------------------------------read dataset---------------------------------
    # Resolve the ARFF file for `dataset_name` and parse it into a DataFrame.
    # NOTE(review): nothing in this section depends on i_exp, so the file is
    # re-read on every repetition.
    datasetpath = getdatasetpath(dataset_name)
    import arff
    # Close the file handle deterministically; the previous bare open() leaked
    # one handle per repetition.
    # Assumes arff.load() has finished reading the file when it returns (its
    # result is indexed with ['data'] immediately below) - TODO confirm if a
    # lazy return type is ever requested.
    with open(datasetpath, 'r') as arff_file:
        data = arff.load(arff_file)
    data_df = pd.DataFrame(data['data'])
    # ---------------------------------selected the needed columns---------------------------------
    # Every column except the last is treated as a feature; the last one is the target.
    column_names = data_df.columns.tolist()
    x_column_names = column_names[:-1]
    y_column_names = [column_names[-1]]
    data_df = data_df[x_column_names + y_column_names]
    # ---------------------------------------------------------------------------------------------------
    # ---------------------------------------------------------------------------------------------

    # ---------------------------------preprocess the dataset---------------------------------
    # Drop every row that contains a NaN, then split by row position (no
    # shuffling): the leading rows form the train/validation part and the
    # trailing rows form the test part.
    data_df = data_df.dropna()
    test_ratio = 0.2
    n_rows = data_df.shape[0]
    n_train_rows = int(n_rows * (1 - test_ratio))
    n_test_rows = int(n_rows * test_ratio)
    df_tr = data_df.head(n_train_rows)
    df_te = data_df.tail(n_test_rows)
    # Standardise all columns (features and target alike) with statistics
    # estimated on the training part only; the test part reuses them.
    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    tr_scaled = scaler.fit_transform(df_tr)
    te_scaled = scaler.transform(df_te)
    df_tr = pd.DataFrame(tr_scaled, columns=df_tr.columns)
    df_te = pd.DataFrame(te_scaled, columns=df_te.columns)
    print("preprocessd:", data_df.shape)

    # -------------evaluation task--------------------
    # subjects: the LinearSubject and NNSubject modules defined below
    # metrics: rmse, r2, mae, mape, mse (computed further down)
    # ------------------------------------------------

    # -------------validation pool and test set as float32 tensors----------------
    # The whole train part (df_tr) serves as the validation pool (x_va, y_va);
    # df_te is the test set (x_te, y_te). Targets keep a trailing column axis
    # because y_column_names is a one-element list.
    x_va, y_va, x_te, y_te = (
        torch.tensor(frame[columns].values, dtype=torch.float32)
        for frame in (df_tr, df_te)
        for columns in (x_column_names, y_column_names)
    )
    # ---------------------------------------------------------------------------------------------------

    # ------------------------------create subject space---------------------------------------------------------------------
    # subject space: logistic regression, and 2-layer NN
    # logistic regression
    class LinearSubject(torch.nn.Module):
        """One affine layer (n_input -> 1) followed by a sigmoid."""

        def __init__(self, n_input):
            super().__init__()
            self.linear = torch.nn.Linear(n_input, 1)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            """Return sigmoid(linear(x))."""
            return self.sigmoid(self.linear(x))
    # 2-layer NN
    class NNSubject(torch.nn.Module):
        """Two affine layers (n_input -> maxhidden -> 1), each followed by a sigmoid."""

        def __init__(self, n_input, maxhidden=8):
            super().__init__()
            # Creation order is kept (linear1 before linear2) so parameter
            # order and random initialisation stay the same.
            self.linear1 = torch.nn.Linear(n_input, maxhidden)
            self.linear2 = torch.nn.Linear(maxhidden, 1)
            self.sigmoid = torch.nn.Sigmoid()
            self.layer1 = torch.nn.Sequential(self.linear1, self.sigmoid)
            self.layer2 = torch.nn.Sequential(self.linear2, self.sigmoid)

        def forward(self, x):
            """Return layer2(layer1(x))."""
            hidden = self.layer1(x)
            return self.layer2(hidden)
        
    # ---------------------------------------------------------------------------------------------------

    # ---------------------------sampling from subject space-----------------------------
    # Draw n_subjects randomly initialised models. Each draw is a fair coin:
    # label 0 -> LinearSubject, label 1 -> NNSubject. Every model and its
    # flattened parameter vector is stored both in the pooled lists and in the
    # lists of its own family.
    n_subjects = 2000
    subject_lr_list, subject_lr_vector_list = [], []
    subject_nn_list, subject_nn_vector_list = [], []
    subject_label, subject_list, subject_vector_list = [], [], []
    n_input = len(x_column_names)
    for _ in range(n_subjects):  # ------------subject space sampling methods------------
        label = np.random.choice([0, 1])
        subject_label.append(label)
        if label == 0:
            subject = LinearSubject(n_input)
            family_models, family_vectors = subject_lr_list, subject_lr_vector_list
        else:
            subject = NNSubject(n_input)
            family_models, family_vectors = subject_nn_list, subject_nn_vector_list
        # Flatten all parameters into one plain Python list of floats.
        subject_vector = torch.cat([p.view(-1) for p in subject.parameters()]).detach().numpy().tolist()
        family_models.append(subject)
        # The family list gets its own copy so the two stored lists stay independent objects.
        family_vectors.append(list(subject_vector))
        subject_list.append(subject)
        subject_vector_list.append(subject_vector)
    # ---------------------------------------------------------------------------------------------------

    # -----------------------sampling from the metric space-------------------------------
    va_metric_names = ['rmse', 'r2', 'mae', 'mape', 'mse']
    subject_names = ['lr', 'nn', 'fusion']
    va_valid_names = ['cv5', 'cv10', 'bootstrap', 'holdout100', 'holdout50', 'holdout20', 'holdout10']

    # Validation estimates, keyed "<metric>_<subject>_<validation scheme>".
    va_lists = {
        f"{metric}_{family}_{scheme}": []
        for metric in va_metric_names
        for family in subject_names
        for scheme in va_valid_names
    }

    # Test-set values, keyed "<metric>_<subject>".
    te_lists = {
        f"{metric}_{family}": []
        for metric in va_metric_names
        for family in subject_names
    }
    # ---------------------------------------------------------------------------------------------------

    # Score every sampled subject on the test set. Each value goes to the
    # pooled 'fusion' list and to the list of the subject's own family
    # ('lr' for label 0, 'nn' otherwise).
    te_scorers = (
        ('rmse', root_mean_squared_error),
        ('r2', r2_score),
        ('mae', mean_absolute_error),
        ('mape', mean_absolute_percentage_error),
        ('mse', mean_squared_error),
    )
    for subject, label in zip(subject_list, subject_label):  # on test set,
        prediction = subject(x_te).detach().numpy()
        family = 'lr' if label == 0 else 'nn'
        for metric_name, scorer in te_scorers:
            value = scorer(y_te, prediction)
            te_lists[f"{metric_name}_fusion"].append(value)
            te_lists[f"{metric_name}_{family}"].append(value)
    import time

    # (metric name, scorer) pairs, in the order the metrics are recorded.
    _va_scorers = (
        ('rmse', root_mean_squared_error),
        ('r2', r2_score),
        ('mae', mean_absolute_error),
        ('mape', mean_absolute_percentage_error),
        ('mse', mean_squared_error),
    )

    def _score_subject(model, inputs, targets):
        """Return {metric name: value} for `model` evaluated on (`inputs`, `targets`)."""
        # Convert the prediction once and share it between all five scorers.
        prediction = model(inputs).detach().numpy()
        return {name: scorer(targets, prediction) for name, scorer in _va_scorers}

    def _record_scores(scores, family, scheme):
        """Append `scores` to the pooled 'fusion' list and to the `family` list of `scheme`."""
        for name, value in scores.items():
            va_lists[f"{name}_fusion_{scheme}"].append(value)
            va_lists[f"{name}_{family}_{scheme}"].append(value)

    start = time.time()
    holdout_fractions = (('holdout10', 0.1), ('holdout20', 0.2), ('holdout50', 0.5), ('holdout100', 1.0))
    cv_folds = (('cv5', 5), ('cv10', 10))
    # Accumulated wall-clock seconds per validation scheme.
    elapsed = dict.fromkeys(
        ['holdout10', 'holdout20', 'holdout50', 'holdout100', 'cv5', 'cv10', 'bootstrap'], 0
    )
    for subject, label in zip(subject_list, subject_label):  # on validation set,
        family = 'lr' if label == 0 else 'nn'

        # holdout: score on the leading fraction of the validation pool
        # (a fraction of 1.0 selects the whole pool).
        for scheme, fraction in holdout_fractions:
            tick = time.time()
            n_keep = int(len(x_va) * fraction)
            _record_scores(_score_subject(subject, x_va[:n_keep], y_va[:n_keep]), family, scheme)
            elapsed[scheme] += time.time() - tick

        # k-fold: score each unshuffled held-out fold and record the mean over
        # folds. The complementary "train" folds are not materialised because
        # the subjects are never fitted here.
        for scheme, n_splits in cv_folds:
            tick = time.time()
            fold_scores = [
                _score_subject(subject, x_va[fold_index], y_va[fold_index])
                for _, fold_index in KFold(n_splits=n_splits).split(x_va)
            ]
            mean_scores = {
                name: np.mean([fold[name] for fold in fold_scores]) for name, _ in _va_scorers
            }
            _record_scores(mean_scores, family, scheme)
            elapsed[scheme] += time.time() - tick

        # bootstrap: one resample, with replacement, of the same size as the pool
        tick = time.time()
        bootstrap_index = torch.randint(0, len(x_va), (len(x_va),))
        _record_scores(
            _score_subject(subject, x_va[bootstrap_index], y_va[bootstrap_index]),
            family, 'bootstrap',
        )
        elapsed['bootstrap'] += time.time() - tick

    # Keep the individual accumulator names defined, as before.
    t_holdout10 = elapsed['holdout10']
    t_holdout20 = elapsed['holdout20']
    t_holdout50 = elapsed['holdout50']
    t_holdout100 = elapsed['holdout100']
    t_cv5 = elapsed['cv5']
    t_cv10 = elapsed['cv10']
    t_bootstrap = elapsed['bootstrap']
    print("t_holdout10:", t_holdout10, "per subject:", t_holdout10 / n_subjects)
    print("t_holdout20:", t_holdout20, "per subject:", t_holdout20 / n_subjects)
    print("t_holdout50:", t_holdout50, "per subject:", t_holdout50 / n_subjects)
    print("t_holdout100:", t_holdout100, "per subject:", t_holdout100 / n_subjects)
    print("t_cv5:", t_cv5, "per subject:", t_cv5 / n_subjects)
    print("t_cv10:", t_cv10, "per subject:", t_cv10 / n_subjects)
    print("t_bootstrap:", t_bootstrap, "per subject:", t_bootstrap / n_subjects)
    print("all time:", time.time() - start, "per subject:", (time.time() - start) / n_subjects)
    # NOTE(review): this waits for Enter on every experiment repetition, so the
    # script cannot run unattended. Kept as-is; remove if the pause is only a
    # debugging leftover.
    input()
    # Merge all validation estimates into one matrix per family. Each
    # (metric, validation scheme) pair contributes one row, and the transpose
    # puts subjects on the first axis.
    proxy_keys = [(metric, scheme) for metric in va_metric_names for scheme in va_valid_names]
    lr_proxy_lists, nn_proxy_lists, fusion_proxy_lists = (
        np.array([va_lists[f"{metric}_{family}_{scheme}"] for metric, scheme in proxy_keys]).T
        for family in ('lr', 'nn', 'fusion')
    )

    # show shape,
    print("lr proxy lists shape:", lr_proxy_lists.shape)
    print("nn proxy lists shape:", nn_proxy_lists.shape)
    print("fusion proxy lists shape:", fusion_proxy_lists.shape)
    # ---------------------------------------------------------------------

    # ---------------------------evaluation task-----------------------------
    # Randomly split the subjects of each family 80/20. One train_test_split
    # call per family splits the parameter vectors, the test-set metrics and
    # the validation proxy matrix together, so all of them share the same
    # partition. Outputs come back as (train, test) pairs in argument order.
    (
        s_lr_tr, s_lr_te,
        rmse_lr_tr_te, rmse_lr_te_te,
        proxy_lr_tr_va, proxy_lr_te_va,
        r2_lr_tr_te, r2_lr_te_te,
        mae_lr_tr_te, mae_lr_te_te,
        mape_lr_tr_te, mape_lr_te_te,
        mse_lr_tr_te, mse_lr_te_te,
    ) = train_test_split(
        subject_lr_vector_list,
        te_lists["rmse_lr"],
        lr_proxy_lists,
        te_lists["r2_lr"],
        te_lists["mae_lr"],
        te_lists["mape_lr"],
        te_lists["mse_lr"],
        test_size=0.2,
    )
    (
        s_nn_tr, s_nn_te,
        rmse_nn_tr_te, rmse_nn_te_te,
        proxy_nn_tr_va, proxy_nn_te_va,
        r2_nn_tr_te, r2_nn_te_te,
        mae_nn_tr_te, mae_nn_te_te,
        mape_nn_tr_te, mape_nn_te_te,
        mse_nn_tr_te, mse_nn_te_te,
    ) = train_test_split(
        subject_nn_vector_list,
        te_lists["rmse_nn"],
        nn_proxy_lists,
        te_lists["r2_nn"],
        te_lists["mae_nn"],
        te_lists["mape_nn"],
        te_lists["mse_nn"],
        test_size=0.2,
    )
    # ---------------------------------------------------------------------
    # Convert the split pieces to float32 and build the regression inputs.
    # The validation proxy metrics are used as covariates.
    def _as_float32(values):
        """Convert a list or array to a float32 torch tensor."""
        return torch.tensor(values, dtype=torch.float32)

    s_lr_tr, s_lr_te, s_nn_tr, s_nn_te = map(
        _as_float32, (s_lr_tr, s_lr_te, s_nn_tr, s_nn_te)
    )
    proxy_lr_tr_va, proxy_lr_te_va, proxy_nn_tr_va, proxy_nn_te_va = map(
        _as_float32, (proxy_lr_tr_va, proxy_lr_te_va, proxy_nn_tr_va, proxy_nn_te_va)
    )

    # sm_* = [subject parameter vector | validation proxy metrics], one row per subject.
    sm_lr_tr = torch.cat([s_lr_tr, proxy_lr_tr_va], axis=1).detach().numpy()
    sm_lr_te = torch.cat([s_lr_te, proxy_lr_te_va], axis=1).detach().numpy()
    sm_nn_tr = torch.cat([s_nn_tr, proxy_nn_tr_va], axis=1).detach().numpy()
    sm_nn_te = torch.cat([s_nn_te, proxy_nn_te_va], axis=1).detach().numpy()

    # m_* = metric values measured on the test set; only rmse and r2 are converted here.
    m_lr_tr_rmse = _as_float32(rmse_lr_tr_te).detach().numpy()
    m_lr_te_rmse = _as_float32(rmse_lr_te_te).detach().numpy()
    m_nn_tr_rmse = _as_float32(rmse_nn_tr_te).detach().numpy()
    m_nn_te_rmse = _as_float32(rmse_nn_te_te).detach().numpy()
    m_lr_tr_r2 = _as_float32(r2_lr_tr_te).detach().numpy()
    m_lr_te_r2 = _as_float32(r2_lr_te_te).detach().numpy()
    m_nn_tr_r2 = _as_float32(r2_nn_tr_te).detach().numpy()
    m_nn_te_r2 = _as_float32(r2_nn_te_te).detach().numpy()

    # Finally expose the subject vectors and proxy matrices as numpy arrays too.
    s_lr_te, s_nn_te, s_lr_tr, s_nn_tr = (
        tensor.detach().numpy() for tensor in (s_lr_te, s_nn_te, s_lr_tr, s_nn_tr)
    )
    proxy_lr_te_va, proxy_nn_te_va, proxy_lr_tr_va, proxy_nn_tr_va = (
        tensor.detach().numpy()
        for tensor in (proxy_lr_te_va, proxy_nn_te_va, proxy_lr_tr_va, proxy_nn_tr_va)
    )


    # Baseline error per subject = value on the test set minus its validation
    # estimate (te_lists entry - va_lists entry), appended as one list per
    # experiment repetition. This loop covers five (metric, subject) groups;
    # the remaining ('r2', 'nn') group is handled by the explicit statements
    # that follow, starting with 'r2_holdout100_nn'.
    for metric_name, family in (
        ('rmse', 'fusion'), ('r2', 'fusion'),
        ('rmse', 'lr'), ('r2', 'lr'),
        ('rmse', 'nn'),
    ):
        test_values = te_lists[f"{metric_name}_{family}"]
        for scheme in ('holdout100', 'holdout50', 'holdout20', 'holdout10', 'cv5', 'cv10', 'bootstrap'):
            estimates = va_lists[f"{metric_name}_{family}_{scheme}"]
            errors[f"{metric_name}_{scheme}_{family}"].append(
                [test_value - estimate for test_value, estimate in zip(test_values, estimates)]
            )

    errors['r2_holdout100_nn'].append([a - b for a, b in zip(te_lists['r2_nn'],va_lists['r2_nn_holdout100'])])
    errors['r2_holdout50_nn'].append([a - b for a, b in zip(te_lists['r2_nn'],va_lists['r2_nn_holdout50'])])
    errors['r2_holdout20_nn'].append([a - b for a, b in zip(te_lists['r2_nn'],va_lists['r2_nn_holdout20'])])
    errors['r2_holdout10_nn'].append([a - b for a, b in zip(te_lists['r2_nn'],va_lists['r2_nn_holdout10'])])
    errors['r2_cv5_nn'].append([a - b for a, b in zip(te_lists['r2_nn'],va_lists['r2_nn_cv5'])])
    errors['r2_cv10_nn'].append([a - b for a, b in zip(te_lists['r2_nn'],va_lists['r2_nn_cv10'])])
    errors['r2_bootstrap_nn'].append([a - b for a, b in zip(te_lists['r2_nn'],va_lists['r2_nn_bootstrap'])])
    # ----------------------- em -----------------------
    for model_name in model_names:
        print(f"model_name: {model_name}")
        # Pick the regressor family for this pass. The imports stay inside the
        # branches so that optional back-ends (lightgbm, xgboost, catboost,
        # tabpfn) are only required when they are actually selected.
        em_init_kwargs = {}
        em_fit_kwargs = {}
        if model_name == 'linear':
            from sklearn.linear_model import LinearRegression
            em_cls = LinearRegression
        elif model_name == 'mlp':
            from sklearn.neural_network import MLPRegressor
            em_cls = MLPRegressor
            em_init_kwargs = {'hidden_layer_sizes': (100,), 'max_iter': 1000}
        elif model_name == 'svm':
            from sklearn.svm import SVR
            em_cls = SVR
        elif model_name == 'rf':
            from sklearn.ensemble import RandomForestRegressor
            em_cls = RandomForestRegressor
        elif model_name == 'lgbm':
            import lightgbm as lgb
            em_cls = lgb.LGBMRegressor
        elif model_name == 'xgb':
            import xgboost as xgb
            em_cls = xgb.XGBRegressor
        elif model_name == 'catboost':
            from catboost import CatBoostRegressor
            em_cls = CatBoostRegressor
            # CatBoost logs every iteration unless silenced at fit time.
            em_fit_kwargs = {'verbose': 0}
        elif model_name == 'tabpfn':
            from tabpfn import TabPFNRegressor
            em_cls = TabPFNRegressor
        else:
            raise ValueError("model_name is not supported!")

        # Fit one fresh regressor per (feature set, subject model, metric).
        # The fit order is deliberately the same as in the former per-model
        # branches (lr/rmse, nn/rmse, lr/r2, nn/r2 for each feature set), so
        # estimators that draw from a shared random state when no seed is
        # given see it in the same sequence as before.
        em_fitted = []
        for em_x_lr, em_x_nn in (
            (sm_lr_tr, sm_nn_tr),              # models without a suffix
            (s_lr_tr, s_nn_tr),                # *_subject models
            (proxy_lr_tr_va, proxy_nn_tr_va),  # *_proxy models
        ):
            for em_x, em_y in (
                (em_x_lr, m_lr_tr_rmse),
                (em_x_nn, m_nn_tr_rmse),
                (em_x_lr, m_lr_tr_r2),
                (em_x_nn, m_nn_tr_r2),
            ):
                em_regressor = em_cls(**em_init_kwargs)
                em_regressor.fit(em_x, em_y, **em_fit_kwargs)
                em_fitted.append(em_regressor)

        # Keep the historical variable names bound so any code that inspects
        # the fitted models afterwards keeps working.
        (lrmodel_rmse, nnmodel_rmse, lrmodel_r2, nnmodel_r2,
         lrmodel_rmse_subject, nnmodel_rmse_subject,
         lrmodel_r2_subject, nnmodel_r2_subject,
         lrmodel_rmse_proxy, nnmodel_rmse_proxy,
         lrmodel_r2_proxy, nnmodel_r2_proxy) = em_fitted

        # Residuals on the held-out rows: observed metric minus the value the
        # fitted regressor predicts from the matching feature set.
        lr_error_rmse = m_lr_te_rmse - lrmodel_rmse.predict(sm_lr_te)
        nn_error_rmse = m_nn_te_rmse - nnmodel_rmse.predict(sm_nn_te)
        lr_error_r2 = m_lr_te_r2 - lrmodel_r2.predict(sm_lr_te)
        nn_error_r2 = m_nn_te_r2 - nnmodel_r2.predict(sm_nn_te)
        lr_error_rmse_subject = m_lr_te_rmse - lrmodel_rmse_subject.predict(s_lr_te)
        nn_error_rmse_subject = m_nn_te_rmse - nnmodel_rmse_subject.predict(s_nn_te)
        lr_error_r2_subject = m_lr_te_r2 - lrmodel_r2_subject.predict(s_lr_te)
        nn_error_r2_subject = m_nn_te_r2 - nnmodel_r2_subject.predict(s_nn_te)
        lr_error_rmse_proxy = m_lr_te_rmse - lrmodel_rmse_proxy.predict(proxy_lr_te_va)
        nn_error_rmse_proxy = m_nn_te_rmse - nnmodel_rmse_proxy.predict(proxy_nn_te_va)
        lr_error_r2_proxy = m_lr_te_r2 - lrmodel_r2_proxy.predict(proxy_lr_te_va)
        nn_error_r2_proxy = m_nn_te_r2 - nnmodel_r2_proxy.predict(proxy_nn_te_va)
        
        # Pool the LR-side and NN-side residuals (LR first, NN second) into a
        # single "fusion" vector for each metric and each feature set.
        error_list_rmse = np.concatenate((lr_error_rmse, nn_error_rmse))
        error_list_r2 = np.concatenate((lr_error_r2, nn_error_r2))
        error_list_rmse_subject = np.concatenate((lr_error_rmse_subject, nn_error_rmse_subject))
        error_list_r2_subject = np.concatenate((lr_error_r2_subject, nn_error_r2_subject))
        error_list_rmse_proxy = np.concatenate((lr_error_rmse_proxy, nn_error_rmse_proxy))
        error_list_r2_proxy = np.concatenate((lr_error_r2_proxy, nn_error_r2_proxy))

        # Append this run's residual vectors to the three result dictionaries:
        # `errors` (unsuffixed models), `errors_subject` and `errors_proxy`.
        # Each row: target dict, pooled rmse, pooled r2, lr rmse, nn rmse,
        # lr r2, nn r2.
        for (err_store, pooled_rmse, pooled_r2,
             part_lr_rmse, part_nn_rmse, part_lr_r2, part_nn_r2) in (
            (errors, error_list_rmse, error_list_r2,
             lr_error_rmse, nn_error_rmse, lr_error_r2, nn_error_r2),
            (errors_subject, error_list_rmse_subject, error_list_r2_subject,
             lr_error_rmse_subject, nn_error_rmse_subject,
             lr_error_r2_subject, nn_error_r2_subject),
            (errors_proxy, error_list_rmse_proxy, error_list_r2_proxy,
             lr_error_rmse_proxy, nn_error_rmse_proxy,
             lr_error_r2_proxy, nn_error_r2_proxy),
        ):
            err_store[f"{model_name}_rmse_fusion"].append(pooled_rmse)
            err_store[f"{model_name}_r2_fusion"].append(pooled_r2)
            err_store[f"{model_name}_rmse_lr"].append(part_lr_rmse)
            err_store[f"{model_name}_rmse_nn"].append(part_nn_rmse)
            err_store[f"{model_name}_r2_lr"].append(part_lr_r2)
            err_store[f"{model_name}_r2_nn"].append(part_nn_r2)

        # Flatten the pooled residual vectors and clip them to their own
        # 5% / 95% quantiles (winsorize defaults). Only the pooled "fusion"
        # vectors are clipped; the per-subject lr/nn vectors are used as-is.
        # Later statements that read error_list_rmse / error_list_r2 see the
        # clipped versions.
        error_list_rmse = winsorize(error_list_rmse.reshape(-1))
        error_list_r2 = winsorize(error_list_r2.reshape(-1))

        # ------iid check by central limit theorem-------
        # (1) For each residual vector draw 30 resamples of half its length
        #     and keep the mean of each resample. np.random.choice samples
        #     WITH replacement by default, so a resample may repeat values.
        #     The draw order (all six vectors per round) is kept fixed so the
        #     global NumPy RNG is consumed in a stable sequence.
        # A dedicated name is used for the round count so this block does not
        # write to the shared module-level `n_group` setting.
        iid_n_resamples = 30
        iid_inputs = (
            ('rmse_fusion', error_list_rmse),
            ('r2_fusion', error_list_r2),
            ('rmse_lr', lr_error_rmse),
            ('rmse_nn', nn_error_rmse),
            ('r2_lr', lr_error_r2),
            ('r2_nn', nn_error_r2),
        )
        iid_group_means = {tag: [] for tag, _ in iid_inputs}
        for _ in range(iid_n_resamples):
            for iid_tag, iid_values in iid_inputs:
                iid_draw = np.random.choice(iid_values, int(len(iid_values) / 2))
                iid_group_means[iid_tag].append(np.mean(iid_draw))

        # (2) If the residuals behave like iid draws, the resample means
        #     should look normal. scipy.stats.normaltest is the
        #     D'Agostino-Pearson omnibus test (skewness and kurtosis combined),
        #     not the kurtosis-only test; its null hypothesis is normality.
        from scipy.stats import normaltest
        iid_results = []
        for iid_tag, _ in iid_inputs:
            _, iid_p = normaltest(iid_group_means[iid_tag])
            iid_results.append((iid_tag, iid_p))
        # Key layout: "<model_name>_iid_<metric>_<subject>",
        # e.g. "linear_iid_rmse_fusion".
        for iid_tag, iid_p in iid_results:
            p_values[f"{model_name}_iid_{iid_tag}"].append(iid_p)

        # ------identical-distribution check (Kruskal-Wallis H-test)-------
        # Each residual vector is cut into chunks of roughly 30 values and the
        # chunks are compared with scipy.stats.kruskal (null hypothesis: all
        # chunks share the same median). This is a single omnibus test per
        # vector, not a pairwise KS test.
        n_sample_per_group = 30
        # Chunk counts get their own names so the module-level `n_group`
        # setting is no longer overwritten here. max(2, ...) keeps the test
        # defined for vectors shorter than 60 values, where the plain integer
        # division would give 0 or 1 chunk and make array_split / kruskal
        # raise.
        id_chunks_fusion = max(2, len(error_list_rmse) // n_sample_per_group)
        id_chunks_lr = max(2, len(lr_error_rmse) // n_sample_per_group)
        id_chunks_nn = max(2, len(nn_error_rmse) // n_sample_per_group)
        id_inputs = (
            ('rmse_fusion', error_list_rmse, id_chunks_fusion),
            ('r2_fusion', error_list_r2, id_chunks_fusion),
            ('rmse_lr', lr_error_rmse, id_chunks_lr),
            ('rmse_nn', nn_error_rmse, id_chunks_nn),
            ('r2_lr', lr_error_r2, id_chunks_lr),
            ('r2_nn', nn_error_r2, id_chunks_nn),
        )

        # np.array_split cuts contiguous chunks in the existing order; nothing
        # is shuffled.
        # NOTE(review): the pooled vectors are built as [lr residuals, nn
        # residuals], so their early chunks are LR-only and their late chunks
        # NN-only. The fusion test therefore also reacts to LR-vs-NN
        # differences; confirm that this is intended.
        from scipy.stats import kruskal
        id_results = []
        for id_tag, id_values, id_n_chunks in id_inputs:
            _, id_p = kruskal(*np.array_split(id_values, id_n_chunks))
            id_results.append((id_tag, id_p))
        # Key layout: "<model_name>_id_<metric>_<subject>",
        # e.g. "linear_id_rmse_fusion".
        for id_tag, id_p in id_results:
            p_values[f"{model_name}_id_{id_tag}"].append(id_p)
        
        # ------bias check (one-sample t-test against zero)-------
        # Null hypothesis: the mean of the residual vector is 0, i.e. the
        # error model is unbiased. Only the p-value of each test is kept.
        from scipy.stats import ttest_1samp
        bias_vectors = (
            error_list_rmse,
            error_list_r2,
            lr_error_rmse,
            nn_error_rmse,
            lr_error_r2,
            nn_error_r2,
        )
        bias_pvals = []
        for bias_vec in bias_vectors:
            _, bias_p = ttest_1samp(bias_vec, 0)
            bias_pvals.append(bias_p)
        # Result keys, in the same order as bias_vectors
        # (e.g. model_name = "linear" -> "linear_bias_rmse_fusion").
        bias_keys = (
            f"{model_name}_bias_rmse_fusion",
            f"{model_name}_bias_r2_fusion",
            f"{model_name}_bias_rmse_lr",
            f"{model_name}_bias_rmse_nn",
            f"{model_name}_bias_r2_lr",
            f"{model_name}_bias_r2_nn",
        )
        for bias_key, bias_p in zip(bias_keys, bias_pvals):
            p_values[bias_key].append(bias_p)
        # ---------------------------------------------------------------------------------------------------
print("complete evaluation!")
# ---------------------------evaluation task-----------------------------
# Persist the four result dictionaries (p_values, errors, errors_subject,
# errors_proxy) as pickle files under ./expdata/<dataset_name>/.
import pickle
result_dir = './expdata/' + dataset_name
# Create the target directory first: without it the first open() fails and
# every result computed above is lost.
os.makedirs(result_dir, exist_ok=True)
for result_name, result_obj in (
    ('p_values', p_values),
    ('errors', errors),
    ('errors_subject', errors_subject),
    ('errors_proxy', errors_proxy),
):
    with open(result_dir + '/' + result_name + '.pkl', 'wb') as f:
        pickle.dump(result_obj, f)
print("complete write!")
# ---------------------------evaluation task-----------------------------
# Reload the same four dictionaries from disk.
# NOTE(review): this reads back exactly what was just written; presumably it
# is kept so the analysis below can be run from saved results alone. Confirm.
# pickle.load executes arbitrary code from the file, so only point it at
# files produced by this script.
with open(result_dir + '/p_values.pkl', 'rb') as f:
    p_values = pickle.load(f)
with open(result_dir + '/errors.pkl', 'rb') as f:
    errors = pickle.load(f)
with open(result_dir + '/errors_subject.pkl', 'rb') as f:
    errors_subject = pickle.load(f)
with open(result_dir + '/errors_proxy.pkl', 'rb') as f:
    errors_proxy = pickle.load(f)
# ---------------------------evaluation task-----------------------------
print("complete load!")

# t_holdout10: 3.4820401668548584 per subject: 0.0017410200834274293
# t_holdout20: 3.4632561206817627 per subject: 0.0017316280603408814
# t_holdout50: 3.619875192642212 per subject: 0.001809937596321106
# t_holdout100: 3.9251928329467773 per subject: 0.001962596416473389
# t_cv5: 19.77456498146057 per subject: 0.009887282490730285
# t_cv10: 38.206010818481445 per subject: 0.019103005409240724
# t_bootstrap: 4.334221839904785 per subject: 0.0021671109199523928
# all time: 76.8140230178833 per subject: 0.03840701150894165