import logging
logging.disable(logging.CRITICAL)
import time
from typing import List, Optional, Union
from Utils.PMB_mixgraph import MixGraph
from Utils.sepset import Separation_Set
import pandas as pd
import numpy as np
from Utils.Graph_utils import Edge, Mark, Node
from Utils.PMB_CI_test import CI_test
from itertools import combinations
import networkx as nx

from Utils.Markov_Blanket_Learner import MB_learn

# NOTE: ``Node`` objects are only used in the graph-learning process; other
# processes (e.g. Markov-blanket learning) work with node names instead.
#
# Logging is disabled process-wide at import time (``logging.disable`` call at
# the top of this module), so the INFO messages below are suppressed unless a
# caller re-enables logging with ``logging.disable(logging.NOTSET)``.
logger = logging.getLogger("PMB_Learner")
logger.setLevel(logging.INFO)


def local_get_pastar_mb_pag(data: pd.DataFrame, target: str, alpha: float = 0.05, ci_type: str = "Fisher_Z", latent_nodes: Optional[list] = None, **kwargs) -> dict:
    """
    Local PAG learning for the Pastar set and Markov Blanket of a target node.

    Args:
        data: pandas.DataFrame. Observational data for data-based CI tests
            such as "Fisher_Z"; for ci_type="D_sep", the square adjacency
            matrix of the graph (column names are node names).
        target: str, target node name.
        alpha: significance level for CI tests.
        ci_type: type of CI test ("Fisher_Z" or "D_sep").
        latent_nodes: list of latent node names (optional; only supported
            together with ci_type="D_sep").
        **kwargs: extra options forwarded to ``PMB_Learner`` (e.g. ``max_K``,
            the maximum conditioning-set size).

    Returns:
        dict with keys
          'Pastar': list[str], Pastar node names sorted by node index;
          'PAG.MixGraph': the learned PAG (MixGraph);
          'CI_num': int, number of CI tests performed.
    """
    pmb_learner = PMB_Learner(data, alpha=alpha, ci_type=ci_type, latent_nodes=latent_nodes, **kwargs)
    result = pmb_learner.Pastar_learner(target)
    result['CI_num'] = pmb_learner.get_ci_test_number()
    return result





class PMB_Learner:
    """
    Local learner of the PAG structure around a target variable.

    The learner discovers Markov blankets (MB), learns local skeletons over
    MB-plus sets (MB ∪ {node}) with PC/FCI-style CI testing, orients them with
    the FCI orientation rules (R1-R4 and R8-R10, Zhang 2008), merges the local
    results into the global PAG ``self.pag`` and finally extracts the target's
    'Pastar' set.

    NOTE: ``Node`` objects are only used in the graph-learning process; the MB
    learner and the CI test are called with node names (strings).
    """

    def __init__(self, data, alpha=0.05, ci_type="Fisher_Z", **kwargs):
        """
        Args:
            data: pandas.DataFrame. For ci_type="D_sep" it is the square
                adjacency matrix of a DAG (column names are node names);
                otherwise it is the observational data.
            alpha: significance level forwarded to the CI test and MB learner.
            ci_type: "D_sep" or a data-based test such as "Fisher_Z".
            **kwargs: forwarded to ``CI_test`` and ``MB_learn``. Keys read here:
                ``latent_nodes`` (only allowed with "D_sep"; those columns are
                removed from the observed variables) and ``max_K`` (maximum
                conditioning-set size, default 5).

        Raises:
            ValueError: latent nodes with a non-"D_sep" test, or a non-square
                adjacency matrix.
            TypeError: "D_sep" data that is not a DataFrame, or (from
                ``_init_nodes``) observed data that is not a DataFrame.
        """
        self.latent_nodes = kwargs.get("latent_nodes", None)
        if self.latent_nodes is not None and ci_type != "D_sep":
            raise ValueError(
                "Latent nodes or selection bias nodes are only supported with D_sep CI type."
            )
        if ci_type == "D_sep":
            if not isinstance(data, pd.DataFrame):
                raise TypeError("For D_sep, data must be adjacency matrix DataFrame.")

            full_adj_df = data.copy()
            adj = full_adj_df.values
            if adj.shape[0] != adj.shape[1]:
                raise ValueError(
                    f"Adjacency matrix must be square, got {adj.shape}"
                )
            G = nx.from_numpy_array(adj, create_using=nx.DiGraph)
            mapping = {i: name for i, name in enumerate(full_adj_df.columns)}
            self.G = nx.relabel_nodes(G, mapping)
            # Only the column set matters for the observed variables; latent
            # columns are dropped so they never become learner nodes.
            observed_data = (
                full_adj_df.drop(columns=self.latent_nodes)
                if self.latent_nodes is not None
                else full_adj_df
            )

        # ---------- DATA CASE: FisherZ ----------
        else:
            observed_data = data
            self.G = None  # no true graph in FisherZ case

        # ---------- CI test ----------
        # The CI test receives the full input (including latent columns in the
        # D_sep case).
        self.ci_test = CI_test(
            data,
            method_type=ci_type,
            alpha=alpha,
            **kwargs
        )

        # ---------- MB learner ----------
        if ci_type == "D_sep":
            mb_method = "TC"
        else:  # Fisher_Z
            mb_method = "gaussian_MB"

        self.mb_learner = MB_learn(
            observed_data,
            ci_test=self.ci_test,
            alpha=alpha,
            mb_method_type=mb_method,
            **kwargs
        )

        # ---------- basic configs ----------
        self.max_K = kwargs.get("max_K", 5)

        # ---------- nodes / PAG ----------
        self._init_nodes(observed_data)
        self.sepsets = Separation_Set(set(self.Nodes_list))
        self.pag = MixGraph(incoming_graph_data=self.Nodes_list)

        # ---------- bookkeeping ----------
        self.all_vars = observed_data.columns.to_list()
        self._cache_mb = dict()

    def get_mb(self, target_node: Node) -> set[Node]:
        """
        Return the (cached) Markov Blanket of ``target_node`` as a set of Nodes.

        On the first call for a node, every other variable outside the blanket
        is recorded in ``self.sepsets`` as separated from the target by the
        blanket itself.
        """
        if target_node.name in self._cache_mb:
            return self._cache_mb[target_node.name]

        mb_names = self.mb_learner(target_node.name)
        mb_nodes = {self.Nodes_dict[node_name] for node_name in mb_names}
        for node in self.Nodes_list:
            # A node is never "separated" from itself, so skip the target.
            if node not in mb_nodes and node.name != target_node.name:
                self.sepsets._add(target_node, node, mb_nodes)
        self._cache_mb[target_node.name] = mb_nodes
        return mb_nodes

    def get_ci_test_number(self) -> int:
        """
        Get the number of CI tests performed.
        """
        return self.ci_test.get_ci_num()

    def _init_nodes(self, data: pd.DataFrame):
        """
        Build ``Nodes_list`` (one Node per column, indexed by column position)
        and the name -> Node lookup ``Nodes_dict``.

        Raises:
            TypeError: if ``data`` is not a pandas DataFrame.
        """
        if isinstance(data, pd.DataFrame):
            self.Nodes_list = [Node(node_name, index) for index, node_name in enumerate(data.columns)]
            self.Nodes_dict = {node.name: node for node in self.Nodes_list}
        else:
            raise TypeError("Data must be a pandas DataFrame.")

    def learn_skeleton_base_adj(self, sub_nodes: list[Node]) -> MixGraph:
        """
        Learn a PC-style (adjacency-only) skeleton over ``sub_nodes``.

        Starting from the complete graph on ``sub_nodes``, pairs with an
        already-known separating set are removed. Then, for conditioning sizes
        0..self.max_K, every remaining edge x -- y is tested against subsets
        of the symmetric candidate set
            Cand(x, y) = (Adj(x) - {y}) ∪ (Adj(y) - {x}).
        Edges already present in the global PAG are never re-tested. Every
        separating set found is stored in ``self.sepsets``.

        Returns:
            The unoriented skeleton as a new MixGraph.
        """
        sub_graph = MixGraph(incoming_graph_data=sub_nodes)
        sub_graph._init_complete_graph()
        logger.info(f'Initial graph information: {sub_graph}')

        # Reuse separating sets found earlier (e.g. during MB discovery).
        for x, y in combinations(sub_nodes, 2):
            if self.sepsets.has_sepset(x, y):
                sub_graph.remove_Edge(x, y)
                logger.info(f'remove {x} -- {y} via has sepset')

        sep_size = 0
        while sep_size <= self.max_K and (sub_graph.max_degree() - 1 >= sep_size):

            # Snapshot each unordered edge once so removals made at this level
            # do not disturb the iteration.
            edges_snapshot = [
                (x, y)
                for x in sub_graph.node_list
                for y in sub_graph.get_adj_nodes(x)
                if x.index < y.index
            ]

            for x_node, y_node in edges_snapshot:

                if y_node not in sub_graph.get_adj_nodes(x_node):
                    continue  # already removed at this level

                if self.pag.has_edge(x_node, y_node):
                    continue  # adjacency already confirmed in the global PAG

                adj_x = sub_graph.get_adj_nodes(x_node) - {y_node}
                adj_y = sub_graph.get_adj_nodes(y_node) - {x_node}
                cand_nodes = adj_x | adj_y
                if len(cand_nodes) < sep_size:
                    continue

                cand_names = [n.name for n in cand_nodes]

                for sepset in combinations(cand_names, sep_size):
                    if self.ci_test(x_node.name, y_node.name, list(sepset))[0]:
                        self.sepsets._add(
                            x_node,
                            y_node,
                            set(self.Nodes_dict[name] for name in sepset)
                        )
                        sub_graph.remove_Edge(x_node, y_node)
                        break

            sep_size += 1

        return sub_graph

    def learn_skeleton_complete(self, sub_nodes: list[Node], verbose: bool = False) -> MixGraph:
        """
        Learn the skeleton over ``sub_nodes`` with an FCI-style two-stage search.

        1) Symmetric PC-style adjacency pruning: each edge x -- y is tested
           against subsets of Cand(x, y) = (Adj(x) - {y}) ∪ (Adj(y) - {x}),
           for sizes 0..self.max_K. Edges already in the global PAG are kept.
        2) Possible-D-SEP pruning (standard FCI): colliders are oriented, then
           each remaining edge is tested against subsets of Possible-D-SEP(x).

        Separating sets found in both stages are stored in ``self.sepsets``.

        Args:
            sub_nodes: nodes of the local graph.
            verbose: if True, print the wall-clock time of each stage.

        Returns:
            The skeleton as a new MixGraph with all orientations cleared.
        """
        sub_graph = MixGraph(incoming_graph_data=sub_nodes)
        sub_graph._init_complete_graph()
        logger.info(f'Initial graph information: {sub_graph}')

        for x, y in combinations(sub_nodes, 2):
            if self.sepsets.has_sepset(x, y):
                sub_graph.remove_Edge(x, y)
                logger.info(f'remove {x} -- {y} via has sepset')

        # ---------- Stage 1: symmetric adjacency-based pruning ----------
        time_one = time.perf_counter()
        sep_size: int = 0

        while sep_size <= self.max_K and (sub_graph.max_degree() - 1 >= sep_size):

            edges_snapshot = [
                (x_node, y_node)
                for x_node in sub_graph.node_list
                for y_node in sub_graph.get_adj_nodes(x_node)
                if x_node.index < y_node.index
            ]

            for x_node, y_node in edges_snapshot:

                if not sub_graph.has_edge(x_node, y_node):
                    continue

                if self.pag.has_edge(x_node, y_node):
                    continue

                adj_x = sub_graph.get_adj_nodes(x_node) - {y_node}
                adj_y = sub_graph.get_adj_nodes(y_node) - {x_node}
                cand_nodes = adj_x | adj_y

                if len(cand_nodes) < sep_size:
                    continue

                cand_names = [n.name for n in cand_nodes]
                logger.info(
                    f'[Adj-stage] Checking {x_node} -- {y_node}, '
                    f'sep_size={sep_size}, |Cand|={len(cand_names)}'
                )

                for sepset in combinations(cand_names, sep_size):
                    if self.ci_test(x_node.name, y_node.name, list(sepset))[0]:
                        self.sepsets._add(
                            x_node,
                            y_node,
                            set(self.Nodes_dict[name] for name in sepset)
                        )
                        sub_graph.remove_Edge(x_node, y_node)
                        break

            sep_size += 1

        time_two = time.perf_counter()
        if verbose:
            print(f"Time taken for symmetric adj-stage: {time_two - time_one:.4f} seconds")

        # ---------- Stage 2: Possible-D-SEP pruning ----------
        # Possible-D-SEP depends on collider orientations, so orient them first.
        self.orient_collider(sub_graph)

        time_three = time.perf_counter()

        for x_node in sub_graph.node_list:
            pds_set_x = sub_graph.get_possible_d_sep(x_node)
            adj_list_x = list(sub_graph.get_adj_nodes(x_node))

            for y_node in adj_list_x:

                if not sub_graph.has_edge(x_node, y_node):
                    continue

                # Candidate names are invariant across conditioning sizes.
                pds_names = [n.name for n in pds_set_x - {y_node}]

                sep_size = 0
                while sep_size <= self.max_K and (len(pds_set_x) - 1 >= sep_size):

                    logger.info(
                        f'[PDS-stage] Checking {x_node} -- {y_node}, sep_size={sep_size}'
                    )

                    for sepset in combinations(pds_names, sep_size):
                        if self.ci_test(x_node.name, y_node.name, list(sepset))[0]:
                            self.sepsets._add(
                                x_node,
                                y_node,
                                set(self.Nodes_dict[name] for name in sepset)
                            )
                            sub_graph.remove_Edge(x_node, y_node)
                            break
                    if not sub_graph.has_edge(x_node, y_node):
                        break

                    sep_size += 1

        time_four = time.perf_counter()
        if verbose:
            print(f"Time taken for PDS-stage: {time_four - time_three:.4f} seconds")

        sub_graph.clear_all_orientations()
        return sub_graph

    def update_subpag_to_pag(self, sub_pag: MixGraph, target_node: Node):
        """
        Merge a local PAG into the global PAG and re-apply orientation rules.

        Copies into ``self.pag`` (without overwriting existing edges) every
        edge incident to ``target_node`` and every edge lying on an uncovered
        collider path that starts at ``target_node`` in ``sub_pag``.
        """
        for node in sub_pag.get_adj_nodes(target_node):
            edge = sub_pag.get_Edge(target_node, node)
            if not self.pag.has_edge(target_node, node):
                self.pag.add_Edge(target_node, node, edge)

        collider_paths = sub_pag.get_all_uncovered_collider_paths_from_target(target_node)
        for path in collider_paths:
            for i in range(len(path) - 1):
                edge = sub_pag.get_Edge(path[i], path[i + 1])
                if not self.pag.has_edge(path[i], path[i + 1]):
                    self.pag.add_Edge(path[i], path[i + 1], edge)

        self.pag = self.orient_rules(self.pag)

    def stop_one(self, target_node: Node) -> bool:
        """
        Stop criterion 1 for the local search around ``target_node``.

        True when every pair in MB(target) ∪ {target} is resolved (either a
        separating set is stored or the pair is adjacent in the global PAG)
        and no such adjacency still carries a circle mark.
        """
        mb_plus_set = self.get_mb(target_node) | {target_node}
        for a, b in combinations(mb_plus_set, 2):
            # No separating set stored and not adjacent in the PAG: this pair
            # has not been learned yet.
            if (not self.sepsets.has_sepset(a, b)) and (not self.pag.has_edge(a, b)):
                return False
            if self.pag.has_edge(a, b) and \
                (self.pag.get_Edge(a, b).lmark == Mark.CIRCLE or self.pag.get_Edge(a, b).rmark == Mark.CIRCLE):
                return False

        return True

    def stop_two(self, target_node: Node, donelist: list[Node]) -> bool:
        """
        Stop criterion 2 for the local search around ``target_node``.

        Like ``stop_one`` every pair in MB(target) ∪ {target} must be resolved,
        but a circle-marked edge is tolerated when a depth-bounded (5) search
        from the relevant endpoint, following edges without an arrowhead at
        the far end, only dead-ends at nodes already in ``donelist``.
        """

        def stop_rule_three(node, done=None, maxdepth=5, depth=1) -> bool:
            if depth > maxdepth:
                return True
            if done is None:
                done = set()

            done.add(node)
            adj_nodes = self.pag.get_adj_nodes(node) - done

            if len(adj_nodes) == 0:
                return node in donelist

            for adj_n in adj_nodes:
                if self.pag.get_Edge(node, adj_n).rmark != Mark.ARROW:
                    if not stop_rule_three(adj_n, done, maxdepth, depth + 1):
                        return False
            return True

        mb_plus_set = self.get_mb(target_node) | {target_node}
        for a, b in combinations(mb_plus_set, 2):
            if (not self.pag.has_edge(a, b)) and (not self.sepsets.has_sepset(a, b)):
                return False
            if self.sepsets.has_sepset(a, b):
                continue
            if self.pag.has_edge(a, b):
                edge = self.pag.get_Edge(a, b)
                if edge.lmark != Mark.CIRCLE and edge.rmark != Mark.CIRCLE:
                    continue

                if edge.lmark == Mark.CIRCLE and edge.rmark == Mark.ARROW:
                    if not stop_rule_three(b, done={a}):
                        return False
                elif edge.lmark == Mark.ARROW and edge.rmark == Mark.CIRCLE:
                    if not stop_rule_three(a, done={b}):
                        return False
                else:
                    if not stop_rule_three(a):
                        return False
        return True

    def robust_stop_rule_early(self, target_node, donelist) -> bool:
        """
        Fallback stop criterion: every node of MB(target) ∪ {target} has been
        processed and more than twice that many nodes were processed overall.
        """
        mb_plus_set = self.get_mb(target_node) | {target_node}
        if all(v in donelist for v in mb_plus_set) and len(donelist) > len(mb_plus_set)*2:
            return True
        return False

    def mb_strcuture_learner(self, target: str, verbose=False) -> MixGraph:
        """
        Learn the global PAG around ``target`` by breadth-first local learning.

        Starting at the target, each dequeued node gets a local skeleton over
        MB(node) ∪ {node}, which is oriented and merged into ``self.pag``; the
        node's neighbours in the local PAG are then enqueued. The search stops
        as soon as any stop criterion holds or the queue is empty.

        Returns:
            ``self.pag``.
        """
        target_node = self.Nodes_dict[target]
        waitlist = [target_node]
        donelist = []
        while waitlist:
            node = waitlist.pop(0)
            mbplus_node_list = list(self.get_mb(node) | {node})

            skeleton = self.learn_skeleton_base_adj(mbplus_node_list)
            part_directed_graph = self.orient_collider(skeleton)
            sub_pag = self.orient_rules(part_directed_graph)
            self.update_subpag_to_pag(sub_pag, node)

            donelist.append(node)
            for n in list(sub_pag.get_adj_nodes(node)):
                if n not in donelist and n not in waitlist:
                    waitlist.append(n)

            if verbose:
                print(f'-----testing--rules---')
            if self.stop_one(target_node) or self.stop_two(target_node, donelist) or self.robust_stop_rule_early(target_node, donelist):
                if verbose:
                    print(f'-----testing--rules--get-')
                break
            else:
                if verbose:
                    print(f'-----testing--rules--not-get--')

        return self.pag

    def Pastar_learner(self, target: str) -> dict:
        """
        Learn the PAG around ``target`` and extract its Pastar set.

        ``pre_set`` starts as the nodes with an arrowhead into the target in
        the learned PAG. MB members not adjacent to the target are added to it
        when they are independent of the target given some subset of
        ``pre_set`` (sizes 0..min(|pre_set|, max_K)), repeated until no more
        nodes are added. An added node joins Pastar if some arrow-collider
        path from the target to it has all nodes ``path[2:-1]`` in ``pre_set``.

        Returns:
            dict with keys
              'Pastar': list[str], Pastar node names sorted by node index;
              'PAG.MixGraph': the learned PAG.
        """
        target_node = self.Nodes_dict[target]
        mb_pag = self.mb_strcuture_learner(target)
        # Copy: pre_set is grown below and must not alias graph-internal state.
        pre_set = set(mb_pag.get_into_nodes(target_node))
        pastar = pre_set.copy()

        candidate_set = self.get_mb(target_node) - mb_pag.get_adj_nodes(target_node)
        updated = True
        while updated:
            updated = False
            for cand_node in list(candidate_set):
                # Inclusive bound, so the marginal test (size 0) runs even when
                # pre_set is empty and the full pre_set is also tried.
                max_size = min(len(pre_set), self.max_K)
                found = False
                for size in range(max_size + 1):
                    for sepset in combinations(pre_set, size):
                        if self.ci_test(cand_node.name, target_node.name, [node.name for node in sepset])[0]:
                            found = True
                            break
                    if found:
                        break
                if found:
                    pre_set.add(cand_node)
                    candidate_set.remove(cand_node)
                    updated = True

        for node in list(pre_set - pastar):
            for path in mb_pag.get_all_arrow_collider_paths(source=target_node, end=node):
                if all(v in pre_set for v in path[2:-1]):
                    pastar.add(node)
                    break
        pastar_sorted = sorted(pastar, key=lambda n: n.index)
        return {'Pastar': [node.name for node in pastar_sorted], 'PAG.MixGraph': mb_pag}

    def orient_collider(self, undirected_graph: MixGraph) -> MixGraph:
        """
        Orient colliders x *-> y <-* z in place.

        For each triple (z, y, x) returned by ``find_unique_triplets``
        (presumably unshielded triples z - y - x; confirm in MixGraph), the
        collider is oriented when x and z have a stored separating set that
        does not contain y.

        Returns:
            The same (mutated) graph.
        """
        Cand_triplets = undirected_graph.find_unique_triplets()
        for (z, y, x) in Cand_triplets:
            if self.sepsets.has_sepset(x, z):
                if not self.sepsets.is_in_sepset(target=y, node1=x, node2=z):
                    undirected_graph.update_Edge(node1=x, lmark=None, rmark=Mark.ARROW, node2=y)
                    undirected_graph.update_Edge(node1=z, lmark=None, rmark=Mark.ARROW, node2=y)
                    logger.info(f"Orienting collider: {x} *-> {y} <-* {z}")

        return undirected_graph

    def orient_rules(self, pag: MixGraph) -> MixGraph:
        """
        Apply the FCI orientation rules to ``pag`` until no rule fires.

        R1-R4 are applied to a fixpoint, then R8-R10. R8-R10 create new
        directed edges, which can enable R2 and R4 again, so both phases are
        repeated until the R8-R10 phase makes no change. Every rule firing
        turns a circle mark into a tail or arrowhead, so the loop terminates.

        Returns:
            The same (mutated) MixGraph.
        """
        while True:
            update_flag = True
            while update_flag:
                update_flag = False
                pag, update_flag = self.Rule_1(pag, update_flag)
                pag, update_flag = self.Rule_2(pag, update_flag)
                pag, update_flag = self.Rule_3(pag, update_flag)
                pag, update_flag = self.Rule_4(pag, update_flag)

            complete_changed = False
            update_flag = True
            while update_flag:
                update_flag = False
                pag, update_flag = self.Rule_8(pag, update_flag)
                pag, update_flag = self.Rule_9(pag, update_flag)
                pag, update_flag = self.Rule_10(pag, update_flag)
                complete_changed = complete_changed or update_flag

            if not complete_changed:
                return pag

    def Rule_1(self, pag: MixGraph, update_flag: bool) -> tuple[MixGraph, bool]:
        """
        If a *-> b o-* r, and a and r are not adjacent, then orient the triple as a *-> b -> r.
        Non-adjacency is checked through a stored separating set that contains b.

        Returns:
            (pag, update_flag), with update_flag set to True if any edge changed.
        """
        for b, r in pag.get_circ_star_Edge():
            for a in pag.get_into_nodes(b):
                if self.sepsets.has_sepset(a, r) and \
                    self.sepsets.is_in_sepset(target=b, node1=a, node2=r):
                        pag.update_Edge(node1=b, lmark=Mark.TAIL, rmark=Mark.ARROW, node2=r)
                        update_flag = True
                        logger.info(f"Orienting Rule 1: {b} o-* {r} to {b} --> {r}")
                        break

        return pag, update_flag

    def Rule_2(self, pag: MixGraph, update_flag: bool) -> tuple[MixGraph, bool]:
        """
        If a -> b *-> r or a *-> b -> r, and a *-o r, then orient a *-o r as a *-> r.

        Returns:
            (pag, update_flag), with update_flag set to True if any edge changed.
        """
        for r, a in pag.get_circ_star_Edge():
            for b in pag.get_into_nodes(r):
                if pag.has_directed_Edge(a, b) or \
                    (pag.has_into_Edge(a, b) and pag.has_out_Edge(b, r)):
                    pag.update_Edge(node1=a, lmark=None, rmark=Mark.ARROW, node2=r)
                    update_flag = True
                    logger.info(f"Orienting Rule 2: {a} *-o {r} to {a} *-> {r}")
                    break
        return pag, update_flag

    def Rule_3(self, pag: MixGraph, update_flag: bool) -> tuple[MixGraph, bool]:
        """
        If a *-> b <-* r, a *-o t o-* r, a and r are not adjacent, and t *-o b,
        then orient t *-o b as t *-> b.
        The conditions on a and r are checked via a stored sepset of (a, r)
        containing t.

        Returns:
            (pag, update_flag), with update_flag set to True if any edge changed.
        """
        for b, t in pag.get_circ_star_Edge():
            Cand_ar = pag.get_into_nodes(b)
            Cand_ar = {a for a in Cand_ar if pag.has_into_Edge(a, b)}
            if len(Cand_ar) >= 2:
                for a, r in combinations(Cand_ar, 2):
                    if self.sepsets.has_sepset(a, r) and \
                        self.sepsets.is_in_sepset(target=t, node1=a, node2=r):
                            pag.update_Edge(node1=t, lmark=None, rmark=Mark.ARROW, node2=b)
                            update_flag = True
                            logger.info(f"Orienting Rule 3: {t} *-o {b} to {t} *-> {b}")
                            break
        return pag, update_flag

    def updateList(self, path, new_ts, old_path_list):
        """
        Return ``old_path_list`` extended with ``path + [t]`` for every t in ``new_ts``.
        """
        return old_path_list + [path + [t] for t in new_ts]

    def minDiscrPath(self, a: Node, b: Node, r: Node, pag: MixGraph) -> Optional[List[Node]]:
        """
        Find a minimal discriminating path for b, given the pattern a <-* b o-* r with a -> r
        (Rule 4 in Zhang 2008), by breadth-first search backwards from a.

        Parameters:
        - a: the node preceding b on the path.
        - b: the node the path discriminates for.
        - r: the path's end node.

        Returns:
        - The path [t, ..., a, b, r] as a list of nodes, where t has a stored separating set
          with r, or None if no such path exists.
        """
        Cand_ts = pag.get_into_nodes(a)
        visited = {a, b, r}
        Cand_ts = Cand_ts - visited
        if len(Cand_ts) == 0:
            return None

        list_paths = self.updateList([a], Cand_ts, [])

        while list_paths:
            path = list_paths.pop(0)
            cand_t = path[-1]

            # t found: it is not adjacent to r (separated).
            if self.sepsets.has_sepset(cand_t, r):
                return path[:: -1] + [b, r]

            pred_t = path[-2]
            visited.add(cand_t)

            # Extend only through colliders that are parents of r.
            if pag.has_directed_Edge(cand_t, r) and pag.has_into_Edge(pred_t, cand_t):
                Cand_ts = pag.get_into_nodes(cand_t) - visited
                if len(Cand_ts) > 0:
                    list_paths = self.updateList(path, Cand_ts, list_paths)

        return None

    def Rule_4(self, pag: MixGraph, update_flag: bool) -> tuple[MixGraph, bool]:
        """
        If u = <t, ..., a, b, r> is a discriminating path between t and r for b, and b o-* r;
        then if b ∈ SepSet(t, r), orient b o-* r as b -> r; otherwise orient the triple <a, b, r> as a <-> b <-> r.

        Returns:
            (pag, update_flag), with update_flag set to True if any edge changed.
        """
        for b, r in pag.get_circ_star_Edge():
            Cand_as = pag.get_parents(r)
            Cand_as = {a for a in Cand_as if pag.has_into_Edge(b, a)}
            while len(Cand_as) > 0:
                a = Cand_as.pop()
                md_path = self.minDiscrPath(a, b, r, pag)
                if md_path is not None:
                    t = md_path[0]
                    if self.sepsets.is_in_sepset(target=b, node1=t, node2=r):
                        pag.update_Edge(node1=b, lmark=Mark.TAIL, rmark=Mark.ARROW, node2=r)

                        logger.info(f"Orienting Rule 4: {b} o-* {r} to {b} -> {r}")
                    else:
                        pag.update_Edge(node1=a, lmark=Mark.ARROW, rmark=Mark.ARROW, node2=b)
                        pag.update_Edge(node1=b, lmark=Mark.ARROW, rmark=Mark.ARROW, node2=r)
                        logger.info(f"Orienting Rule 4: {a} <-> {b} <-> {r}")
                    update_flag = True
                    break
        return pag, update_flag

    def minUncovCircPath(self, path, pag: MixGraph) -> Optional[List[Node]]:
        """
        Find a minimal uncovered circle path starting from the given path ([a, r, ..., t, b]).

        Parameters:
            path: [a, r, t, b] under interest, such that r o-o a o-o b o-o t and a, t are not adjacent, b, r are not adjacent.

        Returns:
            The path [a, r, ..., t, b] if one passes ``is_uncovered_path``, else None.
        """
        a = path[0]
        r = path[1]
        t = path[2]
        b = path[3]
        Cand_xs = pag.get_nondirect_adj_nodes(r)
        visited = {r, a, b, t}
        Cand_xs = Cand_xs - visited
        if len(Cand_xs) == 0:
            return None

        list_paths = self.updateList([r], Cand_xs, [])

        while list_paths:
            path = list_paths.pop(0)
            cand_x = path[-1]
            visited.add(cand_x)
            if pag.has_circ_circ_Edge(cand_x, t):
                mpath = [a] + path + [t, b]
                if self.is_uncovered_path(mpath):
                    return mpath
            else:
                Cand_xis = pag.get_nondirect_adj_nodes(cand_x)
                Cand_xis = Cand_xis - visited
                if len(Cand_xis) > 0:
                    list_paths = self.updateList(path, Cand_xis, list_paths)
        return None

    def Rule_8(self, pag: MixGraph, update_flag: bool) -> tuple[MixGraph, bool]:
        """
        If a -> b -> r or a -o b -> r, and a o-> r, orient a o-> r as a -> r.

        Returns:
            (pag, update_flag), with update_flag set to True if any edge changed.
        """
        for a, r in pag.get_circ_arrow_Edge():
            for b in pag.get_parents(r):
                if pag.has_directed_Edge(a, b) or pag.has_tail_circ_Edge(a, b):
                    pag.update_Edge(node1=a, lmark=Mark.TAIL, rmark=Mark.ARROW, node2=r)
                    update_flag = True
                    logger.info(f"Orienting Rule 8: {a} o-> {r} to {a} -> {r}")
                    break
        return pag, update_flag

    def is_uncovered_path(self, path: List[Node]) -> bool:
        """
        Check whether ``path`` is uncovered, using separating sets, for 'minUncovPdPath' and 'minUncovCircPath'.

        Each consecutive triple (path[i-1], path[i], path[i+1]) must have a stored
        separating set for its endpoints that contains the middle node.
        In local learning, direct adjacency between nodes may not be fully determined,
        so this sepset-based check is used instead of MixGraph.is_uncovered_path.
        """
        for i in range(1, len(path) - 1):
            if not (self.sepsets.has_sepset(path[i - 1], path[i + 1]) and self.sepsets.is_in_sepset(target=path[i], node1=path[i - 1], node2=path[i + 1])):
                return False
        return True

    def minUncovPdPath(self, a: Node, b: Node, r: Node, pag: MixGraph) -> Optional[List[Node]]:
        """
        Find a minimal uncovered potentially directed path <a, b, ..., r> by
        breadth-first search from b.

        Returns:
            The path as a list of nodes, or None if no such path is found.
        """
        Cand_ts = pag.get_pd_path_nodes(b)
        visited = {a, b, r}
        Cand_ts = Cand_ts - visited
        if len(Cand_ts) == 0:
            return None

        list_paths = self.updateList([b], Cand_ts, [])

        while list_paths:
            path = list_paths.pop(0)
            cand_t = path[-1]
            visited.add(cand_t)
            tr_edge = pag.get_Edge(cand_t, r)
            # cand_t *-* r is potentially directed towards r: no arrowhead at
            # cand_t and no tail at r.
            if tr_edge is not None and \
                tr_edge.lmark != Mark.ARROW and tr_edge.rmark != Mark.TAIL:
                    mpath = [a] + path + [r]
                    if self.is_uncovered_path(mpath):
                        return mpath

            else:
                Cand_tis = pag.get_pd_path_nodes(cand_t)
                Cand_tis = Cand_tis - visited
                if len(Cand_tis) > 0:
                    list_paths = self.updateList(path, Cand_tis, list_paths)

        return None

    def Rule_9(self, pag: MixGraph, update_flag: bool) -> tuple[MixGraph, bool]:
        """
        If a o-> r, and p = <a, b, ..., r> is an uncovered potentially directed path from a to r
        such that r and b are not adjacent, then orient a o-> r as a -> r.

        Returns:
            (pag, update_flag), with update_flag set to True if any edge changed.
        """
        for a, r in pag.get_circ_arrow_Edge():
            # Exclude r itself without risking a KeyError if it is absent.
            Cand_bs = {b for b in pag.get_pd_path_nodes(a) if b != r and not pag.has_edge(b, r)}
            while len(Cand_bs) > 0:
                b = Cand_bs.pop()
                if not (self.sepsets.has_sepset(b, r) and \
                    self.sepsets.is_in_sepset(target=a, node1=b, node2=r)):
                    continue
                upd_path = self.minUncovPdPath(a, b, r, pag)
                if upd_path is not None:
                    pag.update_Edge(node1=a, lmark=Mark.TAIL, rmark=Mark.ARROW, node2=r)
                    update_flag = True
                    logger.info(f"Orienting Rule 9: {a} o-> {r} to {a} -> {r} via {b}")
                    break

        return pag, update_flag

    def Rule_10(self, pag: MixGraph, update_flag: bool) -> tuple[MixGraph, bool]:
        """
        R10: Suppose a o-> r, b -> r <- t, p1 is an uncovered p.d. path from a to b, and p2 is an uncovered p.d. path from a to t.
        Let u be the vertex adjacent to a on p1 (u could be b), and w be the vertex adjacent to a on p2 (w could be t).
        If u and w are distinct, and are not adjacent, then orient a o-> r as a -> r.

        Returns:
            (pag, update_flag), with update_flag set to True if any edge changed.
        """
        for a, r in pag.get_circ_arrow_Edge():
            pa_r = pag.get_parents(r)
            if len(pa_r) < 2:
                continue
            for b, t in combinations(pa_r, 2):

                if pag.has_pd_Edge(a, b) and pag.has_pd_Edge(a, t) and \
                    self.sepsets.has_sepset(b, t) and \
                    self.sepsets.is_in_sepset(target=a, node1=b, node2=t):
                    pag.update_Edge(node1=a, lmark=Mark.TAIL, rmark=Mark.ARROW, node2=r)
                    update_flag = True
                    logger.info(f"Orienting Rule 10: {a} o-> {r} to {a} -> {r} via {b} and {t}")
                    break
                else:
                    # New set: never mutate the graph's returned set in place,
                    # and do not fail if r is absent.
                    Cand_uw = pag.get_pd_path_nodes(a) - {r}
                    if len(Cand_uw) < 2:
                        continue
                    for u, w in combinations(Cand_uw, 2):

                        if pag.has_edge(u, w):
                            continue

                        if u == b:
                            p1 = [a, b]
                        else:
                            p1 = self.minUncovPdPath(a, u, b, pag)
                        if p1 is not None:
                            if w == t:
                                p2 = [a, t]
                            else:
                                p2 = self.minUncovPdPath(a, w, t, pag)

                            if p2 is not None and u != w and \
                                self.sepsets.has_sepset(u, w) and \
                                self.sepsets.is_in_sepset(target=a, node1=u, node2=w):
                                    pag.update_Edge(node1=a, lmark=Mark.TAIL, rmark=Mark.ARROW, node2=r)
                                    update_flag = True
                                    logger.info(f"Orienting Rule 10: {a} o-> {r} to {a} -> {r} via {b} and {t}")
                                    break
                    if pag.has_directed_Edge(a, r):
                        break

        return pag, update_flag


