#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@author: Haoyue
@file: MAG_tools.py
@time: 9/21/24 05:14
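@desc: Oracle (d-separation based) helpers for building MAGs from DAGs with latent and selection nodes, recovering skeletons and sepsets with a PC-style search, orienting PAGs with FCI rules R0-R4, and checking a twin-network procedure over interventional domains.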
"""

import networkx as nx
from itertools import chain, combinations, permutations, product
def powerset(iterable):
    """Return all subsets of *iterable* as tuples, smallest first.

    Within one size, subsets follow ``itertools.combinations`` order:
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = list(iterable)
    result = []
    for size in range(len(pool) + 1):
        result.extend(combinations(pool, size))
    return result

# Edge-end marks of a PAG/MAG edge: arrowhead, tail, circle.
AROW = 'AROW'
DASH = 'DASH'
CIRC = 'CIRC'
# Selects which end of an ordered pair (node1, node2) a mark query refers to.
LEFT = 'LEFT'
RIGHT = 'RIGHT'

def get_PAG_from_skeleton_and_sepsets(
    nodelist,                      # collection of nodes (iterated more than once)
    skeleton_edges,                # iterable of adjacent node pairs; no need to be symmetric
    sepsets,                       # dict{(node1, node2): set_of_nodes}; every non-adjacent pair must have a sepset
    sure_direct_edges=None,        # optional iterable of (a, b) forced to a -> b before any rule runs
    sure_no_latents=False          # if True, also convert every o-> into ->
):
    """Orient a PAG from a skeleton and separating sets using FCI rules R0-R4.

    Every skeleton edge starts as o-o. Edges in ``sure_direct_edges`` are then
    forced to ``->`` (trusted, no consistency check). Next, unshielded colliders
    are oriented by R0 using ``sepsets``. Finally, R1-R4 (plus, when
    ``sure_no_latents`` is True, a rule turning every ``o->`` into ``->``) are
    applied repeatedly until a full pass changes nothing.

    Args:
        nodelist: collection of nodes.
        skeleton_edges: iterable of adjacent node pairs, in either orientation.
        sepsets: dict mapping a non-adjacent pair ``(node1, node2)`` to its
            separating set. After symmetrisation, every ordered pair of distinct
            nodes must be either adjacent or have a sepset, never both.
        sure_direct_edges: optional iterable of pairs ``(a, b)`` oriented as
            ``a -> b`` before R0 runs.
        sure_no_latents: if True, additionally orient every ``o->`` as ``->``.

    Returns:
        dict with keys '->', '<->', '--', '⚬--', '⚬->', '⚬-⚬'. Each value is a
        set of ordered pairs ``(node1, node2)`` whose marks read left to right.
        The symmetric types ('<->', '--', '⚬-⚬') contain both orientations. For
        the asymmetric types only the matching orientation is stored; for
        example, an edge ``b <-o a`` appears as ``(a, b)`` under '⚬->'.

    Raises:
        AssertionError: if inputs mention nodes outside ``nodelist``, or if
            adjacencies and sepsets overlap or fail to cover all node pairs.
    """
    assert set().union(*skeleton_edges) <= set(nodelist)
    assert all(set(k) | set(v) <= set(nodelist) for k, v in sepsets.items())
    if sure_direct_edges is not None: assert set().union(*sure_direct_edges) <= set(nodelist)

    ALLNODES = set(nodelist)
    # CURREDGES[(a, b)] = (mark at a, mark at b). It is stored under both orderings of
    # every adjacent pair, so key membership doubles as the adjacency test.
    CURREDGES, SEPSETS = {}, {}
    for x, y in skeleton_edges: CURREDGES[(x, y)] = CURREDGES[(y, x)] = (CIRC, CIRC)
    for (node1, node2), Z in sepsets.items(): SEPSETS[(node1, node2)] = SEPSETS[(node2, node1)] = set(Z)
    # Every ordered pair of distinct nodes is either adjacent or separated, never both.
    assert len(set(CURREDGES.keys()) & set(SEPSETS.keys())) == 0
    assert set(CURREDGES.keys()) | set(SEPSETS.keys()) == {(x, y) for x, y in product(nodelist, nodelist) if x != y}

    def get_curr_edge_type(node1, node2, end=LEFT):
        # Mark at node1 (end=LEFT) or at node2 (end=RIGHT) on edge node1 - node2.
        # Returns False when the nodes are not adjacent, so comparisons with a mark fail.
        if (node1, node2) not in CURREDGES: return False
        if end == LEFT:
            return CURREDGES[(node1, node2)][0]
        elif end == RIGHT:
            return CURREDGES[(node1, node2)][1]
        assert False

    def update_edge(node1, node2, type1, type2):
        # Set the marks at node1 / node2 (None keeps the current mark), and mirror
        # the change onto the reversed key so both orderings stay consistent.
        if type1 is None: type1 = CURREDGES[(node1, node2)][0]
        if type2 is None: type2 = CURREDGES[(node1, node2)][1]
        CURREDGES[(node1, node2)] = (type1, type2)
        CURREDGES[(node2, node1)] = (type2, type1)

    def _RO():
        # R0: for non-adjacent alpha, beta with a common neighbour gamma not in
        # Sepset(alpha, beta), orient alpha *-> gamma <-* beta.
        # print('RO...')
        for alpha, beta in combinations(ALLNODES, 2):  # safe here; it's symmetric
            if (alpha, beta) in CURREDGES: continue
            for gamma in ALLNODES - {alpha, beta}:
                if (alpha, gamma) in CURREDGES and (beta, gamma) in CURREDGES:
                    gamma_not_in_sepset = gamma not in SEPSETS[(alpha, beta)]
                    if gamma_not_in_sepset:
                        update_edge(alpha, gamma, None, AROW)
                        update_edge(beta, gamma, None, AROW)

    def _R1():
        # print('R1...')
        # If α∗→ β◦−−∗ γ, and α and γ are not adjacent, then orient the triple as α∗→ β →γ
        # Returns True if any mark was changed.
        changed_something = False
        for alpha in ALLNODES:
            for beta in [bt for bt in ALLNODES - {alpha} if get_curr_edge_type(alpha, bt, RIGHT) == AROW]:
                for gamma in [gm for gm in ALLNODES - {alpha, beta} if
                              get_curr_edge_type(beta, gm, LEFT) == CIRC and (alpha, gm) not in CURREDGES]:
                    update_edge(beta, gamma, DASH, AROW)
                    changed_something = True
        return changed_something

    def _R2():
        # print('R2...')
        # If α→β∗→ γ or α∗→ β →γ, and α ∗−◦ γ,then orient α ∗−◦ γ as α∗→γ.
        # Returns True if any mark was changed.
        changed_something = False
        for alpha in ALLNODES:
            for beta in [bt for bt in ALLNODES - {alpha} if
                         get_curr_edge_type(alpha, bt, RIGHT) == AROW]:
                for gamma in [gm for gm in ALLNODES - {alpha, beta} if
                              get_curr_edge_type(beta, gm, RIGHT) == AROW]:
                    if get_curr_edge_type(alpha, gamma, RIGHT) == CIRC:
                        # a tail at alpha on alpha-beta, or at beta on beta-gamma
                        if get_curr_edge_type(alpha, beta, LEFT) == DASH or \
                           get_curr_edge_type(beta, gamma, LEFT) == DASH:
                            changed_something = True
                            update_edge(alpha, gamma, None, AROW)
        return changed_something

    def _R3():
        # print('R3...')
        # If α∗→ β ←∗γ, α ∗−◦ θ ◦−∗ γ, α and γ are not adjacent, and θ ∗−◦ β, then orient θ ∗−◦ β as θ∗→ β.
        # Returns True if any mark was changed.
        changed_something = False
        for alpha, gamma in combinations(ALLNODES, 2):  # safe here; it's symmetric
            if (alpha, gamma) in CURREDGES: continue
            for beta in [bt for bt in ALLNODES - {alpha, gamma} if
                         get_curr_edge_type(alpha, bt, RIGHT) == AROW and \
                         get_curr_edge_type(bt, gamma, LEFT) == AROW]:
                for theta in [th for th in ALLNODES - {alpha, beta, gamma} if
                         get_curr_edge_type(alpha, th, RIGHT) == CIRC and \
                         get_curr_edge_type(th, gamma, LEFT) == CIRC]:
                    if get_curr_edge_type(theta, beta, RIGHT) == CIRC:
                        changed_something = True
                        update_edge(theta, beta, None, AROW)
        return changed_something

    def _R4():
        # print('R4...')
        # If u = <θ, ...,α,β,γ> is a discriminating path between θ and γ for β, and β◦−−∗γ;
        # then if β ∈ Sepset(θ,γ), orient β◦−−∗ γ as β →γ; otherwise orient the triple <α,β,γ> as α ↔β ↔γ.
        # Returns True if the rule fired for at least one path.
        changed_something = False
        for theta in ALLNODES:
            for gamma in ALLNODES - {theta}:
                if (theta, gamma) in CURREDGES: continue
                for beta in {bt for bt in ALLNODES - {theta, gamma} if
                             get_curr_edge_type(bt, gamma, LEFT) == CIRC}:
                    # Every vertex strictly between θ and β on a discriminating path must be a parent of γ.
                    gamma_parents = {af for af in ALLNODES - {theta, gamma, beta} if \
                                     get_curr_edge_type(af, gamma, LEFT) == DASH and
                                     get_curr_edge_type(af, gamma, RIGHT) == AROW}
                    if len(gamma_parents) < 1: continue
                    # to prevent from nx.all_simple_paths(self.mag_undirected_graph, ..) (too slow) we use subgraph to only allow paths through gamma_parents
                    subgraph = nx.Graph()   # undirected
                    subgraph.add_nodes_from(gamma_parents | {theta, beta})
                    subgraph.add_edges_from([(x, y) for x, y in combinations(gamma_parents | {theta, beta}, 2) if (x, y) in CURREDGES])
                    for theta_beta_path in nx.all_simple_paths(subgraph, theta, beta):
                        # need at least one vertex between θ and β
                        if len(theta_beta_path) < 3: continue
                        path = theta_beta_path + [gamma]
                        # every interior vertex of θ..β must be a collider on the path
                        if all(
                                get_curr_edge_type(path[i - 1], path[i], RIGHT) == AROW and
                                get_curr_edge_type(path[i], path[i + 1], LEFT) == AROW
                                for i in range(1, len(path) - 2)
                        ):
                            changed_something = True
                            # print('    R4 applied!')
                            if beta in SEPSETS[(theta, gamma)]:
                                update_edge(beta, gamma, DASH, AROW)
                            else:
                                # path[-3] is α, the vertex right before β
                                update_edge(path[-3], beta, AROW, AROW)
                                update_edge(beta, gamma, AROW, AROW)
        return changed_something

    def _R_no_latents():
        # when we are sure that there are no latents, we confirm all ⚬-> as ->
        # Only values of existing keys change during iteration (the key set is fixed),
        # so iterating CURREDGES directly stays valid.
        changed_something = False
        for (node1, node2), (type1, type2) in CURREDGES.items():
            if type1 == CIRC and type2 == AROW:
                update_edge(node1, node2, DASH, AROW)
                changed_something = True
        return changed_something

    # ============================= main part ======================================
    # first apply background knowledge (for now we dont do consistency check; just trust it)
    # NOTE(review): a pair not in the skeleton would silently become adjacent here,
    #   because the adjacency/sepset assertions above have already run.
    if sure_direct_edges is not None:
        for node1, node2 in sure_direct_edges:
            CURREDGES[(node1, node2)] = (DASH, AROW)
            CURREDGES[(node2, node1)] = (AROW, DASH)

    # then fix the unshielded triples using observed CIs
    _RO()

    # then iteratively apply the rules until no more changes
    RULES = [_R1, _R2, _R3, _R4]
    if sure_no_latents: RULES.append(_R_no_latents)
    while True:
        changed_something = False
        for rule in RULES:
            changed_something |= rule()
        if not changed_something:
            break

    # finally, return the result
    # Orientations whose marks match none of these patterns (e.g. (AROW, CIRC)) are
    # recorded via their reversed key instead.
    pag_edges = {'->': set(), '<->': set(), '--': set(),
                      '⚬--': set(), '⚬->': set(), '⚬-⚬': set()}
    for (node1, node2), (type1, type2) in CURREDGES.items():
        if type1 == DASH and type2 == AROW:
            pag_edges['->'].add((node1, node2))
        elif type1 == AROW and type2 == AROW:
            pag_edges['<->'].add((node1, node2))  # has symmetric repeats
        elif type1 == DASH and type2 == DASH:
            pag_edges['--'].add((node1, node2))  # has symmetric repeats
        elif type1 == CIRC and type2 == DASH:
            pag_edges['⚬--'].add((node1, node2))
        elif type1 == CIRC and type2 == AROW:
            pag_edges['⚬->'].add((node1, node2))
        elif type1 == CIRC and type2 == CIRC:
            pag_edges['⚬-⚬'].add((node1, node2))  # has symmetric repeats

    return pag_edges

def get_skeleton_and_sepsets(nodelist, CI_tester):
    """Run the adjacency (skeleton) phase of the PC algorithm.

    Start from the complete undirected graph over ``nodelist``. For
    l = 0, 1, 2, ..., each remaining edge i - j is tested against every
    conditioning set S of size l. S is drawn from the current neighbours of i
    (excluding j) or of j (excluding i). The edge is removed as soon as
    ``CI_tester(i, j, S)`` returns True. The search stops at the first level
    where no remaining edge has an endpoint with enough neighbours to form a
    conditioning set of that size.

    Args:
        nodelist: iterable of hashable, mutually comparable nodes.
        CI_tester: callable ``(i, j, S) -> bool``. Returns True when i and j
            are conditionally independent given the frozenset S.

    Returns:
        (skeleton, sepsets):
            ``skeleton`` is a set of pairs ``(i, j)`` with ``i < j``.
            ``sepsets`` maps both ``(i, j)`` and ``(j, i)`` of every removed
            pair to the separating set (a ``set``) that removed it.
    """
    ALLNODES = set(nodelist)
    # A dict used as an insertion-ordered set gives O(1) deletion. It also keeps
    # the deterministic visiting order that the choice of sepsets depends on.
    curr_skeleton = dict.fromkeys((min(i, j), max(i, j)) for i, j in combinations(ALLNODES, 2))
    curr_neighbors = {i: set(ALLNODES) - {i} for i in ALLNODES}
    Sepsets = {}
    depth = 0
    while True:
        found_something = False
        # Neighbour sets only ever shrink, so a pair that fails the size check
        # now cannot pass it later in this level: one ordered pass per level
        # visits exactly the pairs a rescan-from-the-start would.
        for i, j in list(curr_skeleton):
            assert j in curr_neighbors[i]
            if len(curr_neighbors[i]) - 1 < depth and len(curr_neighbors[j]) - 1 < depth:
                continue
            found_something = True
            choose_subset_from = set(map(frozenset, combinations(curr_neighbors[i] - {j}, depth))) | \
                                 set(map(frozenset, combinations(curr_neighbors[j] - {i}, depth)))
            for subset in choose_subset_from:
                if CI_tester(i, j, subset):
                    del curr_skeleton[(i, j)]
                    curr_neighbors[i].remove(j)
                    curr_neighbors[j].remove(i)
                    Sepsets[(i, j)] = Sepsets[(j, i)] = set(subset)
                    break
        if not found_something:
            break
        depth += 1
    return set(curr_skeleton), Sepsets

class MAG(object):
    """Maximal ancestral graph (MAG) over the observed nodes of a DAG.

    The DAG's nodes are split into observed, latent and selected nodes.
    Adjacencies among observed nodes are found with a PC-style skeleton search.
    That search is driven by a d-separation oracle that always conditions on
    every selected node. The marks of each adjacency are then read off
    ancestral relations in the DAG.
    """

    def __init__(self, nxDiG, observed_nodes, latent_nodes, selected_nodes):
        """Build the MAG of ``nxDiG``.

        Args:
            nxDiG: networkx ``DiGraph``; must be acyclic.
            observed_nodes, latent_nodes, selected_nodes: pairwise disjoint
                collections whose union equals the node set of ``nxDiG``.
        """
        assert set(observed_nodes) | set(latent_nodes) | set(selected_nodes) == set(nxDiG.nodes)
        assert set(observed_nodes) & set(latent_nodes) == set() and set(observed_nodes) & set(
            selected_nodes) == set() and set(latent_nodes) & set(selected_nodes) == set()
        self.observed_nodes = set(observed_nodes)
        self.latent_nodes = set(latent_nodes)
        self.selected_nodes = set(selected_nodes)

        self.dag = nxDiG
        assert nx.is_directed_acyclic_graph(self.dag)
        self.dag_parents = {i: set(self.dag.predecessors(i)) for i in self.dag.nodes}
        self.dag_children = {i: set(self.dag.successors(i)) for i in self.dag.nodes}
        self.dag_ancestors = {i: nx.ancestors(self.dag, i) | {i} for i in self.dag.nodes}  # including itself
        self.dag_descendants = {i: nx.descendants(self.dag, i) | {i} for i in self.dag.nodes}  # including itself

        self.init_MAG()

    def oracle_ci_with_selection(self, x, y, Z=None, allow_access_latents=False):
        """Return True iff x and y are d-separated in the DAG given Z plus all selected nodes.

        Results are memoised in ``self.CI_cache``, keyed on the unordered pair
        and ``frozenset(Z)``.

        Args:
            x, y: distinct nodes.
            Z: optional iterable of conditioning nodes; must not contain x or y.
            allow_access_latents: if False, x, y and Z must all be observed.
        """
        if not hasattr(self, 'CI_cache'): self.CI_cache = {}
        Z = set() if Z is None else set(Z)
        assert x != y and {x, y} & Z == set()
        if not allow_access_latents: assert ({x, y} | Z) <= self.observed_nodes
        # allow_access_latents is a sanity check (usually only CIs among observd can be queried); but sometimes
        #    due to deterministic relations, condition on Xi also conditioned on Xi*, so we allow this backdoor call.
        x, y = min(x, y), max(x, y)
        cachekey = (x, y, frozenset(Z))
        if cachekey not in self.CI_cache:
            # NetworkX 3.3 renamed d_separated -> is_d_separator. The old name was
            # deprecated and later removed, so prefer whichever one exists.
            is_d_separator = getattr(nx, 'is_d_separator', None) or nx.d_separated
            # BE SURE TO INCLUDE ALL THE SELECTED NODES HERE!
            self.CI_cache[cachekey] = is_d_separator(self.dag, {x}, {y}, Z | self.selected_nodes)
        return self.CI_cache[cachekey]

    def is_m_separated(self, x, y, Z=None):
        """Return True iff observed nodes x and y are m-separated by Z in the MAG.

        This enumerates every simple path in the MAG skeleton, so it is
        exponential in the worst case.
        """
        # confirmed: it's equivalent to oracle_ci_with_selection. But prevent from using it (too slow in finding paths)
        Z = set() if Z is None else set(Z)
        assert x != y and {x, y} & Z == set() and ({x, y} | Z) <= self.observed_nodes
        for path in nx.all_simple_paths(self.mag_undirected_graph, x, y):
            # a collider has arrowheads into it from both path neighbours (parent or spouse)
            colliders_on_path = {path[i] for i in range(1, len(path) - 1) if
                                 {path[i - 1], path[i + 1]} <= self.mag_parents[path[i]] | self.mag_spouses[path[i]]}
            noncolliders_on_path = {path[i] for i in range(1, len(path) - 1) if path[i] not in colliders_on_path}
            # m-connecting: no non-collider in Z, and every collider is an ancestor of Z
            if len(noncolliders_on_path & Z) == 0 and all(
                    len(self.mag_descendants[c] & Z) > 0 for c in colliders_on_path):
                return False
        return True

    def init_MAG(self):
        """Compute the MAG skeleton, sepsets, edge marks and derived relations.

        Sets the following attributes:
            ``skeleton_edges_in_mag``: pairs ``(i, j)`` with ``i < j``.
            ``Sepset_cache``: separating sets of the removed pairs.
            ``mag_edges``: keys '->', '<->', '--'; the symmetric types hold
                both orientations.
            Parent, child, ancestor, descendant, spouse and neighbour maps
                over the observed nodes.
        """
        def _exists_inducing_path(x, y):
            # This is correct but we want to prevent using it: too slow!!
            for path in nx.all_simple_paths(self.dag.to_undirected(), x, y):
                observed_and_selected_nodes_on_path = {path[i] for i in range(1, len(path) - 1) if
                                                       path[i] in self.observed_nodes | self.selected_nodes}
                colliders_on_path = {path[i] for i in range(1, len(path) - 1) if
                                     {path[i - 1], path[i + 1]} <= self.dag_parents[path[i]]}
                ancestors_of_x_y_and_S = (self.dag_ancestors[x] | self.dag_ancestors[y]).union(
                    *[self.dag_ancestors[s] for s in self.selected_nodes])
                if observed_and_selected_nodes_on_path <= colliders_on_path and colliders_on_path <= ancestors_of_x_y_and_S:
                    return True
            return False

        # Since checking for inducing path (nx.all_simple_paths) is too slow, we run PC phase 1 to get adjacencies.
        self.skeleton_edges_in_mag, self.Sepset_cache = get_skeleton_and_sepsets(self.observed_nodes, self.oracle_ci_with_selection)
        # adjacencies_found_by_inducing_path = {tuple(sorted([x, y]))
        #                 for x, y in combinations(self.observed_nodes, 2) if _exists_inducing_path(x, y)}
        # assert adjacencies_found_by_inducing_path == self.skeleton_edges_in_mag

        # Loop-invariant: every node that is an ancestor of some selected node.
        ancestors_of_selection = set().union(*[self.dag_ancestors[s] for s in self.selected_nodes])
        self.mag_edges = {'->': set(), '<->': set(), '--': set()}
        for x, y in self.skeleton_edges_in_mag:   # set of tuples (i, j) where i < j
            # An endpoint gets a tail iff it is an ancestor of the other endpoint or of
            # a selected node; otherwise it gets an arrowhead.
            x_causes_y_or_selection = x in self.dag_ancestors[y] or x in ancestors_of_selection
            y_causes_x_or_selection = y in self.dag_ancestors[x] or y in ancestors_of_selection
            if x_causes_y_or_selection and not y_causes_x_or_selection:
                self.mag_edges['->'].add((x, y))
            elif not x_causes_y_or_selection and y_causes_x_or_selection:
                self.mag_edges['->'].add((y, x))
            elif not x_causes_y_or_selection and not y_causes_x_or_selection:
                self.mag_edges['<->'] |= {(x, y), (y, x)}
            else:
                self.mag_edges['--'] |= {(x, y), (y, x)}

        self.mag_undirected_graph = nx.Graph()
        self.mag_undirected_graph.add_nodes_from(self.observed_nodes)
        self.mag_undirected_graph.add_edges_from(self.skeleton_edges_in_mag)
        self.mag_only_directed_edges_graph = nx.DiGraph()
        self.mag_only_directed_edges_graph.add_nodes_from(self.observed_nodes)
        self.mag_only_directed_edges_graph.add_edges_from(self.mag_edges['->'])
        self.mag_parents = {i: set(self.mag_only_directed_edges_graph.predecessors(i)) for i in self.observed_nodes}
        self.mag_children = {i: set(self.mag_only_directed_edges_graph.successors(i)) for i in self.observed_nodes}
        self.mag_ancestors = {i: nx.ancestors(self.mag_only_directed_edges_graph, i) | {i} for i in self.observed_nodes}
        self.mag_descendants = {i: nx.descendants(self.mag_only_directed_edges_graph, i) | {i} for i in
                                self.observed_nodes}
        self.mag_spouses = {i: {j for j in self.observed_nodes if (i, j) in self.mag_edges['<->']} for i in
                            self.observed_nodes}
        self.mag_neighbors = {i: {j for j in self.observed_nodes if (i, j) in self.mag_edges['--']} for i in
                              self.observed_nodes}



class Checker(object):
    """Oracle-based check of a two-stage PAG orientation over interventional domains with selection.

    DAG nodes are the integers 1..nodenum. Three graphs are built:
      * ``original_G``: the DAG plus one selection node ``S{i}`` per entry of
        ``selection_parents``.
      * ``new_G_without_selection``: the DAG plus one indicator ``I{k}`` per
        entry of ``interv_targets``, pointing into that entry's targets.
      * ``twin_net_G``: two copies ``X{i}`` and ``X*{i}`` of the DAG that share
        the parent ``E{i}``. Selection nodes ``S*{i}`` hang off the starred
        copy, and the indicators ``I{k}`` point into the unstarred targets.
    MAGs are built for ``original_G`` and ``twin_net_G``.
    """

    def __init__(self, nodenum, edgelist, interv_targets, selection_parents):
        """
        Args:
            nodenum: number of DAG nodes; nodes are the integers 1..nodenum.
            edgelist: list of directed edges ``(i, j)`` between DAG nodes.
            interv_targets: ordered list of collections; entry k-1 holds the
                targets of intervention domain k.
            selection_parents: ordered list of collections; entry i-1 holds the
                parents of selection node i.
        """
        self.nodenum = nodenum
        self.edgelist = edgelist
        self.interv_targets = interv_targets  # ordered list of enumerates; can be empty(?)
        self.selection_parents = selection_parents  # ordered list of enumerates; can be empty
        self.num_of_selections = len(self.selection_parents)
        self.num_of_intervs = len(self.interv_targets)

        self.original_G = nx.DiGraph()
        self.original_G.add_nodes_from(
            list(range(1, self.nodenum + 1)) + [f'S{i}' for i in range(1, len(self.selection_parents) + 1)])
        self.original_G.add_edges_from(
            self.edgelist + [(j, f'S{i + 1}') for i, parents in enumerate(self.selection_parents) for j in parents])

        self.new_G_without_selection = nx.DiGraph()
        self.new_G_without_selection.add_nodes_from(
            list(range(1, self.nodenum + 1)) + [f'I{i + 1}' for i in range(len(self.interv_targets))])
        self.new_G_without_selection.add_edges_from(
            self.edgelist + [(f'I{i + 1}', j) for i, targets_in_config_i in enumerate(self.interv_targets) for j in
                             targets_in_config_i])

        self.twin_net_G = nx.DiGraph()
        self.twin_net_G.add_nodes_from([f'X{i + 1}' for i in range(self.nodenum)] +
                                       [f'X*{i + 1}' for i in range(self.nodenum)] +
                                       [f'S*{i + 1}' for i in range(len(self.selection_parents))] +
                                       [f'E{i + 1}' for i in range(self.nodenum)] +
                                       [f'I{i + 1}' for i in range(len(self.interv_targets))])
        self.twin_net_G.add_edges_from([(f'X{i}', f'X{j}') for i, j in self.edgelist])
        self.twin_net_G.add_edges_from([(f'X*{i}', f'X*{j}') for i, j in self.edgelist])
        self.twin_net_G.add_edges_from(
            [(f'X*{j}', f'S*{i + 1}') for i, parents in enumerate(self.selection_parents) for j in parents])
        self.twin_net_G.add_edges_from([(f'E{i}', f'X{i}') for i in range(1, 1 + self.nodenum)])
        self.twin_net_G.add_edges_from([(f'E{i}', f'X*{i}') for i in range(1, 1 + self.nodenum)])
        self.twin_net_G.add_edges_from(
            [(f'I{i + 1}', f'X{j}') for i, targets_in_config_i in enumerate(self.interv_targets) for j in
             targets_in_config_i])

        self.MAG_of_original_G = MAG(self.original_G,
                                     observed_nodes=list(range(1, self.nodenum + 1)),
                                     latent_nodes=[],
                                     selected_nodes=[f'S{i}' for i in range(1, len(self.selection_parents) + 1)])
        self.MAG_of_twin_G = MAG(
            self.twin_net_G,
            observed_nodes=[f'X{i + 1}' for i in range(self.nodenum)] + [f'I{i + 1}' for i in
                                                                         range(len(self.interv_targets))],
            latent_nodes=[f'X*{i + 1}' for i in range(self.nodenum)] + [f'E{i + 1}' for i in range(self.nodenum)],
            selected_nodes=[f'S*{i + 1}' for i in range(len(self.selection_parents))])

    def _twin_conditioning_nodes(self, Z, k):
        """Twin-network nodes to condition on for 'given X_Z' in domain k (k >= 1).

        Returns {X{z} : z in Z} together with {X*{z} : z in Z}, restricted to
        those z that are not descendants (in ``original_G``) of any target of
        intervention k. For such variables, conditioning on X{z} is treated as
        also conditioning on X*{z}.
        """
        Ik_targets = self.interv_targets[k - 1]
        varIDs_changed_by_intervention = {int(_) for _ in set().union(*[self.MAG_of_original_G.dag_descendants[ikt]
                                                                        for ikt in Ik_targets]) if
                                          not isinstance(_, str)}  # exclude S variables, just for sanity
        XZs = {f'X{z}' for z in Z}
        XZ_stars_accessable = {f'X*{z}' for z in
                               set(Z) - varIDs_changed_by_intervention}  # i.e., those not changed by intervention at all
        return XZs | XZ_stars_accessable

    def is_CI_in_domain_k(self, x, y, Z, k):
        """Oracle test of X_x _||_ X_y | X_Z in domain k (k == 0 is observational)."""
        assert 0 <= k <= self.num_of_intervs
        assert x != y and len({x, y} & set(Z)) == 0

        # assume faithfulness in observ. domain
        if k == 0: return self.MAG_of_original_G.oracle_ci_with_selection(x, y, set(Z))

        all_I_indices = {f'I{i + 1}' for i in
                         range(self.num_of_intervs)}  # condition on this: Ik=1, and all other Ikprimes=0
        return self.MAG_of_twin_G.oracle_ci_with_selection(
            f'X{x}', f'X{y}', self._twin_conditioning_nodes(Z, k) | all_I_indices, allow_access_latents=True)

    def is_conditional_distribution_invariant_in_domain_pk_and_p0(self, x, Z, k):
        """Oracle test of I_k _||_ X_x | X_Z (other indicators conditioned on) in the twin network."""
        assert 1 <= k <= self.num_of_intervs
        assert x not in Z

        all_I_other_indices = {f'I{kprime}' for kprime in range(1, 1 + self.num_of_intervs) if kprime != k}
        return self.MAG_of_twin_G.oracle_ci_with_selection(
            f'X{x}', f'I{k}', self._twin_conditioning_nodes(Z, k) | all_I_other_indices, allow_access_latents=True)

    def run_algo(self):
        """Run the two-stage orientation and classify every ordered pair of DAG nodes.

        Side effects: sets ``pag_edges_from_observational`` and
        ``pag_edges_from_interventional``.

        Returns:
            dict with keys 'NO', '->', '<->', '⚬->', '⚬-⚬'. Each value is a set
            of ordered integer pairs ``(x, y)``.
            'NO' holds pairs not adjacent in the observational MAG skeleton.
            The other keys hold adjacent pairs by their interventional-PAG
            marks read from x to y. A pair whose marks match only in the
            reverse orientation (e.g. ``x <- y``) is listed only as ``(y, x)``.
        """
        # ======= step 1: FCI on X on observational domain only. mainly to get skeleton and sepsets for vstrucs.
        #         the orientation rules are FCI's R0-R4, with additionally: all ⚬-> to -> (since we know there're no latents)
        self.pag_edges_from_observational = get_PAG_from_skeleton_and_sepsets(
            nodelist=self.MAG_of_original_G.observed_nodes,
            skeleton_edges=self.MAG_of_original_G.skeleton_edges_in_mag,
            sepsets=self.MAG_of_original_G.Sepset_cache,
            sure_no_latents=True)
        adjacencies_from_observational = self.MAG_of_original_G.skeleton_edges_in_mag  # (x,y) with x<y

        # ======= (the whole step 2 is technically just running FCI over I and X again; we decouple it to escape from bothering over Is)
        # ======= step 2.1: find the adjacencies between I and X, the (pseudo)-intervention targets of each domain.
        #         by "pseudo" we mean the adjacent Xs contains but not limited to the actual targets.
        adjacencies_between_I_and_X = set()
        Sepsets_in_Gtwin = {}
        for k, kprime in combinations(range(1, 1 + self.num_of_intervs), 2):
            Sepsets_in_Gtwin[(f'I{k}', f'I{kprime}')] = Sepsets_in_Gtwin[
                (f'I{kprime}', f'I{k}')] = set()  # ensure indices mutually indep.
        # observational-skeleton neighbours of each Xi; they do not depend on k, so compute once
        observational_neighbors = {
            i: {j for j in range(1, 1 + self.nodenum) if tuple(sorted((i, j))) in adjacencies_from_observational}
            for i in range(1, 1 + self.nodenum)}
        for k in range(1, 1 + self.num_of_intervs):
            for i in range(1, 1 + self.nodenum):
                # we search if there's a subset C s.t. p(Xi|XC) is invariant between domain p(0) and p(k)
                #   can show that such search on C can be restricted to the adjacencies of Xi in observational skeleton.
                found_invariant = False
                for C in powerset(observational_neighbors[i]):
                    # to check whether p(Xi|XC) is invariant, we check whether Ik _||_ Xi | XC and other Ikprimes=0
                    #   if d-separation on twin G^I holds, by Markov property, the CI holds.
                    #   if not, will there also be unfaithfulness as in those among X?
                    #       the answer is no, as in this case [given XC* yields however dsep] => [Ik affects XC (so XC* cannot be given)]
                    # therefore, here we can safely use d-separation on twin G^I to check CI.
                    if self.is_conditional_distribution_invariant_in_domain_pk_and_p0(i, C, k):
                        found_invariant = True
                        Sepsets_in_Gtwin[(f'I{k}', f'X{i}')] = Sepsets_in_Gtwin[(f'X{i}', f'I{k}')] = \
                            {f'X{c}' for c in C} | {f'I{kprime}' for kprime in range(1, 1 + self.num_of_intervs) if
                                                    kprime != k}
                        break
                if not found_invariant:
                    adjacencies_between_I_and_X.add((f'I{k}', f'X{i}'))

        # ======= step 2.2: some additional adjacencies among X have to be added.
        #         though we know they are incorrect (as the correct ones are just those from step 1's observational domain),
        #            we have to add them to ensure the next step's FCI orientation rules are correct.
        #         basically it's because: an Xi_||_Xj | XC in p(0) may be dependent in p(k).
        #         [*]: Xi--Xj can be cut off, iff there exists a C, s.t. Xi_||_Xj | XC holds in all domains from p(0) to p(K).
        adjacencies_between_X_and_X = set()

        def all_tester(x, y, Z):
            return all(self.is_CI_in_domain_k(x, y, Z, k) for k in range(0, 1 + self.num_of_intervs))

        all_skeleton_full, all_sepsets = get_skeleton_and_sepsets(self.MAG_of_original_G.observed_nodes, all_tester)
        for x, y in combinations(range(1, 1 + self.nodenum), 2):
            if (x, y) in all_skeleton_full:
                adjacencies_between_X_and_X.add((f'X{x}', f'X{y}'))
            else:
                Xz_and_all_Iks = {f'X{z}' for z in all_sepsets[(x, y)]} | {f'I{k}' for k in
                                                                           range(1, 1 + self.num_of_intervs)}
                Sepsets_in_Gtwin[(f'X{x}', f'X{y}')] = Sepsets_in_Gtwin[(f'X{y}', f'X{x}')] = Xz_and_all_Iks

        # ======= step 2.3: apply FCI rules on those adjacencies and sepsets, and known directions.
        directed_edges_from_observational = {(f'X{x}', f'X{y}') for x, y in self.pag_edges_from_observational['->']}
        self.pag_edges_from_interventional = get_PAG_from_skeleton_and_sepsets(
            nodelist=self.MAG_of_twin_G.observed_nodes,
            skeleton_edges=adjacencies_between_X_and_X | adjacencies_between_I_and_X,
            sepsets=Sepsets_in_Gtwin,
            sure_direct_edges=directed_edges_from_observational | adjacencies_between_I_and_X,
        )

        # Copy before symmetrising. The MAG's skeleton set is shared state that
        # holds only (i, j) with i < j, and must not be mutated in place here.
        adjacencies_in_MAG = set(self.MAG_of_original_G.skeleton_edges_in_mag)
        adjacencies_in_MAG |= {(y, x) for x, y in adjacencies_in_MAG}

        directed_edges, bidirected_edges, circdirected_edges, circcirc_edges, no_edges = set(), set(), set(), set(), set()

        for x, y in product(range(1, 1 + self.nodenum), repeat=2):
            if x == y: continue
            elif (x, y) not in adjacencies_in_MAG:
                no_edges.add((x, y))
            elif (f'X{x}', f'X{y}') in self.pag_edges_from_interventional['->']:
                directed_edges.add((x, y))
            elif (f'X{x}', f'X{y}') in self.pag_edges_from_interventional['<->']:
                bidirected_edges.add((x, y))
            elif (f'X{x}', f'X{y}') in self.pag_edges_from_interventional['⚬->']:
                circdirected_edges.add((x, y))
            elif (f'X{x}', f'X{y}') in self.pag_edges_from_interventional['⚬-⚬']:
                circcirc_edges.add((x, y))

        return {
            'NO': no_edges,
            '->': directed_edges,
            '<->': bidirected_edges,
            '⚬->': circdirected_edges,
            '⚬-⚬': circcirc_edges,
        }