from itertools import combinations
import networkx as nx
from Utils.Graph_utils import *
from typing import Optional, Set, Tuple, Union, List
import numpy as np
import pandas as pd
from collections import defaultdict



class MixGraph(nx.Graph):
    def __init__(self, incoming_graph_data: Optional[Union[pd.DataFrame, np.ndarray, list]] = None):
        """
        Initialize a Mix Graph based on networkx.Graph.
        If incoming_graph_data is provided, it will be used to create the graph.
        pd.DataFrame or np.ndarray means the adjacency matrix. if list, it means the node list.
        """
        super().__init__()
        self._cache_path = defaultdict(dict)  # for caching results of get...paths functions
        self._cache_nodes = defaultdict(dict)  # for caching results of get...nodes functions
        if incoming_graph_data is not None:
            if isinstance(incoming_graph_data, pd.DataFrame):
                self.from_pandas_adjacency(incoming_graph_data)
            elif isinstance(incoming_graph_data, np.ndarray):
                self.from_numpy_array(incoming_graph_data)
            elif isinstance(incoming_graph_data, list):
                self.from_node_list(incoming_graph_data)
            else:
                raise TypeError("Invalid graph data type.")

    def clear_cache(self):
        """Clear the cache."""
        self._cache_path.clear()
        self._cache_nodes.clear()

    ####################
    ## Edge functions ##
    ####################

    def _ensure_order(self, node1: Node, node2: Node) -> tuple[Node, Node]:
        """
        Ensure node1 < node2 based on the integer mapping.
        """
        if node1.index > node2.index:
            return node2, node1
        return node1, node2


    def add_Edge(self, node1: Node, node2: Node, edge: Edge):
        """
        Add an edge to the graph. 
        The parent function that for add different types of edges.
        """
        self.clear_cache()
        if not isinstance(node1, Node) or not isinstance(node2, Node):
            raise TypeError("node1 and node2 must be of type Node.")
        if not isinstance(edge, Edge):
            raise TypeError("edge must be of type Edge.")
        if not self.has_node(node1) or not self.has_node(node2):
            raise ValueError("Both nodes must exist in the graph before adding an edge.")
        
        # If an edge exists, remove it first
        if self.has_edge(node1, node2):
            raise ValueError(f"Edge between {node1} and {node2} already exists. Use update_Edge() to update it.")
        
        # Ensure node1.index < node2.index
        n1, n2 = self._ensure_order(node1, node2)
        if (n1, n2) == (node2, node1):
            edge._invert()  # Invert the edge if the order is reversed

        self.add_edge(node1, node2, edge=edge)
        

    def remove_Edge(self, node1: Node, node2: Node):
        """
        Remove an edge between two nodes if it exists.
        """
        self.clear_cache()
        if not isinstance(node1, Node) or not isinstance(node2, Node):
            raise TypeError("node1 and node2 must be of type Node.")
        if not self.has_node(node1) or not self.has_node(node2):
            raise ValueError("Both nodes must exist in the graph before removing an edge.")
        
        if self.has_edge(node1, node2):
            self.remove_edge(node1, node2)
        else:
            raise ValueError(f"Edge between {node1} and {node2} does not exist.")

        


    def add_circ_Edge(self, node1: Node, node2: Node):
        """
        Set an Undirected edge between two nodes with CIRCLE marks. If the edge already exists, it will not be replaced.
        node1 o-o node2
        Parameters:
        - node1: The first node.
        - node2: The second node.
        """
        self.clear_cache()
        # Add the new edge with the specified marks
        edge = Edge(node1, Mark.CIRCLE, Mark.CIRCLE, node2)
        self.add_Edge(node1, node2, edge=edge)

        

    def add_directed_Edge(self, node1: Node, node2: Node):
        """
        Add a directed edge between two nodes if it does not already exist. If the edge already exists, it will not be replaced.
        node1 -> node2
        Parameters:
        - node1: The first node.
        - node2: The second node.  
        """
        self.clear_cache()
        edge = Edge(node1, Mark.TAIL, Mark.ARROW, node2)
        # Add the new edge with the specified marks
        self.add_Edge(node1, node2, edge=edge)

        

    def add_bidirected_Edge(self, node1: Node, node2: Node):
        """
        Add a bidirected edge between two nodes if it does not already exist. If the edge already exists, it will not be replaced.
        node1 <-> node2
        Parameters:
        - node1: The first node.
        - node2: The second node.
        """
        self.clear_cache()
        edge = Edge(node1, Mark.ARROW, Mark.ARROW, node2)
        # Add the new edge with the specified marks
        self.add_Edge(node1, node2, edge=edge)

        

    def add_circ_arrow_Edge(self, node1: Node, node2: Node):
        """
        Add a directed edge between two nodes with CIRCLE marks. If the edge already exists, it will not be replaced.
        node1 o-> node2
        Parameters:
        - node1: The first node.
        - node2: The second node.
        """
        self.clear_cache()
        edge = Edge(node1, Mark.CIRCLE, Mark.ARROW, node2)
        # Add the new edge with the specified marks
        self.add_Edge(node1, node2, edge=edge)

        

    def add_circ_tail_Edge(self, node1: Node, node2: Node):
        """
        Add a directed edge between two nodes with CIRCLE marks. If the edge already exists, it will not be replaced.
        node1 o-- node2
        Parameters:
        - node1: The first node.
        - node2: The second node.
        """
        self.clear_cache()

        edge = Edge(node1, Mark.CIRCLE, Mark.TAIL, node2)
        # Add the new edge with the specified marks
        self.add_Edge(node1, node2, edge=edge)

        

    def add_tail_Edge(self, node1: Node, node2: Node):
        """
        Add a directed edge between two nodes with TAIL marks. If the edge already exists, it will not be replaced.
        node1 --- node2
        Parameters:
        - node1: The first node.
        - node2: The second node.
        """
        self.clear_cache()

        edge = Edge(node1, Mark.TAIL, Mark.TAIL, node2)
        # Add the new edge with the specified marks
        self.add_Edge(node1, node2, edge=edge)

    def clear_all_orientations(self):
        """
        Clear all edge orientations in the graph.
        """
        for u, v in self.edges():
            self.update_Edge(u, Mark.CIRCLE, Mark.CIRCLE, v)
 


    def update_Edge(self, node1: Node, lmark: Mark, rmark: Mark, node2: Node):
        """
        Orient an edge between two nodes with specific marks. If the edge does not exist, it will not be created.

        Parameters:
        - node1: The first node.
        - node2: The second node.
        - lmark: The left mark.
        - rmark: The right mark.
        """
        self.clear_cache()

        if not self.has_node(node1) or not self.has_node(node2):
            raise ValueError("Both nodes must exist in the graph before adding an edge.")

        if self.has_edge(node1, node2):

            # Ensure node1 < node2
            n1, n2 = self._ensure_order(node1, node2)
            if (n1, n2) == (node2, node1):
                lmark, rmark = rmark, lmark  # Swap marks if the order is reversed
                node1, node2 = node2, node1  # Swap nodes

            # Update the existing edge with new marks
            edge = self[node1][node2]['edge'].copy()  # Copy the existing edge object
            if lmark is not None: 
                edge.lmark = lmark
            if rmark is not None:
                edge.rmark = rmark
            self.add_edge(node1, node2, edge=edge)  # Update the edge in the graph
        else:
            raise ValueError(f"Edge between {node1} and {node2} does not exist. Use add_xx_edge() to create it.")

        

    def get_Edge(self, start: Node, end: Node) -> Optional[Edge]:
        """
        Retrieve the edge object between two nodes if it exists.
        start --- end

        Parameters:
        - start: The starting node of the edge.
        - end: The ending node of the edge.

        Returns:
        - The Edge object (start--edge--end) if the edge exists, otherwise None.
        """

        if self.has_edge(start, end):
            if self[start][end]['edge'] is not None:
                edge = self[start][end]['edge'].copy()
            else:
                raise ValueError(f"get_Edge error, 'edge' between {start} and {end} does not exist.")
            if edge.start == end and edge.end == start:
                edge._invert()

            if edge.start != start or edge.end != end:
                raise ValueError(f"get_Edge error, Edge between {start} and {end} is not in the correct order.")
            return  edge # Invert the edge if the order is reversed
        return None

    def has_directed_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is a directed edge from the start node to the end node.  
        start -> end

        Parameters:
        - start: The source node.
        - end: The target node.

        Returns:
        - True if a directed edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark == Mark.TAIL and edge.rmark == Mark.ARROW)
        return False

    def has_bidirected_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is a bidirected edge between two nodes.  
        start <-> end

        Parameters:
        - start: The first node.
        - end: The second node.

        Returns:
        - True if a bidirected edge exists.
        - False otherwise.
        """
        
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark == Mark.ARROW and edge.rmark == Mark.ARROW)
        return False

    def has_into_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is an edge into the start node from the end node.  
        start *-> end

        Parameters:
        - start: The source node.
        - end: The target node.

        Returns:
        - True if a into edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.rmark == Mark.ARROW)
        return False
    
    def has_pd_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is a potential directed edge between two nodes.
        start (o-)-- (o>) end

        Parameters:
        - start: The source node.
        - end: The target node.

        Returns:
        - True if a potential directed edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark in {Mark.CIRCLE, Mark.TAIL} and edge.rmark in {Mark.CIRCLE, Mark.ARROW})
        return False

    def has_out_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is an edge out of the start node to the end node.  
        start --* end

        Parameters:
        - start: The source node.
        - end: The target node.

        Returns:
        - True if an out edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark == Mark.TAIL)
        return False
    
    def has_circ_star_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is a circ-star edge between two nodes.  
        start o-* end

        Parameters:
        - start: The first node.
        - end: The second node.

        Returns:
        - True if a circ-star edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark == Mark.CIRCLE)
        return False
    
    def has_tail_circ_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is a tail-circ edge between two nodes.  
        start --o end

        Parameters:
        - start: The first node.
        - end: The second node.

        Returns:
        - True if a tail-circ edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark == Mark.TAIL and edge.rmark == Mark.CIRCLE)
        return False

    def has_tail_tail_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is a tail-tail edge between two nodes.  
        start --- end

        Parameters:
        - start: The first node.
        - end: The second node.

        Returns:
        - True if a tail-tail edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark == Mark.TAIL and edge.rmark == Mark.TAIL)
        return False


    def has_circ_circ_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is a circ-circ edge between two nodes.  
        start o-o end

        Parameters:
        - start: The first node.
        - end: The second node.

        Returns:
        - True if a circ-circ edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark == Mark.CIRCLE and edge.rmark == Mark.CIRCLE)
        return False
    
    def has_circ_arrow_Edge(self, start: Node, end: Node) -> bool:
        """
        Determine if there is a circ-arrow edge between two nodes.  
        start o-> end

        Parameters:
        - start: The first node.
        - end: The second node.

        Returns:
        - True if a circ-arrow edge exists.
        - False otherwise.
        """
        edge = self.get_Edge(start, end)
        if edge is not None:
            return (edge.lmark == Mark.CIRCLE and edge.rmark == Mark.ARROW)
        return False

    def _check_Edge(self):
        """
        Check if the edges in the graph are node1 < node2.
        """

        for edge in self.edges(data=True):
            edge_obj = edge[2]['edge']
            node1 = edge_obj.start
            node2 = edge_obj.end

            if node1.index > node2.index:
                raise ValueError(f"Edge {node1} - {node2} is not in the correct order. Edge:{edge_obj}: {node1}:{node1.name}, {node1.index} - {node2}:{node2.name}, {node2.index}")

    


    def get_circ_star_Edge(self) -> List[tuple[Node, Node]]:
        """
        Get all edges formed by o-* in the graph.
        Returns:
            A list of tuples representing the edges formed by o-* in the graph.
            In the list, each tuple is of the form (start_node, end_node), start_node o-* end_node.
        """
        circ_star_edges = []
        for edge in self.edges(data=True):
            edge_obj = edge[2]['edge']
            if edge_obj.lmark == Mark.CIRCLE:  #  start o-* end
                circ_star_edges.append((edge_obj.start, edge_obj.end))
            if edge_obj.rmark == Mark.CIRCLE:  #  end o-* start
                circ_star_edges.append((edge_obj.end, edge_obj.start))

        return circ_star_edges
    
    def get_circ_circ_Edge(self) -> List[tuple[Node, Node]]:
        """
        Get all edges formed by o-o in the graph.
        Returns:
            A list of tuples representing the edges formed by o-o in the graph.
            In the list, each tuple is of the form (start_node, end_node), start_node o-o end_node.
        """
        circ_circ_edges = []
        for edge in self.edges(data=True):
            edge_obj = edge[2]['edge']
            if edge_obj.lmark == Mark.CIRCLE and edge_obj.rmark == Mark.CIRCLE:  #  start o-o end
                circ_circ_edges.append((edge_obj.start, edge_obj.end))
            if edge_obj.rmark == Mark.ARROW and edge_obj.lmark == Mark.CIRCLE:
                circ_circ_edges.append((edge_obj.end, edge_obj.start))   #  end o-o start

        return circ_circ_edges
    
    def get_circ_arrow_Edge(self) -> List[tuple[Node, Node]]:
        """
        Get all directed edges in the graph.
        Returns:
            A list of tuples representing the directed edges in the graph.
            In the list, each tuple is of the form (start_node, end_node), start_node o-> end_node.
        """
        circ_arrow_edges = []
        for edge in self.edges(data=True):
            edge_obj = edge[2]['edge']
            if edge_obj.lmark == Mark.CIRCLE and edge_obj.rmark == Mark.ARROW:  #  start o-> end
                circ_arrow_edges.append((edge_obj.start, edge_obj.end))
            if edge_obj.lmark == Mark.ARROW and edge_obj.rmark == Mark.CIRCLE:
                circ_arrow_edges.append((edge_obj.end, edge_obj.start))   #  start <-o end

        return circ_arrow_edges

    def get_directed_Edge(self) -> List[tuple[Node, Node]]:
        """
        Get all directed edges in the graph.
        Returns:
            A list of tuples representing the directed edges in the graph.
            In the list, each tuple is of the form (start_node, end_node), start_node -> end_node.
        """
        directed_edges = []
        for edge in self.edges(data=True):
            edge_obj = edge[2]['edge']
            if edge_obj.lmark == Mark.TAIL and edge_obj.rmark == Mark.ARROW:  #  start --> end
                directed_edges.append((edge_obj.start, edge_obj.end))
            if edge_obj.lmark == Mark.ARROW and edge_obj.rmark == Mark.TAIL:  # start <-- end
                directed_edges.append((edge_obj.end, edge_obj.start))
        
        return directed_edges

    def max_degree(self) -> int:
        """
        Get the maximum degree of the graph.
        """
        return max(dict(self.degree()).values(), default=0)

    def visible_Edge(self, X: Node, Y: Node) -> bool:
        """
        Purpose: Check if the directed edge from X to Y in a MAG or in a PAG
        is visible or not.
        A visible edge X → Y means that there are no latent confounders between X and Y.

        Parameters:
        - X: The first node.
        - Y: The second node.

        Returns:
        - True if a directed edge exists and visible.
        - False otherwise.
        """

        if not self.has_node(X) or not self.has_node(Y):
            raise ValueError("Both nodes must exist in the graph before checking visibility.")
        
        if not self.has_directed_Edge(X, Y):
            return False

        # 1. scenario: there exists a vertex V not adjacent to Y with *--> X
        for V in self.get_into_nodes(X):
            if not self.has_edge(V, Y):
                return True


        # 2. scenario: there is a collider path between V and X that is into X and every non-endpoint node on the path is a parent of Y

        Parents_Y = self.get_parents(Y)
        District_X = self.get_district(X)
        discriminator = (Parents_Y & District_X)  # The discriminator set includes all parents of Y that are in the district of X

        Cand_Vs= set()
        for node in discriminator:
            S = self.get_into_nodes(node) - {X}.update(discriminator)    # V*-> node
            if S:
                for s in S:
                    if not self.has_edge(s, Y):
                        Cand_Vs.add(s)        # Add all candidates to the set Cand_Vs
                # Cand_Vs |= S  

        if len(Cand_Vs) == 0:
            return False
        else:
            for V in Cand_Vs:
                for path in self.get_all_paths(source=V, end=X):
                    if all(node in discriminator for node in path[1:-1]) and self.is_collider_path(path):  # path[1:-1] are all non-endpoint nodes
                        return True
        
        return False



    def clear_all_Edges(self):
        """
        Clear all edges in the graph.
        """
        self.clear_edges()

    ########################
    ## End Edge functions ##
    ########################

    ####################
    ## Path functions ##
    ####################

    def has_path(self, source: Node, end: Node) -> bool:
        """
        Check if there is a path between two nodes in the graph.

        Parameters:
        - source: The source node.
        - end: The end node.

        Returns:
        - True if a path exists, False otherwise.
        """
        if not self.has_node(source) or not self.has_node(end):
            raise ValueError("Both nodes must exist in the graph before checking for a path.")

        return nx.has_path(self, source, end)

    def has_pd_path(self, source: Node, end: Node) -> bool:
        """
        Efficiently check if there is a potentially directed path (p.d. path) from source to end.
        Uses DFS, only traversing edges that satisfy has_pd_Edge.
        """
        if not self.has_node(source) or not self.has_node(end):
            raise ValueError("Both nodes must exist in the graph before checking for a p.d. path.")

        visited = set()

        def dfs(current):
            if current == end:
                return True
            visited.add(current)
            for neighbor in self.get_adj_nodes(current):
                if neighbor not in visited and self.has_pd_Edge(current, neighbor):
                    if dfs(neighbor):
                        return True
            return False

        return dfs(source)
    
    
    
    def has_directed_path(self, source: Node, end: Node) -> bool:
        """
        Efficiently check if there is a directed path from source to end.
        Uses DFS, only traversing edges that satisfy has_directed_Edge.
        """
        if not self.has_node(source) or not self.has_node(end):
            raise ValueError("Both nodes must exist in the graph before checking for a directed path.")

        visited = set()

        def dfs(current):
            if current == end:
                return True
            visited.add(current)
            for neighbor in self.get_adj_nodes(current):
                if neighbor not in visited and self.has_directed_Edge(current, neighbor):
                    if dfs(neighbor):
                        return True
            return False

        return dfs(source)
    


    def get_all_paths(self, source: Node, end: Node) -> List[List[Node]]:
        """
        Get all paths between two nodes in the graph.

        Parameters:
        - source: The source node.
        - end: The end node.

        Returns:
        - A list of all paths, where each path is represented as a list of nodes.
        """
        if not self.has_node(source) or not self.has_node(end):
            raise ValueError("Both nodes must exist in the graph before getting all paths.")
        
        si, ei = source.index, end.index
        paths_by_src = self._cache_path['paths']   # If not exists, automatically create an empty dict
        src_map = paths_by_src.setdefault(si, {})  # Ensure src_map exists, create if not
        cached = src_map.get(ei)
        if cached is not None:
            return cached.copy()


        paths = list(nx.all_simple_paths(self, source=source, target=end))
        src_map[ei] = paths
        return paths.copy()



    def get_all_uncovered_pd_path(self, source: Node, end: Node) -> List[List[Node]]:
        """
        Efficiently generate all uncovered potentially directed paths (p.d. paths) from source to end using DFS.
        - Paths contain no repeated nodes.
        - Each step must be a potentially directed edge.
        - For every triple (Vi-1, Vi, Vi+1) in the path, Vi-1 and Vi+1 must not be adjacent (uncovered).
        """

        si, ei = source.index, end.index
        paths_by_src = self._cache_path['uncovered_pd_paths']   # If not exists, automatically create an empty dict
        src_map = paths_by_src.setdefault(si, {})  # Ensure src_map exists, create if not
        cached = src_map.get(ei)
        if cached is not None:
            return cached.copy()

        results = []

        def dfs(path: List[Node], visited: set):
            current = path[-1]
            if current == end:
                results.append(path.copy())
                return
            for neighbor in self.get_adj_nodes(current):
                if neighbor in visited:
                    continue
                # Must be a potentially directed edge
                if not self.has_pd_Edge(current, neighbor):
                    continue
                # Ensure uncovered: path[-2] and neighbor are not adjacent (if length >= 2)
                if len(path) >= 2 and self.has_edge(path[-2], neighbor):
                    continue
                path.append(neighbor)
                visited.add(neighbor)
                dfs(path, visited)
                path.pop()
                visited.remove(neighbor)

        visited = {source}
        dfs([source], visited)
        src_map[ei] = results
        return results.copy()
    
    def get_all_uncovered_collider_paths_from_target(self, source: Node) -> List[List[Node]]:
        """
        Get all uncovered collider paths starting from a specific source node.
        """
        si = source.index
        paths_by_src = self._cache_path['uncovered_collider_paths_from_t']   # If not exists, automatically create an empty dict
        uncovered_collider_paths_from_target = paths_by_src.setdefault(si, {})  # Ensure uncovered_collider_paths_from_target exists, create if not
        if uncovered_collider_paths_from_target:
            return uncovered_collider_paths_from_target.copy()


        results = []

        def dfs(path: List[Node], visited: set):
            current = path[-1]
            if current == source:
                if path not in results:
                    results.append(path.copy())
                return
            for neighbor in self.get_adj_nodes(current):
                if neighbor in visited:
                    continue                
                prev = path[-2]
                # Current node must be collider: prev *-> current <-* neighbor
                if self.has_into_Edge(neighbor, current) and (not self.has_edge(prev, neighbor)):
                    if self.has_into_Edge(current, neighbor): # if current <-> neighbor
                        path.append(neighbor)
                        visited.add(neighbor)
                        dfs(path, visited)
                        path.pop()
                        visited.remove(neighbor)
                    else:                             # if not current <-> neighbor
                        path.append(neighbor)
                        if path not in results and len(path) >= 3:
                            results.append(path.copy())
                        path.pop()
                    
            
        candidate_paths = []
        for neighbor in self.get_adj_nodes(source):
            if self.has_into_Edge(source, neighbor):
                candidate_paths.append([source, neighbor])

        for path in candidate_paths:
            dfs(path, visited=set(path))

        uncovered_collider_paths_from_target = results
        return results.copy()
    
    def get_all_arrow_collider_paths(self, source: Node, end: Node) -> List[List[Node]]:
        """
        Efficiently search for all arrow-collider paths:
        - The first edge must point to the source (i.e., source <-* neighbor)
        - Every non-endpoint node must be a collider (i.e., Vi-1 *-> Vi <-* Vi+1)
        - The path must not contain repeated nodes
        """

        si, ei = source.index, end.index
        paths_by_src = self._cache_path['arrow_collider_paths']   # If not exists, automatically create an empty dict
        src_map = paths_by_src.setdefault(si, {})  # Ensure src_map exists, create if not
        cached = src_map.get(ei)
        if cached is not None:
            return cached.copy()

        results = []

        def dfs(path: List[Node], visited: set):
            current = path[-1]
            if current == end:
                if len(path) == 2 or all(
                    self.has_into_Edge(path[i-1], path[i]) and self.has_into_Edge(path[i+1], path[i])
                    for i in range(1, len(path)-1)
                ):
                    results.append(path.copy())
                return
            for neighbor in self.get_adj_nodes(current):
                if neighbor in visited:
                    continue
                # The first edge must point to source
                if len(path) == 1 and not self.has_into_Edge(neighbor, current):
                    continue
                # Non-endpoint must be collider
                if len(path) >= 2:
                    prev = path[-2]
                    # Current node must be collider: prev *-> current <-* neighbor
                    if not (self.has_into_Edge(prev, current) and self.has_into_Edge(neighbor, current)):
                        continue
                path.append(neighbor)
                visited.add(neighbor)
                dfs(path, visited)
                path.pop()
                visited.remove(neighbor)

        visited = {source}
        dfs([source], visited)
        # self._cache[key] = results
        src_map[ei] = results
        return results.copy()

    def is_collider(self, node1: Node, node2: Node, node3: Node) -> bool:
        """
        Check if the path node1 -> node2 <- node3 is a collider path.

        Parameters:
        - node1: The first node.
        - node2: The second node (the collider).
        - node3: The third node.

        Returns:
        - True if the path is a collider path, False otherwise.
        """
        if not self.has_node(node1) or not self.has_node(node2) or not self.has_node(node3):
            raise ValueError("All nodes must exist in the graph before checking for a collider.")
        
        return self.has_into_Edge(node1, node2) and self.has_into_Edge(node3, node2)


    def is_collider_path(self, path: List[Node]) -> bool:
        """
        Check if a given path is a collider path.

        Parameters:
        - path: The path to check.

        Returns:
        - True if the path is a collider path, False otherwise.
        """
        if len(path) < 3:
            return False

        # Check if the middle node is a collider
        for i in range(1, len(path) - 1):
            if not self.is_collider(path[i - 1], path[i], path[i + 1]):
                return False
        return True

    def is_uncovered_path(self, path: List[Node]) -> bool:
        """
        Check if a given path is an uncovered path.

        An uncovered path <V0, ..., Vn> is one where for every consecutive triple (Vi-1, Vi, Vi+1),
        Vi-1 and Vi+1 are not adjacent.

        Parameters:
        - path: The path to check (list of Node).

        Returns:
        - True if the path is uncovered, False otherwise.
        """
        if len(path) < 3:
            return True  # Trivially uncovered

        for i in range(1, len(path) - 1):
            prev_node = path[i - 1]
            next_node = path[i + 1]
            if self.has_edge(prev_node, next_node):
                return False
        return True
    
    def is_potential_directed_path(self, path: List[Node]) -> bool:
        """
        Check if a given path is a potentially directed path (p.d. path).

        A path p = <V0, ..., Vn> is potentially directed if for every 0 <= i <= n-1,
        the edge between Vi and Vi+1 is not into Vi or out of Vi+1.

        Parameters:
        - path: The path to check (list of Node).

        Returns:
        - True if the path is potentially directed, False otherwise.
        """
        if len(path) < 2:
            return False  # A path must have at least two nodes

        for i in range(len(path) - 1):
            Vi = path[i]
            Vi1 = path[i + 1]
            edge = self.get_Edge(Vi, Vi1)
            if edge is None:
                return False
            # Not into Vi: 
            if edge.lmark == Mark.ARROW:
                return False
            # Not out of Vi+1: 
            if edge.rmark == Mark.TAIL:
                return False
        return True
 

    ########################
    ## End Path functions ##
    ########################



    ###########################################
    ## Initializing Graphs in Different Ways ##
    ###########################################
    def from_numpy_array(self, adj_matrix: np.ndarray, node_list: List[Node] = None):
        """
        Create a Mixed Graph from a NumPy adjacency matrix.
        """
        if node_list is None:
            self.node_list = [Node(name=str(i), index=i) for i in range(adj_matrix.shape[0])]  # [Node(str, 0), Node(str, 1), ...]
        else:
            self.node_list = node_list

        self.add_nodes_from(self.node_list)
        for i, j in combinations(self.node_list, 2):
            mark_ij = adj_matrix[i.index, j.index]  # i markji --- markij j
            mark_ji = adj_matrix[j.index, i.index]

            # print(f"Mark.ARROW:{Mark.ARROW.value}, Mark.CIRCLE:{Mark.CIRCLE}, Mark.TAIL:{Mark.TAIL}, Mark.NULL:{Mark.NULL}")

            if mark_ij == Mark.ARROW.value:
                if mark_ji == Mark.ARROW.value:    # i <-> j
                    self.add_bidirected_Edge(i, j)  
                elif mark_ji == Mark.CIRCLE.value:  # i o-> j
                    self.add_circ_arrow_Edge(i, j)
                elif mark_ji == Mark.TAIL.value:     # i --> j
                    self.add_directed_Edge(i, j)
                else:  
                    raise ValueError(f"Invalid edge mark: {mark_ji} between {i} and {j}.")
            elif mark_ij == Mark.CIRCLE.value:      
                if mark_ji == Mark.ARROW.value:    # i <-o j
                    self.add_circ_arrow_Edge(j, i)
                elif mark_ji == Mark.CIRCLE.value:  # i o-o j
                    self.add_circ_Edge(j, i)
                elif mark_ji == Mark.TAIL.value:  # i --o j
                    self.add_circ_tail_Edge(j, i)
                else:  
                    raise ValueError(f"Invalid edge mark: {mark_ji} between {i} and {j}.")
            elif mark_ij == Mark.TAIL.value:
                if mark_ji == Mark.ARROW.value:  # i <-- j
                    self.add_directed_Edge(j, i)
                elif mark_ji == Mark.CIRCLE.value: # i o-- j
                    self.add_circ_tail_Edge(i, j)
                elif mark_ji == Mark.TAIL.value:  # i --- j
                    self.add_tail_Edge(i, j)
                else:  
                    raise ValueError(f"Invalid edge mark: {mark_ji} between {i} and {j}.")
            elif mark_ij == Mark.NULL.value:
                pass  # No edge between i and j
            else:
                raise ValueError(f"Invalid edge mark: {mark_ij} between {i} and {j}." "Invalid edge mark in the adjacency matrix.")
            
    def DAG_from_numpy_array(self, adj_matrix: np.ndarray, node_list: List[Node] = None):
        """
        Create a Directed Acyclic Graph (DAG) from a NumPy adjacency matrix.
        """
        if node_list is None:
            self.node_list = [Node(name=str(i), index=i) for i in range(adj_matrix.shape[0])]  # [Node(str, 0), Node(str, 1), ...]
        else:
            self.node_list = node_list


        self.add_nodes_from(self.node_list)
        for i, j in combinations(self.node_list, 2):
            mark_ij = adj_matrix[i.index, j.index]  # i markji --- markij j
            mark_ji = adj_matrix[j.index, i.index]

            # print(f"Mark.ARROW:{Mark.ARROW.value}, Mark.CIRCLE:{Mark.CIRCLE}, Mark.TAIL:{Mark.TAIL}, Mark.NULL:{Mark.NULL}")

            if mark_ij == 1:
                self.add_directed_Edge(i, j)     # i --> j
            elif mark_ij == 0:
                pass  # No edge between i and j
            else:
                raise ValueError(f"Invalid edge mark: {mark_ij} between {i} and {j}." "Invalid edge mark in the adjacency matrix.")

    def from_pandas_adjacency(self, adj_matrix: pd.DataFrame, graph_type: str = 'MG'):

        """
        Create a Mixed Graph from a Pandas DataFrame adjacency matrix.
        """

        Node_list = [Node(name=node, index=i) for i, node in enumerate(adj_matrix.columns.to_list())]  # [Node
        if graph_type == 'MG':
            self.from_numpy_array(adj_matrix.to_numpy(), node_list=Node_list)
        elif graph_type == 'DAG':
            self.DAG_from_numpy_array(adj_matrix.to_numpy(), node_list=Node_list)
            

    def from_node_list(self, node_list: List[Union[str, Node]]):
        """
        Create a empty graph from a list of nodes.
        """
        if all(isinstance(node, Node) for node in node_list):
            # If node_list is already a list of Node objects, use it directly
            Node_list = node_list
        else:
            Node_list = [Node(name=node, index=i) for i, node in enumerate(node_list)] # [Node(str, 0), Node(str, 1), ...]
        self.node_list = Node_list
        self.add_nodes_from(self.node_list)
        self.clear_cache()

    ###############################################
    ## End Initializing Graphs in Different Ways ##
    ###############################################



    ####################
    ## Node functions ##
    ####################
    def add_Node(self, node: Node):
        """
        Add a node to the graph.

        Parameters:
        - node: The node to be added.
        """
        if not isinstance(node, Node):
            raise TypeError("node must be of type Node.")
        self.add_node(node)
        self.node_list.append(node)
        # Keep node_list ordered by node.index (ascending)
        self.node_list.sort(key=lambda n: n.index if n.index is not None else 0)

    def remove_Node(self, node: Node):
        """
        Remove a node from the graph.

        Parameters:
        - node: The node to be removed.
        """
        
        if not isinstance(node, Node):
            raise TypeError("node must be of type Node.")
        if not self.has_node(node):
            raise ValueError(f"Node {node} does not exist in the graph.")
        
        self.remove_node(node)
        # Remove all edges connected to the node
        self.remove_edges_from(list(self.edges(node)))
        self.clear_cache()

    def get_adj_nodes(self, node: Node) -> Set[Node]:
        """
        Get the set of nodes that are adjacent to the given node in the graph.  
        *-* node
        """
        t = node.index
        nodes_by_src = self._cache_nodes['adj_nodes']   # If not exists, automatically create an empty dict
        adj_nodes = nodes_by_src.setdefault(t, set())
        if adj_nodes:  # Cache exists and the cached set is not empty
            return adj_nodes.copy()
        
        adj = set(self.neighbors(node))
        nodes_by_src[t] = adj

        return adj.copy()

    def get_into_nodes(self, node: Node) -> Set[Node]:
        """
        Get the set of nodes that are adjacent into the given node in the graph. 
        node <-*
        """
        t = node.index
        nodes_by_src = self._cache_nodes['into_nodes']   # If not exists, automatically create an empty dict
        into_nodes_t = nodes_by_src.setdefault(t, set())
        if into_nodes_t:  # Cache exists and the cached set is not empty
            return into_nodes_t.copy()

        adj_nodes = self.get_adj_nodes(node)
        into_nodes = set()
        for adj_node in adj_nodes:
            if self.has_into_Edge(adj_node, node):
                into_nodes.add(adj_node)

        nodes_by_src[t] = into_nodes
        return into_nodes.copy()
    
    def get_no_into_nodes(self, node: Node) -> Set[Node]: 
        """
        Get the set of nodes that are adjacent not into the given node in the graph. 
        node -* or node o-*
        """
        t = node.index
        nodes_by_src = self._cache_nodes['no_into_nodes']   # If not exists, automatically create an empty dict
        no_into_nodes_t = nodes_by_src.setdefault(t, set())
        if no_into_nodes_t:  # Cache exists and the cached set is not empty
            return no_into_nodes_t.copy()

        adj_nodes = self.get_adj_nodes(node)
        into_nodes = self.get_into_nodes(node)
        no_into_nodes = adj_nodes - into_nodes

        nodes_by_src[t] = no_into_nodes
        return no_into_nodes.copy()
    
    def get_nondirect_adj_nodes(self, node: Node) -> Set[Node]:
        """
        Get the set of nondirected neighbors of the given node in the graph.
        node o-o adj_node
        """

        t = node.index
        nodes_by_src = self._cache_nodes['nondirect_adj_nodes']   # If not exists, automatically create an empty dict
        nondirect_adj_nodes_t = nodes_by_src.setdefault(t, set())
        if nondirect_adj_nodes_t:  # Cache exists and the cached set is not empty
            return nondirect_adj_nodes_t.copy()

        adj_nodes = self.get_adj_nodes(node)
        nondirect_adj_nodes = set()
        for adj_node in adj_nodes:
            if self.has_circ_circ_Edge(node, adj_node):
                nondirect_adj_nodes.add(adj_node)

        nodes_by_src[t] = nondirect_adj_nodes
        return nondirect_adj_nodes.copy()

    def get_circ_star_nodes(self, node: Node) -> Set[Node]:
        """
        Get the set of adjacent nodes such that node (o) --* adj_node.
        node o-* adj_node
        """

        t = node.index
        nodes_by_src = self._cache_nodes['circ_star_nodes']   # If not exists, automatically create an empty dict
        circ_star_nodes_t = nodes_by_src.setdefault(t, set())
        if circ_star_nodes_t:  # Cache exists and the cached set is not empty
            return circ_star_nodes_t.copy()

        adj_nodes = self.get_adj_nodes(node)
        circ_star_nodes = set()
        for adj_node in adj_nodes:
            if self.has_circ_star_Edge(node, adj_node):
                circ_star_nodes.add(adj_node)


        nodes_by_src[t] = circ_star_nodes
        return circ_star_nodes.copy()
    
    def max_pds_size(self) -> int:
        """
        Get the maximum size of the possible d-separation set for all nodes in the graph.
        """
        return max(len(self.get_possible_d_sep(node)) for node in self.node_list) if self.node_list else 0

    def get_possible_d_sep(self, node: Node) -> Set[Node]:
        """
        Compute Possible-D-SEP(node):
        X_k ∈ pds(C, X_i, X_j) iff there exists a path π between X_i and X_k such that
        for every subpath <X_m, X_l, X_h> of π, X_l is a collider on the subpath in C
        or <X_m, X_l, X_h> is a triangle in C (i.e. X_m and X_h are adjacent).
        Here we return the set of nodes X_k reachable from `node` satisfying that condition.
        """
        t = node.index
        nodes_by_src = self._cache_nodes['possible_d_sep']   # If not exists, automatically create an empty dict
        pds_t = nodes_by_src.setdefault(t, set())
        if pds_t:  # cache exists and not empty
            return pds_t.copy()

        results = set()

        # Initialize stack with 1-step paths [node, neighbor]
        for neigh in self.get_adj_nodes(node):
            if neigh == node:
                continue
            stack = [[node, neigh]]
            while stack:
                path = stack.pop()
                last = path[-1]
                # add reachable node (exclude the start node itself)
                if last is not node:
                    results.add(last)

                # try to extend path
                for nbr in self.get_adj_nodes(last):
                    if nbr in path:
                        continue  # avoid cycles / repeated nodes
                    # for the triple (prev, curr, nbr), check condition:
                    prev = path[-2]  # guaranteed because path length >=2
                    curr = last
                    # condition: curr is collider on subpath OR prev and nbr are adjacent (triangle)
                    if self.is_collider(prev, curr, nbr) or self.has_edge(prev, nbr):
                        # extension allowed
                        new_path = path + [nbr]
                        stack.append(new_path)
                    else:
                        # extension not allowed by definition
                        continue

        # remove the node itself if accidentally added
        results.discard(node)

        nodes_by_src[t] = results
        return results.copy()



    
    def get_pd_path_nodes(self, node: Node) -> Set[Node]:
        """
        Get the set of adjacent nodes such that node (o or -) -- (o or >) adj_node.
        node o-o adj_node or node o-> adj_node or node --o adj_node or node --> adj_node
        """
        t = node.index
        nodes_by_src = self._cache_nodes['pd_path_nodes']   
        pd_path_nodes_t = nodes_by_src.setdefault(t, set())
        if pd_path_nodes_t:  
            return pd_path_nodes_t.copy()

        adj_nodes = self.get_adj_nodes(node)
        pd_path_nodes = set()
        for adj_node in adj_nodes:
            edge = self.get_Edge(node, adj_node)
            if (edge.lmark in {Mark.CIRCLE, Mark.TAIL}) and (edge.rmark in {Mark.CIRCLE, Mark.ARROW}):
                pd_path_nodes.add(adj_node)

        nodes_by_src[t] = pd_path_nodes
        return pd_path_nodes.copy()

    def get_parents(self, node: Node) -> Set[Node]:
        """
        Get the set of parent nodes of the given node in the graph.  
        --> node
        """

        t = node.index
        nodes_by_src = self._cache_nodes['parents']   
        parents_t = nodes_by_src.setdefault(t, set())
        if parents_t:  
            return parents_t.copy()

        adj_nodes = self.get_adj_nodes(node)
        parents = set()
        for adj_node in adj_nodes:
            if self.has_directed_Edge(adj_node, node):
                parents.add(adj_node)

        nodes_by_src[t] = parents
        return parents.copy()

    def get_children(self, node: Node) -> Set[Node]:
        """
        Get the set of child nodes of the given node in the graph.  
        node -> *
        """

        t = node.index
        nodes_by_src = self._cache_nodes['children']   
        children_t = nodes_by_src.setdefault(t, set())
        if children_t:  
            return children_t.copy()

        adj_nodes = self.get_adj_nodes(node)
        children = set()
        for adj_node in adj_nodes:
            if self.has_directed_Edge(node, adj_node):
                children.add(adj_node)

        nodes_by_src[t] = children
        return children.copy()

    def get_spouse(self, node: Node) -> Set[Node]:
        """
        Get the set of spouse nodes of the given node in the graph.  
        node <-> *
        """
        t = node.index
        nodes_by_src = self._cache_nodes['spouses']   
        spouses_t = nodes_by_src.setdefault(t, set())
        if spouses_t:  
            return spouses_t.copy()

        adj_nodes = self.get_adj_nodes(node)
        spouses = set()
        for adj_node in adj_nodes:
            if self.has_bidirected_Edge(node, adj_node):
                spouses.add(adj_node)

        nodes_by_src[t] = spouses
        return spouses.copy()

    def get_district(self, node: Node) -> Set[Node]:
        """
        Get the district of the given node in the graph.  
        The district includes all nodes reachable from the given node via only bidirected edges (<->).
        """
        t = node.index
        nodes_by_src = self._cache_nodes['district']   # If not exists, automatically create an empty dict
        district_t = nodes_by_src.setdefault(t, set())
        if district_t:  # Cache exists and the cached set is not empty
            return district_t.copy()

        district = set()
        stack = [node]  # Use a stack for DFS
        visited = set()

        while stack:
            current_node = stack.pop()
            if current_node not in visited:
                visited.add(current_node)
                district.add(current_node)
                # Add all spouses (nodes connected by bidirected edges) to the stack
                spouses = self.get_spouse(current_node)
                stack.extend(spouses - visited)

        # Remove the node itself from the district
        district.discard(node)

        nodes_by_src[t] = district
        return district.copy()
    
    

    def get_PossibleDe(self, node: Union[str, int]) -> Set[Union[str, int]]:
        """
        A possibly directed path or possibly causal path from X to Y is a path from X to Y that does not contain an arrowhead pointing in the direction of X.
        If there is a directed (possibly directed) path from X to Y, then X is a ancestor (possible ancestor) of Y, and Y is a descendant (possible descendant) of X.
        """


        possible_de = set()
        stack = [node]
        visited = set()

        while stack:
            current_node = stack.pop()
            if current_node not in visited:
                visited.add(current_node)
                possible_de.add(current_node)
                # Add all adjacent nodes to the stack
                not_into_nodes = self.get_no_into_nodes(current_node)
                stack.extend(not_into_nodes - visited)

        # Remove the node itself from the possible_de set
        possible_de.discard(node)


        return possible_de



    def find_unique_triplets(self) -> List[Tuple[Node, Node, Node]]:
        """
        Efficiently find unique triplets <z, y, x> in MixGraph, avoiding symmetric duplicates.
        z.index < y.index < x.index, avoid <z, y, x> and <x, y, z>
        Returns:
            List [Tuple[Node, Node, Node]]: List of unique triplets. 
        """


        triplets = []

        for y in self.node_list:
            # Get neighbors of y
            neighbors_y = self.get_adj_nodes(y)

            if len(neighbors_y) >= 2:
                for z, x in combinations(neighbors_y, 2):
                    triplets.append((z, y, x))


        return triplets


    ########################
    ## End Node functions ##
    ########################

    def _init_complete_graph(self):
        """
        Initialize a complete undirected graph with the given node set.
        """
        for node1, node2 in combinations(self.node_list, 2):
            self.add_circ_Edge(node1, node2)
        self.clear_cache()
    
    #############################
    ## Visualization functions ##
    #############################
    """
    Install Graphviz from https://graphviz.org/download/ (12.2.1)
    and make sure to add the Graphviz bin directory to your PATH environment variable.
    Use 'dot -V' in the command line to check if Graphviz is installed correctly.
    NOTE: (1) pip install graphviz==0.20.3 pydot==3.0.2
          (2) add the graphviz bin directory (default path: C:\Program Files\Graphviz\bin) to your PATH environment variable

    The function to_pydot() is better than to_dot() in using.
    """
    def to_dot(self, filename: str = "example_file_dot", view: bool = True):
        """
        Visualize the PAG using graphviz.
        Reference:
        https://github.com/cmu-phil/py-tetrad/blob/main/pytetrad/tools/visualize.py
        """
        from graphviz import Digraph

        # arrowhead/arrowtail map(Graphviz odot= '◦', normal= '>', none= '—'）
        map_mark = {Mark.CIRCLE: "odot", 
                    Mark.TAIL: "none", 
                    Mark.ARROW: "normal"}
        
        gdot = Digraph(comment="MG")
        # Add nodes
        for Node in self.node_list:
            gdot.node(Node.name,shape="circle",
                  fixedsize="true",
                  style="filled",
                  color="lightgray")
            
        for edge in self.edges(data=True):
            edge_obj = edge[2]['edge']
            node1 = edge_obj.start
            node2 = edge_obj.end
            head = edge_obj.rmark
            tail = edge_obj.lmark

            gdot.edge(node1.name, node2.name,
                    arrowtail=map_mark[tail],
                    arrowhead=map_mark[head],
                    dir="both")

        gdot.render(filename, view=view, format="pdf")  
        return gdot  


    def to_pydot(self, filename: str = None, view: bool = True, **kwargs):
        """
        Convert the PAG to a PyDot graph object.

        Reference:
        https://github.com/py-why/causal-learn/blob/9689c1bdc468847729eacf0921b76f598161ae16/causallearn/utils/GraphUtils.py#L512
        """
        import sys
        if not (sys.version_info.major == 3 and sys.version_info.minor == 9 and sys.version_info.micro == 19):
            raise RuntimeError("Python version must be 3.9.19")

        import pydot

        # Create a new PyDot graph
        graph_pydot = pydot.Dot(graph_type='digraph', fontsize=18)

        target = kwargs.get('target', None) # A Node
        Mb_nodes = kwargs.get('Mb_nodes', None) # A list of Node objects
        Mb_nodes = [node.name for node in Mb_nodes] if Mb_nodes is not None else None
        # Add nodes
        for node in self.node_list:
            if str(node.name).startswith("L"):
                graph_pydot.add_node(pydot.Node(node.name, shape='circle', style='filled', color='lightgray'))
            elif str(node.name).startswith("S"):
                
                graph_pydot.add_node(pydot.Node(node.name, shape='box', style='filled', color='lightgray'))
            else:
                if target is not None and node.name == target.name:
                    
                    graph_pydot.add_node(pydot.Node(node.name, shape='circle', style='filled', fillcolor='#FF6666CC',  
                    color='#CC0000',       
                    ))
                elif Mb_nodes is not None and node.name in Mb_nodes:
                    graph_pydot.add_node(pydot.Node(node.name, shape='circle', style='filled', fillcolor='#66B3FFCC',  
                    color='#0055AA',        
                    ))
                else:
                    graph_pydot.add_node(pydot.Node(node.name, shape='circle', style='', color='black'))

        map_mark = {Mark.CIRCLE: "odot", 
                    Mark.TAIL: "none", 
                    Mark.ARROW: "normal"}

        # Add edges
        for edge in self.edges(data=True):
            edge_obj = edge[2]['edge']
            node1 = edge_obj.start
            node2 = edge_obj.end
            head = edge_obj.rmark
            tail = edge_obj.lmark

            # Add directed edges with appropriate arrowheads and arrowtails
            graph_pydot.add_edge(pydot.Edge(node1.name, node2.name,
                                      arrowtail=map_mark[tail],
                                      arrowhead=map_mark[head],
                                      dir='both'))
        if filename is not None:
            # Save the graph to a file
            pdf_path = filename + '.pdf'
            graph_pydot.write_pdf(pdf_path)
            if view:
                # Automatically open the PDF file if view is True
                import os
                os.startfile(pdf_path)
        

        return graph_pydot

    



    #################################
    ## End Visualization functions ##
    #################################

    #####################
    ## Graph functions ##
    #####################

    def MG_to_numpy_array(self) -> np.ndarray:
        """
        Convert the graph to a NumPy adjacency matrix.
        """
        length = len(self.node_list)
        graph_matrix = np.zeros((length, length), dtype=int)
        for u, v in self.edges():
            edge = self.get_Edge(u, v)
            lmark_index = edge.lmark.value
            rmark_index = edge.rmark.value
            graph_matrix[u.index, v.index] = rmark_index
            graph_matrix[v.index, u.index] = lmark_index

        return graph_matrix
    
    def MG_to_pandas_adjacency(self) -> pd.DataFrame:
        """
        Convert the graph to a Pandas DataFrame adjacency matrix.
        """
        graph_matrix = self.MG_to_numpy_array()
        return pd.DataFrame(graph_matrix, index=[node.name for node in self.node_list], columns=[node.name for node in self.node_list])

    def is_equal(self, other: "MixGraph") -> bool:
        """
        Compare this MixGraph with another MixGraph for structural and edge-mark equivalence using networkx.is_isomorphic.

        Parameters:
        - other: Another MixGraph instance.

        Returns:
        - True if both graphs are isomorphic with respect to edge marks.
        - False otherwise.
        """
        if not isinstance(other, MixGraph):
            return False

        return nx.is_isomorphic(self, other)
    
    def copy(self) -> "MixGraph":
        """
        Create a deep copy of the MixGraph, including the node_list attribute.
        """
        new_graph = MixGraph()
        # Copy the graph structure from networkx
        new_graph.add_nodes_from(self.nodes())
        new_graph.add_edges_from((u, v, attr.copy()) for (u, v, attr) in self.edges(data=True))
        # Copy the node_list
        new_graph.node_list = self.node_list.copy() if hasattr(self, 'node_list') else []
        # Copy the cache
        new_graph._cache_path = self._cache_path.copy()
        new_graph._cache_nodes = self._cache_nodes.copy()
        return new_graph

    #########################
    ## End Graph functions ##
    #########################

