from abc import ABC
from typing import Dict, List, Optional

import numpy as np
from numpy import ndarray

from .Node import Node
from .Edge import Edge
from .Endpoint import Endpoint


class GeneralGraph(ABC):
    """Graph with typed edges over a fixed, ordered list of nodes.

    Declared as an ABC, although no abstract methods are defined here.

    Every edge is stored as an ``Edge`` object in ``self.graph`` under both of
    its endpoints. Directed, bidirected and blunt edges are additionally
    recorded in integer indicator matrices indexed by node position
    (``self.node_map``); the directed matrix backs the parent/child and path
    queries.
    """

    def __init__(self, nodes: List[Node]):
        """Create a graph over ``nodes`` with no edges.

        Args:
            nodes: The graph's nodes. Their order fixes each node's row/column
                index in the matrices below. Nodes are used as dictionary keys,
                so they must be hashable.
        """
        self.nodes: List[Node] = nodes
        self.num_vars: int = len(nodes)

        # Position of each node in ``self.nodes`` (= its matrix row/column).
        self.node_map: Dict[Node, int] = {node: i for i, node in enumerate(nodes)}

        # Cache written by calculate_strongly_connected_components(): a list
        # mapping node index -> component number, or None if not computed.
        self.scc_dict = None

        # graph[node1][node2] is the list of edges between node1 and node2.
        # Every node starts with an empty neighbour dictionary.
        self.graph: Dict[Node, Dict[Node, List[Edge]]] = {node: {} for node in nodes}

        shape = (self.num_vars, self.num_vars)
        # directed_edges[child, parent] == 1 for every edge parent --> child.
        self.directed_edges: ndarray = np.zeros(shape, np.dtype(int))
        # Symmetric indicator matrices for <-> and blunt edges.
        self.bidirected_edges: ndarray = np.zeros(shape, np.dtype(int))
        self.blunt_edges: ndarray = np.zeros(shape, np.dtype(int))

        # NOTE(review): no method in this class writes this matrix; presumably
        # it is filled in by subclasses or callers -- confirm.
        self.adjacency_matrix = np.zeros(shape, np.dtype(int))

        # NOTE(review): not read by any method in this class; presumably flags
        # the graph as a PAG -- confirm against callers.
        self.pag = False

    def get_node(self, str: str) -> Optional[Node]:
        """Return the first node whose ``get_name()`` equals the given name.

        Args:
            str: Name to look up. The parameter keeps its historical name
                (which shadows the builtin ``str`` inside this method) so that
                keyword callers keep working.

        Returns:
            The matching node (linear scan over ``self.nodes``), or None when
            no node has that name.
        """
        return next((node for node in self.nodes if node.get_name() == str), None)

    def print_edges(self):
        """Print each distinct edge once, followed by its ``print_edge()`` output.

        Edges are stored under both endpoints in ``self.graph``, so they are
        collected into a set first to drop the duplicates.
        """
        unique_edges = {
            edge
            for neighbours in self.graph.values()
            for edges in neighbours.values()
            for edge in edges
        }
        for edge in unique_edges:
            print(edge, edge.get_endpoint1(), edge.get_endpoint2())
            edge.print_edge()

    def add_edge(self, edge: Edge):
        """Store ``edge`` in ``self.graph`` under both of its endpoints.

        Endpoints that are not yet keys of ``self.graph`` are added on the fly.
        The edge goes into ``graph[node1][node2]`` and ``graph[node2][node1]``,
        so ``get_edges_node`` finds it from either side. For a self-loop
        (node1 == node2) both slots are the same list, and the edge is stored
        only once.

        The indicator matrices are not touched here; ``add_directed_edge``,
        ``add_bidirected_edge`` and ``add_blunt_edge`` update them.
        """
        node1, node2 = edge.get_node1(), edge.get_node2()
        self.graph.setdefault(node1, {}).setdefault(node2, []).append(edge)
        if node1 != node2:
            self.graph.setdefault(node2, {}).setdefault(node1, []).append(edge)

    def add_directed_edge(self, node1: Node, node2: Node):
        """Add the edge ``node1 --> node2`` (tail at node1, arrowhead at node2).

        Also sets ``directed_edges[node2, node1]`` (row = child, column =
        parent). The cached SCC labelling is discarded, because a new directed
        edge can merge components.
        """
        self.add_edge(Edge(node1, node2, Endpoint.TAIL, Endpoint.ARROW))
        self.directed_edges[self.node_map[node2], self.node_map[node1]] = 1
        self.scc_dict = None

    def add_bidirected_edge(self, node1: Node, node2: Node):
        """Add the edge ``node1 <-> node2`` and mark it symmetrically in ``bidirected_edges``."""
        self.add_edge(Edge(node1, node2, Endpoint.ARROW, Endpoint.ARROW))
        i, j = self.node_map[node1], self.node_map[node2]
        self.bidirected_edges[i, j] = 1
        self.bidirected_edges[j, i] = 1

    def add_blunt_edge(self, node1: Node, node2: Node):
        """Add a blunt-blunt edge between the nodes and mark it symmetrically in ``blunt_edges``."""
        self.add_edge(Edge(node1, node2, Endpoint.BLUNT, Endpoint.BLUNT))
        i, j = self.node_map[node1], self.node_map[node2]
        self.blunt_edges[i, j] = 1
        self.blunt_edges[j, i] = 1

    def add_undirected_edge(self, node1: Node, node2: Node):
        """Add the edge ``node1 --- node2`` (tails at both ends); stored only in ``self.graph``."""
        self.add_edge(Edge(node1, node2, Endpoint.TAIL, Endpoint.TAIL))

    def add_nondirected_edge(self, node1: Node, node2: Node):
        """Add the edge ``node1 o-o node2`` (circles at both ends); stored only in ``self.graph``."""
        self.add_edge(Edge(node1, node2, Endpoint.CIRCLE, Endpoint.CIRCLE))

    def add_partially_oriented_edge(self, node1: Node, node2: Node):
        """Add the edge ``node1 o-> node2`` (circle at node1, arrowhead at node2); stored only in ``self.graph``."""
        self.add_edge(Edge(node1, node2, Endpoint.CIRCLE, Endpoint.ARROW))

    def get_parents(self, node: Node) -> List[Node]:
        """Return every X with a directed edge ``X --> node``, in node-index order.

        Reads row ``node`` of ``directed_edges`` (rows are children).
        """
        row = self.directed_edges[self.node_map[node], :]
        return [self.nodes[i] for i in row.nonzero()[0]]

    def get_children(self, node: Node) -> List[Node]:
        """Return every Y with a directed edge ``node --> Y``, in node-index order.

        Reads column ``node`` of ``directed_edges`` (columns are parents).
        """
        column = self.directed_edges[:, self.node_map[node]]
        return [self.nodes[i] for i in column.nonzero()[0]]

    def get_edges_node(self, node: Node) -> List[Edge]:
        """Return all edges stored in ``self.graph`` under ``node``, grouped by neighbour."""
        return [edge for edges in self.graph[node].values() for edge in edges]

    def dfs(self, start_node: Node, destination_node: Node, visited: List[Node]) -> bool:
        """Depth-first search along directed edges (via ``get_children``).

        Args:
            start_node: Node the search starts from.
            destination_node: Node being looked for.
            visited: Nodes that must not be expanded. Every node this search
                expands is appended to it, in depth-first order, and the list
                is mutated in place.

        Returns:
            True if ``destination_node`` is reached, including the trivial case
            ``start_node == destination_node``; otherwise False.

        The search uses an explicit stack of child iterators rather than
        recursion, so long directed paths cannot hit Python's recursion limit.
        Membership is checked against a set mirror of ``visited`` instead of
        scanning the list.
        """
        if start_node == destination_node:
            return True
        visited.append(start_node)
        seen = set(visited)
        exhausted = object()  # sentinel returned by next() on a finished iterator
        stack = [iter(self.get_children(start_node))]
        while stack:
            child = next(stack[-1], exhausted)
            if child is exhausted:
                stack.pop()  # all children of this node explored: backtrack
                continue
            if child in seen:
                continue
            if child == destination_node:
                return True
            visited.append(child)
            seen.add(child)
            stack.append(iter(self.get_children(child)))
        return False

    def exists_directed_path(self, start_node: Node, destination_node: Node) -> bool:
        """Return True if there is a directed path ``start_node --> ... --> destination_node``.

        A node always has a (trivial) path to itself.
        """
        return self.dfs(start_node, destination_node, [])

    def calculate_strongly_connected_components(self):
        """Label every node with the number of its strongly connected component.

        The result is stored in ``self.scc_dict`` as a list: entry ``i`` is the
        component number (1, 2, ...) of ``self.nodes[i]``. Components are
        numbered in the order of their lowest-indexed node. Nodes i and j share
        a component iff each is reachable from the other by a directed path.
        Nothing is recomputed while the cached value is not None.

        Reachability is computed once per node, with one graph search each,
        instead of running two path searches for every pair of nodes.
        """
        if self.scc_dict is not None:
            return

        # reachable[i]: indices of the nodes reachable from node i (i included).
        reachable = []
        for i in range(self.num_vars):
            found = {i}
            frontier = [self.nodes[i]]
            while frontier:
                for child in self.get_children(frontier.pop()):
                    k = self.node_map[child]
                    if k not in found:
                        found.add(k)
                        frontier.append(child)
            reachable.append(found)

        labels = [0] * self.num_vars  # 0 = not yet assigned to a component
        next_label = 1
        for i in range(self.num_vars):
            if labels[i]:
                continue
            labels[i] = next_label
            for j in reachable[i]:
                if j > i and labels[j] == 0 and i in reachable[j]:
                    labels[j] = next_label
            next_label += 1
        self.scc_dict = labels
