import os
import pandas as pd
import numpy as np
import networkx as nx
from itertools import chain, combinations
from dodiscover.ci import KernelCITest

def powerset(num_nodes, x, y, include_empty=False):
    """Return the candidate conditioning sets for a CI test between ``x`` and ``y``.

    The candidates are the subsets of the nodes ``0 .. num_nodes - 1`` with
    ``x`` and ``y`` removed.

    Args:
        num_nodes: Number of nodes; nodes are the integers ``0 .. num_nodes - 1``.
        x: First node to exclude.
        y: Second node to exclude.
        include_empty: If True, the empty tuple is returned as the first subset.
            Defaults to False, which matches the original behaviour of dropping
            the empty set.

    Returns:
        A list of tuples. Tuples are ordered by increasing size and, within a
        size, in ``itertools.combinations`` order over the sorted remaining nodes.

    Raises:
        KeyError: If ``x`` or ``y`` is not in ``range(num_nodes)``, or if ``x == y``.
    """
    variables = set(range(num_nodes))
    variables.remove(x)
    variables.remove(y)
    # Sort explicitly so the output order never depends on set iteration order.
    remaining = sorted(variables)
    min_size = 0 if include_empty else 1
    # Generate only the sizes that are wanted, instead of building the full
    # power set (previously twice) and slicing the empty set off afterwards.
    return list(
        chain.from_iterable(
            combinations(remaining, r) for r in range(min_size, len(remaining) + 1)
        )
    )


def _count_path_cancellations(X, A, A_p, kci, alpha):
    """Count canceled true edges whose endpoints the CI test finds separable.

    For every node ``child`` and every ``parent`` that is a parent in ``A`` but
    not in ``A_p``, run ``kci.test`` over the non-empty conditioning sets from
    ``powerset``. The search stops at the first set with p-value > ``alpha``,
    and that edge is counted once.

    Args:
        X: Samples whose columns are labelled ``0 .. num_nodes - 1``.
        A: Ground-truth adjacency matrix, indexed as ``A[parent, child]``.
        A_p: Adjacency matrix of the graph the data is faithful to, indexed
            the same way.
        kci: Object exposing ``test(X, x_vars, y_vars, z_vars) -> (stat, pvalue)``.
        alpha: Significance level; p-value > alpha is read as independence.

    Returns:
        The number of canceled edges for which a separating set was found.
    """
    num_nodes = A.shape[0]
    n_path_cancel = 0
    # Every node can have canceled parents, including the last one.
    for child in range(num_nodes):
        parents = np.flatnonzero(A[:, child])
        # Parents according to the (unfaithful) dependencies.
        faithful_parents = np.flatnonzero(A_p[:, child])
        # Parents in the ground truth but not in the faithful graph.
        canceled_causes = np.setdiff1d(parents, faithful_parents)
        faithful_parent_set = set(faithful_parents.tolist())
        for parent in canceled_causes.tolist():
            for subset in powerset(num_nodes, child, parent):
                cond_set = set(subset)
                _, pvalue = kci.test(X, {parent}, {child}, cond_set)
                if pvalue > alpha:
                    n_path_cancel += 1
                    # Diagnostic: the separating set includes a faithful parent.
                    if faithful_parent_set & cond_set:
                        print("parent in separators")
                    break
    return n_path_cancel


def test_unfaithfulness(
    data_dir="/efs/data/ER/gauss/unfaithful/unfaithful_1.0/1000_large20_dense",
    seeds=range(20),
    n_samples=500,
    alpha=0.05,
):
    """Report how many canceled edges a kernel CI test can separate, per dataset.

    For each seed this reads ``data_{seed}.csv``, ``groundtruth_{seed}.csv`` and
    ``unfaithful_{seed}.csv`` from ``data_dir``. It keeps the first
    ``n_samples`` rows of the data and prints the per-dataset count, then
    prints the average over all seeds.

    The keyword defaults reproduce the original hard-coded configuration.
    Because they have defaults, pytest does not treat them as fixtures.
    """
    # TODO: fix dataset! (default data_dir points at the "large20" dataset)
    tot_cancellations = []
    for seed in seeds:
        data_file = os.path.join(data_dir, f"data_{seed}.csv")
        groundtruth_file = os.path.join(data_dir, f"groundtruth_{seed}.csv")
        unfaithful_adj_file = os.path.join(data_dir, f"unfaithful_{seed}.csv")

        A = np.genfromtxt(groundtruth_file, delimiter=",")
        A_p = np.genfromtxt(unfaithful_adj_file, delimiter=",")  # graph faithful to P
        num_nodes = A.shape[0]
        X = pd.read_csv(
            data_file, header=None, names=list(range(num_nodes))
        ).iloc[:n_samples, :]

        # A fresh tester per dataset, as before (its internal state is not
        # assumed to be reusable).
        kci = KernelCITest()
        n_path_cancel = _count_path_cancellations(X, A, A_p, kci, alpha)

        print(f"Dataset {seed}: {n_path_cancel} path cancellations")
        tot_cancellations.append(n_path_cancel)

    print(f"Avg number of cancellations: {np.mean(tot_cancellations)}")


def _is_d_separated(G, x, y, z):
    """Return whether node sets ``x`` and ``y`` are d-separated by ``z`` in DAG ``G``.

    Uses ``nx.is_d_separator`` when NetworkX provides it (3.3+). Otherwise it
    falls back to ``nx.d_separated``, which NetworkX 3.3 deprecated.
    """
    is_d_separator = getattr(nx, "is_d_separator", None)
    if is_d_separator is None:
        return nx.d_separated(G, x, y, z)
    return is_d_separator(G, x, y, z)


def test_unfaithfulness_example():
    """Count path cancellations on a hand-built 4-node example.

    ``A_gt`` is the true DAG. ``A_faithful`` is the same DAG without the edge
    0 -> 2. For every true parent missing from the faithful graph, the
    function checks whether it is d-separated from its child, given the
    child's faithful parents. It then prints the count.
    """
    A_gt = np.array([[0, 1, 1, 1], [0, 0, 1, 0], [0, 0, 0, 0], [0, 1, 0, 0]])
    G_gt = nx.from_numpy_array(A_gt, create_using=nx.DiGraph)
    assert nx.is_directed_acyclic_graph(G_gt)
    print(G_gt.edges())

    A_faithful = np.array([[0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 0], [0, 1, 0, 0]])
    # Build the faithful graph once; it was previously rebuilt for every check.
    G_faithful = nx.from_numpy_array(A_faithful, create_using=nx.DiGraph)

    is_dsep = _is_d_separated(G_faithful, {2}, {0}, {1, 3})
    print(is_dsep)

    num_nodes = A_gt.shape[0]
    paths_cancelling = 0
    edges_cancelling = 0
    for node in range(num_nodes):
        parents = np.flatnonzero(A_gt[:, node])
        # Parents according to the (unfaithful) dependencies.
        faithful_parents = np.flatnonzero(A_faithful[:, node])
        # Parents in A_gt but not in A_faithful.
        canceled_causes = np.setdiff1d(parents, faithful_parents)
        # Use Python ints so the sets match the graph's integer node labels.
        conditioning_set = set(faithful_parents.tolist())
        for unfaithful_parent in canceled_causes.tolist():
            print(node)
            if _is_d_separated(G_faithful, {node}, {unfaithful_parent}, conditioning_set):
                paths_cancelling += 1
            edges_cancelling += 1
    print(f"{paths_cancelling}/{edges_cancelling} path cancellations")

if __name__ == "__main__":
    # Script entry point: run only the dataset-wide check
    # (test_unfaithfulness_example is not invoked here).
    test_unfaithfulness()
