#!/usr/bin/env python3
from __future__ import annotations

from collections import deque
from itertools import permutations
from typing import List, Tuple, Deque

import pydot

from compare_algs.causallearn_package.graph.AdjacencyConfusion import AdjacencyConfusion
from compare_algs.causallearn_package.graph.ArrowConfusion import ArrowConfusion
from compare_algs.causallearn_package.graph.Edge import Edge
from compare_algs.causallearn_package.graph.Edges import Edges
from compare_algs.causallearn_package.graph.Endpoint import Endpoint
from compare_algs.causallearn_package.graph.Graph import Graph
from compare_algs.causallearn_package.graph.Node import Node
from compare_algs.causallearn_package.graph.NodeType import NodeType


class GraphUtils:

    def __init__(self):
        pass

    # Returns true if node1 is d-connected to node2 on the set of nodes z.
    ### CURRENTLY THIS DOES NOT IMPLEMENT UNDERLINE TRIPLE EXCEPTIONS ###
    def is_dconnected_to(self, node1: Node, node2: Node, z: List[Node], graph: Graph):
        if node1 == node2:
            return True

        edgenode_deque = deque([])

        for edge in graph.get_node_edges(node1):
            if edge.get_distal_node(node1) == node2:
                return True
            edgenode_deque.append((edge, node1))

        while len(edgenode_deque) > 0:
            edge, node_a = edgenode_deque.pop()
            node_b = edge.get_distal_node(node_a)

            for edge2 in graph.get_node_edges(node_b):
                node_c = edge2.get_distal_node(node_b)

                if node_c == node_a:
                    continue

                if self.reachable(edge, edge2, node_a, z, graph):
                    if node_c == node2:
                        return True
                    else:
                        edgenode_deque.append((edge2, node_b))

        return False

    def edge_string(self, edge: Edge) -> str:
        node1 = edge.get_node1()
        node2 = edge.get_node2()

        endpoint1 = edge.get_endpoint1()
        endpoint2 = edge.get_endpoint2()

        edge_string = node1.get_name() + " "

        if endpoint1 == Endpoint.TAIL:
            edge_string = edge_string + "-"
        else:
            if endpoint1 == Endpoint.ARROW:
                edge_string = edge_string + "<"
            else:
                edge_string = edge_string + "o"

        edge_string = edge_string + "-"

        if endpoint2 == Endpoint.TAIL:
            edge_string = edge_string + "-"
        else:
            if endpoint2 == Endpoint.ARROW:
                edge_string = edge_string + ">"
            else:
                edge_string = edge_string + "o"

        edge_string = edge_string + " " + node2.get_name()
        return edge_string

    def graph_string(self, graph: Graph) -> str:
        nodes = graph.get_nodes()
        edges = graph.get_graph_edges()

        # nodes.sort()
        # edges.sort()

        graph_string = "Graph Nodes:\n"

        for i in range(len(nodes) - 1):
            node = nodes[i]
            graph_string = graph_string + node.get_name() + ";"

        if len(nodes) > 0:
            graph_string = graph_string + nodes[-1].get_name()

        graph_string = graph_string + "\n\nGraph Edges:\n"

        count = 0
        for edge in edges:
            count = count + 1
            graph_string = graph_string + str(count) + ". " + str(edge) + "\n"

        return graph_string

    # Helper method. Determines if two edges do or do not form a block for d-separation, conditional on a set of nodes z
    # starting from a node a
    def reachable(self, edge1: Edge, edge2: Edge, node_a: Node, z: List[Node], graph: Graph) -> bool:
        node_b = edge1.get_distal_node(node_a)

        collider = str(edge1.get_proximal_endpoint(node_b)) == "ARROW" and str(
            edge2.get_proximal_endpoint(node_b)) == "ARROW"

        if (not collider) and not (node_b in z):
            return True

        ancestor = self.is_ancestor(node_b, z, graph)

        return collider and ancestor

    # Helper method. Determines if a given node is an ancestor of any node in a set of nodes z.
    def is_ancestor(self, node: Node, z: List[Node], graph: Graph) -> bool:
        if node in z:
            return True

        nodedeque = deque([])

        for node_z in z:
            nodedeque.append(node_z)

        while len(nodedeque) > 0:
            node_t = nodedeque.pop()
            if node_t == node:
                return True

            for node_c in graph.get_parents(node_t):
                if node_c not in nodedeque:
                    nodedeque.append(node_c)
        return False

    def get_sepset(self, x: Node, y: Node, graph: Graph) -> List[Node] | None:
        sepset = self.get_sepset_visit(x, y, graph)
        if sepset is None:
            sepset = self.get_sepset_visit(y, x, graph)

        return sepset

    def get_sepset_visit(self, x: Node, y: Node, graph: Graph) -> List[Node] | None:
        if x == y:
            return None

        z: List[Node] = []

        while True:
            _z = z.copy()
            path: List[Node] = [x]
            colliders = []

            for b in graph.get_adjacent_nodes(x):
                if self.sepset_path_found(x, b, y, path, z, graph, colliders):
                    return None

            z.sort()
            _z.sort()
            if z == _z:
                break

        return z

    def sepset_path_found(self, a: Node, b: Node, y: Node, path: List[Node], z: List[Node], graph: Graph,
                          colliders: List[Tuple[Node, Node, Node]]) -> bool:
        if b == y:
            return True

        if b in path:
            return False

        path.append(b)

        if b.get_node_type == NodeType.LATENT or b in z:
            pass_nodes = self.get_pass_nodes(a, b, z, graph, None)

            for c in pass_nodes:
                if self.sepset_path_found(b, c, y, path, z, graph, colliders):
                    path.remove(b)
                    return True

            path.remove(b)
            return False
        else:
            found1 = False
            colliders1 = []
            pass_nodes1 = self.get_pass_nodes(a, b, z, graph, colliders1)

            for c in pass_nodes1:
                if self.sepset_path_found(b, c, y, path, z, graph, colliders1):
                    found1 = True
                    break

            if not found1:
                path.remove(b)
                colliders.extend(colliders1)
                return False

            z.append(b)
            found2 = False
            colliders2: List[Tuple[Node, Node, Node]] = []
            pass_nodes2 = self.get_pass_nodes(a, b, z, graph, None)

            for c in pass_nodes2:
                if self.sepset_path_found(b, c, y, path, z, graph, colliders2):
                    found2 = True
                    break

            if not found2:
                path.remove(b)
                colliders.extend(colliders2)
                return False

            z.remove(b)
            path.remove(b)
            return True

    def get_pass_nodes(self, a: Node, b: Node, z: List[Node], graph: Graph,
                       colliders: List[Tuple[Node, Node, Node]] | None) -> List[Node]:
        pass_nodes: List[Node] = []

        for c in graph.get_adjacent_nodes(b):
            if c == a:
                continue

            if self.node_reachable(a, b, c, z, graph, colliders):
                pass_nodes.append(c)

        return pass_nodes

    def node_reachable(self, a: Node, b: Node, c: Node, z: List[Node], graph: Graph,
                       colliders: List[Tuple[Node, Node, Node]] | None) -> bool:
        collider = graph.is_def_collider(a, b, c)

        if not collider and not (b in z):
            return True

        ancestor = self.is_ancestor(b, z, graph)

        collider_reachable = collider and ancestor

        if colliders is not None and collider and not ancestor:
            colliders.append((a, b, c))

        return collider_reachable

    # Returns a tiered ordering of variables in an acyclic graph. THIS ALGORITHM IS NOT ALWAYS CORRECT.
    def get_causal_order(self, graph: Graph) -> List[Node]:
        if graph.exists_directed_cycle():
            raise ValueError("Graph must be acyclic.")

        found: List[Node] = []
        not_found: List[Node] = graph.get_nodes()
        sub_not_found: List[Node] = []

        for node in not_found:
            if node.get_node_type() == NodeType.ERROR:
                sub_not_found.append(node)

        not_found = [e for e in not_found if e not in sub_not_found]

        all_nodes = not_found.copy()

        while len(not_found) > 0:
            sub_not_found: List[Node] = []
            for node in not_found:
                # print(node)
                parents = graph.get_parents(node)
                sub_parents: List[Node] = []
                for node1 in parents:
                    if not (node1 in all_nodes):
                        sub_parents.append(node1)

                parents = [e for e in parents if e not in sub_parents]

                if all(node1 in found for node1 in parents):
                    found.append(node)
                    sub_not_found.append(node)

            not_found = [e for e in not_found if e not in sub_not_found]

        return found

    def find_unshielded_triples(self, graph: Graph):
        """Return the list of unshielded triples i o-o j o-o k in adjmat as (i, j, k)"""
        from compare_algs.causallearn_package.graph.Dag import Dag
        if not isinstance(graph, Dag):
            raise ValueError("graph must be a DAG")
        triples = []

        for pair in permutations(graph.get_graph_edges(), 2):
            node1 = pair[0].get_node1()
            node2 = pair[0].get_node2()
            node3 = pair[1].get_node1()
            node4 = pair[1].get_node1()

            node_map = graph.get_node_map()

            if node1 == node3:
                if node2 != node4 and graph.get_adjacency_matrix()[node_map[node2], node_map[node4]] == 0:
                    triples.append((node2, node1, node4))
                    continue
            if node1 == node4:
                if node2 != node3 and graph.get_adjacency_matrix()[node_map[node2], node_map[node3]] == 0:
                    triples.append((node2, node1, node3))
                    continue
            if node2 == node3:
                if node1 != node4 and graph.get_adjacency_matrix()[node_map[node1], node_map[node4]] == 0:
                    triples.append((node1, node2, node4))
                    continue
            if node2 == node4:
                if node2 != node3 and graph.get_adjacency_matrix()[node_map[node2], node_map[node3]] == 0:
                    triples.append((node1, node2, node3))

        return triples

    #    return [(pair[0].get_node1(), pair[0].get_node2(), pair[1].get_node2) for pair in permutations(graph.get_graph_edges(), 2)
    #            if pair[0].get_node2() == pair[1].get_node1() and pair[0].get_node1() != pair[1].get_node2() and graph.get_adjacency_matrix()[graph.get_node_map()[pair[0].get_node1()], graph.get_node_map()[pair[1].get_node2()]] == -1]

    def find_triangles(self, graph: Graph) -> List[Tuple[Node, Node, Node]]:
        """Return the list of triangles i o-o j o-o k o-o i in adjmat as (i, j, k) [with symmetry]"""
        Adj = graph.get_graph_edges()
        triangles: List[Tuple[Node, Node, Node]] = []

        for pair in permutations(Adj, 2):
            node1 = pair[0].get_node1()
            node2 = pair[0].get_node2()
            node3 = pair[1].get_node1()
            node4 = pair[1].get_node2()

            if node1 == node3:
                if graph.is_adjacent_to(node2, node4):
                    triangles.append((node2, node1, node4))
                    continue
            if node1 == node4:
                if graph.is_adjacent_to(node2, node3):
                    triangles.append((node2, node1, node3))
                    continue
            if node2 == node3:
                if graph.is_adjacent_to(node1, node4):
                    triangles.append((node1, node2, node4))
                    continue
            if node2 == node4:
                if graph.is_adjacent_to(node1, node3):
                    triangles.append((node1, node2, node3))

        return triangles

    #    return [(pair[0].get_node1(), pair[0].get_node2(), pair[1].get_node2) for pair in permutations(Adj, 3)
    #            if pair[0].get_node2 == pair[1].get_node1() and pair[0].get_node1() != pair[1].get_node2() and (pair[0][0], pair[1][1]) in Adj]

    def find_kites(self, graph) -> List[Tuple[Node, Node, Node, Node]]:
        kites: List[Tuple[Node, Node, Node, Node]] = []
        for pair in permutations(self.find_triangles(graph), 2):
            if (pair[0][0] == pair[1][0]) and (pair[0][2] == pair[1][2]) and (
                    graph.node_map[pair[0][1]] < graph.node_map[pair[1][1]]) and (
                    graph.graph[graph.node_map[pair[0][1]], graph.node_map[pair[1][1]]] == 0):
                kites.append((pair[0][0], pair[0][1], pair[1][1], pair[0][2]))

        return kites

        # return [(pair[0][0], pair[0][1], pair[1][1], pair[0][2]) for pair in permutations(self.findTriangles(), 2)
        #        if pair[0][0] == pair[1][0] and pair[0][2] == pair[1][2]
        #        and pair[0][1] < pair[1][1] and self.adjmat[pair[0][1], pair[1][1]] == -1]

    def sdh(self, graph1: Graph, graph2: Graph) -> int:
        nodes = graph1.get_nodes()
        error = 0

        for i1 in list(range(1, graph1.get_num_nodes())):
            for i2 in list(range(i1 + 1, graph1.get_num_nodes())):
                e1 = graph1.get_edge(nodes[i1], nodes[i2])
                e2 = graph2.get_edge(nodes[i1], nodes[i2])
                error = error + self.shd_one_edge(e1, e2)

        return error

    def shd_one_edge(self, e1: Edge, e2: Edge) -> int:
        if self.no_edge(e1) and self.undirected(e2):
            return 1
        elif self.no_edge(e2) and self.undirected(e1):
            return 1
        elif self.no_edge(e1) and self.directed(e2):
            return 2
        elif self.no_edge(e2) and self.directed(e1):
            return 2
        elif self.undirected(e1) and self.directed(e2):
            return 1
        elif self.undirected(e2) and self.directed(e1):
            return 1
        elif self.directed(e1) and self.directed(e2):
            if e1.get_endpoint1() == e2.get_endpoint2():
                return 1
        elif self.bi_directed(e1) or self.bi_directed(e2):
            return 2
        return 0

    def no_edge(self, e: Edge | None) -> bool:
        return e is None

    def undirected(self, e: Edge) -> bool:
        return e.get_endpoint1() == Endpoint.TAIL and e.get_endpoint2() == Endpoint.TAIL

    def directed(self, e: Edge) -> bool:
        return (e.get_endpoint1() == Endpoint.TAIL and e.get_endpoint2() == Endpoint.ARROW) \
               or (e.get_endpoint1() == Endpoint.ARROW and e.get_endpoint2() == Endpoint.TAIL)

    def bi_directed(self, e: Edge) -> bool:
        return e.get_endpoint1() == Endpoint.ARROW and e.get_endpoint2() == Endpoint.ARROW

    def adj_precision(self, truth: Graph, est: Graph) -> float:
        confusion = AdjacencyConfusion(truth, est)
        return confusion.get_adj_tp() / (confusion.get_adj_tp() + confusion.get_adj_fp())

    def adj_recall(self, truth: Graph, est: Graph) -> float:
        confusion = AdjacencyConfusion(truth, est)
        return confusion.get_adj_tp() / (confusion.get_adj_tp() + confusion.get_adj_fn())

    def arrow_precision(self, truth: Graph, est: Graph) -> float:
        confusion = ArrowConfusion(truth, est)
        return confusion.get_arrows_tp() / (confusion.get_arrows_tp() + confusion.get_arrows_fp())

    def arrow_recall(self, truth: Graph, est: Graph) -> float:
        confusion = ArrowConfusion(truth, est)
        return confusion.get_arrows_tp() / (confusion.get_arrows_tp() + confusion.get_arrows_fn())

    def arrow_precision_common_edges(self, truth: Graph, est: Graph) -> float:
        confusion = ArrowConfusion(truth, est)
        return confusion.get_arrows_tp() / (confusion.get_arrows_tp() + confusion.get_arrows_fp_ce())

    def arrow_recall_common_edges(self, truth: Graph, est: Graph) -> float:
        confusion = ArrowConfusion(truth, est)
        return confusion.get_arrows_tp() / (confusion.get_arrows_tp() + confusion.get_arrows_fn_ce())

    def exists_directed_path_from_to_breadth_first(self, node_from: Node, node_to: Node, G: Graph) -> bool:
        Q: Deque[Node] = deque()
        V: List[Node] = [node_from]
        Q.append(node_from)

        while len(Q) > 0:
            t = Q.pop()
            for u in G.get_adjacent_nodes(t):
                if G.is_parent_of(t, u) and G.is_parent_of(u, t):
                    return True

                edge = G.get_edge(t, u)
                edges = Edges()
                c = edges.traverse_directed(t, edge)

                if c is None:
                    continue
                if c in V:
                    continue
                if c == node_to:
                    return True

                V.append(c)
                Q.append(c)

    @staticmethod
    def to_pgv(G: Graph, title: str = ""):
        """Convert a graph object to a pygraphviz ``AGraph``.

        Nodes are added under their position in ``G.get_nodes()`` and labelled
        with their names; latent nodes are drawn as squares. Every edge is
        drawn with ``dir='both'`` and its two endpoint marks mapped to arrow
        shapes (TAIL -> none, ARROW -> normal, CIRCLE -> odot).

        Parameters
        ----------
        G : Graph
            Graph to draw.
        title : str, optional (default="")
            Label of the drawing.

        Returns
        -------
        graphviz_g : pygraphviz.AGraph

        Raises
        ------
        NotImplementedError
            If an edge has an endpoint other than TAIL, ARROW or CIRCLE.
        """
        # warnings.warn("GraphUtils.to_pgv() is deprecated", DeprecationWarning)
        import pygraphviz as pgv

        def arrow_shape(endpoint):
            # Graphviz arrow type for one endpoint mark.
            if endpoint == Endpoint.TAIL:
                return 'none'
            if endpoint == Endpoint.ARROW:
                return 'normal'
            if endpoint == Endpoint.CIRCLE:
                return 'odot'
            raise NotImplementedError()

        drawing = pgv.AGraph(directed=True)
        drawing.graph_attr['label'] = title
        drawing.graph_attr['labelfontsize'] = 18

        node_list = G.get_nodes()
        for position, node in enumerate(node_list):
            drawing.add_node(position)
            drawing.get_node(position).attr['label'] = node.get_name()
            if node.get_node_type() == NodeType.LATENT:
                drawing.get_node(position).attr['shape'] = 'square'

        for edge in G.get_graph_edges():
            if not edge:
                continue
            tail_id = node_list.index(edge.get_node1())
            head_id = node_list.index(edge.get_node2())
            drawing.add_edge(tail_id, head_id)

            drawn_edge = drawing.get_edge(tail_id, head_id)
            drawn_edge.attr['dir'] = 'both'
            drawn_edge.attr['arrowtail'] = arrow_shape(edge.get_endpoint1())
            drawn_edge.attr['arrowhead'] = arrow_shape(edge.get_endpoint2())

        return drawing

    @staticmethod
    def to_pydot(G: Graph, edges: List[Edge] | None = None, labels: List[str] | None = None, title: str = "", dpi: float = 200):
        '''
        Convert a graph object to a DOT object.

        Nodes are added under their position in ``G.get_nodes()``; latent
        nodes are drawn as squares. Edges carrying ``Edge.Property.dd`` are
        coloured "green3" and edges carrying ``Edge.Property.nl`` get a pen
        width of 2.0.

        Parameters
        ----------
        G : Graph
            A graph object of causal-learn
        edges : list, optional (default=None)
            Edges list of graph G; the graph's own edges are used when None
        labels : list, optional (default=None)
            Nodes labels of graph G; must have one entry per node
        title : str, optional (default="")
            The name of graph G
        dpi : float, optional (default=200)
            The dots per inch of dot object
        Returns
        -------
        pydot_g : Dot
        '''

        def arrow_shape(endpoint):
            # Graphviz arrow type for one endpoint mark.
            if endpoint == Endpoint.TAIL:
                return 'none'
            if endpoint == Endpoint.ARROW:
                return 'normal'
            if endpoint == Endpoint.CIRCLE:
                return 'odot'
            raise NotImplementedError()

        node_list = G.get_nodes()
        if labels is not None:
            assert len(labels) == len(node_list)

        dot = pydot.Dot(title, graph_type="digraph", fontsize=18)
        dot.obj_dict["attributes"]["dpi"] = dpi

        for position, node in enumerate(node_list):
            caption = node.get_name() if labels is None else labels[position]
            node_attrs = {"label": caption}
            if node.get_node_type() == NodeType.LATENT:
                node_attrs["shape"] = 'square'
            dot.add_node(pydot.Node(position, **node_attrs))

        edge_list = G.get_graph_edges() if edges is None else edges

        for edge in edge_list:
            tail_id = node_list.index(edge.get_node1())
            head_id = node_list.index(edge.get_node2())
            drawn_edge = pydot.Edge(
                tail_id,
                head_id,
                dir='both',
                arrowtail=arrow_shape(edge.get_endpoint1()),
                arrowhead=arrow_shape(edge.get_endpoint2()),
            )

            edge_attrs = drawn_edge.obj_dict["attributes"]
            if Edge.Property.dd in edge.properties:
                edge_attrs["color"] = "green3"
            if Edge.Property.nl in edge.properties:
                edge_attrs["penwidth"] = 2.0

            dot.add_edge(drawn_edge)

        return dot
