import sys
import os
# Resolve the absolute directory of this file and walk up one level to the
# project root (your_project), then append the root to sys.path so that the
# project's own packages can be imported further down.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)  # adjust to the actual directory depth
sys.path.append(project_root)
# import a function from ../datasets/datasetloader.py
from datasets.datasetloader import getdatasetpath
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, accuracy_score, recall_score, precision_score, f1_score, average_precision_score
from sklearn.model_selection import train_test_split
import warnings
from sklearn.model_selection import KFold
warnings.filterwarnings("ignore")

def winsorize(data, lower=0.05, upper=0.95):
    """Clip *data* to its own ``lower`` and ``upper`` quantiles.

    Both quantiles are taken over all values of ``data``; every value below
    the lower quantile is raised to it and every value above the upper
    quantile is lowered to it.

    Args:
        data: array-like of numbers.
        lower: quantile (0..1) used as the lower clipping bound.
        upper: quantile (0..1) used as the upper clipping bound.

    Returns:
        The clipped values as returned by ``np.clip``.
    """
    floor, ceiling = np.quantile(data, [lower, upper])
    return np.clip(data, floor, ceiling)

n_exp = 30  # number of repetitions of the experiment loop (range(n_exp))
n_group = 30  # NOTE(review): not referenced in the visible part of this file; confirm where it is used
dataset_name = 'grid'  # dataset key passed to getdatasetpath()

# ------------------------assessment of iide and bias-------------------------
p_values = {}
model_names = ['linear', 'mlp', 'svm', 'rf', 'lgbm', 'xgb', 'catboost'] # , 'tabpfn']
test_names = ['iid', 'id', 'bias']
target_names = ['rocauc', 'acc']
subject_names = ['lr', 'nn', 'fusion']
shap_features = ['fusion', 'subject', 'proxy']

for model_name in model_names:
    for test_name in test_names:
        for target_name in target_names:
            for subject_name in subject_names:
                p_values[f"{model_name}_{test_name}_{target_name}_{subject_name}"] = []
# -------------------------------------------------------------------

# ------------------------assessment of errors-------------------------
errors = {}
errors_subject = {}
errors_proxy = {}

# Per-model entries: the same "<model>_<target>_<subject>" key is created in
# all three dicts, each with its own empty list.
for model_name in model_names:
    for target_name in target_names:
        for subject_name in subject_names:
            for error_store in (errors, errors_subject, errors_proxy):
                error_store[f"{model_name}_{target_name}_{subject_name}"] = []

# Baseline entries, only in `errors`: "<metric>_<scheme>_<subject>" for the
# hold-out (100/50/20/10), cross-validation (cv5/cv10) and bootstrap schemes.
baseline_schemes = ('holdout100', 'holdout50', 'holdout20', 'holdout10', 'cv5', 'cv10', 'bootstrap')
for baseline_metric in ('rocauc', 'acc'):
    for baseline_subject in ('fusion', 'lr', 'nn'):
        for baseline_scheme in baseline_schemes:
            errors[f"{baseline_metric}_{baseline_scheme}_{baseline_subject}"] = []
# -------------------------------------------------------------------

for i_exp in range(n_exp):
    print(f"i_exp: {i_exp}")
    # ---------------------------------read dataset---------------------------------
    # Load the ARFF dataset named by `dataset_name` into a DataFrame. The frame
    # is built inside the `with` block so the file is still open while the rows
    # are consumed, and the handle is closed as soon as the frame exists
    # (the bare open() used before was never closed).
    datasetpath = getdatasetpath(dataset_name)
    import arff
    with open(datasetpath, 'r') as arff_file:
        data = arff.load(arff_file)
        data_df = pd.DataFrame(data['data'])

    # ---------------------------------selected the needed columns---------------------------------
    # every column but the last is a feature; the last column is the target
    x_column_names = data_df.columns.tolist()[:-1]
    y_column_names = [data_df.columns.tolist()[-1]]
    data_df = data_df[x_column_names + y_column_names]
    # ---------------------------------------------------------------------------------------------------

    # change y column to binary, 'stable' to 0, 'unstable' to 1
    data_df[y_column_names] = data_df[y_column_names].replace({'stable': 0, 'unstable': 1})
    # ---------------------------------------------------------------------------------------------

    # ---------------------------------preprocess the dataset---------------------------------
    # Drop rows containing any NaN and renumber the remaining rows 0..n-1.
    # Without the renumbering the target keeps its pre-drop row labels while
    # the scaled features below get a fresh 0..n-1 index, so the axis=1 concat
    # would pair features with the wrong targets and reintroduce NaN rows
    # whenever at least one row was dropped.
    data_df = data_df.dropna().reset_index(drop=True)
    # Make the recoded target explicitly numeric instead of relying on
    # replace() to downcast a string column; the float32 tensor conversion
    # further down needs a numeric dtype.
    data_df[y_column_names] = data_df[y_column_names].apply(pd.to_numeric)
    # Standardise the feature columns.
    # NOTE(review): the scaler is fitted on all rows, i.e. before the
    # validation/test split that follows.
    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    x = data_df.drop(columns=y_column_names)
    y = data_df[y_column_names]
    X_scaled = scaler.fit_transform(x)
    X_scaled_df = pd.DataFrame(X_scaled, columns=x.columns)
    data_df = pd.concat([X_scaled_df, y], axis=1)
    print("data_df head:", data_df.shape)
    # input()

    # -------------evaluation task--------------------
    # subjects: logistic regression and 2-layer NN (classes defined below)
    # condition: none
    # metric: roc_auc_score (further metrics are computed alongside it)
    # ------------------------------------------------

    # -------------leading rows -> validation tensors, trailing rows -> test tensors----------------
    test_ratio = 0.2
    n_rows = data_df.shape[0]
    df_tr = data_df.head(int(n_rows * (1 - test_ratio)))
    df_te = data_df.tail(int(n_rows * test_ratio))
    # The whole of df_tr is used as the validation pool (no extra sampling);
    # features and target are converted straight to float32 tensors.
    x_va = torch.tensor(df_tr[x_column_names].values, dtype=torch.float32)
    y_va = torch.tensor(df_tr[y_column_names].values, dtype=torch.float32)
    x_te = torch.tensor(df_te[x_column_names].values, dtype=torch.float32)
    y_te = torch.tensor(df_te[y_column_names].values, dtype=torch.float32)
    # ---------------------------------------------------------------------------------------------------

    # ------------------------------create subject space---------------------------------------------------------------------
    # subject space: logistic regression, and 2-layer NN
    # logistic regression
    class LinearSubject(torch.nn.Module):
        """Logistic-regression subject that emits hard 0/1 predictions.

        A single linear layer maps ``n_input`` features to one logit; the
        sigmoid of that logit is rounded, so ``forward`` returns 0.0 or 1.0
        for every input row.
        """

        def __init__(self, n_input):
            super().__init__()
            self.linear = torch.nn.Linear(n_input, 1)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            # probability -> hard label (0 or 1)
            return torch.round(self.sigmoid(self.linear(x)))
    # 2-layer NN
    class NNSubject(torch.nn.Module):
        """Two-layer sigmoid network subject that emits hard 0/1 predictions.

        ``n_input`` features go through a hidden layer of ``maxhidden``
        sigmoid units and then a single sigmoid output unit whose value is
        rounded. The two linear layers are registered both directly and
        inside the ``layer1`` / ``layer2`` Sequential wrappers.
        """

        def __init__(self, n_input, maxhidden = 8):
            super().__init__()
            self.linear1 = torch.nn.Linear(n_input, maxhidden)
            self.linear2 = torch.nn.Linear(maxhidden, 1)
            self.sigmoid = torch.nn.Sigmoid()
            self.layer1 = torch.nn.Sequential(self.linear1, self.sigmoid)
            self.layer2 = torch.nn.Sequential(self.linear2, self.sigmoid)

        def forward(self, x):
            hidden = self.layer1(x)
            # probability -> hard label (0 or 1)
            return torch.round(self.layer2(hidden))
        
    # ---------------------------------------------------------------------------------------------------

    # ---------------------------sampling from subject space-----------------------------
    n_subjects = 2000
    n_input = len(x_column_names)
    subject_lr_list, subject_lr_vector_list = [], []
    subject_nn_list, subject_nn_vector_list = [], []
    subject_label = []
    subject_list = []
    subject_vector_list = []
    for i in range(n_subjects): # ------------subject space sampling methods------------
        # coin flip: 0 -> linear subject, 1 -> NN subject
        drawn_label = np.random.choice([0, 1])
        subject_label.append(drawn_label)
        if drawn_label == 0:
            subject = LinearSubject(n_input)
            family_subjects, family_vectors = subject_lr_list, subject_lr_vector_list
        else:
            subject = NNSubject(n_input)
            family_subjects, family_vectors = subject_nn_list, subject_nn_vector_list
        # flatten all parameters of the freshly initialised subject into one list
        subject_vector = torch.cat([p.view(-1) for p in subject.parameters()]).detach().numpy().tolist()
        family_subjects.append(subject)
        # the family list gets its own copy, independent of the combined list
        family_vectors.append(list(subject_vector))
        subject_list.append(subject)
        subject_vector_list.append(subject_vector)
    # ---------------------------------------------------------------------------------------------------

    # -----------------------sampling from the metric space-------------------------------
    va_metric_names = ['rocauc', 'acc', 'recall', 'precision', 'f1', 'prauc']
    subject_names = ['lr', 'nn', 'fusion']
    va_valid_names = ['cv5', 'cv10', 'bootstrap', 'holdout100', 'holdout50', 'holdout20', 'holdout10']

    # te_lists: "<metric>_<subject>"            -> values measured on the test set
    # va_lists: "<metric>_<subject>_<scheme>"   -> values measured on the validation set
    va_lists = {}
    te_lists = {}
    for va_metric_name in va_metric_names:
        for subject_name in subject_names:
            te_lists[f"{va_metric_name}_{subject_name}"] = []
            for va_valid_name in va_valid_names:
                va_lists[f"{va_metric_name}_{subject_name}_{va_valid_name}"] = []
    # ---------------------------------------------------------------------------------------------------

    # Score every sampled subject on the test set with all six metrics.
    te_scorers = (
        ('rocauc', roc_auc_score),
        ('acc', accuracy_score),
        ('recall', recall_score),
        ('precision', precision_score),
        ('f1', f1_score),
        ('prauc', average_precision_score),
    )
    for i, subject in enumerate(subject_list): # on test set,
        output = subject(x_te)
        hard_predictions = output.detach().numpy()
        # label 0 was drawn for linear subjects, 1 for NN subjects
        family = 'lr' if subject_label[i] == 0 else 'nn'
        for metric_key, scorer in te_scorers:
            metric_value = scorer(y_te, hard_predictions)
            # every subject feeds the combined list and the list of its own family
            te_lists[f"{metric_key}_fusion"].append(metric_value)
            te_lists[f"{metric_key}_{family}"].append(metric_value)
    import time
    # overall start of the validation pass
    start = time.time()
    # per-scheme wall-clock accumulators, all starting at zero
    t_holdout10 = t_holdout20 = t_holdout50 = t_holdout100 = 0
    t_cv5 = t_cv10 = t_bootstrap = 0
    for i in range(n_subjects): # on validation set, 
        # print(f"validation i: {i}")
        subject = subject_list[i]
        # holdout 0.1
        # Hold-out estimate on the first 10% of the validation rows: the
        # subject's hard 0/1 outputs are scored with all six metrics, and the
        # whole step (slicing, forward pass, scoring, bookkeeping) is timed.
        start_holdout10 = time.time()
        x_va10 = x_va[:int(len(x_va) * 0.1)]
        y_va10 = y_va[:int(len(y_va) * 0.1)]
        output = subject(x_va10)
        rocauc = roc_auc_score(y_va10, output.detach().numpy())
        acc = accuracy_score(y_va10, output.detach().numpy())
        recall = recall_score(y_va10, output.detach().numpy())
        precision = precision_score(y_va10, output.detach().numpy())
        f1 = f1_score(y_va10, output.detach().numpy())
        prauc = average_precision_score(y_va10, output.detach().numpy())
        # every subject contributes to the combined ("fusion") lists ...
        va_lists[f"rocauc_fusion_holdout10"].append(rocauc)
        va_lists[f"acc_fusion_holdout10"].append(acc)
        va_lists[f"recall_fusion_holdout10"].append(recall)
        va_lists[f"precision_fusion_holdout10"].append(precision)
        va_lists[f"f1_fusion_holdout10"].append(f1)
        va_lists[f"prauc_fusion_holdout10"].append(prauc)
        # ... and to the lists of its own family (label 0 -> lr, otherwise nn)
        if subject_label[i] == 0: 
            va_lists[f"rocauc_lr_holdout10"].append(rocauc)
            va_lists[f"acc_lr_holdout10"].append(acc)
            va_lists[f"recall_lr_holdout10"].append(recall)
            va_lists[f"precision_lr_holdout10"].append(precision)
            va_lists[f"f1_lr_holdout10"].append(f1)
            va_lists[f"prauc_lr_holdout10"].append(prauc)
        else:
            va_lists[f"rocauc_nn_holdout10"].append(rocauc)
            va_lists[f"acc_nn_holdout10"].append(acc)
            va_lists[f"recall_nn_holdout10"].append(recall)
            va_lists[f"precision_nn_holdout10"].append(precision)
            va_lists[f"f1_nn_holdout10"].append(f1)
            va_lists[f"prauc_nn_holdout10"].append(prauc)
        # accumulate the time spent on this scheme across subjects
        t_holdout10 += time.time() - start_holdout10
        # holdout 0.2
        # Same procedure as the 10% hold-out, on the first 20% of the
        # validation rows; timed separately.
        start_holdout20 = time.time()
        x_va20 = x_va[:int(len(x_va) * 0.2)]
        y_va20 = y_va[:int(len(y_va) * 0.2)]
        output = subject(x_va20)
        rocauc = roc_auc_score(y_va20, output.detach().numpy())
        acc = accuracy_score(y_va20, output.detach().numpy())
        recall = recall_score(y_va20, output.detach().numpy())
        precision = precision_score(y_va20, output.detach().numpy())
        f1 = f1_score(y_va20, output.detach().numpy())
        prauc = average_precision_score(y_va20, output.detach().numpy())
        # combined ("fusion") lists receive every subject
        va_lists[f"rocauc_fusion_holdout20"].append(rocauc)
        va_lists[f"acc_fusion_holdout20"].append(acc)
        va_lists[f"recall_fusion_holdout20"].append(recall)
        va_lists[f"precision_fusion_holdout20"].append(precision)
        va_lists[f"f1_fusion_holdout20"].append(f1)
        va_lists[f"prauc_fusion_holdout20"].append(prauc)
        # family-specific lists (label 0 -> lr, otherwise nn)
        if subject_label[i] == 0: 
            va_lists[f"rocauc_lr_holdout20"].append(rocauc)
            va_lists[f"acc_lr_holdout20"].append(acc)
            va_lists[f"recall_lr_holdout20"].append(recall)
            va_lists[f"precision_lr_holdout20"].append(precision)
            va_lists[f"f1_lr_holdout20"].append(f1)
            va_lists[f"prauc_lr_holdout20"].append(prauc)
        else:
            va_lists[f"rocauc_nn_holdout20"].append(rocauc)
            va_lists[f"acc_nn_holdout20"].append(acc)
            va_lists[f"recall_nn_holdout20"].append(recall)
            va_lists[f"precision_nn_holdout20"].append(precision)
            va_lists[f"f1_nn_holdout20"].append(f1)
            va_lists[f"prauc_nn_holdout20"].append(prauc)
        # accumulate the time spent on this scheme across subjects
        t_holdout20 += time.time() - start_holdout20
        # holdout 0.5
        # Same procedure as the smaller hold-outs, on the first 50% of the
        # validation rows; timed separately.
        start_holdout50 = time.time()
        x_va50 = x_va[:int(len(x_va) * 0.5)]
        y_va50 = y_va[:int(len(y_va) * 0.5)]
        output = subject(x_va50)
        rocauc = roc_auc_score(y_va50, output.detach().numpy())
        acc = accuracy_score(y_va50, output.detach().numpy())
        recall = recall_score(y_va50, output.detach().numpy())
        precision = precision_score(y_va50, output.detach().numpy())
        f1 = f1_score(y_va50, output.detach().numpy())
        prauc = average_precision_score(y_va50, output.detach().numpy())
        # combined ("fusion") lists receive every subject
        va_lists[f"rocauc_fusion_holdout50"].append(rocauc)
        va_lists[f"acc_fusion_holdout50"].append(acc)
        va_lists[f"recall_fusion_holdout50"].append(recall)
        va_lists[f"precision_fusion_holdout50"].append(precision)
        va_lists[f"f1_fusion_holdout50"].append(f1)
        va_lists[f"prauc_fusion_holdout50"].append(prauc)
        # family-specific lists (label 0 -> lr, otherwise nn)
        if subject_label[i] == 0: 
            va_lists[f"rocauc_lr_holdout50"].append(rocauc)
            va_lists[f"acc_lr_holdout50"].append(acc)
            va_lists[f"recall_lr_holdout50"].append(recall)
            va_lists[f"precision_lr_holdout50"].append(precision)
            va_lists[f"f1_lr_holdout50"].append(f1)
            va_lists[f"prauc_lr_holdout50"].append(prauc)
        else:
            va_lists[f"rocauc_nn_holdout50"].append(rocauc)
            va_lists[f"acc_nn_holdout50"].append(acc)
            va_lists[f"recall_nn_holdout50"].append(recall)
            va_lists[f"precision_nn_holdout50"].append(precision)
            va_lists[f"f1_nn_holdout50"].append(f1)
            va_lists[f"prauc_nn_holdout50"].append(prauc)
        # accumulate the time spent on this scheme across subjects
        t_holdout50 += time.time() - start_holdout50
        # holdout 1.0
        # Score the subject on the complete validation set (no slicing);
        # timed separately from the partial hold-outs.
        start_holdout100 = time.time()
        x_va100 = x_va
        y_va100 = y_va
        output = subject(x_va100)
        rocauc = roc_auc_score(y_va100, output.detach().numpy())
        acc = accuracy_score(y_va100, output.detach().numpy())
        recall = recall_score(y_va100, output.detach().numpy())
        precision = precision_score(y_va100, output.detach().numpy())
        f1 = f1_score(y_va100, output.detach().numpy())
        prauc = average_precision_score(y_va100, output.detach().numpy())
        # combined ("fusion") lists receive every subject
        va_lists[f"rocauc_fusion_holdout100"].append(rocauc)
        va_lists[f"acc_fusion_holdout100"].append(acc)
        va_lists[f"recall_fusion_holdout100"].append(recall)
        va_lists[f"precision_fusion_holdout100"].append(precision)
        va_lists[f"f1_fusion_holdout100"].append(f1)
        va_lists[f"prauc_fusion_holdout100"].append(prauc)
        # family-specific lists (label 0 -> lr, otherwise nn)
        if subject_label[i] == 0:
            va_lists[f"rocauc_lr_holdout100"].append(rocauc)
            va_lists[f"acc_lr_holdout100"].append(acc)
            va_lists[f"recall_lr_holdout100"].append(recall)
            va_lists[f"precision_lr_holdout100"].append(precision)
            va_lists[f"f1_lr_holdout100"].append(f1)
            va_lists[f"prauc_lr_holdout100"].append(prauc)
        else:
            va_lists[f"rocauc_nn_holdout100"].append(rocauc)
            va_lists[f"acc_nn_holdout100"].append(acc)
            va_lists[f"recall_nn_holdout100"].append(recall)
            va_lists[f"precision_nn_holdout100"].append(precision)
            va_lists[f"f1_nn_holdout100"].append(f1)
            va_lists[f"prauc_nn_holdout100"].append(prauc)
        # accumulate the time spent on this scheme across subjects
        t_holdout100 += time.time() - start_holdout100
        # cv5
        # 5-fold estimate: the subject is scored on each held-out fold of the
        # validation set and the per-fold values are averaged per metric.
        start_cv5 = time.time()
        kf = KFold(n_splits=5)
        # per-subject scratch lists holding one value per fold
        tmp_va_lists = {}
        for va_metric_name in va_metric_names:
            for subject_name in subject_names:
                tmp_va_lists[f"{va_metric_name}_{subject_name}_cv5"] = []
        for train_index, test_index in kf.split(x_va):
            # the train part of each split is materialised but not used in
            # this section; only the held-out fold is scored
            x_train, x_test = x_va[train_index], x_va[test_index]
            y_train, y_test = y_va[train_index], y_va[test_index]
            output = subject(x_test)
            rocauc = roc_auc_score(y_test, output.detach().numpy())
            acc = accuracy_score(y_test, output.detach().numpy())
            recall = recall_score(y_test, output.detach().numpy())
            precision = precision_score(y_test, output.detach().numpy())
            f1 = f1_score(y_test, output.detach().numpy())
            prauc = average_precision_score(y_test, output.detach().numpy())
            tmp_va_lists[f"rocauc_fusion_cv5"].append(rocauc)
            tmp_va_lists[f"acc_fusion_cv5"].append(acc)
            tmp_va_lists[f"recall_fusion_cv5"].append(recall)
            tmp_va_lists[f"precision_fusion_cv5"].append(precision)
            tmp_va_lists[f"f1_fusion_cv5"].append(f1)
            tmp_va_lists[f"prauc_fusion_cv5"].append(prauc)
            if subject_label[i] == 0:
                tmp_va_lists[f"rocauc_lr_cv5"].append(rocauc)
                tmp_va_lists[f"acc_lr_cv5"].append(acc)
                tmp_va_lists[f"recall_lr_cv5"].append(recall)
                tmp_va_lists[f"precision_lr_cv5"].append(precision)
                tmp_va_lists[f"f1_lr_cv5"].append(f1)
                tmp_va_lists[f"prauc_lr_cv5"].append(prauc)
            else:
                tmp_va_lists[f"rocauc_nn_cv5"].append(rocauc)
                tmp_va_lists[f"acc_nn_cv5"].append(acc)
                tmp_va_lists[f"recall_nn_cv5"].append(recall)
                tmp_va_lists[f"precision_nn_cv5"].append(precision)
                tmp_va_lists[f"f1_nn_cv5"].append(f1)
                tmp_va_lists[f"prauc_nn_cv5"].append(prauc)
        # print("tmp va lists:", tmp_va_lists)
        # Average over the folds. Means are computed for all three scratch
        # lists, but only the fusion mean and the mean of the subject's own
        # family (lr or nn) are stored; the other family's scratch list is
        # empty for this subject and its mean is discarded.
        for va_metric_name in va_metric_names:
            mean_cv5_fusion = np.mean(tmp_va_lists[f"{va_metric_name}_fusion_cv5"])
            mean_cv5_lr = np.mean(tmp_va_lists[f"{va_metric_name}_lr_cv5"])
            mean_cv5_nn = np.mean(tmp_va_lists[f"{va_metric_name}_nn_cv5"])
            va_lists[f"{va_metric_name}_fusion_cv5"].append(mean_cv5_fusion)
            if subject_label[i] == 0:
                va_lists[f"{va_metric_name}_lr_cv5"].append(mean_cv5_lr)
            else:
                va_lists[f"{va_metric_name}_nn_cv5"].append(mean_cv5_nn)

        # accumulate the time spent on this scheme across subjects
        t_cv5 += time.time() - start_cv5
        # cv10
        # 10-fold estimate: identical procedure to the 5-fold section, with
        # ten held-out folds of the validation set; timed separately.
        start_cv10 = time.time()
        kf = KFold(n_splits=10)
        # per-subject scratch lists holding one value per fold
        tmp_va_lists = {}
        for va_metric_name in va_metric_names:
            for subject_name in subject_names:
                tmp_va_lists[f"{va_metric_name}_{subject_name}_cv10"] = []
        for train_index, test_index in kf.split(x_va):
            # the train part of each split is materialised but not used in
            # this section; only the held-out fold is scored
            x_train, x_test = x_va[train_index], x_va[test_index]
            y_train, y_test = y_va[train_index], y_va[test_index]
            output = subject(x_test)
            rocauc = roc_auc_score(y_test, output.detach().numpy())
            acc = accuracy_score(y_test, output.detach().numpy())
            recall = recall_score(y_test, output.detach().numpy())
            precision = precision_score(y_test, output.detach().numpy())
            f1 = f1_score(y_test, output.detach().numpy())
            prauc = average_precision_score(y_test, output.detach().numpy())
            tmp_va_lists[f"rocauc_fusion_cv10"].append(rocauc)
            tmp_va_lists[f"acc_fusion_cv10"].append(acc)
            tmp_va_lists[f"recall_fusion_cv10"].append(recall)
            tmp_va_lists[f"precision_fusion_cv10"].append(precision)
            tmp_va_lists[f"f1_fusion_cv10"].append(f1)
            tmp_va_lists[f"prauc_fusion_cv10"].append(prauc)
            if subject_label[i] == 0:
                tmp_va_lists[f"rocauc_lr_cv10"].append(rocauc)
                tmp_va_lists[f"acc_lr_cv10"].append(acc)
                tmp_va_lists[f"recall_lr_cv10"].append(recall)
                tmp_va_lists[f"precision_lr_cv10"].append(precision)
                tmp_va_lists[f"f1_lr_cv10"].append(f1)
                tmp_va_lists[f"prauc_lr_cv10"].append(prauc)
            else:
                tmp_va_lists[f"rocauc_nn_cv10"].append(rocauc)
                tmp_va_lists[f"acc_nn_cv10"].append(acc)
                tmp_va_lists[f"recall_nn_cv10"].append(recall)
                tmp_va_lists[f"precision_nn_cv10"].append(precision)
                tmp_va_lists[f"f1_nn_cv10"].append(f1)
                tmp_va_lists[f"prauc_nn_cv10"].append(prauc)
        # Average over the folds; only the fusion mean and the mean of the
        # subject's own family (lr or nn) are stored.
        for va_metric_name in va_metric_names:
            mean_cv10_fusion = np.mean(tmp_va_lists[f"{va_metric_name}_fusion_cv10"])
            mean_cv10_lr = np.mean(tmp_va_lists[f"{va_metric_name}_lr_cv10"])
            mean_cv10_nn = np.mean(tmp_va_lists[f"{va_metric_name}_nn_cv10"])
            va_lists[f"{va_metric_name}_fusion_cv10"].append(mean_cv10_fusion)
            if subject_label[i] == 0:
                va_lists[f"{va_metric_name}_lr_cv10"].append(mean_cv10_lr)
            else:
                va_lists[f"{va_metric_name}_nn_cv10"].append(mean_cv10_nn)

        # accumulate the time spent on this scheme across subjects
        t_cv10 += time.time() - start_cv10
        # bootstrap 
        # Bootstrap estimate: one resample of the validation set per subject,
        # scored with all six metrics; timed separately.
        start_bootstrap = time.time()
        # draw len(x_va) row indices uniformly with replacement
        bootstrap_index = torch.randint(0, len(x_va), (len(x_va),))
        x_va_bootstrap = x_va[bootstrap_index]
        y_va_bootstrap = y_va[bootstrap_index]
        output = subject(x_va_bootstrap)
        rocauc = roc_auc_score(y_va_bootstrap, output.detach().numpy())
        acc = accuracy_score(y_va_bootstrap, output.detach().numpy())
        recall = recall_score(y_va_bootstrap, output.detach().numpy())
        precision = precision_score(y_va_bootstrap, output.detach().numpy())
        f1 = f1_score(y_va_bootstrap, output.detach().numpy())
        prauc = average_precision_score(y_va_bootstrap, output.detach().numpy())
        # combined ("fusion") lists receive every subject
        va_lists[f"rocauc_fusion_bootstrap"].append(rocauc)
        va_lists[f"acc_fusion_bootstrap"].append(acc)
        va_lists[f"recall_fusion_bootstrap"].append(recall)
        va_lists[f"precision_fusion_bootstrap"].append(precision)
        va_lists[f"f1_fusion_bootstrap"].append(f1)
        va_lists[f"prauc_fusion_bootstrap"].append(prauc)
        # family-specific lists (label 0 -> lr, otherwise nn)
        if subject_label[i] == 0:
            va_lists[f"rocauc_lr_bootstrap"].append(rocauc)
            va_lists[f"acc_lr_bootstrap"].append(acc)
            va_lists[f"recall_lr_bootstrap"].append(recall)
            va_lists[f"precision_lr_bootstrap"].append(precision)
            va_lists[f"f1_lr_bootstrap"].append(f1)
            va_lists[f"prauc_lr_bootstrap"].append(prauc)
        else:
            va_lists[f"rocauc_nn_bootstrap"].append(rocauc)
            va_lists[f"acc_nn_bootstrap"].append(acc)
            va_lists[f"recall_nn_bootstrap"].append(recall)
            va_lists[f"precision_nn_bootstrap"].append(precision)
            va_lists[f"f1_nn_bootstrap"].append(f1)
            va_lists[f"prauc_nn_bootstrap"].append(prauc) 
        # accumulate the time spent on this scheme across subjects
        t_bootstrap += time.time() - start_bootstrap
    # Report the accumulated wall-clock time of every validation scheme,
    # in total and averaged over the subjects.
    for scheme_label, seconds_spent in (
        ("holdout10", t_holdout10),
        ("holdout20", t_holdout20),
        ("holdout50", t_holdout50),
        ("holdout100", t_holdout100),
        ("cv5", t_cv5),
        ("cv10", t_cv10),
        ("bootstrap", t_bootstrap),
    ):
        print(f"time {scheme_label}:", seconds_spent, "per subject:", seconds_spent / n_subjects)
    print("all time:", time.time() - start, "per subject:", (time.time() - start) / n_subjects)
    # NOTE(review): blocks until Enter is pressed, once per experiment iteration
    input()
    # Merge all validation metrics: collect one row per (metric, scheme) pair
    # for each subject family, then transpose so that rows are subjects and
    # columns are the proxy metrics.
    proxy_rows = {'lr': [], 'nn': [], 'fusion': []}
    for va_metric_name in va_metric_names:
        for va_valid_name in va_valid_names:
            for family, rows in proxy_rows.items():
                rows.append(va_lists[f"{va_metric_name}_{family}_{va_valid_name}"])
    lr_proxy_lists = np.array(proxy_rows['lr']).T
    nn_proxy_lists = np.array(proxy_rows['nn']).T
    fusion_proxy_lists = np.array(proxy_rows['fusion']).T

    # only the LR parameter vectors are turned into an array here
    subject_lr_vector_list = np.array(subject_lr_vector_list)

    # show shapes
    for caption, shown_array in (
        ("lr proxy lists shape:", lr_proxy_lists),
        ("nn proxy lists shape:", nn_proxy_lists),
        ("fusion proxy lists shape:", fusion_proxy_lists),
        ("subject lr vector list shape:", subject_lr_vector_list),
    ):
        print(caption, shown_array.shape)
    # ---------------------------------------------------------------------
    # ---------------------------------------------------------------------

    # ---------------------------evaluation task-----------------------------
    # Randomly split the subjects 80/20, separately per family. Each call
    # splits the parameter vectors, the test-set metrics and the validation
    # proxy rows together, so index k of every output refers to the same subject.
    # Naming: *_tr_* / *_te_* = subject train/test part; the trailing _te / _va
    # says whether the values were measured on the test or the validation set.
    (
        s_lr_tr, s_lr_te,
        rocauc_lr_tr_te, rocauc_lr_te_te,
        proxy_lr_tr_va, proxy_lr_te_va,
        acc_lr_tr_te, acc_lr_te_te,
        recall_lr_tr_te, recall_lr_te_te,
        precision_lr_tr_te, precision_lr_te_te,
        f1_lr_tr_te, f1_lr_te_te,
        prauc_lr_tr_te, prauc_lr_te_te,
    ) = train_test_split(
        subject_lr_vector_list,
        te_lists["rocauc_lr"],
        lr_proxy_lists,
        te_lists["acc_lr"],
        te_lists["recall_lr"],
        te_lists["precision_lr"],
        te_lists["f1_lr"],
        te_lists["prauc_lr"],
        test_size=0.2,
    )

    (
        s_nn_tr, s_nn_te,
        rocauc_nn_tr_te, rocauc_nn_te_te,
        proxy_nn_tr_va, proxy_nn_te_va,
        acc_nn_tr_te, acc_nn_te_te,
        recall_nn_tr_te, recall_nn_te_te,
        precision_nn_tr_te, precision_nn_te_te,
        f1_nn_tr_te, f1_nn_te_te,
        prauc_nn_tr_te, prauc_nn_te_te,
    ) = train_test_split(
        subject_nn_vector_list,
        te_lists["rocauc_nn"],
        nn_proxy_lists,
        te_lists["acc_nn"],
        te_lists["recall_nn"],
        te_lists["precision_nn"],
        te_lists["f1_nn"],
        te_lists["prauc_nn"],
        test_size=0.2,
    )
    # ---------------------------------------------------------------------
    # Convert the split pieces to float32 numpy arrays and build the combined
    # "subject parameters + validation proxy metrics" design matrices.
    def _to_float32_array(values):
        """Round-trip *values* through a float32 tensor and return a numpy array."""
        return torch.tensor(values, dtype=torch.float32).detach().numpy()

    # subject parameter vectors
    s_lr_tr, s_lr_te = _to_float32_array(s_lr_tr), _to_float32_array(s_lr_te)
    s_nn_tr, s_nn_te = _to_float32_array(s_nn_tr), _to_float32_array(s_nn_te)
    # proxy metrics measured on the validation set (used as covariates)
    proxy_lr_tr_va, proxy_lr_te_va = _to_float32_array(proxy_lr_tr_va), _to_float32_array(proxy_lr_te_va)
    proxy_nn_tr_va, proxy_nn_te_va = _to_float32_array(proxy_nn_tr_va), _to_float32_array(proxy_nn_te_va)

    # column-wise concatenation: [parameter vector | proxy metrics]
    sm_lr_tr = np.concatenate([s_lr_tr, proxy_lr_tr_va], axis=1)
    sm_lr_te = np.concatenate([s_lr_te, proxy_lr_te_va], axis=1)
    sm_nn_tr = np.concatenate([s_nn_tr, proxy_nn_tr_va], axis=1)
    sm_nn_te = np.concatenate([s_nn_te, proxy_nn_te_va], axis=1)

    # true metrics (measured on the test set); only roc_auc and accuracy are kept
    m_lr_tr_rocauc = _to_float32_array(rocauc_lr_tr_te)
    m_lr_te_rocauc = _to_float32_array(rocauc_lr_te_te)
    m_nn_tr_rocauc = _to_float32_array(rocauc_nn_tr_te)
    m_nn_te_rocauc = _to_float32_array(rocauc_nn_te_te)
    m_lr_tr_acc = _to_float32_array(acc_lr_tr_te)
    m_lr_te_acc = _to_float32_array(acc_lr_te_te)
    m_nn_tr_acc = _to_float32_array(acc_nn_tr_te)
    m_nn_te_acc = _to_float32_array(acc_nn_te_te)

    # Baseline estimators: signed error = test-set metric minus the estimate
    # obtained from the validation scheme (te - va), element by element.
    for _metric in ('rocauc', 'acc'):
        for _subject in ('fusion', 'lr', 'nn'):
            _truth = te_lists[f'{_metric}_{_subject}']
            for _scheme in ('holdout100', 'holdout50', 'holdout20', 'holdout10',
                            'cv5', 'cv10', 'bootstrap'):
                _estimate = va_lists[f'{_metric}_{_subject}_{_scheme}']
                errors[f'{_metric}_{_scheme}_{_subject}'].append(
                    [te_value - va_value for te_value, va_value in zip(_truth, _estimate)]
                )
    # ----------------------- em -----------------------
    # Error models: for each regressor family in model_names, learn to predict
    # the true (test-set) ROC-AUC / accuracy of the lr and nn subjects from
    # three feature sets -- fusion (sm_*), subject only (s_*) and proxy only
    # (proxy_*). The signed residuals (true - predicted) are stored in
    # errors / errors_subject / errors_proxy, and the fusion-feature residuals
    # are then tested for iid-ness, identical distribution and zero bias.
    from scipy.stats import kruskal, normaltest, ttest_1samp

    def _fit_error_models(regressor_cls, init_kwargs, fit_kwargs, training_pairs):
        """Create one regressor per (features, target) pair and fit them in order.

        All regressors are constructed before the first fit, and fits run in
        the order of ``training_pairs``.

        Args:
            regressor_cls: regressor class exposing ``fit`` / ``predict``.
            init_kwargs: keyword arguments for the constructor.
            fit_kwargs: extra keyword arguments forwarded to ``fit``.
            training_pairs: sequence of ``(features, target)`` tuples.

        Returns:
            List of fitted regressors, aligned with ``training_pairs``.
        """
        regressors = [regressor_cls(**init_kwargs) for _ in training_pairs]
        for regressor, (features, target) in zip(regressors, training_pairs):
            regressor.fit(features, target, **fit_kwargs)
        return regressors

    # Suffixes of the p_values keys, in the order the tests are evaluated.
    result_keys = ('rocauc_fusion', 'acc_fusion', 'rocauc_lr', 'rocauc_nn', 'acc_lr', 'acc_nn')

    for model_name in model_names:
        print(f"model_name: {model_name}")

        # Pick the regressor class. Imports stay lazy so that an optional
        # back-end is only required when it is actually requested.
        init_kwargs = {}
        fit_kwargs = {}
        if model_name == 'linear':
            from sklearn.linear_model import LinearRegression
            regressor_cls = LinearRegression
        elif model_name == 'mlp':
            from sklearn.neural_network import MLPRegressor
            regressor_cls = MLPRegressor
            init_kwargs = {'hidden_layer_sizes': (100,), 'max_iter': 1000}
        elif model_name == 'svm':
            from sklearn.svm import SVR
            regressor_cls = SVR
        elif model_name == 'rf':
            from sklearn.ensemble import RandomForestRegressor
            regressor_cls = RandomForestRegressor
        elif model_name == 'lgbm':
            import lightgbm as lgb
            regressor_cls = lgb.LGBMRegressor
        elif model_name == 'xgb':
            import xgboost as xgb
            regressor_cls = xgb.XGBRegressor
        elif model_name == 'catboost':
            from catboost import CatBoostRegressor
            regressor_cls = CatBoostRegressor
            fit_kwargs = {'verbose': 0}  # silence CatBoost's per-iteration log
        elif model_name == 'tabpfn':
            from tabpfn import TabPFNClassifier, TabPFNRegressor
            regressor_cls = TabPFNRegressor
        else:
            raise ValueError("model_name is not supported!")

        # Fit order per feature set: lr-rocauc, nn-rocauc, lr-acc, nn-acc.
        # Fusion features (subject + proxy).
        lrmodel_rocauc, nnmodel_rocauc, lrmodel_acc, nnmodel_acc = _fit_error_models(
            regressor_cls, init_kwargs, fit_kwargs,
            [(sm_lr_tr, m_lr_tr_rocauc), (sm_nn_tr, m_nn_tr_rocauc),
             (sm_lr_tr, m_lr_tr_acc), (sm_nn_tr, m_nn_tr_acc)],
        )
        # Subject features only.
        (lrmodel_rocauc_subject, nnmodel_rocauc_subject,
         lrmodel_acc_subject, nnmodel_acc_subject) = _fit_error_models(
            regressor_cls, init_kwargs, fit_kwargs,
            [(s_lr_tr, m_lr_tr_rocauc), (s_nn_tr, m_nn_tr_rocauc),
             (s_lr_tr, m_lr_tr_acc), (s_nn_tr, m_nn_tr_acc)],
        )
        # Proxy (validation-set) features only.
        (lrmodel_rocauc_proxy, nnmodel_rocauc_proxy,
         lrmodel_acc_proxy, nnmodel_acc_proxy) = _fit_error_models(
            regressor_cls, init_kwargs, fit_kwargs,
            [(proxy_lr_tr_va, m_lr_tr_rocauc), (proxy_nn_tr_va, m_nn_tr_rocauc),
             (proxy_lr_tr_va, m_lr_tr_acc), (proxy_nn_tr_va, m_nn_tr_acc)],
        )

        # Signed residuals on the held-out rows: true metric - predicted metric.
        lr_error_rocauc = m_lr_te_rocauc - lrmodel_rocauc.predict(sm_lr_te)
        nn_error_rocauc = m_nn_te_rocauc - nnmodel_rocauc.predict(sm_nn_te)
        lr_error_acc = m_lr_te_acc - lrmodel_acc.predict(sm_lr_te)
        nn_error_acc = m_nn_te_acc - nnmodel_acc.predict(sm_nn_te)
        lr_error_rocauc_subject = m_lr_te_rocauc - lrmodel_rocauc_subject.predict(s_lr_te)
        nn_error_rocauc_subject = m_nn_te_rocauc - nnmodel_rocauc_subject.predict(s_nn_te)
        lr_error_acc_subject = m_lr_te_acc - lrmodel_acc_subject.predict(s_lr_te)
        nn_error_acc_subject = m_nn_te_acc - nnmodel_acc_subject.predict(s_nn_te)
        lr_error_rocauc_proxy = m_lr_te_rocauc - lrmodel_rocauc_proxy.predict(proxy_lr_te_va)
        nn_error_rocauc_proxy = m_nn_te_rocauc - nnmodel_rocauc_proxy.predict(proxy_nn_te_va)
        lr_error_acc_proxy = m_lr_te_acc - lrmodel_acc_proxy.predict(proxy_lr_te_va)
        nn_error_acc_proxy = m_nn_te_acc - nnmodel_acc_proxy.predict(proxy_nn_te_va)

        # The "fusion" subject is the lr residuals followed by the nn residuals.
        error_list_rocauc = np.concatenate([lr_error_rocauc, nn_error_rocauc])
        error_list_acc = np.concatenate([lr_error_acc, nn_error_acc])
        error_list_rocauc_subject = np.concatenate([lr_error_rocauc_subject, nn_error_rocauc_subject])
        error_list_acc_subject = np.concatenate([lr_error_acc_subject, nn_error_acc_subject])
        error_list_rocauc_proxy = np.concatenate([lr_error_rocauc_proxy, nn_error_rocauc_proxy])
        error_list_acc_proxy = np.concatenate([lr_error_acc_proxy, nn_error_acc_proxy])

        # Record the raw residuals (before any winsorizing) for this run.
        # Fusion-feature residuals.
        errors[f"{model_name}_rocauc_fusion"].append(error_list_rocauc)
        errors[f"{model_name}_acc_fusion"].append(error_list_acc)
        errors[f"{model_name}_rocauc_lr"].append(lr_error_rocauc)
        errors[f"{model_name}_rocauc_nn"].append(nn_error_rocauc)
        errors[f"{model_name}_acc_lr"].append(lr_error_acc)
        errors[f"{model_name}_acc_nn"].append(nn_error_acc)
        # Subject-feature residuals.
        errors_subject[f"{model_name}_rocauc_fusion"].append(error_list_rocauc_subject)
        errors_subject[f"{model_name}_acc_fusion"].append(error_list_acc_subject)
        errors_subject[f"{model_name}_rocauc_lr"].append(lr_error_rocauc_subject)
        errors_subject[f"{model_name}_rocauc_nn"].append(nn_error_rocauc_subject)
        errors_subject[f"{model_name}_acc_lr"].append(lr_error_acc_subject)
        errors_subject[f"{model_name}_acc_nn"].append(nn_error_acc_subject)
        # Proxy-feature residuals.
        errors_proxy[f"{model_name}_rocauc_fusion"].append(error_list_rocauc_proxy)
        errors_proxy[f"{model_name}_acc_fusion"].append(error_list_acc_proxy)
        errors_proxy[f"{model_name}_rocauc_lr"].append(lr_error_rocauc_proxy)
        errors_proxy[f"{model_name}_rocauc_nn"].append(nn_error_rocauc_proxy)
        errors_proxy[f"{model_name}_acc_lr"].append(lr_error_acc_proxy)
        errors_proxy[f"{model_name}_acc_nn"].append(nn_error_acc_proxy)

        # Flatten and winsorize the merged residuals used by the tests below.
        # NOTE(review): only the fusion residuals are winsorized; the lr / nn
        # residuals enter the tests unclipped -- confirm this is intended.
        error_list_rocauc = winsorize(error_list_rocauc.reshape(-1))
        error_list_acc = winsorize(error_list_acc.reshape(-1))

        # Samples under test, aligned with result_keys.
        test_samples = (
            error_list_rocauc, error_list_acc,
            lr_error_rocauc, nn_error_rocauc,
            lr_error_acc, nn_error_acc,
        )

        # ------iid check by central limit theorem-------
        # (1) Draw n_group half-size resamples of every sample and keep their
        #     means. np.random.choice samples with replacement by default.
        # NOTE(review): n_group is also defined at module level; it is rebound
        # here and again in the id check below.
        n_group = 30
        resample_means = [[] for _ in test_samples]
        for _ in range(n_group):
            # Keep the per-round sample order so the RNG stream is consumed
            # in the same sequence on every run.
            for means, sample in zip(resample_means, test_samples):
                means.append(np.mean(np.random.choice(sample, len(sample) // 2)))
        # (2) Under the CLT the resample means should look normal; checked
        #     with scipy.stats.normaltest (D'Agostino-Pearson omnibus test).
        # key pattern: "<model>_iid_<target>_<subject>"
        for key, means in zip(result_keys, resample_means):
            _, p_value = normaltest(means)
            p_values[f"{model_name}_iid_{key}"].append(p_value)

        # ------id check by Kruskal-Wallis H-test-------
        # Split each sample sequentially (np.array_split, no shuffling) into
        # chunks of about n_sample_per_group values and test whether the
        # chunks come from the same distribution.
        n_sample_per_group = 30
        n_group = len(error_list_rocauc) // n_sample_per_group
        n_group_lr = len(lr_error_rocauc) // n_sample_per_group
        n_group_nn = len(nn_error_rocauc) // n_sample_per_group
        # The acc samples reuse the chunk counts derived from the rocauc ones.
        chunk_counts = (n_group, n_group, n_group_lr, n_group_nn, n_group_lr, n_group_nn)
        # key pattern: "<model>_id_<target>_<subject>"
        for key, sample, n_chunks in zip(result_keys, test_samples, chunk_counts):
            _, p_value = kruskal(*np.array_split(sample, n_chunks))
            p_values[f"{model_name}_id_{key}"].append(p_value)

        # ------bias check by one-sample t-test-------
        # Null hypothesis: the mean residual is 0 (the estimator is unbiased).
        # key pattern: "<model>_bias_<target>_<subject>"
        for key, sample in zip(result_keys, test_samples):
            _, p_value = ttest_1samp(sample, 0)
            p_values[f"{model_name}_bias_{key}"].append(p_value)
        # ---------------------------------------------------------------------------------------------------
print("complete evaluation!")
# ---------------------------evaluation task-----------------------------
# Persist the result dictionaries (p_values, errors, errors_subject,
# errors_proxy) as pickle files under ./expdata/<dataset_name>/.
import pickle

result_dir = './expdata/' + dataset_name + '/'
# open(..., 'wb') does not create missing directories; make sure the target
# exists so a finished run is not lost at the very last step.
os.makedirs(result_dir, exist_ok=True)
for result_name, result_dict in (
    ('p_values', p_values),
    ('errors', errors),
    ('errors_subject', errors_subject),
    ('errors_proxy', errors_proxy),
):
    with open(result_dir + result_name + '.pkl', 'wb') as f:
        pickle.dump(result_dict, f)
print("complete write!")
# ---------------------------evaluation task-----------------------------
# Reload the dictionaries that were just written.
# NOTE: pickle.load executes arbitrary code from the file; only load pickles
# produced by this script.
with open(result_dir + 'p_values.pkl', 'rb') as f:
    p_values = pickle.load(f)
with open(result_dir + 'errors.pkl', 'rb') as f:
    errors = pickle.load(f)
with open(result_dir + 'errors_subject.pkl', 'rb') as f:
    errors_subject = pickle.load(f)
with open(result_dir + 'errors_proxy.pkl', 'rb') as f:
    errors_proxy = pickle.load(f)
# ---------------------------evaluation task-----------------------------
print("complete load!")

# time holdout10: 15.326334953308105 per subject: 0.007663167476654052
# time holdout20: 17.94678568840027 per subject: 0.008973392844200134
# time holdout50: 26.657612800598145 per subject: 0.013328806400299072
# time holdout100: 41.87955331802368 per subject: 0.020939776659011842
# time cv5: 92.27963638305664 per subject: 0.04613981819152832
# time cv10: 159.310613155365 per subject: 0.0796553065776825
# time bootstrap: 43.22130465507507 per subject: 0.021610652327537536
# all time: 396.6331970691681 per subject: 0.19831659853458405