from typing import List
from ..Graph.Graph import GeneralGraph
from ..Graph.Node import Node
from ..Graph.Edge import Edge
from ..Graph.NodeType import NodeType
from ..Graph.Endpoint import convert_endpoint_to_sign
from collections import deque
import numpy as np
import itertools


def calculate_E_separation_triple(graph: GeneralGraph, verbose = False) -> List:
    """Enumerate every E-separation statement that holds in ``graph``.

    For each ordered pair ``(node1, node2)`` of nodes (``node1 == node2``
    included) and each subset ``Z`` of the nodes other than ``node1`` (``Z``
    may contain ``node2``), ``is_E_connected_to`` is queried. Whenever it
    reports no connection, the triple ``(node1 name, node2 name, [names in Z])``
    is recorded.

    The search is exhaustive: ``n**2 * 2**(n - 1)`` queries for ``n`` nodes.
    ``is_E_connected_to`` also rebuilds the lifted graph on every call, so
    this is only practical for very small graphs.

    Args:
        graph: Graph whose nodes are enumerated.
        verbose: Forwarded to ``is_E_connected_to``.

    Returns:
        List of ``(str, str, list_of_str)`` triples, in enumeration order.
    """
    triples = []
    all_nodes = graph.nodes
    for source in all_nodes:
        others = [candidate for candidate in all_nodes if candidate != source]
        # Every subset of ``others``: by increasing size, then in combinations order.
        subsets = [
            subset
            for size in range(len(others) + 1)
            for subset in itertools.combinations(others, size)
        ]
        for target in all_nodes:
            for subset in subsets:
                connected, _ = is_E_connected_to(source, target, list(subset), graph, verbose=verbose)
                if not connected:
                    triples.append((source.get_name(), target.get_name(), [member.get_name() for member in subset]))
    return triples


def is_E_connected_to(node1: Node, node2: Node, z: List[Node], graph: GeneralGraph, verbose: bool = True) -> (bool, List[Node]):
    """Test whether ``node1`` and ``node2`` are E-connected given ``z``.

    Steps:

    1. Build the lifted dependence graph of ``graph``
       (``generate_lifted_dependence_graph``).
    2. Compute its strongly connected components.
    3. Ask ``is_connected_to`` (with ``connectivity_type="sigma"``) whether
       the copies ``<node1>_0`` and ``<node2>_1`` are connected.

    The lifted conditioning set contains the ``_0`` copy of every member of
    ``z``. It also contains the ``_1`` copy of every member except ``node2``.

    Args:
        node1: Source node in ``graph``.
        node2: Target node in ``graph``.
        z: Conditioning set (nodes of ``graph``).
        graph: The original (unlifted) graph.
        verbose: Print the lifted graph and forward verbosity to the search.

    Returns:
        ``(connected, path)`` exactly as returned by ``is_connected_to``. The
        path refers to nodes of the lifted graph.
    """
    ldg = generate_lifted_dependence_graph(graph)
    if verbose:
        print("LDG: ")
        ldg.print_edges()
    ldg.calculate_strongly_connected_components()

    def lifted(original, suffix):
        # Look up the copy of ``original`` in the lifted graph by name.
        return ldg.get_node(original.get_name() + suffix)

    source = lifted(node1, "_0")
    target = lifted(node2, "_1")
    lifted_z = []
    for member in z:
        lifted_z.append(lifted(member, "_0"))
        if member != node2:
            lifted_z.append(lifted(member, "_1"))
    return is_connected_to(source, target, lifted_z, ldg, connectivity_type="sigma", verbose=verbose)


def generate_lifted_dependence_graph(graph: GeneralGraph):
    """Build the lifted dependence graph of ``graph``.

    Every node ``X`` of ``graph`` is split into two ``NodeType.OBSERVED``
    copies, ``X_0`` and ``X_1``. The graph is then populated as follows:

    * For every ordered pair ``(X, Y)`` with
      ``graph.directed_edges[node_map[Y], node_map[X]] == 1``:
      if ``X`` and ``Y`` share a name (self-loop), add ``X_0 -> X_1``.
      Otherwise add ``X_0 -> Y_0``, ``X_0 -> Y_1`` and ``X_1 -> Y_1``.
    * For every unordered pair ``X`` before ``Y`` in ``graph.nodes`` with
      ``graph.bidirected_edges[node_map[X], node_map[Y]] == 1``, add a
      bidirected edge between each of ``X_0``/``X_1`` and each of
      ``Y_0``/``Y_1``.

    NOTE(review): only one orientation of ``bidirected_edges`` is read, so this
    assumes that matrix is stored symmetrically - confirm in GeneralGraph.

    Returns:
        The new ``GeneralGraph``. Edges are inserted in a fixed order, which
        determines the order later returned by ``get_edges_node``.
    """
    directed = graph.directed_edges
    bidirected = graph.bidirected_edges
    lifted_graph = GeneralGraph([
        Node(original.get_name() + suffix, NodeType.OBSERVED)
        for original in graph.nodes
        for suffix in ("_0", "_1")
    ])
    index = graph.node_map

    def copy_of(name, suffix):
        # Copy of the original node called ``name`` inside the lifted graph.
        return lifted_graph.get_node(name + suffix)

    for source, target in itertools.product(graph.nodes, repeat=2):
        if directed[index[target], index[source]] == 1:
            source_name = source.get_name()
            target_name = target.get_name()
            if source_name == target_name:
                lifted_graph.add_directed_edge(copy_of(source_name, "_0"), copy_of(source_name, "_1"))
            else:
                for tail_suffix, head_suffix in (("_0", "_0"), ("_0", "_1"), ("_1", "_1")):
                    lifted_graph.add_directed_edge(copy_of(source_name, tail_suffix), copy_of(target_name, head_suffix))

    for first, second in itertools.combinations(graph.nodes, 2):
        if bidirected[index[first], index[second]] == 1:
            first_name = first.get_name()
            second_name = second.get_name()
            for first_suffix, second_suffix in itertools.product(("_0", "_1"), repeat=2):
                lifted_graph.add_bidirected_edge(copy_of(first_name, first_suffix), copy_of(second_name, second_suffix))
    return lifted_graph


# def is_connected_to(node1: Node, node2: Node, z: List[Node], graph: GeneralGraph, connectivity_type: str ="d", verbose: bool = True) -> (bool, List[Node]): # connectivity_type from ["d", "sigma", "mu"]
#     if node1 in z or node2 in z:
#         return False, []
#     elif node1 == node2: # if the nodes are the same, they are always d-connected (and not in z!)
#         return True, [node1]
#     else:
#         edgenode_deque = deque([]) # list containing tuples of edges and nodes
#         nodes_visited = set()
#         initial_path = [node1]
#         if verbose == True:
#             print("Initial node: ", node1.get_name())
#         for edge in graph.get_edges_node(node1): # loop over all edges connected to node1
#             if edge.get_distal_node(node1) == node2: # if the edge is directly connected to node2, they are d-connected
#                 if verbose == True:
#                     print("edge: ", node1.get_name(), "----", edge.get_distal_node(node1).get_name(), "and", edge.get_distal_node(node1).get_name(), " the wanted endnode")
#                 return True, initial_path + [(edge, node2)]
#             edgenode_deque.append((edge, node1, initial_path)) # otherwise, add the edge and node from where we came added to deque
#             if verbose == True:
#                 print("edge: ", node1.get_name(), "----", edge.get_distal_node(node1).get_name(), "and", node1.get_name(), "added")
#         while len(edgenode_deque) > 0:
#             edge, node_a, path = edgenode_deque.pop() # pop the last element from the deque (edge and node from where we came)
#             node_b = edge.get_distal_node(node_a) # get the node on the other side of the edge "edge" (not node_a)
#             if node_b in nodes_visited:
#                 continue
#             nodes_visited.add(node_b)
#             for edge2 in graph.get_edges_node(node_b): # loop over all edges connected to node_b
#                 node_c = edge2.get_distal_node(node_b)
#                 if node_c == node_a:
#                     if connectivity_type=="d":
#                         continue
#                     elif connectivity_type=="sigma":
#                         continue
#                     elif connectivity_type=="mu":
#                         raise ValueError("mu connectivity is not implemented yet")
#                 if check_triple_configuration(edge, edge2, node_a, z, graph,connectivity_type=connectivity_type, verbose = verbose): # in configuration node_a --edge-- node_b --edge2-- node_c
#                     new_path = path + [(edge,node_b)]
#                     if node_c == node2:
#                         return True, new_path + [(edge2, node_c)]
#                     else:
#                         edgenode_deque.append((edge2, node_b, new_path))
#         return False, []

def get_nodes_from_path(path: List) -> List[Node]:
    """Return the nodes along ``path`` in visiting order.

    ``path`` has the structure ``[node, (edge, node), (edge, node), ...]``.
    The first entry is the start node; each later entry pairs the edge taken
    with the node it reaches, and the second element of each pair is
    collected. An empty ``path`` yields an empty list.
    """
    if not path:
        return []
    start, *hops = path
    return [start] + [hop[1] for hop in hops]

def is_connected_to(node1: Node, node2: Node, z: List[Node], graph: GeneralGraph, current_path: List[Node] = None, connectivity_type: str ="d", verbose: bool = True) -> (bool, List[Node]):
    """Search for a path between ``node1`` and ``node2`` that is open given ``z``.

    Runs a recursive depth-first search over simple paths starting at
    ``node1``. A path ending in ``... node_b --edge-- node_a`` is extended by
    ``edge_new`` to ``next_node`` only when:

    * ``next_node`` is not already on the path, and
    * ``check_triple_configuration`` accepts the triple
      ``node_b --edge-- node_a --edge_new-- next_node`` under
      ``connectivity_type``.

    Args:
        node1: Start node.
        node2: Target node.
        z: Conditioning set.
        graph: Graph to search.
        current_path: Partial path of the form
            ``[node1, (edge, node), (edge, node), ...]``. Leave it as ``None``
            (or pass an empty list) to start a new search; non-empty values
            are produced by the recursive calls.
        connectivity_type: ``"d"`` or ``"sigma"``; forwarded to
            ``check_triple_configuration``.
        verbose: Print the progress of the search.

    Returns:
        * ``(False, [])`` if either endpoint is in ``z``.
        * ``(True, [node1])`` if ``node1 == node2``.
        * ``(True, path)`` for the first open path found.
        * ``(False, [])`` if no open path exists.
    """
    if node1 in z or node2 in z:
        return False, []
    if node1 == node2:  # identical nodes outside z are trivially connected
        return True, [node1]

    if not current_path:
        # Fresh search: each edge leaving node1 starts a candidate path.
        # A None default avoids sharing one mutable list across calls.
        start_path = [node1]
        if verbose:
            print("Algorithm just started and Initial node: ", node1.get_name())
        for edge in graph.get_edges_node(node1):
            next_node = edge.get_distal_node(node1)
            if next_node == node2:
                # Adjacent endpoints (neither in z) are always connected.
                if verbose:
                    print("edge: ", node1.get_name(), "----", next_node.get_name(), "and", next_node.get_name(), " the wanted endnode")
                return True, start_path + [(edge, node2)]
            if verbose:
                print("edge: ", node1.get_name(), "----", next_node.get_name(), "and", node1.get_name(), "added")
            found, path = is_connected_to(node1, node2, z, graph, start_path + [(edge, next_node)], connectivity_type=connectivity_type, verbose=verbose)
            if found:
                return found, path
        return False, []

    # Continuing a search: current_path ends with node_b --edge-- node_a.
    if verbose:
        print("Current path: ")
        print_path(current_path)
    edge, node_a = current_path[-1]
    if node_a == node2:  # the path already reaches the target
        return True, current_path
    node_b = edge.get_distal_node(node_a)
    # Nodes already on the path (node_b included). Skipping them keeps the path
    # simple. This does not change inside the loop, so compute it once.
    on_path = get_nodes_from_path(current_path)
    for edge_new in graph.get_edges_node(node_a):
        next_node = edge_new.get_distal_node(node_a)
        if next_node in on_path:
            continue
        if check_triple_configuration(edge, edge_new, node_b, z, graph, connectivity_type=connectivity_type, verbose=verbose):
            found, path = is_connected_to(node1, node2, z, graph, current_path + [(edge_new, next_node)], connectivity_type=connectivity_type, verbose=verbose)
            if found:
                return found, path
    return False, []

# returned_value = is_connected_to(next_node, node2, z, graph, current_path + [(edge, next_node)], connectivity_type=connectivity_type, verbose = verbose)
#                     if returned_value is not None:
#                         bool_help, path = returned_value
#                         if bool_help == True:
#                             return bool_help, path

def generate_DMG(directed_edges_matrix: np.ndarray,symmetric_edges_matrix: np.ndarray,graph_type: str = "DG") -> GeneralGraph: # graph_type from ["DG", "DMG"]
    """Build a graph with observed nodes ``X1 .. Xn`` from adjacency matrices.

    Args:
        directed_edges_matrix: Matrix whose first dimension ``n`` sets the
            number of nodes. For every ``directed_edges_matrix[j, i] == 1``
            this calls ``graph.add_directed_edge(nodes[i], nodes[j])``.
        symmetric_edges_matrix: Read only when ``graph_type == "DMG"``. Only
            the strictly lower triangle is inspected: for ``i < j``,
            ``symmetric_edges_matrix[j, i] == 1`` calls
            ``graph.add_bidirected_edge(nodes[i], nodes[j])``.
        graph_type: ``"DG"`` (directed edges only) or ``"DMG"`` (directed
            plus bidirected edges).

    Returns:
        The constructed ``GeneralGraph``.

    Raises:
        ValueError: If ``graph_type`` is neither ``"DG"`` nor ``"DMG"``. A
            typo would otherwise silently produce a graph with no edges.
    """
    if graph_type not in ("DG", "DMG"):
        raise ValueError("graph_type must be 'DG' or 'DMG', got " + repr(graph_type))
    nodes = [Node("X" + str(i + 1), NodeType.OBSERVED) for i in range(directed_edges_matrix.shape[0])]
    graph = GeneralGraph(nodes)
    for i in range(directed_edges_matrix.shape[0]):
        for j in range(directed_edges_matrix.shape[1]):
            # Column i, row j -> add_directed_edge(nodes[i], nodes[j]).
            if directed_edges_matrix[j, i] == 1:
                graph.add_directed_edge(nodes[i], nodes[j])
    if graph_type == "DMG":
        for i in range(symmetric_edges_matrix.shape[0]):
            for j in range(i + 1, symmetric_edges_matrix.shape[1]):
                if symmetric_edges_matrix[j, i] == 1:
                    graph.add_bidirected_edge(nodes[i], nodes[j])
    return graph

# Checks the triple node_a --edge1-- node_b --edge2-- node_c, where node_b is the middle node.
def check_triple_configuration(edge1: Edge, edge2: Edge, node_a: Node, z: List[Node], graph: GeneralGraph,connectivity_type: str ="d", verbose:bool = False) -> bool:
    """Decide whether a path may pass through the middle of a triple given ``z``.

    The triple is ``node_a --edge1-- node_b --edge2-- node_c``:

    * ``node_b`` is the endpoint of ``edge1`` opposite ``node_a``.
    * ``node_c`` is the endpoint of ``edge2`` opposite ``node_b``.
    * ``node_b`` is a collider when both edges carry an ``ARROW`` or
      ``BLUNT`` mark at ``node_b``.

    Rules:

    * Collider (any ``connectivity_type``): open iff
      ``is_ancestor(node_b, z, graph)``, i.e. ``node_b`` is in ``z`` or is an
      ancestor of a member of ``z``.
    * Non-collider, ``"d"``: open iff ``node_b`` is not in ``z``.
    * Non-collider, ``"sigma"``: blocked only if ``node_b`` is in ``z`` and
      one of the two edges has a ``TAIL`` at ``node_b`` while its other
      endpoint lies in a different entry of ``graph.scc_dict``. Otherwise
      open. This presumes ``graph.scc_dict`` was filled beforehand, e.g. by
      ``calculate_strongly_connected_components``.
    * Non-collider, any other value: always blocked.
      NOTE(review): ``"mu"`` therefore is not really implemented - confirm.

    ``is_ancestor`` is only evaluated for colliders, because its result
    cannot change the outcome for non-colliders.

    Returns:
        True if the triple lets the path pass, otherwise False.
    """
    node_b = edge1.get_distal_node(node_a)
    node_c = edge2.get_distal_node(node_b)
    # Endpoint marks of the two edges at the middle node.
    mark_b_left = str(edge1.get_proximal_endpoint(node_b))
    mark_b_right = str(edge2.get_proximal_endpoint(node_b))
    if verbose:
        print("Triple config: ", node_a.get_name(),
              convert_endpoint_to_sign(str(edge1.get_proximal_endpoint(node_a)), "left") + "--" + convert_endpoint_to_sign(mark_b_left, "right"),
              node_b.get_name(),
              convert_endpoint_to_sign(mark_b_right, "left") + "--" + convert_endpoint_to_sign(str(edge2.get_proximal_endpoint(node_c)), "right"),
              node_c.get_name())
    arrowheads = ("ARROW", "BLUNT")
    collider = mark_b_left in arrowheads and mark_b_right in arrowheads
    if verbose:
        print("Is node_b middle a collider: ", collider)

    if not collider:
        if connectivity_type == "d":
            return node_b not in z
        if connectivity_type == "sigma":
            scc = graph.scc_dict
            index = graph.node_map
            # node_b -> node_a (tail at node_b) leaving node_b's strongly connected component.
            left_out_of = (scc[index[node_a]] != scc[index[node_b]]) and mark_b_left == "TAIL"
            # node_b -> node_c (tail at node_b) leaving node_b's strongly connected component.
            right_out_of = (scc[index[node_b]] != scc[index[node_c]]) and mark_b_right == "TAIL"
            blockable_noncollider = left_out_of or right_out_of
            if verbose:
                print("left_out_of: ", left_out_of, "right_out_of: ", right_out_of, "blockable_noncollider: ", blockable_noncollider)
            return not (blockable_noncollider and node_b in z)
        # Any other connectivity type: non-colliders block the path.
        return False

    ancestor = is_ancestor(node_b, z, graph)
    if verbose:
        print("Is node_b middle an ancestor: ", ancestor)
    return ancestor

# reimplementation of is_ancestor for cyclic graphs
def is_ancestor(node: Node, z: List[Node], graph: GeneralGraph) -> bool:
    """Return True if ``node`` is a member of ``z`` or an ancestor of one.

    The search starts from the members of ``z`` and repeatedly follows
    ``graph.get_parents``, taking the most recently added node first
    (last-in, first-out). Nodes already expanded are skipped, so directed
    cycles cannot make the search loop forever.
    """
    if node in z:
        return True
    expanded = []
    frontier = deque(z)
    while frontier:
        current = frontier.pop()
        if current in expanded:
            continue
        expanded.append(current)
        if current == node:
            return True
        for parent in graph.get_parents(current):
            # Avoid queuing the same parent twice while it is still pending.
            if parent not in frontier:
                frontier.append(parent)
    return False

    
# Print a path of the form [node, (edge, node), (edge, node), ...] using endpoint signs.
def print_path(path: List):
    """Print ``path`` (``[node, (edge, node), ...]``) on a single line.

    A one-node path is printed as ``"Trivial path: <name>"``. Otherwise the
    line starts with the first node's name. Each hop then appends:

    * the edge's mark at the previous node (``convert_endpoint_to_sign`` with
      ``"left"``),
    * ``"--"``,
    * the edge's mark at the reached node (``"right"``),
    * a space, the reached node's name, and a trailing space.

    An empty ``path`` raises ``IndexError``.
    """
    current = path[0]
    if len(path) == 1:
        print("Trivial path: " + current.get_name())
        return
    pieces = [current.get_name()]
    for edge, reached in path[1:]:
        left_sign = convert_endpoint_to_sign(str(edge.get_proximal_endpoint(current)), "left")
        right_sign = convert_endpoint_to_sign(str(edge.get_proximal_endpoint(reached)), "right")
        pieces.append(left_sign + "--" + right_sign + " " + reached.get_name() + " ")
        current = reached
    print("".join(pieces))