from itertools import combinations

import numpy as np
from causallearn.graph.GraphClass import CausalGraph
from causallearn.utils.cit import chisq, gsq
from causallearn.utils.PCUtils.Helper import append_value
from tqdm.auto import tqdm


def _remove_edge_both_ways(cg, x, y):
    """Remove the edge between nodes ``x`` and ``y`` from ``cg.G``, if present.

    The edge is looked up with ``get_edge`` in both argument orders (x, y) and
    (y, x); whatever is found is removed.
    """
    for a, b in ((x, y), (y, x)):
        edge = cg.G.get_edge(cg.G.nodes[a], cg.G.nodes[b])
        if edge is not None:
            cg.G.remove_edge(edge)


def skeleton_discovery(
    data,
    alpha,
    indep_test,
    stable=True,
    background_knowledge=None,
    labels=None,
    verbose=False,
    show_progress=True,
):
    """
    Perform skeleton discovery (the adjacency search of the PC algorithm).

    The search proceeds in rounds of increasing ``depth``. In each round, the
    edge x - y is removed when ``indep_test`` returns a p-value above
    ``alpha``. The conditioning set S used for that test is a subset of size
    ``depth`` drawn from the current neighbours of x, excluding y.

    Parameters
    ----------
    data : input data set (numpy ndarray), shape (n_samples, n_features).
    alpha: float, desired significance level of independence tests (p_value) in (0,1)
    indep_test : the function of the independence test being used
            [fisherz, chisq, gsq, mv_fisherz, kci]
           - fisherz: Fisher's Z conditional independence test
           - chisq: Chi-squared conditional independence test
           - gsq: G-squared conditional independence test
           - mv_fisherz: Missing-value Fishers'Z conditional independence test
           - kci: Kernel-based conditional independence test
           For chisq and gsq, every column is first re-encoded as integer
           codes 0..k-1 (see ``np.unique(..., return_inverse=True)``).
    stable : run stabilized skeleton discovery if True (default = True).
        In stable mode, edges found removable at a given depth are deleted
        only after all nodes have been processed at that depth. The neighbour
        sets used for conditioning therefore do not change within a depth.
    background_knowledge : optional object exposing
        ``is_forbidden(node_a, node_b)``. An edge is removed (with an empty
        separating set) when it is forbidden in both directions.
    labels : optional node labels forwarded to ``CausalGraph``
        (default: no labels).
    verbose : True iff verbose output should be printed.
    show_progress : True iff the algorithm progress should be shown in console.

    Returns
    -------
    cg : a CausalGraph object. Where cg.G.graph[j,i]=0 and cg.G.graph[i,j]=1 indicates  i -> j ,
                    cg.G.graph[i,j] = cg.G.graph[j,i] = -1 indicates i -- j,
                    cg.G.graph[i,j] = cg.G.graph[j,i] = 1 indicates i <-> j.

    """

    assert type(data) == np.ndarray
    assert 0 < alpha < 1

    if labels is None:
        labels = {}  # fresh dict per call instead of a shared mutable default

    n_features = data.shape[1]
    cg = CausalGraph(n_features, labels=labels)
    cg.set_ind_test(indep_test)
    cg.data_hash_key = hash(str(data))
    if indep_test == chisq or indep_test == gsq:
        # If dealing with discrete data, data is a numpy.ndarray with n rows
        # and m columns. Each column's discrete values are translated to int
        # indices starting from 0, e.g.
        #   [45, 45, 6, 7, 6, 7] -> [2, 2, 0, 1, 0, 1]
        #   ['apple', 'apple', 'pear', 'peach', 'pear'] -> [0, 0, 2, 1, 2]
        # The old code presumed discrete `data` was already indexed. Here we
        # make sure it is in indexed form, which allows more user input,
        # e.g. 'apple'.
        def _unique(column):
            return np.unique(column, return_inverse=True)[1]

        cg.is_discrete = True
        cg.data = np.apply_along_axis(_unique, 0, data).astype(np.int64)
        cg.cardinalities = np.max(cg.data, axis=0) + 1
    else:
        cg.data = data

    depth = -1
    pbar = tqdm(total=n_features) if show_progress else None
    try:
        while cg.max_degree() - 1 > depth:
            depth += 1
            edge_removal = []
            if show_progress:
                pbar.reset()
            for x in range(n_features):
                if show_progress:
                    pbar.update()
                    pbar.set_description(f"Depth={depth}, working on node {x}")
                neigh_x = cg.neighbors(x)
                # Cheap pre-filter. Nodes with too few neighbours would yield
                # no conditioning sets of size `depth` below anyway.
                if len(neigh_x) < depth - 1:
                    continue
                for y in neigh_x:
                    sepsets = set()
                    knowledge_ban_edge = background_knowledge is not None and (
                        background_knowledge.is_forbidden(cg.G.nodes[x], cg.G.nodes[y])
                        and background_knowledge.is_forbidden(
                            cg.G.nodes[y], cg.G.nodes[x]
                        )
                    )
                    if knowledge_ban_edge:
                        if not stable:
                            _remove_edge_both_ways(cg, x, y)
                            append_value(cg.sepset, x, y, ())
                            append_value(cg.sepset, y, x, ())
                            # Skip only this pair. A `break` here would abandon
                            # every remaining neighbour of x at this depth.
                            continue
                        # Stable mode: defer removal until every node has
                        # been processed at this depth.
                        edge_removal.append((x, y))
                        edge_removal.append((y, x))

                    neigh_x_noy = np.delete(neigh_x, np.where(neigh_x == y))
                    for S in combinations(neigh_x_noy, depth):
                        p = cg.ci_test(x, y, S)
                        if p > alpha:
                            if verbose:
                                print("%d ind %d | %s with p-value %f\n" % (x, y, S, p))
                            if not stable:
                                _remove_edge_both_ways(cg, x, y)
                                append_value(cg.sepset, x, y, S)
                                append_value(cg.sepset, y, x, S)
                                break
                            # Stable mode: remember the pair and accumulate the
                            # union of every separating set found at this depth.
                            edge_removal.append((x, y))
                            edge_removal.append((y, x))
                            sepsets.update(S)
                        else:
                            append_value(cg.p_values, x, y, p)
                            if verbose:
                                print("%d dep %d | %s with p-value %f\n" % (x, y, S, p))
                    # NOTE(review): after a non-stable removal, this also
                    # appends () after S. Adjacent pairs get () once per depth.
                    # Confirm downstream consumers of cg.sepset expect this.
                    append_value(cg.sepset, x, y, tuple(sepsets))
                    append_value(cg.sepset, y, x, tuple(sepsets))

            if show_progress:
                pbar.refresh()

            # Both (x, y) and (y, x) are queued, so one lookup per entry suffices.
            for x, y in set(edge_removal):
                edge = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
                if edge is not None:
                    cg.G.remove_edge(edge)
    finally:
        # Release the progress bar even if a CI test raises.
        if pbar is not None:
            pbar.close()

    return cg


def local_skeleton_discovery(
    data, local_node, alpha, indep_test, mi=None, labels=None, verbose=False
):
    """
    Perform skeleton discovery restricted to the edges incident to one node.

    Only edges x - y with x = ``local_node`` are tested and removed. At each
    depth, the current neighbours of x are visited in a random order drawn
    from NumPy's global RNG, so the result can depend on the seed. x - y is
    tested conditionally on subsets of size ``depth`` of the neighbours of y
    (other than x) that are also adjacent to x. Depth 0 is the unconditional
    test. The first independence found removes the edge.

    Parameters
    ----------
    data : input data set (numpy ndarray), shape (n_samples, n_features).
    local_node : int, column index of the node whose adjacencies are learned.
        Must be smaller than n_features.
    alpha : float, significance level of the independence tests, in (0, 1).
    indep_test : the function of the independence test being used
        [fisherz, chisq, gsq, mv_fisherz, kci]. For chisq and gsq, every
        column is first re-encoded as integer codes 0..k-1.
    mi : iterable of node indices whose edge to ``local_node`` is removed
        before any test is run (default: none).
    labels : optional node labels forwarded to ``CausalGraph``
        (default: no labels).
    verbose : True iff verbose output should be printed.

    Returns
    -------
    cg : the CausalGraph with the edges of ``local_node`` pruned. Separating
        sets are recorded in ``cg.sepset`` and p-values of dependent tests in
        ``cg.p_values``. Nodes found independent of ``local_node`` at depth 0
        are passed to ``cg.append_to_mi``.
    """
    assert type(data) == np.ndarray
    # Valid column indices are 0 .. n_features - 1 (the old `<=` was off by one).
    assert local_node < data.shape[1]
    assert 0 < alpha < 1

    # Fresh containers per call instead of shared mutable defaults.
    if mi is None:
        mi = []
    if labels is None:
        labels = {}

    no_of_var = data.shape[1]
    cg = CausalGraph(no_of_var, labels=labels)
    cg.set_ind_test(indep_test)
    cg.data_hash_key = hash(str(data))
    if indep_test == chisq or indep_test == gsq:
        # Discrete tests: re-encode each column as integer codes 0..k-1.
        def _unique(column):
            return np.unique(column, return_inverse=True)[1]

        cg.is_discrete = True
        cg.data = np.apply_along_axis(_unique, 0, data).astype(np.int64)
        cg.cardinalities = np.max(cg.data, axis=0) + 1
    else:
        cg.data = data

    depth = -1
    x = local_node
    # Remove edges between nodes in MI and the F-node (local_node) up front.
    for i in mi:
        cg.remove_edge(x, i)

    # NOTE(review): the bound uses cg.max_degree(), while this function only
    # removes edges incident to x. Later depths may therefore produce no
    # conditioning sets and only re-shuffle x's neighbours. Consider bounding
    # by x's own degree.
    while cg.max_degree() - 1 > depth:
        depth += 1

        local_neigh = np.random.permutation(cg.neighbors(x))
        for y in local_neigh:
            neigh_y = cg.neighbors(y)
            neigh_y = np.delete(neigh_y, np.where(neigh_y == x))
            # Conditioning candidates: neighbours of y (other than x) that
            # are also adjacent to x. Depth 0 uses only the empty set.
            cond_candidates = []
            if depth > 0:
                cond_candidates = [s for s in neigh_y if x in cg.neighbors(s)]

            for S in combinations(cond_candidates, depth):
                p = cg.ci_test(x, y, S)
                if p > alpha:
                    if verbose:
                        print("%d ind %d | %s with p-value %f\n" % (x, y, S, p))
                    cg.remove_edge(x, y)
                    append_value(cg.sepset, x, y, S)
                    append_value(cg.sepset, y, x, S)

                    # Record marginally (unconditionally) independent nodes.
                    if depth == 0:
                        cg.append_to_mi(y)
                    break
                else:
                    append_value(cg.p_values, x, y, p)
                    if verbose:
                        print("%d dep %d | %s with p-value %f\n" % (x, y, S, p))

    return cg
