import warnings
from copy import deepcopy
from itertools import chain, combinations, permutations

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from numpy import ndarray
from sklearn.linear_model import LinearRegression

#######################################################################################################################
"""
def fisherZ(correlation_matrix, X, Y, condition_set, sample_size):
    "Perform an independence test using Fisher-Z's test and output the p-value of the test"
    var = list((X, Y) + condition_set)
    sub_corr_matrix = correlation_matrix[np.ix_(var, var)]
    inv = np.linalg.inv(sub_corr_matrix)
    r = -inv[0, 1] / sqrt(inv[0, 0] * inv[1, 1])
    Z = 0.5 * log((1 + r) / (1 - r))
    X = sqrt(sample_size - len(condition_set) - 3) * abs(Z)
    # p = 2 * (1 - norm.cdf(abs(X)))
    p = 1 - norm.cdf(abs(X))
    return p
"""

#######################################################################################################################
"""
def chisq(data, X, Y, conditioning_set, G_sq=False):
    "Perform an independence test using chi-square test and output the p-value of the test"
    # Step 1: Subset the data
    categories_list = [np.unique(data[:, i]) for i in
                       list(conditioning_set)]  # Obtain the categories of each variable in conditioning_set
    value_config_list = cartesian_product(
        categories_list)  # Obtain all the possible value configurations of the conditioning_set (e.g., [[]] if categories_list == [])

    max_categories = int(
        np.max(data)) + 1  # Used to fix the size of the contingency table (before applying Fienberg's method)

    sum_of_chi_square = 0  # initialize a zero chi_square statistic
    sum_of_df = 0  # initialize a zero degree of freedom

    def recursive_and(L):
        "A helper function for subsetting the data using the conditions in L of the form [(variable, value),...]"
        if len(L) == 0:
            return data
        else:
            condition = data[:, L[0][0]] == L[0][1]
            i = 1
            while i < len(L):
                new_conjunct = data[:, L[i][0]] == L[i][1]
                condition = new_conjunct & condition
                i += 1
            return data[condition]

    for value_config in range(len(value_config_list)):
        L = list(zip(conditioning_set, value_config_list[value_config]))
        sub_data = recursive_and(L)[:, [X,
                                        Y]]  # obtain the subset dataset (containing only the X, Y columns) with only rows specified in value_config

        # Step 2: Generate contingency table (applying Fienberg's method)
        def make_ctable(D, cat_size):
            x = np.array(D[:, 0], dtype=np.dtype(int))
            y = np.array(D[:, 1], dtype=np.dtype(int))
            bin_count = np.bincount(cat_size * x + y)  # Perform linear transformation to obtain frequencies
            diff = (cat_size ** 2) - len(bin_count)
            if diff > 0:  # The number of cells generated by bin_count can possibly be less than cat_size**2
                bin_count = np.concatenate(
                    (bin_count, np.zeros(diff)))  # In that case, we concatenate some zeros to fit cat_size**2
            ctable = bin_count.reshape(cat_size, cat_size)
            ctable = ctable[~np.all(ctable == 0, axis=1)]  # Remove rows consisted entirely of zeros
            ctable = ctable[:, ~np.all(ctable == 0, axis=0)]  # Remove columns consisted entirely of zeros
            return ctable

        ctable = make_ctable(sub_data, max_categories)

        # Step 3: Calculate chi-square statistic and degree of freedom from the contingency table
        row_sum = np.sum(ctable, axis=1)
        col_sum = np.sum(ctable, axis=0)
        expected = np.outer(row_sum, col_sum) / sub_data.shape[0]
        if G_sq == False:
            chi_sq_stat = np.sum(((ctable - expected) ** 2) / expected)
        else:
            div = np.divide(ctable, expected)
            div[div == 0] = 1  # It guarantees that taking natural log in the next step won't cause any error
            chi_sq_stat = 2 * np.sum(ctable * np.log(div))
        df = (ctable.shape[0] - 1) * (ctable.shape[1] - 1)

        sum_of_chi_square += chi_sq_stat
        sum_of_df += df

    # Step 4: Compute p-value from chi-square CDF
    if sum_of_df == 0:
        return 1
    else:
        return chi2.sf(sum_of_chi_square, sum_of_df)
"""


#######################################################################################################################

def append_value(array, i, j, value):
    """
    Append value to the list stored at array[i, j].

    If the cell currently holds None, a new one-element list is stored there instead.
    """
    cell = array[i, j]
    if cell is not None:
        cell.append(value)
        return
    array[i, j] = [value]


#######################################################################################################################

def powerset(L):
    """
    Return every subset of L as a list of tuples, ordered by increasing subset size
    """
    items = list(L)
    subsets = []
    for size in range(len(items) + 1):
        subsets.extend(combinations(items, size))
    return subsets


#######################################################################################################################

def cartesian_product(lists):
    "Return the Cartesian product of lists (list of lists) as a list of lists; [[]] when lists is empty"
    combos = [[]]
    for pool in lists:
        extended = []
        for prefix in combos:
            for item in pool:
                extended.append(prefix + [item])
        combos = extended
    return combos


#######################################################################################################################

def list_union(L1, L2):
    "Return the distinct members of L1 and L2 (lists) as a list; the order of the result is not specified"
    merged = set(L1 + L2)
    return list(merged)


#######################################################################################################################

def list_intersection(L1, L2):
    "Return the members common to L1 and L2 (lists) as a list; the order of the result is not specified"
    common = set(L1).intersection(set(L2))
    return list(common)


#######################################################################################################################

def list_minus(L1, L2):
    "Return a list of the members of L1 (list) that are NOT in L2 (list); the order of the result is not specified"
    remaining = set(L1).difference(set(L2))
    return list(remaining)


#######################################################################################################################

def sort_dict_ascending(dict, descending=False):
    "Return a new dictionary with the items of dict ordered by value (ascending, or descending if requested)"
    ordered = {}
    for key, val in sorted(dict.items(), key=lambda item: item[1], reverse=descending):
        ordered[key] = val
    return ordered


#######################################################################################################################

def np_ignore_nan(ndarray):
    "Return a string copy of ndarray in which every entry that renders as 'nan' is replaced by an empty string"
    as_text = ndarray.astype(str)
    return np.where(as_text == 'nan', '', as_text)


#######################################################################################################################

def neighbors(adjmat, i):
    """
    Find the neighbors of node i in the adjacency matrix adjmat (np.ndarray).

    Returns the column indices where row i holds 0, followed by those where it holds 1.
    """
    row = adjmat[i, :]
    zero_marked = np.nonzero(row == 0)[0]
    one_marked = np.nonzero(row == 1)[0]
    return np.concatenate((zero_marked, one_marked))


#######################################################################################################################

def degree_graph(adjmat):
    "Return the largest neighbor count of any node in the adjacency matrix adjmat (np.ndarray); 0 if there are no nodes"
    return max((len(neighbors(adjmat, node)) for node in range(len(adjmat))), default=0)


#######################################################################################################################
# Graph-related functions
#######################################################################################################################
def find_circ_arrow(adjmat):
    "Return the list of (i, j) positions where the adjacency matrix adjmat (np.ndarray) equals 1, i.e. i o-> j"
    hits = np.nonzero(adjmat == 1)
    return [*zip(hits[0], hits[1])]


#######################################################################################################################

def find_tail(adjmat):
    "Return the list of (i, j) positions where the adjacency matrix adjmat (np.ndarray) equals 0, i.e. i --o j"
    hits = np.nonzero(adjmat == 0)
    return [*zip(hits[0], hits[1])]


#######################################################################################################################

def find_undirected(adjmat):
    """
    Return the list of undirected edges i --- j as (i, j) in the adjacency matrix adjmat (np.ndarray).

    An edge qualifies when both adjmat[i, j] and adjmat[j, i] are 0, so each edge appears in both orientations.
    """
    return [(a, b) for a, b in find_tail(adjmat) if adjmat[b, a] == 0]


#######################################################################################################################

def find_fully_directed(adjmat):
    """
    Return the list of directed edges i --> j as (i, j) in the adjacency matrix adjmat (np.ndarray).

    An edge qualifies when adjmat[i, j] is 1 and adjmat[j, i] is 0.
    """
    return [(a, b) for a, b in find_circ_arrow(adjmat) if adjmat[b, a] == 0]


#######################################################################################################################

def find_bi_directed(adjmat):
    """
    Return the list of bi-directed edges i <-> j as (i, j) in the adjacency matrix adjmat (np.ndarray).

    An edge qualifies when both adjmat[i, j] and adjmat[j, i] are 1, so each edge appears in both orientations.
    """
    return [(a, b) for a, b in find_circ_arrow(adjmat) if adjmat[b, a] == 1]


#######################################################################################################################

def find_adj(adjmat):
    """
    Return the list of adjacencies as (i, j) in the adjacency matrix adjmat (np.ndarray).

    The positions holding 0 come first, followed by the positions holding 1.
    """
    return [*find_tail(adjmat), *find_circ_arrow(adjmat)]


#######################################################################################################################

def is_fully_directed(adjmat, i, j):
    "Return True if i --> j holds (adjmat[i, j] is 1 and adjmat[j, i] is 0) in adjmat (np.ndarray), False otherwise"
    forward_is_one = adjmat[i, j] == 1
    if not forward_is_one:
        return forward_is_one
    return adjmat[j, i] == 0


#######################################################################################################################

def is_undirected(adjmat, i, j):
    "Return True if i --- j holds (adjmat[i, j] and adjmat[j, i] are both 0) in adjmat (np.ndarray), False otherwise"
    forward_is_zero = adjmat[i, j] == 0
    if not forward_is_zero:
        return forward_is_zero
    return adjmat[j, i] == 0


#######################################################################################################################

def is_bi_directed(adjmat, i, j):
    "Return True if i <-> j holds (adjmat[i, j] and adjmat[j, i] are both 1) in adjmat (np.ndarray), False otherwise"
    forward_is_one = adjmat[i, j] == 1
    if not forward_is_one:
        return forward_is_one
    return adjmat[j, i] == 1


#######################################################################################################################

def is_adj(adjmat, i, j):
    """
    Return True if i and j are joined by an edge in the adjacency matrix adjmat (np.ndarray) and False otherwise.

    The edge may be i --> j, j --> i, i --- j or i <-> j; the checks are tried in that order and stop at the first hit.
    """
    verdict = is_fully_directed(adjmat, i, j)
    if not verdict:
        verdict = is_fully_directed(adjmat, j, i)
    if not verdict:
        verdict = is_undirected(adjmat, i, j)
    if not verdict:
        verdict = is_bi_directed(adjmat, i, j)
    return verdict


#######################################################################################################################

def find_unshielded_triples(adjmat):
    """
    Return the list of unshielded triples i o-o j o-o k as (i, j, k) from the adjacency matrix adjmat (np.ndarray).

    Two adjacencies (i, j) and (j, k) form a triple when i differs from k and adjmat[i, k] is -1.
    """
    triples = []
    for (i, j), (mid, k) in permutations(find_adj(adjmat), 2):
        if j == mid and i != k and adjmat[i, k] == -1:
            triples.append((i, j, k))
    return triples


#######################################################################################################################

def find_triangles(adjmat):
    """
    Return the list of non-ambiguous triangles i o-o j o-o k o-o i as (i, j, k) from adjmat (np.ndarray).

    Two adjacencies (i, j) and (j, k) form a triangle when i differs from k and (i, k) is itself an adjacency.
    """
    adjacencies = find_adj(adjmat)
    # Set membership is O(1); scanning the list for every candidate pair made the search cubic in the edge count.
    adjacent = set(adjacencies)
    return [(i, j, k) for (i, j), (mid, k) in permutations(adjacencies, 2)
            if j == mid and i != k and (i, k) in adjacent]


#######################################################################################################################

def find_kites(graph):
    """
    Return the list of non-ambiguous kites i o-o j o-o l o-o k o-o i o-o l as (i, j, k, l) from graph (np.ndarray).

    A kite is built from two triangles (i, j, l) and (i, k, l) with j < k and graph[j, k] equal to -1.
    """
    kites = []
    for first, second in permutations(find_triangles(graph), 2):
        shares_ends = first[0] == second[0] and first[2] == second[2]
        if shares_ends and first[1] < second[1] and graph[first[1], second[1]] == -1:
            kites.append((first[0], first[1], second[1], first[2]))
    return kites


#######################################################################################################################

def find_all_conditioning_sets(adjmat, x, y):
    """
    Return the list of candidate conditioning sets for x and y in the adjacency matrix adjmat (np.ndarray).

    The result is the union of the powerset of x's neighbors and the powerset of y's neighbors.
    """
    subsets_of_x = powerset(neighbors(adjmat, x))
    subsets_of_y = powerset(neighbors(adjmat, y))
    return list_union(subsets_of_x, subsets_of_y)


#######################################################################################################################

def find_conditioning_sets_with_middle(adjmat, x, y, z):
    "Return the conditioning sets built from the neighbors of x or y in adjmat (np.ndarray) that contain z"
    selected = []
    for candidate in find_all_conditioning_sets(adjmat, x, y):
        if z in candidate:
            selected.append(candidate)
    return selected


#######################################################################################################################

def find_conditioning_sets_without_middle(adjmat, x, y, z):
    "Return the conditioning sets built from the neighbors of x or y in adjmat (np.ndarray) that do not contain z"
    selected = []
    for candidate in find_all_conditioning_sets(adjmat, x, y):
        if z not in candidate:
            selected.append(candidate)
    return selected


#######################################################################################################################

def find_uc(adjmat):
    """
    Return the list of unshielded colliders x --> y <-- z as (x, y, z) in the adjacency matrix adjmat (np.ndarray).

    Only the orientation with x < z is reported, and adjmat[x, z] must be -1.
    """
    colliders = []
    for (x, y), (z, target) in permutations(find_fully_directed(adjmat), 2):
        if y == target and x < z and adjmat[x, z] == -1:
            colliders.append((x, y, z))
    return colliders


#######################################################################################################################

def rearrange_columns(adjmat, PATH):
    """
    Rearrange the adjacency matrix adjmat (np.ndarray) according to the column order of the data at PATH.

    PATH is read as a tab-separated file whose column names have the form 'X<k>' (1-based). The same
    reordering is applied to the columns and then to the rows of adjmat.
    """
    header = pd.read_csv(PATH, sep='\t').columns
    positions = [int(col.split('X')[1]) - 1 for col in list(header)]
    inverse = np.zeros_like(positions)
    # The slot for the first column is left at its initial 0, which is already the rank it would receive.
    for rank, pos in enumerate(positions[1:], start=1):
        inverse[pos] = rank
    return adjmat[:, inverse][inverse, :]


#######################################################################################################################

def dag2pattern(adjmat):
    """
    Generate the pattern of the adjacency matrix adjmat (np.ndarray).

    Works on a deep copy: every 1 is first reset to 0, the unshielded colliders of adjmat are re-oriented,
    and three orientation rules are then applied repeatedly until a full pass changes nothing.
    """
    pattern = deepcopy(adjmat)
    pattern[pattern == 1] = 0  # Remove all the arrowheads from the DAG to obtain the skeleton
    for (i, j, k) in find_uc(adjmat):
        pattern[i, j] = 1
        pattern[k, j] = 1

    # The structures are collected once, before the orientation loop starts.
    unshielded = find_unshielded_triples(pattern)
    triangles = find_triangles(pattern)
    kites = find_kites(pattern)

    changed = True
    while changed:
        changed = False

        # i --> j --- k with i, k non-adjacent: orient j --> k
        for (i, j, k) in unshielded:
            if is_fully_directed(pattern, i, j) and is_undirected(pattern, j, k):
                pattern[j, k] = 1
                changed = True

        # i --> j --> k with i --- k: orient i --> k
        for (i, j, k) in triangles:
            if (is_fully_directed(pattern, i, j)
                    and is_fully_directed(pattern, j, k)
                    and is_undirected(pattern, i, k)):
                pattern[i, k] = 1
                changed = True

        # i --- j --> l, i --- k --> l with i --- l: orient i --> l
        for (i, j, k, l) in kites:
            if (is_undirected(pattern, i, j)
                    and is_undirected(pattern, i, k)
                    and is_fully_directed(pattern, j, l)
                    and is_fully_directed(pattern, k, l)
                    and is_undirected(pattern, i, l)):
                pattern[i, l] = 1
                changed = True

    return pattern


#######################################################################################################################

def adjmat2digraph(adjmat):
    """
    Recover the directed graph from the adjacency matrix adjmat (np.ndarray) and return a nx.DiGraph.

    Nodes are 0..len(adjmat)-1. Each edge carries a 'color' attribute: 'g' for undirected,
    'b' for directed and 'r' for bi-directed edges.
    """
    g = nx.DiGraph()
    g.add_nodes_from(range(len(adjmat)))
    colored_edges = (
        (find_undirected(adjmat), 'g'),  # Green edge: undirected edge
        (find_fully_directed(adjmat), 'b'),  # Blue edge: directed edge
        (find_bi_directed(adjmat), 'r'),  # Red edge: bidirected edge
    )
    for edges, color in colored_edges:
        for (i, j) in edges:
            g.add_edge(i, j, color=color)
    return g


#######################################################################################################################

def draw_graph(nx_graph):
    """
    Draw the nx_graph (networkx graph object) on a circular layout, colouring each edge by its 'color' attribute.

    Prints a colour legend first and installs an 'ignore' filter for UserWarning before drawing.
    """
    print("Green: undirected; Blue: directed; Red: bi-directed")
    warnings.filterwarnings("ignore", category=UserWarning)
    edge_colors = []
    for u, v in nx_graph.edges():
        edge_colors.append(nx_graph[u][v]['color'])
    layout = nx.circular_layout(nx_graph)
    nx.draw(nx_graph, pos=layout, with_labels=True, edge_color=edge_colors)
    plt.draw()
    plt.show()


#######################################################################################################################

def is_dsep(nx_graph, x, y, Z):
    """
    Return True if x and y are d-separated by the set Z in nx_graph (networkx graph object) and False otherwise.

    x, y and the members of Z are converted with str() before the query, so the graph's nodes are
    expected to be labelled by those strings.
    """
    conditioning = {str(i) for i in Z}
    # NetworkX 3.3 introduced is_d_separator and deprecated d_separated (later removed);
    # prefer the new name and fall back to the old one on older releases.
    d_sep_test = getattr(nx, "is_d_separator", None) or nx.d_separated
    return d_sep_test(nx_graph, {str(x)}, {str(y)}, conditioning)


#######################################################################################################################

def tetrad2adjmat(path):
    """
    Convert the graph (.txt output by TETRAD) at path into an adjacency matrix (np.ndarray).

    The file is read as a single-column, tab-separated table. The first data row lists the variable
    names separated by ',' or ';'; edge lines start at the third data row and look like
    '<n>. X<a> --> X<b>', '<n>. X<a> --- X<b>' or '<n>. X<a> <-> X<b>' (1-based indices).
    Lines containing none of these arrows are ignored.

    :return: float matrix with NaN on the diagonal and -1 for "no edge". For an edge between
        a and b, [a, b] / [b, a] are set to 1 / 0 for '-->', 0 / 0 for '---' and 1 / 1 for '<->'.
    :raises ValueError: if an edge line contradicts marks already recorded for the same pair.
    """
    tetrad_file = pd.read_csv(path, sep='\t')

    # Positional access: Series[0] on a label-indexed Series is deprecated in pandas.
    header = str(tetrad_file.iloc[0, 0])
    if ',' in header:
        var_names = header.split(',')
    elif ';' in header:
        var_names = header.split(';')
    else:
        var_names = []

    n_vars = len(var_names)
    adjmat = np.full((n_vars, n_vars), -1.0)
    np.fill_diagonal(adjmat, np.nan)

    # (arrow token, mark written at [left, right], mark written at [right, left]).
    # The order is the order in which a line is matched against the tokens.
    edge_types = (('-->', 1, 0), ('---', 0, 0), ('<->', 1, 1))
    bidirected = 0

    for row in range(2, tetrad_file.shape[0]):
        line = str(tetrad_file.iloc[row, 0])
        for arrow, mark_lr, mark_rl in edge_types:
            if arrow not in line:
                continue
            if arrow == '<->':
                bidirected += 1
            ends = line.split('. ')[1].split(' ' + arrow + ' ')
            left = int(ends[0].split('X')[1]) - 1
            right = int(ends[1].split('X')[1]) - 1
            if adjmat[left, right] != -1 and adjmat[right, left] != -1:
                # The pair was already filled in: it must agree with this line.
                if adjmat[left, right] != mark_lr or adjmat[right, left] != mark_rl:
                    raise ValueError(
                        "Inconsistency detected. Check the source file on {} and {}.".format(ends[0], ends[1]))
            else:
                adjmat[left, right] = mark_lr
                adjmat[right, left] = mark_rl
            break  # a line describes at most one edge

    if bidirected > 0:
        print("The source file contains", bidirected, "bi-directed edges.")

    return adjmat


#######################################################################################################################

def adjmat2tetrad(PATH, adjmat):
    """
    Convert the adjacency matrix adjmat (np.ndarray) into a text file at PATH which is readable by TETRAD.

    Nodes are written as X1..Xn separated by ';'. Edges are numbered consecutively: directed edges
    first, then undirected, then bi-directed; symmetric edges are written once, with i < j.
    """
    directed = find_fully_directed(adjmat)
    undirected = [(i, j) for (i, j) in find_undirected(adjmat) if i < j]
    bidirected = [(i, j) for (i, j) in find_circ_arrow(adjmat) if adjmat[j, i] == 1 and i < j]
    node_size = adjmat.shape[0]
    edge_groups = ((directed, ' --> '), (undirected, ' --- '), (bidirected, ' <-> '))

    # The context manager closes the file even if one of the writes fails.
    with open(str(PATH), 'w') as file:
        file.write('Graph Nodes: \n')
        for node in range(node_size - 1):
            file.write('X' + str(node + 1) + ';')
        file.write('X' + str(node_size) + '\n')
        file.write('\n')

        file.write('Graph Edges: \n')
        number = 1
        for edges, arrow in edge_groups:
            for (i, j) in edges:
                file.write(str(number) + '. X' + str(i + 1) + arrow + 'X' + str(j + 1) + '\n')
                number += 1


#######################################################################################################################
# mvpc utils

def gen_vir_data(regMs, rss, Ws, num_test_var, effective_sz):
    """Generate the virtual data that follows the full data distribution P(X, Y, S)

    Column i of the result is regMs[i].predict(Ws) + rss[i].
    :return: float ndarray of shape (effective_sz, num_test_var), column-major
    """
    data_vir = np.empty((effective_sz, num_test_var), dtype=float, order='F')
    for col in range(num_test_var):
        predicted = regMs[col].predict(Ws)
        data_vir[:, col] = predicted + rss[col]
    return data_vir


def get_predictor_ws(mdata, num_test_var, effective_sz):
    """Get the data of the predictors, Ws
        1. no missing samples
        2. shuffled
        3. at most effective_sz rows are kept
    :params:
        mdata: the data of involved variables in the correction; the predictor columns start at num_test_var
        num_test_var: number of the variables in the test
        effective_sz: effective sample size
    :return:
        2-D ndarray holding the selected rows of the predictor columns
    """
    # 1. keep only the predictor rows without missing values
    complete_ws = test_wise_deletion(mdata[:, num_test_var:])

    # 2. draw a random row order
    n_rows, n_cols = np.shape(complete_ws)
    order = np.arange(n_rows)
    np.random.shuffle(order)

    # 3. take the first effective_sz rows of that order
    return complete_ws[order[:effective_sz], :].reshape(-1, n_cols)


def cond_perm_c(X, Y, condition_set, prt_m, skel):
    """Check whether the test of X and Y given condition_set (tuple) requires a correction or not

    A correction is required when a tested variable is listed in prt_m['m'] and X and Y share a
    neighbor in skel that is a parent of such a missingness indicator.
    """
    var = list((X, Y) + condition_set)
    if not contain_crrn_m(var, prt_m):
        return False
    return bool(contain_common_neighbors_prt_mvar(X, Y, condition_set, skel, prt_m))


def contain_crrn_m(var, prt_m):
    """Check whether any variable in var appears in prt_m['m'],
    the list of variables whose missingness indicators require correction"""
    shared = set(var).intersection(set(prt_m['m']))
    return bool(shared)


def contain_common_neighbors_prt_mvar(X, Y, condition_set, skel, prt_m):
    """Check whether X and Y share a neighbor that is a parent of a missingness indicator.

    Returns False when X and Y have no common neighbor in skel, or when none of their common
    neighbors is a parent of the missingness indicator of a variable in the test
    (X, Y and the members of condition_set). Returns True otherwise.
    """
    adjacency = nx.to_numpy_array(skel).astype(int)

    # A common neighbor is a column marked 1 in both the row of X and the row of Y.
    shared_mask = (adjacency[X, :] == 1) & (adjacency[Y, :] == 1)
    if not shared_mask.any():
        return False

    shared_nodes = np.flatnonzero(shared_mask)
    prt_ls = get_prt_mvars([X, Y] + list(condition_set), prt_m)
    return bool(set(shared_nodes) & set(prt_ls))


def get_prt_mvars(var, prt_m):
    """ Get the parents of the missingness indicators of the variables in var
    :params:
        - var: a list or a tuple
        - prt_m: mapping with 'm' (variables whose indicators require correction) and 'prt' (their parents)
    :return:
        - a list with unique elements, sorted by np.unique
    """
    parents = []
    for vi in var:
        if vi in prt_m['m']:  # vi has a missingness indicator requiring correction
            parents.extend(get_prt_of_mi(vi, prt_m))
    return list(np.unique(parents))


def get_prt_of_mi(vi, prt_m):
    """Get the parents of the missingness indicator of vi as a list.

    Looks vi up in prt_m['m'] and returns the entry of prt_m['prt'] at the same position;
    returns None when vi is not listed.
    """
    for position, candidate in enumerate(prt_m['m']):
        if candidate != vi:
            continue
        return list(prt_m['prt'][position])
    return None


def get_prt_mw(W_indx_, prt_m):
    """Iteratively get the parents of missingness indicators of W
    :params:
        W_indx_: a list with unique elements; it is not modified
        prt_m: mapping with 'm' (variables whose indicators require correction) and 'prt' (their parents)
    :return:
        W_indx: a list with unique elements, containing W_indx_ and every parent reached from it
    """
    # Work on a copy: extending the caller's list in place would silently change its argument.
    W_indx = list(W_indx_)
    prt_W = get_prt_mvars(W_indx, prt_m)
    while set(prt_W) - set(W_indx):  # some parents are still outside W_indx
        W_indx = list(np.unique(W_indx + prt_W))
        prt_W = get_prt_mvars(W_indx, prt_m)

    # No more parents of W_indx outside of the list W_indx
    return list(np.unique(W_indx))


def test_wise_deletion(data):
    """Return the rows of data that have a complete record (test-wise deletion)"""
    return data[get_indx_complete_rows(data), :]


def learn_regression_model(tdel_data, num_model):
    """Learn a linear regression model for each variable in the independence test
    :params:
        tdel_data: test-wise deleted dataset; the first num_model columns are the targets,
            the remaining columns are the predictors
        num_model: number of regression models
    :return:
        regressMs: list, fitted models (one per target column)
        residuals: list, residuals of each model on tdel_data
    """
    predictors = tdel_data[:, num_model:]
    regressMs, residuals = [], []
    for col in range(num_model):
        target = tdel_data[:, col]
        model = LinearRegression().fit(predictors, target)
        regressMs.append(model)
        residuals.append(get_residual(model, predictors, target))
    return regressMs, residuals


def get_residual(regM, X, y):
    """Return the residuals y - regM.predict(X) of a regression model"""
    return y - regM.predict(X)


def get_sub_correlation_matrix(mvdata):
    """
    Get the correlation matrix of the input data, using only the rows with complete records
    -------
    INPUT:
    -------
    mvdata: data, columns represent variables, rows represent records/samples
    -------
    OUTPUT:
    -------
    matrix: the correlation matrix of all the variables
    sample_size: the sample size of the dataset after test-wise deletion
    """
    complete_rows = get_indx_complete_rows(mvdata)
    return np.corrcoef(mvdata[complete_rows, :], rowvar=False), len(complete_rows)


def get_indx_complete_rows(mvdata):
    """
    Get the index of the rows with complete records
    -------
    INPUT:
    -------
    mvdata: 2-D data, columns represent variables, rows represent records/samples
    -------
    OUTPUT:
    -------
    integer ndarray with the index of the rows that contain no NaN
    (an empty integer array when there are no rows, so it can always be used for indexing)
    """
    values = np.asarray(mvdata)
    nrow, _ncol = np.shape(values)  # also rejects input that is not 2-D
    complete = ~np.isnan(values).any(axis=1)
    # np.arange keeps an integer dtype even for zero rows, unlike np.array(list(range(0))).
    return np.arange(nrow)[complete]
