
from collections import deque
from itertools import combinations
from causallearn.graph.Node import Node
from typing import List, Set, Tuple, Dict
from causallearn.graph.Graph import Graph
from causallearn.graph.Edge import Edge


def markEssentialEdges_by_conditioning_set(essential_edges, D, D_copy, cond_sets, nodes, essential_edges_name, essential_edges_by_order):
    """Mark edges of ``D`` whose removal destroys a tested d-connection.

    For every edge ``node1 - node2`` of ``D`` that is not already listed in
    ``essential_edges.required_rules_specs``, the edge is temporarily removed
    from ``D_copy``. It is marked essential when either

    * ``node1`` and ``node2`` are d-connected given the empty set in ``D`` but
      not in ``D_copy`` (recorded under order 0), or
    * for the first triple ``(i, j, S)`` in ``cond_sets``, ``nodes[i]`` and
      ``nodes[j]`` are d-connected given ``[nodes[k] for k in S]`` in ``D`` but
      not in ``D_copy`` (recorded under order ``len(S)``).

    The edge is always added back to ``D_copy`` afterwards, even if a
    d-connection query raises.

    Args:
        essential_edges: object exposing ``required_rules_specs`` (membership
            tested with ``(node1, node2)``) and
            ``add_required_by_node(node1, node2)``.
        D: reference graph; only read.
        D_copy: working copy of ``D``; each edge is removed and re-added.
        cond_sets: sequence of ``(first_idx, second_idx, cond_indices)``
            triples whose indices refer to ``nodes``.
        nodes: list of nodes indexed by ``cond_sets``.
        essential_edges_name: list that receives one
            ``(D.node_map[node1], D.node_map[node2])`` pair per essential edge.
        essential_edges_by_order: mapping from conditioning-set size to a list.
            The same pair is appended under the order at which the edge was
            found. NOTE(review): assumes every needed key exists (e.g. a
            ``defaultdict(list)``); confirm with callers.
    """
    for edge in D.get_graph_edges():
        # Test for None before the edge is dereferenced.
        if edge is None:
            continue
        node1 = edge.get_node1()
        node2 = edge.get_node2()
        if (node1, node2) in essential_edges.required_rules_specs:
            continue
        node1name = D.node_map[node1]
        node2name = D.node_map[node2]

        D_copy.remove_edge(edge)
        try:
            found_order = None
            # Zero-order test between the edge's own endpoints.
            if is_dconnected_to(node1, node2, [], D) and not is_dconnected_to(node1, node2, [], D_copy):
                found_order = 0
            else:
                # Otherwise, look for the first tested pair whose d-connection
                # disappears when this edge is removed.
                for first_idx, second_idx, cond_indices in cond_sets:
                    first_node = nodes[first_idx]
                    second_node = nodes[second_idx]
                    cond = [nodes[k] for k in cond_indices]
                    if (is_dconnected_to(first_node, second_node, cond, D)
                            and not is_dconnected_to(first_node, second_node, cond, D_copy)):
                        found_order = len(cond)
                        break

            if found_order is not None:
                essential_edges.add_required_by_node(node1, node2)
                essential_edges_name.append((node1name, node2name))
                essential_edges_by_order[found_order].append((node1name, node2name))
        finally:
            # Restore D_copy so the next edge is tested against the full graph.
            D_copy.add_edge(edge)


def markEssentialEdges(essential_edges, D, D_copy, cond_size, essential_edges_name, essential_edges_by_order):
    """Mark edges of ``D`` whose removal d-separates their own endpoints.

    For every edge ``node1 - node2`` of ``D`` that is not already listed in
    ``essential_edges.required_rules_specs``, the edge is temporarily removed
    from ``D_copy``. The conditioning sets are every subset of size
    ``0..cond_size`` of the nodes adjacent to ``node1`` or ``node2`` in
    ``D_copy``. The edge is essential if, for some subset, ``node1`` and
    ``node2`` are d-connected in ``D`` but not in ``D_copy``. The edge is
    always added back to ``D_copy`` afterwards, even if a query raises.

    Args:
        essential_edges: object exposing ``required_rules_specs`` (membership
            tested with ``(node1, node2)``) and
            ``add_required_by_node(node1, node2)``.
        D: reference graph; only read.
        D_copy: working copy of ``D``; each edge is removed and re-added.
        cond_size: largest conditioning-set size to try.
        essential_edges_name: list that receives one
            ``(D.node_map[node1], D.node_map[node2])`` pair per essential edge.
        essential_edges_by_order: mapping; matches are appended to
            ``essential_edges_by_order[cond_size]``.

    NOTE(review): matches are recorded under ``cond_size`` (the maximum size
    searched), not under the size of the subset that matched. The sibling
    ``markEssentialEdges_by_conditioning_set`` records the matched size. The
    two agree only if callers invoke this with increasing ``cond_size``.
    Confirm before changing.
    """
    for edge in D.get_graph_edges():
        # Test for None before the edge is dereferenced.
        if edge is None:
            continue
        node1 = edge.get_node1()
        node2 = edge.get_node2()
        if (node1, node2) in essential_edges.required_rules_specs:
            continue
        node1name = D.node_map[node1]
        node2name = D.node_map[node2]

        D_copy.remove_edge(edge)
        try:
            # Neighbours of either endpoint once the edge is gone.
            # dict.fromkeys drops common neighbours that would otherwise be
            # listed twice (and yield redundant subsets), keeping the order.
            candidates = list(dict.fromkeys(
                D_copy.get_adjacent_nodes(node1) + D_copy.get_adjacent_nodes(node2)))
            # any() stops at the first separating subset, smallest size first.
            found = any(
                is_dconnected_to(node1, node2, list(cond), D)
                and not is_dconnected_to(node1, node2, list(cond), D_copy)
                for size in range(cond_size + 1)
                for cond in combinations(candidates, size)
            )
            if found:
                essential_edges.add_required_by_node(node1, node2)
                essential_edges_name.append((node1name, node2name))
                essential_edges_by_order[cond_size].append((node1name, node2name))
        finally:
            # Restore D_copy so the next edge is tested against the full graph.
            D_copy.add_edge(edge)


def is_dconnected_to(node1: Node, node2: Node, z: List[Node], graph: Graph) -> bool:
    """Return True if ``node1`` and ``node2`` are d-connected given ``z`` in ``graph``.

    The search walks the graph one edge at a time. A state ``(edge, node_a)``
    means "the walk stands at ``node_a`` and will next cross ``edge``". From
    the far end ``node_b``, the walk may continue along ``edge2`` to
    ``node_c`` whenever ``reachable`` says ``node_b`` does not block it given
    ``z``. Each state is expanded at most once, so the search terminates on
    graphs with cycles.

    Args:
        node1: first query node.
        node2: second query node.
        z: conditioning set.
        graph: graph providing ``get_node_edges``. Its edges must provide
            ``get_distal_node`` plus whatever ``reachable`` uses.

    Returns:
        True if the nodes are equal, adjacent, or joined by an unblocked walk;
        False otherwise.
    """
    if node1 == node2:
        return True
    # Pending (edge, node_a) states, expanded last-in first-out (depth-first).
    pending = deque()
    # States already queued. This is a list rather than a set because
    # membership relies on Edge equality, and Edge objects are not assumed to
    # be hashable.
    seen = []
    for edge in graph.get_node_edges(node1):
        if edge.get_distal_node(node1) == node2:
            # Adjacent nodes are always d-connected.
            return True
        state = (edge, node1)
        pending.append(state)
        seen.append(state)

    while pending:
        edge, node_a = pending.pop()
        node_b = edge.get_distal_node(node_a)
        for edge2 in graph.get_node_edges(node_b):
            node_c = edge2.get_distal_node(node_b)
            if node_c == node_a:
                # Never step straight back to the node we just came from.
                continue
            if not reachable(edge, edge2, node_a, z, graph):
                continue
            if node_c == node2:
                return True
            state = (edge2, node_b)
            if state not in seen:
                seen.append(state)
                pending.append(state)
    return False

def reachable(edge1: Edge, edge2: Edge, node_a: Node, z: List[Node], graph: Graph) -> bool:
    """Decide whether a walk may pass through the middle of ``node_a - node_b - node_c``.

    ``node_b`` is the end of ``edge1`` opposite ``node_a``, and ``edge2`` is
    the next edge leaving ``node_b``. ``node_b`` is a collider when both
    edges have an arrowhead (endpoint string ``"ARROW"``) at ``node_b``.

    Returns:
        For a collider: whether ``node_b`` is in ``z`` or is an ancestor of a
        member of ``z`` (via ``is_ancestor``). For a non-collider: whether
        ``node_b`` is outside ``z``.
    """
    middle = edge1.get_distal_node(node_a)
    is_collider = (
        str(edge1.get_proximal_endpoint(middle)) == "ARROW"
        and str(edge2.get_proximal_endpoint(middle)) == "ARROW"
    )
    if is_collider:
        # A collider is open only when it (or a descendant) is conditioned on.
        return is_ancestor(middle, z, graph)
    # A non-collider is open only when it is not conditioned on.
    return middle not in z

def is_ancestor(node: Node, z: List[Node], graph: Graph) -> bool:
    """Return True if ``node`` is in ``z`` or is an ancestor of some member of ``z``.

    Walks parent links (``graph.get_parents``) backwards from every member of
    ``z``. Each node is queued at most once, so the walk terminates on graphs
    with cycles.

    Args:
        node: candidate ancestor.
        z: nodes whose ancestors are searched. Nodes must be hashable (they
            are dict keys in ``graph.node_map`` elsewhere in this module).
        graph: graph providing ``get_parents(node)``.

    Returns:
        True if ``node`` is reachable from ``z`` by following parent links
        (zero or more steps), otherwise False.
    """
    if node in z:
        return True
    # Seed the visited set with z so its members are never re-queued. A set
    # gives O(1) membership; the previous version scanned a list and the deque.
    visited = set(z)
    frontier = deque(z)
    while frontier:
        current = frontier.pop()
        for parent in graph.get_parents(current):
            if parent == node:
                return True
            if parent not in visited:
                visited.add(parent)
                frontier.append(parent)
    return False
