import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
import sys
matplotlib.use('TKAgg')  # this command will enable ctrl+C quit in command line

from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestRegressor, BaggingRegressor, GradientBoostingRegressor
from sklearn.neighbors import KNeighborsRegressor
# from sklearn.ensemble import GradientBoostingClassifier
# from sklearn.ensemble import RandomForestClassifier
# from sklearn.svm import SVR
from scipy.special import expit, xlogy
from scipy.optimize import fmin_bfgs, minimize_scalar, minimize
from sklearn.metrics import accuracy_score, mean_squared_error, mean_absolute_error
from genData import genMoon, gen_Logis, gen0, gen1, gen2, gen3tasks
import os
from sklearn.preprocessing import StandardScaler
import pickle
import random

my_path = os.path.dirname(os.path.realpath(__file__))
fontsize = 20
font = {'size': fontsize}
matplotlib.rc('font', **font)
# Root folder that contains the `data/` directory read by test_MIMIC().
# Set the PAL_DATA_PATH environment variable, or edit the fallback below;
# by default the folder holding this script is used.
PATH = os.environ.get('PAL_DATA_PATH', my_path)

'''
    Main script for testing the incentivized PAL algorithm with three entities.
'''


def fit_reg(X, y, method='linear'):
    '''
        Fit a regressor chosen by name and return it with its in-sample predictions.

        Input:
            X: nxp design matrix
            y: n-vector of responses
            method: 'linear', 'rf', 'rf_large', 'bs' (gradient boosting),
                    'bag' or 'knn'

        Output:
            f: the fitted regressor
            yhat: n-vector, f's predictions on X (fitted values)

        Any other method name terminates the program via sys.exit.
    '''
    # (name, factory) pairs; the lambdas defer construction until a name matches.
    choices = (
        ('linear', lambda: LinearRegression()),
        ('rf', lambda: RandomForestRegressor(n_estimators=50, max_depth=5)),
        ('rf_large', lambda: RandomForestRegressor(n_estimators=100, max_depth=10)),
        ('bs', lambda: GradientBoostingRegressor(n_estimators=50, max_depth=5)),
        ('bag', lambda: BaggingRegressor(n_estimators=20)),
        ('knn', lambda: KNeighborsRegressor(n_neighbors=20)),
    )
    build = next((make for name, make in choices if method == name), None)
    if build is None:
        sys.exit('wrong reg method!')

    model = build()
    model.fit(X, y)
    return model, model.predict(X)
 

def evaluate(y, yhat):
    '''
        Root-mean-squared error between responses and predictions.

        Input:
            y: n-vector of true responses
            yhat: n-vector of predictions (in-sample or out-of-sample)
        Output:
            acc: sqrt(mean((y - yhat)^2)), in the units of y; the smaller the better
    '''
    mse = mean_squared_error(y, yhat)
    return np.sqrt(mse)


def visual(X, y, yhat):
    '''
        Scatter-plot predictions against targets on the current figure.

        Blue dots are the (y, yhat) pairs; the red dotted diagonal marks
        perfect prediction. Both axes are fixed to [-20, 20]. X is not used.
    '''
    dot_size, label_size = 6, 20
    half_width = 20

    plt.plot(y, yhat, 'b.', markersize=dot_size)
    plt.xlabel(r'$E(Y \mid X)$', fontsize=label_size)
    plt.ylabel(r'$\hat{Y}(X)$', fontsize=label_size)
    plt.xlim([-half_width, half_width])
    plt.ylim([-half_width, half_width])
    plt.plot([-half_width, half_width], [-half_width, half_width], 'r:')

 
def useAPIs(APIs, X_list, y=None):
    '''
        Invoke the APIs of one entity and sum their predictions.

        Input:
            APIs: list describing one entity's additive model chain.
                APIs[0] = (thisEntity, f): the entity id (1-based) and its first
                    local model.
                APIs[i], i > 0, is either
                  - (entity, f): an extra local model evaluated on thisEntity's
                    own features, or
                  - (u, v, f_uu, f_vu, c1, f_uv, f_vv, c2): an 8-tuple built by
                    runPAL with u == thisEntity and partner v, contributing
                    (f_uu(X_u) + f_vu(X_v)) * c1 + (f_uv(X_u) + f_vv(X_v)) * c2
            X_list: feature matrices in entity order, X_list[k - 1] belongs to
                entity k (callers pass e.g. X_val[1:]); any number of entities
                is supported
            y: optional responses; when given, the RMSE of yhat is computed

        Output:
            yhat: sum of the predictions of every model in the chain
            acc: evaluate(y, yhat), or None when y is None

        Exits via sys.exit if a PAL tuple belongs to a different entity;
        raises IndexError for an entity id with no matrix in X_list.
    '''
    def features(entity):
        # Entities are numbered from 1 while X_list is 0-indexed; reject ids
        # that would otherwise silently wrap around to the end of the list.
        if not 1 <= entity <= len(X_list):
            raise IndexError(f'no feature matrix for entity {entity}')
        return X_list[entity - 1]

    thisEntity, f = APIs[0]
    Xu = features(thisEntity)
    yhat = f.predict(Xu)
    for tup in APIs[1:]:
        # Build new arrays instead of `+=`: never write into an array a model returned.
        if len(tup) == 2:
            # an extra local model
            yhat = yhat + tup[1].predict(Xu)
        else:
            # a PAL model shared with partner v
            u, v = tup[0], tup[1]
            if thisEntity != u:
                sys.exit('mismatched thisEntity!')
            Xv = features(v)
            yhat = yhat + ((tup[2].predict(Xu) + tup[3].predict(Xv)) * tup[4]
                           + (tup[5].predict(Xu) + tup[6].predict(Xv)) * tup[7])

    acc = evaluate(y, yhat) if y is not None else None
    return yhat, acc


def runPAL(a, b, APIs, X_tr, r_tr, X_val, y_val, X_test, y_test, method, PRINT=False):
    '''
        One round of pairwise assisted learning (PAL) between entities a and b.

        Each side's target mixes both entities' current training residuals
        (weights tauA, tauB). For A's target, B fits first on its features and
        A fits what B leaves; for B's target, A fits first and B fits the rest.
        The four models are packed into one 8-tuple API per entity (the format
        read by useAPIs) and appended to APIs[a] and APIs[b].

        Input:
            a, b: entity ids (e.g. 1, 2, 3) of the two collaborators
            APIs: list indexed by entity id; APIs[k] is entity k's model chain.
                  APIs[a] and APIs[b] are appended to in place.
            X_tr, r_tr: per-entity training features and current residuals
            X_val, y_val: per-entity validation features and responses
            X_test, y_test: per-entity test data (not used here)
            method: regressor name passed to fit_reg
            PRINT: whether to log progress

        Output (all differences of validation RMSE; positive = improvement):
            APIs: the same list object, updated
            muab: a's RMSE without the new API minus a's RMSE with it
            mua2b: muba - mub, i.e. b's RMSE without the new API minus b's RMSE
                   with a freshly fitted local model in its place
            muba: b's RMSE without the new API minus b's RMSE with it
            mub2a: muab - mua, i.e. a's RMSE without the new API minus a's RMSE
                   with a freshly fitted local model in its place
    '''
    Xa_tr, ra_tr = X_tr[a], r_tr[a]
    Xb_tr, rb_tr = X_tr[b], r_tr[b]

    # KEY control param: weight of the partner's residual in each side's target
    tauA, tauB = 1.0, -1.0  # 1.0, -1.0 0.1, -0.1

    # === A's models: target ra + tauB * rb; B fits first, A fits B's leftover
    if PRINT:
        print('Update B for A')
    target_a = ra_tr + tauB * rb_tr
    fba, fba_fit = fit_reg(Xb_tr, target_a, method)
    if PRINT:
        print('Update A for A')
    faa, _ = fit_reg(Xa_tr, target_a - fba_fit, method)

    # === B's models: target rb + tauA * ra; A fits first, B fits A's leftover
    if PRINT:
        print('Update A for B')
    target_b = rb_tr + tauA * ra_tr
    fab, fab_fit = fit_reg(Xa_tr, target_b, method)
    if PRINT:
        print('Update B for B')
    fbb, _ = fit_reg(Xb_tr, target_b - fab_fit, method)

    # === A and B pack this round's API (coefficient layout as consumed by useAPIs)
    denom = 1 - tauA * tauB
    API_a = (a, b, faa, fba, 1/denom, fab, fbb, -tauB/denom)
    API_b = (b, a, fbb, fab, 1/denom, fba, faa, -tauA/denom)
    APIs[a].append(API_a)
    APIs[b].append(API_b)

    def val_rmse(chain, k):
        # validation RMSE of entity k when it predicts with `chain`
        _, rmse = useAPIs(chain, X_val[1:], y_val[k])
        return rmse

    # A: RMSE with the new API / without it / with a fresh local model instead
    a_val = val_rmse(APIs[a], a)
    a_val_last = val_rmse(APIs[a][:-1], a)
    fa_local, _ = fit_reg(Xa_tr, ra_tr, method)
    a_val_local = val_rmse(APIs[a][:-1] + [(a, fa_local)], a)

    # B: the same three variants
    b_val = val_rmse(APIs[b], b)
    b_val_last = val_rmse(APIs[b][:-1], b)
    fb_local, _ = fit_reg(Xb_tr, rb_tr, method)
    b_val_local = val_rmse(APIs[b][:-1] + [(b, fb_local)], b)

    muab = a_val_last - a_val
    mua = a_val_local - a_val
    muba = b_val_last - b_val
    mub = b_val_local - b_val

    mub2a = muab - mua
    mua2b = muba - mub

    return APIs, muab, mua2b, muba, mub2a


def arg_max(idx_exclude, vec):
    '''
        Position of the largest entry of vec, ignoring position idx_exclude.

        Only the two largest entries are considered: the leader is returned
        unless it sits at idx_exclude, in which case the runner-up is.
    '''
    leaders = np.argsort(-vec)[:2]  # indices of the two largest entries, largest first
    return leaders[1] if leaders[0] == idx_exclude else leaders[0]


def visu_result(acc_tr_pal, acc_te_pal, acc_te_orac, msg):
    '''
        Plot the per-iteration PAL test error against two flat reference
        lines: the oracle error and the single-entity error (acc_te_pal[0]).

        acc_tr_pal is only used for its length (the number of iterations).
        Draws on the current figure; does not call plt.show().
    '''
    n_iter = len(acc_tr_pal)
    iters = np.arange(n_iter)
    flat = np.ones(n_iter)
    marker, fsize = 7, 14

    plt.title(msg, fontsize=fsize)
    plt.plot(iters, acc_te_pal, 'ko-', markersize=marker, label='PAL')
    plt.plot(iters, acc_te_orac * flat, 'rs-', markersize=marker, label='Oracle')
    plt.plot(iters, acc_te_pal[0] * flat, 'c*-', markersize=marker, label='Single')
    plt.legend(fontsize=fsize)
    plt.xlabel('Iteration', fontsize=fsize)
    ticks = list(range(n_iter))
    plt.xticks(ticks, ticks)
    plt.ylabel('Error', fontsize=fsize)
    plt.xticks(fontsize=fsize)
    plt.yticks(fontsize=fsize)


def test(randomSchemeOverwrite=False, PRINT=False):
    '''
        Incentivized PAL experiment on synthetic data from gen3tasks (3 entities).

        Entity k in {1, 2, 3} owns two feature columns (S1, S2, S3) and its own
        response y_k. Round 0 fits one local model per entity. In every later
        round each participating entity i scores the others by
            q = (U - C[i]) * mu[i, 1:] + C[1:] * mu2[i, 1:]
        and picks its favourite partner; if two entities pick each other they
        run one PAL round (runPAL). The remaining entities append one more local
        model fitted to their current residual. An entity whose validation
        error rises in a round drops that round's newest model.

        Input:
            randomSchemeOverwrite: ignore the consensus rule and pair entities
                at random every round (baseline)
            PRINT: forwarded to runPAL for logging
        Output:
            acc_local: test RMSE of entities 1-3 after round 0 (local models)
            acc_PAL: test RMSE of entities 1-3 after the last round
        Side effects: prints progress and shows a matplotlib figure.
    '''
    # Control param
    method = 'linear'  # 'linear', 'bs', 'rf', 'bag'
    I = 10  # number of rounds; round 0 fits the local (single-entity) models
    sca = [1, 1, 1]  # default: no scaling
    useSCALE = True  # True False

    n, s = 1000, 1
    S1, S2, S3 = np.array([0, 1]), np.array([2, 3]), np.array([4, 5])  # feature columns of entities 1, 2, 3
    X, y1, y2, y3, _ = gen3tasks(n, s, False)
    X_te, y1_te, y2_te, y3_te, _ = gen3tasks(5000, s, PRINT=False)

    # test responses; left unscaled unless useSCALE
    y_test = [y1_te, y2_te, y3_te]
    if useSCALE:
        # standardize the three responses with the statistics of the training pool
        labels = np.column_stack((y1, y2, y3))
        scaler = StandardScaler(with_mean=True, with_std=True).fit(labels)
        sca = scaler.scale_
        print('scaling factor is ', sca)
        labels = scaler.transform(labels)
        y1, y2, y3 = labels[:, 0], labels[:, 1], labels[:, 2]

        labels_test = scaler.transform(np.column_stack(y_test))
        y_test = [labels_test[:, 0], labels_test[:, 1], labels_test[:, 2]]

    # data: first third of the rows for validation, the rest for training
    n = X.shape[0]
    nval = int(n/3)
    X_tr = [None, X[nval:, S1], X[nval:, S2], X[nval:, S3]]
    y_tr = [None, y1[nval:], y2[nval:], y3[nval:]]
    X_val = [None, X[:nval, S1], X[:nval, S2], X[:nval, S3]]
    y_val = [None, y1[:nval], y2[:nval], y3[:nval]]
    X_test = [None, X_te[:, S1], X_te[:, S2], X_te[:, S3]]
    y_test = [None] + y_test

    # --- Incentivized PAL ---
    U = 1
    # index 0 of C is a dummy (entities are 1-based); np.inf, since NumPy 2.0 removed np.Inf
    # C = np.array([np.inf, 60, 1, 1])
    C = np.array([np.inf, 0.1, 0.1, 0.1])
    r_tr = [None] * 4
    APIs = [None] * 4
    f = [None] * 4
    mu = np.zeros((4, 4))  # to avoid inf-inf error
    mu2 = np.ones((4, 4)) * 999999  # to avoid 0 * inf error
    acc_te = np.zeros((4, I))
    acc_va = np.zeros((4, I))
    partiSet = [1, 2, 3]  # initially no one quits PAL

    for ite in range(I):

        print(f"Iteration {ite} mu = {mu}")
        print(f"Iteration {ite} mu2 = {mu2}")

        if ite == 0:
            # single-entity benchmark
            for i in range(1, 4):
                f[i], _ = fit_reg(X_tr[i], y_tr[i], method)
                r_tr[i] = y_tr[i] - f[i].predict(X_tr[i])
                APIs[i] = [(i, f[i])]
                _, acc_va[i, 0] = useAPIs(APIs[i], X_val[1:], y_val[i])

        else:
            # detect if there is a consensus: each entity's favourite partner
            favor = -1 * np.ones(4)
            for i in partiSet:
                q = (U - C[i]) * mu[i, 1:] + C[1:] * mu2[i, 1:]
                favor[i] = 1 + arg_max(idx_exclude=i-1, vec=q)

            if favor[1] == 2 and favor[2] == 1 and (1 in partiSet) and (2 in partiSet):
                a, b = 1, 2
                collab = True
                individualSet = [3] if (3 in partiSet) else []
            elif favor[1] == 3 and favor[3] == 1 and (1 in partiSet) and (3 in partiSet):
                a, b = 1, 3
                collab = True
                individualSet = [2] if (2 in partiSet) else []
            elif favor[2] == 3 and favor[3] == 2 and (2 in partiSet) and (3 in partiSet):
                a, b = 2, 3
                collab = True
                individualSet = [1] if (1 in partiSet) else []
            else:
                collab = False
                individualSet = [i for i in partiSet]

            if randomSchemeOverwrite:  # for baseline purpose
                idx = [1, 2, 3]
                random.shuffle(idx)
                collab = True
                a, b, individualSet = idx[0], idx[1], [idx[2]]

            print(f"Iteration {ite}, individualSet = {individualSet}")

            # PAL one round
            if collab:
                APIs, mu[a, b], mu2[a, b], mu[b, a], mu2[b, a] = runPAL(
                    a, b, APIs, X_tr, r_tr, X_val, y_val, X_test, y_test, method, PRINT=PRINT)

            # remaining entities: one more local model, fitted to the current
            # residual because useAPIs adds up the predictions of the whole chain
            for i in individualSet:
                f[i], _ = fit_reg(X_tr[i], r_tr[i], method)
                APIs[i].append((i, f[i]))

        # update residual and validation/test loss for ALL entities and decide the stopping
        for c in range(1, 4):

            _, acc_te[c, ite] = useAPIs(APIs[c], X_test[1:], y_test[c])
            _, acc_va[c, ite] = useAPIs(APIs[c], X_val[1:], y_val[c])

            if ite > 0 and acc_va[c, ite] > acc_va[c, ite-1] and c in partiSet:
                # NOTE: to avoid stopping early, a client need not exit; it drops
                # this round's model, keeps its previous chain and skips the round.
                print(f"Entity {c} skips the game at round {ite}")
                del APIs[c][-1]
                acc_te[c, ite] = acc_te[c, ite-1]
                acc_va[c, ite] = acc_va[c, ite-1]
            else:
                # if not skipping, update the residual
                yfit, _ = useAPIs(APIs[c], X_tr[1:])
                r_tr[c] = y_tr[c] - yfit

    # oracle benchmark: one learner per response on all columns of X
    f_ora = [None] * 4
    acc_ora = np.zeros(4)
    for i in range(1, 4):
        f_ora[i], _ = fit_reg(X[nval:], y_tr[i], method)
        acc_ora[i] = evaluate(y_test[i], f_ora[i].predict(X_te))

    if useSCALE:
        # scale the errors back to the original response units
        sca_col = np.reshape(sca, (-1, 1))
        acc_te[1:, :] *= sca_col
        acc_va[1:, :] *= sca_col
        acc_ora[1:] *= sca

    plt.figure(0)
    fs = 14  # font size
    rounds = np.arange(I)  # x position = round index, matching the tick labels
    plt.plot(rounds, acc_te[1, :], 'b-.', label='Entity 1')
    plt.plot(rounds, acc_te[2, :], 'r-o', label='Entity 2')
    plt.plot(rounds, acc_te[3, :], 'g-x', label='Entity 3')
    plt.legend(fontsize=fs)
    plt.xlabel('Iteration', fontsize=fs)
    ticks = list(range(I))
    plt.xticks(ticks, ticks)
    plt.ylabel('RMSE', fontsize=fs)  # evaluate() returns the root of the MSE
    plt.xticks(fontsize=fs)
    plt.yticks(fontsize=fs)
    plt.show()

    acc_local = acc_te[1:, 0]
    acc_PAL = acc_te[1:, -1]
    print(f"acc_ora = {acc_ora[1:]}")
    print(f"acc_local = {acc_local}")
    print(f"acc_PAL = {acc_PAL}")

    return acc_local, acc_PAL



def test_MIMIC(I=10, randomSchemeOverwrite=False, PRINT=False):
    '''
        Incentivized PAL experiment on the MIMIC data (3 entities, 'rf' learners).

        Reads PATH/data/MIMIC_cleaned_data.csv. Columns 3, 8 and 17 are the
        responses of entities 1, 2 and 3; entity k sees only its own feature
        columns (S1, S2, S3). Half of the rows are held out for testing, and the
        first third of the remaining rows is used for validation. Partner
        selection and the skip rule are the same as in test(), but entities
        that do not collaborate in a round do not refit anything.

        Input:
            I: number of rounds, round 0 (local models) included
            randomSchemeOverwrite: pair entities at random each round (baseline)
            PRINT: forwarded to runPAL for logging
        Output:
            length-6 array: test RMSE of entities 1-3 after round 0, followed by
                their test RMSE after the last round
            (3, I) array: test RMSE of entities 1-3 at every round
    '''
    # Control param
    method = 'rf'  # 'linear', 'bs', 'rf', 'bag'
    sca = [1, 1, 1]  # default: no scaling
    useSCALE = True  # True False

    df = np.genfromtxt(os.path.join(PATH, 'data', 'MIMIC_cleaned_data.csv'), delimiter=',', skip_header=1)
    labels = df[:, np.array([3, 8, 17])]  # responses of entities 1, 2, 3
    S1, S2, S3 = np.array([0, 1, 2, 4, 5, 6, 7]), np.array([0, 1, 2]), np.array([9, 10, 11, 12, 13, 14, 15, 16])
    # Pooled feature columns of all entities, used by the oracle. Using every df
    # column would hand the oracle the response columns themselves.
    S_all = np.union1d(np.union1d(S1, S2), S3)

    if useSCALE:
        scaler = StandardScaler(with_mean=True, with_std=True).fit(labels)
        sca = scaler.scale_
        print('scaling factor is ', sca)
        labels = scaler.transform(labels)

    X, X_te, y, y_te = train_test_split(df, labels, test_size=0.5)
    y1, y2, y3 = y[:, 0], y[:, 1], y[:, 2]

    # data: first third of the rows for validation, the rest for training
    n = X.shape[0]
    nval = int(n/3)
    X_tr = [None, X[nval:, S1], X[nval:, S2], X[nval:, S3]]
    y_tr = [None, y1[nval:], y2[nval:], y3[nval:]]
    X_val = [None, X[:nval, S1], X[:nval, S2], X[:nval, S3]]
    y_val = [None, y1[:nval], y2[:nval], y3[:nval]]
    X_test = [None, X_te[:, S1], X_te[:, S2], X_te[:, S3]]
    y_test = [None, y_te[:, 0], y_te[:, 1], y_te[:, 2]]

    # --- Incentivized PAL ---
    U = 100
    # index 0 of C is a dummy (entities are 1-based); np.inf, since NumPy 2.0 removed np.Inf
    # C = np.array([np.inf, U, 0, U]) # not significant
    C = np.array([np.inf, U, U, 0])  # significant
    # C = np.array([np.inf, 0, U, U])
    r_tr = [None] * 4
    APIs = [None] * 4
    f = [None] * 4
    mu = np.zeros((4, 4))  # to avoid inf-inf error
    mu2 = np.ones((4, 4)) * 999999  # to avoid 0 * inf error
    acc_te = np.zeros((4, I))
    acc_va = np.zeros((4, I))
    partiSet = [1, 2, 3]  # initially no one quits PAL

    for ite in range(I):

        print(f"Iteration {ite} mu = {mu}")
        print(f"Iteration {ite} mu2 = {mu2}")

        if ite == 0:
            # single-entity benchmark
            for i in range(1, 4):
                f[i], _ = fit_reg(X_tr[i], y_tr[i], method)
                r_tr[i] = y_tr[i] - f[i].predict(X_tr[i])
                APIs[i] = [(i, f[i])]
                _, acc_va[i, 0] = useAPIs(APIs[i], X_val[1:], y_val[i])

        else:
            # detect if there is a consensus: each entity's favourite partner
            favor = -1 * np.ones(4)
            for i in partiSet:
                q = (U - C[i]) * mu[i, 1:] + C[1:] * mu2[i, 1:]
                favor[i] = 1 + arg_max(idx_exclude=i-1, vec=q)

            if favor[1] == 2 and favor[2] == 1 and (1 in partiSet) and (2 in partiSet):
                a, b = 1, 2
                collab = True
                individualSet = [3] if (3 in partiSet) else []
            elif favor[1] == 3 and favor[3] == 1 and (1 in partiSet) and (3 in partiSet):
                a, b = 1, 3
                collab = True
                individualSet = [2] if (2 in partiSet) else []
            elif favor[2] == 3 and favor[3] == 2 and (2 in partiSet) and (3 in partiSet):
                a, b = 2, 3
                collab = True
                individualSet = [1] if (1 in partiSet) else []
            else:
                collab = False
                individualSet = [i for i in partiSet]

            if randomSchemeOverwrite:  # for baseline purpose
                idx = [1, 2, 3]
                random.shuffle(idx)
                collab = True
                a, b, individualSet = idx[0], idx[1], [idx[2]]

            print(f"Iteration {ite}, individualSet = {individualSet}")

            # PAL one round
            if collab:
                APIs, mu[a, b], mu2[a, b], mu[b, a], mu2[b, a] = runPAL(
                    a, b, APIs, X_tr, r_tr, X_val, y_val, X_test, y_test, method, PRINT=PRINT)

            # NOTE: local training of the remaining entities is suspended so that
            # any gain comes from collaboration rather than from the entity itself.
            # To resume it, fit on the current residual (the chain is additive):
            # for i in individualSet:
            #     f[i], _ = fit_reg(X_tr[i], r_tr[i], method)
            #     APIs[i].append((i, f[i]))

        # update residual and validation/test loss for ALL entities and decide the stopping
        for c in range(1, 4):

            _, acc_te[c, ite] = useAPIs(APIs[c], X_test[1:], y_test[c])
            _, acc_va[c, ite] = useAPIs(APIs[c], X_val[1:], y_val[c])

            if ite > 0 and acc_va[c, ite] > acc_va[c, ite-1] and c in partiSet:
                # NOTE: to avoid stopping early, a client need not exit; it drops
                # this round's model, keeps its previous chain and skips the round.
                print(f"Entity {c} skips the game at round {ite}")
                del APIs[c][-1]
                acc_te[c, ite] = acc_te[c, ite-1]
                acc_va[c, ite] = acc_va[c, ite-1]
            else:
                # if not skipping, update the residual
                yfit, _ = useAPIs(APIs[c], X_tr[1:])
                r_tr[c] = y_tr[c] - yfit

    # oracle benchmark: one learner per response on the pooled entity features
    f_ora = [None] * 4
    acc_ora = np.zeros(4)
    for i in range(1, 4):
        f_ora[i], _ = fit_reg(X[nval:, S_all], y_tr[i], method)
        acc_ora[i] = evaluate(y_test[i], f_ora[i].predict(X_te[:, S_all]))

    if useSCALE:
        # scale the errors back to the original response units
        sca_col = np.reshape(sca, (-1, 1))
        acc_te[1:, :] *= sca_col
        acc_va[1:, :] *= sca_col
        acc_ora[1:] *= sca

    # plt.figure(0)
    # m, f = 7, 14
    # plt.plot(np.arange(1, I+1), (acc_te[1,:]), 'b-.', label=f'Entity 1')
    # plt.plot(np.arange(1, I+1), (acc_te[2,:]), 'r-o', label=f'Entity 2')
    # plt.plot(np.arange(1, I+1), (acc_te[3,:]), 'g-x', label=f'Entity 3')
    # plt.legend(fontsize=f)
    # plt.xlabel('Iteration', fontsize=f)
    # plt.xticks([i for i in range(I)], [i for i in range(I)])
    # plt.ylabel('MSE', fontsize=f)
    # plt.xticks(fontsize=f)
    # plt.yticks(fontsize=f)
    # plt.show()

    acc_local = acc_te[1:, 0]
    acc_PAL = acc_te[1:, -1]
    print(f"acc_ora = {acc_ora[1:]}")
    print(f"acc_local = {acc_local}")
    print(f"acc_PAL = {acc_PAL}")

    return np.concatenate((acc_local, acc_PAL), axis=0), acc_te[1:, :]


def print_mean_ste(res, prec):
    '''
        Print and return "mean (standard error)" for every column of res.

        res: numpy array of shape (nrep, p), one row per replication
        prec: number of decimal places to round to
        The standard error is the population std (ddof=0) divided by
        sqrt(nrep). Every entry in the returned string is followed by two spaces.
    '''
    nrep, _ = res.shape
    means = res.mean(axis=0)
    stes = res.std(axis=0)/np.sqrt(nrep)
    out = ''.join(
        str(np.round(mean_j, prec)) + ' (' + str(np.round(ste_j, prec)) + ')  '
        for mean_j, ste_j in zip(means, stes)
    )
    print(out)
    return out

 
if __name__ == '__main__':
    # Uncomment the experiment(s) to run.

    # run single experiment
    # test_MIMIC(randomSchemeOverwrite = False)
    # test_ADULT()  # NOTE(review): test_ADULT is not defined in this module
    # test(randomSchemeOverwrite = False)

    # run multiple experiments
    # I = 10
    # nrep = 10
    # res = np.zeros((nrep, 6))
    # traject = np.zeros((3, I))

    # for i in range(nrep):
    #     res[i,:], temp = test_MIMIC(I = I)
    #     traject += temp/nrep

    # print('\nSummary: ')
    # print_mean_ste(res, 4)

    # print(f"traject\n {traject}")

    # The guard needs at least one statement while every experiment is commented out;
    # without it the module fails to compile (IndentationError).
    pass
