import numpy as np
def split(data, prop = None, seed = 1000):
    """Randomly split every task's samples into a training and a test part.

    Args:
        data: Pair ``[X_raw, y_raw]``. Each is a sequence with one entry per
            task; entry ``j`` of both is indexed with an integer index array,
            so the entries must support NumPy-style fancy indexing.
        prop: Fraction of each task's samples placed in the test set. Either
            a sequence with one fraction per task, or a single scalar that is
            applied to every task. ``None`` means 0.2 for every task.
        seed: Value passed to ``np.random.seed``. Note that this reseeds
            NumPy's global random state as a side effect.

    Returns:
        ``[data_train, data_test]``, each of the form ``[X_list, y_list]``
        with one array per task. Training rows keep their original order;
        test rows follow the order of the random draw.
    """
    X_raw, y_raw = data
    np.random.seed(seed)
    m = len(X_raw)
    if prop is None:
        prop = 0.2 * np.ones(m)
    elif np.ndim(prop) == 0:
        # A bare scalar cannot be indexed per task; apply it to all tasks.
        prop = np.full(m, prop)

    X_train, X_test = [], []
    y_train, y_test = [], []
    for j in range(m):
        n_j = len(y_raw[j])
        # The test-set size is truncated towards zero.
        n_test = int(prop[j] * n_j)
        idx_test = np.random.choice(n_j, n_test, replace=False)
        # Everything that was not drawn for testing is used for training.
        idx_train = np.delete(np.arange(n_j), idx_test)

        X_test.append(X_raw[j][idx_test])
        y_test.append(y_raw[j][idx_test])
        X_train.append(X_raw[j][idx_train])
        y_train.append(y_raw[j][idx_train])

    data_train = [X_train, y_train]
    data_test = [X_test, y_test]
    return [data_train, data_test]


def split_cv(n_list, n_fold = 5, seed = 1000):
    """Assign sample indices to cross-validation folds, task by task.

    For every task size in ``n_list`` the indices ``0 .. n_j - 1`` are
    shuffled and dealt out to the folds in an interleaved fashion: fold ``k``
    receives the shuffled entries at positions ``k, k + n_fold, ...``.

    Args:
        n_list: Indexable collection holding the sample count of each task.
        n_fold: Number of folds built per task.
        seed: Value passed to ``np.random.seed``; this reseeds NumPy's global
            random state.

    Returns:
        A list with one entry per task; each entry is a list of ``n_fold``
        lists of sample indices.
    """
    np.random.seed(seed)
    folds_by_task = []
    for task in range(len(n_list)):
        size = n_list[task]
        shuffled = np.random.permutation(size)
        base = int(size / n_fold)
        extra = size % n_fold
        task_folds = [
            # The first `extra` folds take one more index than the others.
            [shuffled[k + step * n_fold]
             for step in range(base + 1 if k < extra else base)]
            for k in range(n_fold)
        ]
        folds_by_task.append(task_folds)
    return folds_by_task


def MTL_preprocessing(data, link = 'linear', intercept = True, n_class = 1, standardization = False):
    """Arrange multi-task data into per-task design and response matrices.

    Only the ``standardization=False`` path is handled here. When
    ``standardization`` is truthy nothing is built and ``None`` is returned.

    Args:
        data: Pair ``[X_list, y_list]`` with one 2-D feature array and one
            label array per task. The feature count is read from the first
            task's array.
        link: ``'linear'`` or ``'logistic'``. Any other value leaves the
            response list empty.
        intercept: If true, a column of ones is prepended to every design
            matrix.
        n_class: Number of classes for ``link='logistic'``. A value of 2
            keeps the labels as a single column; any other value produces an
            indicator (one-hot) matrix with ``n_class`` columns, which
            requires integer labels.
        standardization: See above.

    Returns:
        ``[X, Y, X_means, X_stds, y_mean, y_std, n_list, d_out]`` where ``X``
        and ``Y`` are per-task lists, ``n_list`` holds the per-task sample
        counts and ``d_out`` is the number of response columns.
    """
    m = len(data[0])
    d = data[0][0].shape[1]
    n_list = np.zeros(m).astype(int)

    if standardization:
        # NOTE(review): no standardized path is visible in this function, so
        # callers receive None here - confirm this is intended.
        return None

    # Placeholder statistics, since nothing is rescaled on this path.
    # NOTE(review): the feature entries of X_stds are zeros, not ones -
    # confirm that callers never divide by them.
    X_means = np.zeros((d, 1))
    X_stds = np.zeros((d, 1))
    if intercept:
        X_means = np.vstack((np.zeros((1, 1)), X_means))
        X_stds = np.vstack((np.ones((1, 1)), X_stds))
    y_mean, y_std = 0, 1

    # Indicator encoding is used for logistic models unless n_class is 2.
    one_hot = link == 'logistic' and n_class != 2
    d_out = n_class if one_hot else 1

    X, Y = [], []
    for j in range(m):
        # design matrix of task j
        features = data[0][j]
        n_list[j] = features.shape[0]
        if intercept:
            features = np.hstack((np.ones((n_list[j], 1)), features))
        X.append(features)

        # response of task j
        if one_hot:
            labels = data[1][j]
            n_obs = labels.shape[0]
            indicator = np.zeros((n_obs, n_class))
            indicator[np.arange(n_obs), labels.reshape(-1,)] = 1
            Y.append(indicator)
        elif link == 'linear' or link == 'logistic':
            Y.append(data[1][j].reshape(-1, 1))

    return [X, Y, X_means, X_stds, y_mean, y_std, n_list, d_out]
    

