import networkx as nx
import subprocess
import pyAgrum as gum
import graphical_models
import copy
import numpy as np
import matplotlib.pyplot as plt
#import graphtheory
from itertools import chain, combinations
import random
import os
from collections import deque
import pickle

import pandas as pd
from scipy.stats import chisquare
import causallearn
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from causallearn.graph.Edge import Edge
from causallearn.graph.Endpoint import Endpoint

import numpy as np
import itertools

def random_complete_admg(n, p=0.5):
    """
    Sample a random ADMG whose directed part is a complete DAG on n nodes.

    Args:
        n: number of nodes; nodes are labelled 0..n-1
        p: probability with which each unordered pair of nodes receives a
           bidirected edge (defaults to 0.5, the value that used to be
           hard-coded here)

    Returns: the dict built by add_bidirected_edges, i.e.
             {'directed': [...], 'bidirected': [...]}
    """
    # Step 1: Make a random complete DAG (one directed edge per node pair)
    dag = random_complete_dag(n)

    # Step 2: Add a bidirected edge on each pair independently with probability p
    return add_bidirected_edges(dag, n, p=p)

def random_complete_dag(n):
    """
    Generate a random *complete* DAG on n nodes (0..n-1) using:
      1. A random permutation of the nodes, used as the topological order.
      2. For each pair a < b, the edge points from whichever node comes
         first in that order to the other one.

    Return a dict:
      {
        'directed': [(u,v), ...]   # all oriented edges
      }
    so that for every distinct pair exactly one directed edge is present.
    The k-th edge in the list always joins the k-th pair (a, b), a < b, in
    lexicographic order; only its orientation is random.
    """
    # 1) Random topological order (random.shuffle works in place)
    order = list(range(n))
    random.shuffle(order)

    # pos[x] = rank of node x in the order, so comparisons are O(1)
    pos = {node: rank for rank, node in enumerate(order)}

    # 2) Orient every pair from the earlier node to the later one
    directed_edges = [
        (a, b) if pos[a] < pos[b] else (b, a)
        for a, b in combinations(range(n), 2)
    ]

    return {
        'directed': directed_edges
    }

def add_bidirected_edges(dag, n, p=0.5):
    """
    Attach randomly chosen bidirected edges to a DAG given in dict form.

    Every unordered node pair (i, j) with i < j is visited once, in
    lexicographic order, and becomes a bidirected edge i<->j whenever a
    fresh draw of random.random() is below p.  Direction is irrelevant for
    bidirected edges, so pairs are always stored with the smaller label first.

    Args:
        dag: dict with a 'directed' entry (list of (u, v) edges)
        n: number of nodes, labelled 0..n-1
        p: per-pair probability of adding a bidirected edge

    Returns a dict:
      {
        'directed':   dag['directed'],   # the same list object, not a copy
        'bidirected': [(x,y), ...]       # selected pairs, x < y
      }
    """
    directed_part = dag['directed']

    # One random draw per pair, taken in the same order the pairs are listed
    chosen_pairs = [
        pair for pair in combinations(range(n), 2)
        if random.random() < p
    ]

    return {'directed': directed_part, 'bidirected': chosen_pairs}

def all_labeled_DAGs(n):
    """
    Enumerate, as adjacency matrices, the DAGs on nodes {0..n-1} whose edges
    all point from a lower label to a higher label.

    Only the strict upper triangle (r < c) is ever filled, so 0, 1, ..., n-1
    is a topological order of every matrix produced and each one is acyclic
    by construction (no separate cycle check is needed).  Note that this is
    NOT every labeled DAG: edges c->r with c > r never occur.  Every DAG can
    be relabelled along a topological order into one of these matrices, so
    the output still covers every isomorphism class, with duplicates.

    Returns a list of 2**(n*(n-1)/2) numpy 2D arrays, each of shape (n,n) and
    integer dtype, with 1 for an edge u->v and 0 otherwise.  Matrices are
    ordered like binary counting over the pairs (0,1), (0,2), ..., (n-2,n-1),
    the last pair varying fastest.
    """
    edges = [(r, c) for r in range(n) for c in range(r + 1, n)]

    results = []
    # Each pair is independently absent (0) or present (1)
    for bits in itertools.product((0, 1), repeat=len(edges)):
        adj = np.zeros((n, n), dtype=int)
        for (r, c), bit in zip(edges, bits):
            adj[r, c] = bit
        results.append(adj)
    return results


def canonical_form(adj):
    """
    Return the canonical representative of 'adj' under node relabelling.

    Every permutation of the n labels is applied to rows and columns
    simultaneously; each permuted matrix is turned into a tuple of row
    tuples and the lexicographically smallest one is returned.  Two matrices
    get the same result exactly when one is a relabelling of the other.
    The search is brute force over all n! permutations.

    adj is an (n,n) numpy array.
    """
    size = adj.shape[0]

    def relabelled(order):
        # entry (a, b) of the permuted matrix is adj[order[a], order[b]]
        picks = list(order)
        shuffled = adj[picks, :][:, picks]
        return tuple(map(tuple, shuffled))

    return min(relabelled(order) for order in itertools.permutations(range(size)))


def all_unlabeled_DAGs_up_to_isomorphism(n):
    """
    Return a list of adjacency matrices (np arrays), *one per isomorphism
    class* of the matrices produced by all_labeled_DAGs(n).

    Matrices are grouped by canonical_form; the first matrix encountered in
    each group is the one kept, and groups appear in first-encounter order.
    """
    # dicts keep insertion order and setdefault never overwrites, so each
    # canonical form stays mapped to the first matrix that produced it
    first_of_class = {}
    for candidate in all_labeled_DAGs(n):
        first_of_class.setdefault(canonical_form(candidate), candidate)
    return list(first_of_class.values())


def all_ADMGs_from_DAG(dag_adj):
    """
    Given a DAG adjacency matrix 'dag_adj' (n x n),
    build all ADMGs obtained by adding any subset of bidirected edges.

    Returns a list of 2**(n*(n-1)/2) dictionaries:
      {
        'directed':   [ (u,v), ... ],    # u->v wherever dag_adj[u,v] == 1
        'bidirected': [ (x,y), ... ]     # chosen pairs, x < y
      }
    Subsets are listed by increasing size and, within one size, in the order
    produced by itertools.combinations.  Every dictionary owns its own
    'directed' list, so editing one ADMG never affects the others.
    """
    n = dag_adj.shape[0]

    # Directed edges in row-major order of the adjacency matrix
    directed_edges = [
        (u, v)
        for u in range(n)
        for v in range(n)
        if dag_adj[u, v] == 1
    ]

    # All unordered pairs i<j are candidate bidirected edges
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]

    results = []
    for subset_size in range(len(pairs) + 1):
        for combo in itertools.combinations(pairs, subset_size):
            results.append({
                # copy: sharing one list across all results would let a
                # mutation of one ADMG leak into every other ADMG
                'directed': list(directed_edges),
                'bidirected': list(combo),
            })
    return results


def all_ADMGs_up_to_isomorphism(n):
    """
    Build every ADMG on n nodes whose directed part is one of the DAG
    representatives returned by all_unlabeled_DAGs_up_to_isomorphism(n).

    For each such DAG (an (n,n) adjacency matrix) all subsets of bidirected
    edges are added via all_ADMGs_from_DAG and the results are concatenated
    in DAG order.  The list grows very quickly, even for small n.

    Returns a flat list of ADMG dicts.
    """
    return [
        admg
        for dag in all_unlabeled_DAGs_up_to_isomorphism(n)
        for admg in all_ADMGs_from_DAG(dag)
    ]


def verify_MAG_equi_adj(adj_dir_1: np.ndarray, adj_bi_1: np.ndarray,
                        adj_dir_2: np.ndarray, adj_bi_2: np.ndarray, discriminating_path=False):
    """
    Decide whether two MAGs, given as adjacency matrices, pass the Markov
    equivalence checks implemented in this module.

    The checks run in order and the first failure returns False:
      1. identical skeletons (element-wise OR of directed and bidirected part)
      2. compare_unshielded_triples agrees
      3. only if discriminating_path is true: compare_discriminating_paths agrees

    Args:
        adj_dir_1: adjacency matrix of the first MAG with directed edges
        adj_bi_1: adjacency matrix of the first MAG with bidirected edges
        adj_dir_2: adjacency matrix of the second MAG with directed edges
        adj_bi_2: adjacency matrix of the second MAG with bidirected edges
        discriminating_path: whether to run check 3

    Returns: True if every enabled check passes, else False
    """
    skeleton_1 = np.logical_or(adj_dir_1, adj_bi_1)
    skeleton_2 = np.logical_or(adj_dir_2, adj_bi_2)
    if not np.array_equal(skeleton_1, skeleton_2):
        return False

    if not compare_unshielded_triples(adj_dir_1, adj_bi_1, adj_dir_2, adj_bi_2):
        return False

    if discriminating_path and not compare_discriminating_paths(adj_dir_1, adj_bi_1, adj_dir_2, adj_bi_2):
        return False

    return True


def verify_MAG_equivalence(mag_1: GeneralGraph, mag_2: GeneralGraph, discriminating_path=False):
    """
    Decide whether two MAGs given as graph objects pass the Markov
    equivalence checks implemented in this module.

    Both graphs are converted to (directed, bidirected) adjacency matrices
    with get_adj_matrix_from_graph, using a name -> index map taken from the
    node order of mag_1.  The matrices are then compared on skeleton,
    unshielded triples and, optionally, discriminating paths.

    Args:
        mag_1: first MAG (e.g. the twin augmented graph of an ADMG)
        mag_2: second MAG (e.g. the augmented graph learned by the algorithm)
        discriminating_path: whether to also compare discriminating paths

    Returns: True if every enabled check passes, else False
    """
    # Index nodes by their position in mag_1
    name_id_map = {
        node.get_name(): position
        for position, node in enumerate(mag_1.get_nodes())
    }
    adj_dir_1, adj_bi_1 = get_adj_matrix_from_graph(mag_1, name_id_map)
    adj_dir_2, adj_bi_2 = get_adj_matrix_from_graph(mag_2, name_id_map)

    # Same adjacencies?
    skeleton_1 = np.logical_or(adj_dir_1, adj_bi_1)
    skeleton_2 = np.logical_or(adj_dir_2, adj_bi_2)
    if not np.array_equal(skeleton_1, skeleton_2):
        return False

    # Same unshielded triples?
    if not compare_unshielded_triples(adj_dir_1, adj_bi_1, adj_dir_2, adj_bi_2):
        return False

    # Same discriminating paths (only when requested)?
    if discriminating_path and not compare_discriminating_paths(adj_dir_1, adj_bi_1, adj_dir_2, adj_bi_2):
        return False

    return True


def get_adj_matrix_from_graph(G: GeneralGraph, name_id_map: dict):
    """
    Encode a graph (MAG) as a pair of square matrices.

    Encoding, for matrix indices a and b:
      * a -> b  : adj_dir[a, b] = 2 (tail side), adj_dir[b, a] = 1 (arrowhead side)
      * a <-> b : adj_bi[a, b] = adj_bi[b, a] = 1
    Any edge that is not a plain directed edge (tail at one end, arrowhead at
    the other) is recorded in adj_bi.

    Each edge is processed exactly once and its orientation is read relative
    to the edge's own node1/node2.  Visiting the same edge from both ends and
    treating "not TAIL/ARROW" as bidirected would either flip a directed edge
    or mark it bidirected as well.

    Args:
        G: a graph, MAG
        name_id_map: node name -> matrix index.  It is used when it assigns
            the indices 0..n-1, each exactly once, to the n nodes of G, so
            that matrices built from different graphs with one shared map
            are index-aligned.  Otherwise the nodes are indexed by their
            position in G.get_nodes().

    Returns: adjacency matrix of G with directed edges, adjacency matrix of G with bidirected edges
    """
    nodes = G.get_nodes()
    n = len(nodes)
    names = [node.get_name() for node in nodes]

    # Prefer the shared map, but only if it is a valid relabelling of G's nodes
    mapped = [name_id_map.get(name) for name in names] if name_id_map else []
    if len(mapped) == n and set(mapped) == set(range(n)):
        index = mapped
    else:
        index = list(range(n))

    adj_dir = np.zeros((n, n))
    adj_bi = np.zeros((n, n))
    for i, j in combinations(range(n), 2):
        if not G.is_adjacent_to(nodes[i], nodes[j]):
            continue

        edge = G.get_edge(nodes[i], nodes[j])
        # The edge object may list its two nodes in either order, so match
        # endpoints to nodes by name rather than by argument position
        if edge.get_node1().get_name() == names[i]:
            end_i, end_j = edge.get_endpoint1(), edge.get_endpoint2()
        else:
            end_i, end_j = edge.get_endpoint2(), edge.get_endpoint1()

        a, b = index[i], index[j]
        if end_i == Endpoint.TAIL and end_j == Endpoint.ARROW:
            # nodes[i] -> nodes[j]
            adj_dir[a, b] = 2
            adj_dir[b, a] = 1
        elif end_i == Endpoint.ARROW and end_j == Endpoint.TAIL:
            # nodes[j] -> nodes[i]
            adj_dir[b, a] = 2
            adj_dir[a, b] = 1
        else:
            # Everything else is treated as a bidirected edge
            adj_bi[a, b] = 1
            adj_bi[b, a] = 1

    return adj_dir, adj_bi


def compare_unshielded_triples(adj_dir_1, adj_bi_1, adj_dir_2, adj_bi_2):
    """
    Check that two MAGs have the same unshielded colliders.

    The triples are those returned by list_unshielded_colliders for each
    MAG; they are compared as sets, so order and repetition do not matter.

    Args:
        adj_dir_1: adjacency matrix of the first MAG with directed edges
        adj_bi_1: adjacency matrix of the first MAG with bidirected edges
        adj_dir_2: adjacency matrix of the second MAG with directed edges
        adj_bi_2: adjacency matrix of the second MAG with bidirected edges

    Returns: True if both MAGs yield the same set of triples, else False
    """
    first = set(list_unshielded_colliders(adj_dir_1, adj_bi_1))
    second = set(list_unshielded_colliders(adj_dir_2, adj_bi_2))
    return first == second


def compare_discriminating_paths(adj_dir_1: np.ndarray, adj_bi_1: np.ndarray, adj_dir_2: np.ndarray,
                                 adj_bi_2: np.ndarray):
    """
    Check that two MAGs have the same 3-edge discriminating paths.

    The (X, A, Z, Y) tuples returned by list_discriminating_paths_3 for each
    MAG are compared as sets.
    NOTE(review): only the presence of each path is compared, not the
    edge marks at Z -- confirm this is the intended equivalence criterion.

    Args:
        adj_dir_1: adjacency matrix of the first MAG with directed edges
        adj_bi_1: adjacency matrix of the first MAG with bidirected edges
        adj_dir_2: adjacency matrix of the second MAG with directed edges
        adj_bi_2: adjacency matrix of the second MAG with bidirected edges

    Returns: True if both MAGs yield the same set of paths, else False
    """
    first = set(list_discriminating_paths_3(adj_dir_1, adj_bi_1))
    second = set(list_discriminating_paths_3(adj_dir_2, adj_bi_2))
    return first == second


def list_unshielded_colliders(adj_dir, adj_bi):
    """
    Return all triples (A, B, C) that form an unshielded collider at B:
      1) A--B, B--C are edges, but A and C are not adjacent.
      2) Arrowheads into B from both A and C (i.e. A->B or A<->B, and C->B or C<->B).

    Ordered triples of distinct nodes are scanned in lexicographic order and
    a triple is skipped when its mirror image (C, B, A) has already been
    recorded, so each collider is reported once.
    """
    n = adj_dir.shape[0]
    colliders = []
    # Mirror of 'colliders' for O(1) duplicate checks; the list keeps the order
    recorded = set()

    for A, B, C in itertools.permutations(range(n), 3):
        if (C, B, A) in recorded:
            continue

        # (1) A and B adjacent, B and C adjacent
        if not (is_adjacent(A, B, adj_dir, adj_bi) and
                is_adjacent(B, C, adj_dir, adj_bi)):
            continue

        # (2) A and C not adjacent => unshielded
        if not is_not_adjacent(A, C, adj_dir, adj_bi):
            continue

        # (3) arrowheads at B from both A and C
        if (arrowhead_at_j_from_i(A, B, adj_dir, adj_bi) and
                arrowhead_at_j_from_i(C, B, adj_dir, adj_bi)):
            colliders.append((A, B, C))
            recorded.add((A, B, C))

    return colliders


def is_adjacent(i, j, adj_dir, adj_bi):
    """
    Tell whether nodes i and j share an edge in the MAG.

    True when the directed matrix has a nonzero entry at [i, j] or [j, i]
    (i->j or j->i), or when the bidirected matrix has a 1 at [i, j] (i<->j).
    """
    has_directed = adj_dir[i, j] != 0 or adj_dir[j, i] != 0
    # adj_bi is only consulted when no directed edge was found
    return bool(has_directed or adj_bi[i, j] == 1)


def is_not_adjacent(i, j, adj_dir, adj_bi):
    """
    Tell whether nodes i and j have no edge at all: the directed matrix is 0
    at both [i, j] and [j, i], and the bidirected matrix is 0 at [i, j].
    """
    no_directed_edge = adj_dir[i, j] == 0 and adj_dir[j, i] == 0
    return no_directed_edge and adj_bi[i, j] == 0


def arrowhead_at_j_from_i(i, j, adj_dir, adj_bi):
    """
    Tell whether the edge between i and j has an arrowhead at j.

    That holds for i->j (adj_dir[i, j] == 2) and for i<->j (adj_bi[i, j] == 1).
    """
    return bool(adj_dir[i, j] == 2 or adj_bi[i, j] == 1)


def compute_descendants(adj_dir):
    """
    For every node x, collect the nodes reachable from x along one or more
    directed edges (u->w is encoded as adj_dir[u, w] == 2).

    Returns a list 'descendants' with one set per node; x is an ancestor of
    y exactly when y is in descendants[x].  A node appears in its own set
    only if a directed cycle leads back to it.
    """
    n = adj_dir.shape[0]
    descendants = []

    for root in range(n):
        reached = set()
        frontier = [root]
        while frontier:
            node = frontier.pop()
            for child in range(n):
                # follow node->child once per child
                if child not in reached and adj_dir[node, child] == 2:
                    reached.add(child)
                    frontier.append(child)
        descendants.append(reached)

    return descendants


def list_discriminating_paths_3(adj_dir, adj_bi):
    """
    Enumerate the 3-edge paths X - A - Z - Y over four distinct nodes of the
    MAG ('adj_dir' directed part, 'adj_bi' bidirected part) that this module
    treats as discriminating paths.

    A tuple (X, A, Z, Y) is returned when all of the following hold:
      1) X and Y are not adjacent
      2) (X,A), (A,Z) and (Z,Y) are edges (directed or bidirected)
      3) A is a collider on (X, A, Z): arrowheads at A from both X and Z
      4) A is a parent of Y: adj_dir[A, Y] == 2

    Tuples are listed in lexicographic order of (X, A, Z, Y).
    """
    n = adj_dir.shape[0]  # number of nodes
    paths = []

    # permutations() yields 4-tuples of distinct nodes in lexicographic order
    for X, A, Z, Y in itertools.permutations(range(n), 4):
        # 1) endpoints must not be adjacent
        if not is_not_adjacent(X, Y, adj_dir, adj_bi):
            continue

        # 2) the three consecutive edges must exist
        if not is_adjacent(X, A, adj_dir, adj_bi):
            continue
        if not is_adjacent(A, Z, adj_dir, adj_bi):
            continue
        if not is_adjacent(Z, Y, adj_dir, adj_bi):
            continue

        # 3) collider at A
        into_a_from_x = arrowhead_at_j_from_i(X, A, adj_dir, adj_bi)
        if not (into_a_from_x and arrowhead_at_j_from_i(Z, A, adj_dir, adj_bi)):
            continue

        # 4) A -> Y
        if adj_dir[A, Y] == 2:
            paths.append((X, A, Z, Y))

    return paths


def exists_inducing_path(i, j, adj_dir, adj_bi):
    """
    Returns True if there is any inducing path from node i to node j in the
    graph given by 'adj_dir' (directed part) and 'adj_bi' (bidirected part).
    Otherwise False.

    The search enumerates simple paths starting at i depth-first and tests
    every one that ends at j with is_inducing_path, so the running time can
    grow exponentially with the number of nodes.
    """
    n = adj_dir.shape[0]

    # Precompute descendants once; is_inducing_path uses them for the
    # "interior node is an ancestor of i or j" test
    descendants = compute_descendants(adj_dir)

    stack = [(i, [i])]  # DFS stack: (current_node, path_list)

    while stack:
        curr, path = stack.pop()

        if curr == j:
            if is_inducing_path(path, i, j, adj_dir, adj_bi, descendants):
                return True
            # A simple path never revisits j, so no extension of this path
            # can end at j again: stop expanding it
            continue

        # Extend the path by every neighbour that is not already on it
        for nxt in range(n):
            if nxt not in path and is_adjacent(curr, nxt, adj_dir, adj_bi):
                stack.append((nxt, path + [nxt]))

    return False


def arrowhead_at_x_from_y(x, y, adj_dir, adj_bi):
    """
    Tell whether the edge between y and x has an arrowhead at x.

    That holds for y->x (adj_dir[y, x] == 2) and for y<->x (adj_bi[y, x] == 1).
    """
    directed_into_x = adj_dir[y, x] == 2
    # adj_bi is only consulted when there is no directed edge y->x
    return True if directed_into_x else bool(adj_bi[y, x] == 1)


def is_inducing_path(path, i, j, adj_dir, adj_bi, descendants):
    """
    Decide whether 'path' (a list of nodes) is an inducing path between
    i and j.

    Requirements, checked in this order:
      * the path starts at i and ends at j
      * every two consecutive nodes are adjacent
      * every interior node v is a collider on the path (arrowheads at v
        from both neighbours) and an ancestor of i or j, i.e. i or j is in
        descendants[v]
    """
    if not (path[0] == i and path[-1] == j):
        return False

    # Every step of the path must be an actual edge
    for left, right in zip(path, path[1:]):
        if not is_adjacent(left, right, adj_dir, adj_bi):
            return False

    # Sliding window over (previous, interior, next)
    for before, v, after in zip(path, path[1:], path[2:]):
        is_collider = (arrowhead_at_x_from_y(v, before, adj_dir, adj_bi) and
                       arrowhead_at_x_from_y(v, after, adj_dir, adj_bi))
        if not is_collider:
            return False

        # v needs a directed path to one of the endpoints
        if i not in descendants[v] and j not in descendants[v]:
            return False

    return True


def get_adj_of_aug_mag(admg: dict, n: int, target_dict: dict):
    """
    Build the adjacency matrices of the augmented MAG of an ADMG, with one
    copy of the n observed nodes per domain plus a single F node.

    Node layout: for every key 'dom' of target_dict (in iteration order) the
    nodes "X{dom}_0" .. "X{dom}_{n-1}" get consecutive indices; "F" gets the
    last index.

    Encoding (same as the other helpers in this file): a -> b is stored as
    dir[a, b] = 2 and dir[b, a] = 1; a <-> b is stored as bi[a, b] = bi[b, a] = 1.

    Args:
        admg: dict with 'directed' and 'bidirected' lists of (u, v) pairs
        n: number of observed nodes per domain
        target_dict: domain -> collection of intervened node indices.
            target_dict[0] and target_dict[1] must exist and support
            symmetric_difference; the F-node step also uses 1 - dom, so the
            keys are presumably exactly 0 and 1 -- TODO confirm with callers.

    Returns: (adj_aug_mag_dir, adj_aug_mag_bi), both of shape
             (len(target_dict) * n + 1, len(target_dict) * n + 1)
    """
    # Construct the mapping from name to ids
    name_id_map = {}
    idx = 0
    for dom in target_dict.keys():
        for i in range(n):
            node_name = f"X{dom}_{i}"
            name_id_map[node_name] = idx
            idx += 1

    # Add F node (idx now equals the number of observed nodes, so F is last)
    name_id_map["F"] = idx
    adj_aug_dir = np.zeros((idx + 1, idx + 1))
    adj_aug_bi = np.zeros((idx + 1, idx + 1))

    # Get the adjacency matrices of the augmented graph of the ADMG
    for dom in target_dict.keys():
        target = target_dict[dom]
        # A directed edge is kept in this domain only if its head (second node) is not a target
        for edge in admg['directed']:
            if edge[1] not in target:
                node1_name = f"X{dom}_{edge[0]}"
                node2_name = f"X{dom}_{edge[1]}"
                adj_aug_dir[name_id_map[node1_name], name_id_map[node2_name]] = 2
                adj_aug_dir[name_id_map[node2_name], name_id_map[node1_name]] = 1
        # A bidirected edge is kept in this domain only if neither endpoint is a target
        for edge in admg['bidirected']:
            if edge[0] not in target and edge[1] not in target:
                node1_name = f"X{dom}_{edge[0]}"
                node2_name = f"X{dom}_{edge[1]}"
                adj_aug_bi[name_id_map[node1_name], name_id_map[node2_name]] = 1
                adj_aug_bi[name_id_map[node2_name], name_id_map[node1_name]] = 1

    # Add edges F -> X{dom}_t, in every domain, for each t in the symmetric difference of the two target sets
    target_diff = target_dict[0].symmetric_difference(target_dict[1])
    for target in list(target_diff):
        for dom in target_dict.keys():
            node_name = f"X{dom}_{target}"
            adj_aug_dir[name_id_map["F"], name_id_map[node_name]] = 2
            adj_aug_dir[name_id_map[node_name], name_id_map["F"]] = 1

    # Output matrices of the augmented MAG; filled in by the three passes below
    adj_aug_mag_dir = np.zeros((idx + 1, idx + 1))
    adj_aug_mag_bi = np.zeros((idx + 1, idx + 1))

    # Pass 1: pairs of observed nodes within one domain that are adjacent in the augmented graph
    for dom in target_dict.keys():
        for (node_1, node_2) in combinations(range(n), 2):
            node1_name = f"X{dom}_{node_1}"
            node2_name = f"X{dom}_{node_2}"
            # A directed edge is copied as is (checked first, so it wins over a
            # bidirected edge on the same pair).  A bidirected edge becomes
            # node 1 -> node 2 if node 1 is an ancestor of node 2, node 2 -> node 1
            # if node 2 is an ancestor of node 1, and stays bidirected otherwise.
            if adj_aug_dir[name_id_map[node1_name], name_id_map[node2_name]] == 2:
                adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 2
                adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 1
            elif adj_aug_dir[name_id_map[node1_name], name_id_map[node2_name]] == 1:
                adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 1
                adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 2
            # Bidirected edge and node 1 is node 2's ancestor
            elif adj_aug_bi[name_id_map[node2_name], name_id_map[node1_name]] == 1 and is_ancestor(
                    name_id_map[node1_name], name_id_map[node2_name], adj_aug_dir):
                adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 2
                adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 1
            # Bidirected edge and node 2 is node 1's ancestor
            elif adj_aug_bi[name_id_map[node2_name], name_id_map[node1_name]] == 1 and is_ancestor(
                    name_id_map[node2_name], name_id_map[node1_name], adj_aug_dir):
                adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 1
                adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 2
            elif adj_aug_bi[name_id_map[node2_name], name_id_map[node1_name]] == 1:
                # Add a bidirected edge
                adj_aug_mag_bi[name_id_map[node1_name], name_id_map[node2_name]] = 1
                adj_aug_mag_bi[name_id_map[node2_name], name_id_map[node1_name]] = 1

    # The row and column of the F node (last index) are copied from adj_aug_dir
    adj_aug_mag_dir[-1, :] = adj_aug_dir[-1, :]
    adj_aug_mag_dir[:, -1] = adj_aug_dir[:, -1]

    # Pass 2: pairs of observed nodes within one domain that are NOT adjacent in the augmented graph
    for dom in target_dict.keys():
        for (node_1, node_2) in combinations(range(n), 2):
            node1_name = f"X{dom}_{node_1}"
            node2_name = f"X{dom}_{node_2}"
            # Skip if the nodes are adjacent in the augmented graph of the ADMG (handled in pass 1)
            if is_adjacent(name_id_map[node1_name], name_id_map[node2_name], adj_aug_dir, adj_aug_bi):
                continue

            # An inducing path between the two nodes makes them adjacent in the MAG
            if exists_inducing_path(name_id_map[node1_name], name_id_map[node2_name], adj_aug_dir, adj_aug_bi):
                # If node 1 is node 2's ancestor, add node 1 -> node 2
                if is_ancestor(name_id_map[node1_name], name_id_map[node2_name], adj_aug_dir):
                    adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 2
                    adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 1
                # If node 2 is node 1's ancestor, add node 2 -> node 1
                elif is_ancestor(name_id_map[node2_name], name_id_map[node1_name], adj_aug_dir):
                    adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 1
                    adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 2
                else:
                    # Add a bidirected edge
                    adj_aug_mag_bi[name_id_map[node1_name], name_id_map[node2_name]] = 1
                    adj_aug_mag_bi[name_id_map[node2_name], name_id_map[node1_name]] = 1

    # Pass 3: the F node against every observed node that does not already have an F edge
    for dom in target_dict.keys():
        for i in range(n):
            # Skip if i is in the symmetric difference of the target sets (F -> i was copied above)
            if i in target_dict[0].symmetric_difference(target_dict[1]):
                continue
            node_name = f"X{dom}_{i}"
            # An inducing path from F to the node in one domain adds F -> node in both domains
            if exists_inducing_path(name_id_map["F"], name_id_map[node_name], adj_aug_dir, adj_aug_bi):
                adj_aug_mag_dir[-1, name_id_map[node_name]] = 2
                adj_aug_mag_dir[name_id_map[node_name], -1] = 1
                # 1 - dom picks the other domain; only meaningful for domain keys 0 and 1
                other_domain = 1 - dom
                other_node_name = f"X{other_domain}_{i}"
                adj_aug_mag_dir[-1, name_id_map[other_node_name]] = 2
                adj_aug_mag_dir[name_id_map[other_node_name], -1] = 1

    return adj_aug_mag_dir, adj_aug_mag_bi


def get_adj_of_aug_mag_soft(admg: dict, n: int, target_dict: dict):
    """
    Build the adjacency matrices of the augmented MAG of an ADMG using a
    single copy of the n observed nodes plus an F node.

    Unlike get_adj_of_aug_mag, no edge of the ADMG is removed for intervened
    nodes; the targets only determine which nodes receive an edge from F.

    Node layout: "X_0" .. "X_{n-1}" get indices 0..n-1 and "F" gets index n.

    Encoding (same as the other helpers in this file): a -> b is stored as
    dir[a, b] = 2 and dir[b, a] = 1; a <-> b is stored as bi[a, b] = bi[b, a] = 1.

    Args:
        admg: dict with 'directed' and 'bidirected' lists of (u, v) pairs
        n: number of observed nodes
        target_dict: target_dict[0] and target_dict[1] must exist and support
            symmetric_difference (presumably sets of intervened node indices
            for two domains -- TODO confirm with callers)

    Returns: (adj_aug_mag_dir, adj_aug_mag_bi), both of shape (n + 1, n + 1)
    """
    # Construct the mapping from name to ids
    name_id_map = {}
    idx = 0
    for i in range(n):
        node_name = f"X_{i}"
        name_id_map[node_name] = idx
        idx += 1

    # Add F node (idx now equals n, so F is last)
    name_id_map["F"] = idx
    adj_aug_dir = np.zeros((idx + 1, idx + 1))
    adj_aug_bi = np.zeros((idx + 1, idx + 1))

    # Get the adjacency matrices of the augmented graph of the ADMG (all edges are kept)
    for edge in admg['directed']:
        node1_name = f"X_{edge[0]}"
        node2_name = f"X_{edge[1]}"
        adj_aug_dir[name_id_map[node1_name], name_id_map[node2_name]] = 2
        adj_aug_dir[name_id_map[node2_name], name_id_map[node1_name]] = 1
    for edge in admg['bidirected']:
        node1_name = f"X_{edge[0]}"
        node2_name = f"X_{edge[1]}"
        adj_aug_bi[name_id_map[node1_name], name_id_map[node2_name]] = 1
        adj_aug_bi[name_id_map[node2_name], name_id_map[node1_name]] = 1

    # Add edges F -> X_t for each t in the symmetric difference of the two target sets
    target_diff = target_dict[0].symmetric_difference(target_dict[1])
    for target in list(target_diff):
        node_name = f"X_{target}"
        adj_aug_dir[name_id_map["F"], name_id_map[node_name]] = 2
        adj_aug_dir[name_id_map[node_name], name_id_map["F"]] = 1

    # Output matrices of the augmented MAG; filled in by the three passes below
    adj_aug_mag_dir = np.zeros((idx + 1, idx + 1))
    adj_aug_mag_bi = np.zeros((idx + 1, idx + 1))

    # Pass 1: pairs of observed nodes that are adjacent in the augmented graph
    for (node_1, node_2) in combinations(range(n), 2):
        node1_name = f"X_{node_1}"
        node2_name = f"X_{node_2}"
        # A directed edge is copied as is (checked first, so it wins over a
        # bidirected edge on the same pair).  A bidirected edge becomes
        # node 1 -> node 2 if node 1 is an ancestor of node 2, node 2 -> node 1
        # if node 2 is an ancestor of node 1, and stays bidirected otherwise.
        if adj_aug_dir[name_id_map[node1_name], name_id_map[node2_name]] == 2:
            adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 2
            adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 1
        elif adj_aug_dir[name_id_map[node1_name], name_id_map[node2_name]] == 1:
            adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 1
            adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 2
        # Bidirected edge and node 1 is node 2's ancestor
        elif adj_aug_bi[name_id_map[node2_name], name_id_map[node1_name]] == 1 and is_ancestor(name_id_map[node1_name],
                                                                                               name_id_map[node2_name],
                                                                                               adj_aug_dir):
            adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 2
            adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 1
        # Bidirected edge and node 2 is node 1's ancestor
        elif adj_aug_bi[name_id_map[node2_name], name_id_map[node1_name]] == 1 and is_ancestor(name_id_map[node2_name],
                                                                                               name_id_map[node1_name],
                                                                                               adj_aug_dir):
            adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 1
            adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 2
        elif adj_aug_bi[name_id_map[node2_name], name_id_map[node1_name]] == 1:
            # Add a bidirected edge
            adj_aug_mag_bi[name_id_map[node1_name], name_id_map[node2_name]] = 1
            adj_aug_mag_bi[name_id_map[node2_name], name_id_map[node1_name]] = 1

    # The row and column of the F node (last index) are copied from adj_aug_dir
    adj_aug_mag_dir[-1, :] = adj_aug_dir[-1, :]
    adj_aug_mag_dir[:, -1] = adj_aug_dir[:, -1]

    # Pass 2: pairs of observed nodes that are NOT adjacent in the augmented graph
    for (node_1, node_2) in combinations(range(n), 2):
        node1_name = f"X_{node_1}"
        node2_name = f"X_{node_2}"
        # Skip if the nodes are adjacent in the augmented graph of the ADMG (handled in pass 1)
        if is_adjacent(name_id_map[node1_name], name_id_map[node2_name], adj_aug_dir, adj_aug_bi):
            continue

        # An inducing path between the two nodes makes them adjacent in the MAG
        if exists_inducing_path(name_id_map[node1_name], name_id_map[node2_name], adj_aug_dir, adj_aug_bi):
            # If node 1 is node 2's ancestor, add node 1 -> node 2
            if is_ancestor(name_id_map[node1_name], name_id_map[node2_name], adj_aug_dir):
                adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 2
                adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 1
            # If node 2 is node 1's ancestor, add node 2 -> node 1
            elif is_ancestor(name_id_map[node2_name], name_id_map[node1_name], adj_aug_dir):
                adj_aug_mag_dir[name_id_map[node1_name], name_id_map[node2_name]] = 1
                adj_aug_mag_dir[name_id_map[node2_name], name_id_map[node1_name]] = 2
            else:
                # Add a bidirected edge
                adj_aug_mag_bi[name_id_map[node1_name], name_id_map[node2_name]] = 1
                adj_aug_mag_bi[name_id_map[node2_name], name_id_map[node1_name]] = 1

    # Pass 3: the F node against every observed node that does not already have an F edge
    for i in range(n):
        node_name = f"X_{i}"
        # Skip if i is in the symmetric difference of the target sets (F -> i was copied above)
        if i in target_diff:
            continue
        # An inducing path from F to the node adds F -> node
        if exists_inducing_path(name_id_map["F"], name_id_map[node_name], adj_aug_dir, adj_aug_bi):
            adj_aug_mag_dir[-1, name_id_map[node_name]] = 2
            adj_aug_mag_dir[name_id_map[node_name], -1] = 1

    return adj_aug_mag_dir, adj_aug_mag_bi


def is_ancestor(i, j, adj_dir):
    """
    Depth-first search for a directed path from node i to node j.

    Only entries equal to 2 are followed: adj_dir[u, v] == 2 is read as the
    directed edge u -> v. Because the start node is tested before any edge is
    followed, is_ancestor(i, i, ...) is True.

    :param i: Starting node index
    :param j: Target node index
    :param adj_dir: A NumPy 2D array; its first dimension gives the node count
    :return: True if j is reachable from i along entries equal to 2, else False
    """
    size = adj_dir.shape[0]
    seen = set()
    frontier = [i]

    while frontier:
        node = frontier.pop()
        if node == j:
            return True
        # Children of `node` that have not been queued before
        children = [child for child in range(size)
                    if adj_dir[node, child] == 2 and child not in seen]
        seen.update(children)
        frontier.extend(children)

    return False


def compare_with_learned_graph(admg_aug: GeneralGraph, learned_net: GeneralGraph, target_dict: dict):
    """
    Check whether the learned augmented graph agrees with the augmented MAG of a given ADMG.

    Two things are checked for every pair of nodes of the learned graph:
      1. the pair is adjacent in both graphs or in neither;
      2. if the pair is adjacent and neither node is "F", every non-circle
         endpoint (ARROW or TAIL) of the learned edge is the same as the
         endpoint at the same node in ``admg_aug``.
    Endpoints on edges incident to "F" are not compared, only their adjacency.

    Args:
        admg_aug: twin augmented graph of an admg
        learned_net: augmented graph learned by the algorithm; its node names
            are looked up in ``admg_aug``
        target_dict: interventions (not used by the comparison)

    Returns: True if both checks pass for every pair of nodes, False otherwise
    """

    def endpoint_at(edge, node_name):
        # Endpoint of `edge` at the node called `node_name`. Endpoints are
        # matched by node name because an Edge object is not guaranteed to list
        # its two nodes in the order in which the pair was requested, so
        # endpoint1 of one graph's edge may belong to the other node.
        if edge.get_node1().get_name() == node_name:
            return edge.get_endpoint1()
        return edge.get_endpoint2()

    for (node_1, node_2) in combinations(learned_net.get_nodes(), 2):
        node_name1 = node_1.get_name()
        node_name2 = node_2.get_name()
        # Nodes of the same name in the augmented MAG of the ADMG
        node_1_aug = admg_aug.get_node(node_name1)
        node_2_aug = admg_aug.get_node(node_name2)

        # Adjacencies must agree (this includes pairs involving the F node)
        adj_learned = learned_net.is_adjacent_to(node_1, node_2)
        adj_aug_admg = admg_aug.is_adjacent_to(node_1_aug, node_2_aug)
        if adj_learned != adj_aug_admg:
            return False

        # Endpoints are only compared on edges between two non-F nodes
        if not adj_learned or node_name1 == "F" or node_name2 == "F":
            continue

        edge_learned = learned_net.get_edge(node_1, node_2)
        edge_aug_admg = admg_aug.get_edge(node_1_aug, node_2_aug)
        for node_name in (node_name1, node_name2):
            learned_mark = endpoint_at(edge_learned, node_name)
            # A circle in the learned graph is compatible with any endpoint
            if learned_mark == Endpoint.CIRCLE:
                continue
            if learned_mark != endpoint_at(edge_aug_admg, node_name):
                return False

    return True


def get_soft_augmented_mag_from_admg(admg: dict, n: int, target_dict: dict):
    """
    Build the augmented graph of an ADMG under soft interventions and convert it to a MAG.

    The graph has nodes "X_0" .. "X_{n-1}" plus "F", the directed and
    bidirected edges listed in ``admg``, and an edge F -> X_t for every t in
    the symmetric difference of target_dict[0] and target_dict[1].

    Args:
        admg (dict): {'directed': [(u, v), ...], 'bidirected': [(u, v), ...]}.
        n (int): Number of observed nodes.
        target_dict (dict): Target sets; entries 0 and 1 are read.

    Returns:
        GeneralGraph: Result of get_aug_mag_soft_from_aug on the built graph.
    """
    soft_graph = GeneralGraph([])
    for idx in range(n):
        soft_graph.add_node(GraphNode(f"X_{idx}"))

    regime_node = GraphNode("F")
    soft_graph.add_node(regime_node)

    def observed(idx):
        # Node object of the observed variable with the given index
        return soft_graph.get_node(f"X_{idx}")

    # Directed edges u -> v
    for pair in admg['directed']:
        soft_graph.add_edge(Edge(observed(pair[0]), observed(pair[1]), Endpoint.TAIL, Endpoint.ARROW))

    # Bidirected edges u <-> v
    for pair in admg['bidirected']:
        soft_graph.add_edge(Edge(observed(pair[0]), observed(pair[1]), Endpoint.ARROW, Endpoint.ARROW))

    # F points at every variable whose intervention status differs between the two target sets
    changed_targets = target_dict[0].symmetric_difference(target_dict[1])
    for target in changed_targets:
        soft_graph.add_edge(Edge(regime_node, observed(target), Endpoint.TAIL, Endpoint.ARROW))

    return get_aug_mag_soft_from_aug(soft_graph, target_dict, n)


def get_aug_mag_soft_from_aug(G_aug: GeneralGraph, target_dict: dict, n: int):
    """
    Derive a MAG over the nodes of a soft-intervention augmented graph.

    For every pair of nodes: an existing edge that touches "F" is copied
    unchanged; otherwise the pair is joined unless is_separable_soft reports
    it separable. A new edge is oriented from ancestor to descendant according
    to G_aug.is_ancestor_of, or made bidirected when neither is an ancestor.

    Args:
        G_aug (GeneralGraph): Augmented graph with nodes "X_i" and "F".
        target_dict (dict): Target sets (not used here).
        n (int): Number of observed nodes; conditioning ids are 0 .. n-1.

    Returns:
        GeneralGraph: New graph with freshly created nodes of the same names.
    """
    observed_ids = set(range(n))

    mag = GeneralGraph([])
    for original in G_aug.get_nodes():
        mag.add_node(GraphNode(original.get_name()))

    for first, second in combinations(G_aug.get_nodes(), 2):
        first_name = first.get_name()
        second_name = second.get_name()

        touches_f = first_name == "F" or second_name == "F"
        if touches_f and G_aug.is_adjacent_to(first, second):
            # Copy the existing F edge, endpoints included
            existing = G_aug.get_edge(first, second)
            mag.add_edge(Edge(mag.get_node(existing.get_node1().get_name()),
                              mag.get_node(existing.get_node2().get_name()),
                              existing.get_endpoint1(), existing.get_endpoint2()))
            continue

        if is_separable_soft(G_aug, first, second, observed_ids):
            continue

        # Orient by ancestral relation; fall back to a bidirected edge
        if G_aug.is_ancestor_of(first, second):
            start, stop, start_mark = first_name, second_name, Endpoint.TAIL
        elif G_aug.is_ancestor_of(second, first):
            start, stop, start_mark = second_name, first_name, Endpoint.TAIL
        else:
            start, stop, start_mark = first_name, second_name, Endpoint.ARROW
        mag.add_edge(Edge(mag.get_node(start), mag.get_node(stop), start_mark, Endpoint.ARROW))

    return mag


def get_number(n):
    """
    Count the regular files directly inside './admgs_<n>/'.

    Sub-directories are not counted and not descended into.

    Args:
        n: Value substituted into the directory name.

    Returns:
        int: Number of entries for which os.path.isfile is true.
    """
    folder = './admgs_{}/'.format(n)
    regular_files = [
        name for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    ]
    return len(regular_files)


def get_augmented_mag_from_admg(admg: dict, n: int, target_dict: dict):
    """
    Build the twin augmented MAG of an ADMG for the domains in ``target_dict``.

    Each domain d gets its own copy "Xd_0" .. "Xd_{n-1}" of the variables,
    carrying the edges that get_domain_skeleton keeps for that domain's target
    set. A node "F" points at both copies of every variable in the symmetric
    difference of target_dict[0] and target_dict[1]. The graph is then passed
    through get_aug_mag_from_aug and get_twin_aug_mag_from_aug_mag.

    Args:
        admg (dict): {'directed': [(u, v), ...], 'bidirected': [(u, v), ...]}.
        n (int): Number of observed variables per domain.
        target_dict (dict): Maps a domain id to its target set; entries 0 and 1 are read.

    Returns:
        GeneralGraph: The twin augmented MAG.
    """
    twin_graph = GeneralGraph([])
    for domain in target_dict.keys():
        for idx in range(n):
            twin_graph.add_node(GraphNode(f"X{domain}_{idx}"))

    regime_node = GraphNode("F")
    twin_graph.add_node(regime_node)

    changed_targets = target_dict[0].symmetric_difference(target_dict[1])

    # Per-domain copies of the edges that remain in that domain
    edge_kinds = (('directed', Endpoint.TAIL), ('bidirected', Endpoint.ARROW))
    for domain in target_dict.keys():
        kept = get_domain_skeleton(admg, target_dict[domain])
        for kind, first_mark in edge_kinds:
            for pair in kept[kind]:
                source = twin_graph.get_node(f"X{domain}_{pair[0]}")
                sink = twin_graph.get_node(f"X{domain}_{pair[1]}")
                twin_graph.add_edge(Edge(source, sink, first_mark, Endpoint.ARROW))

    # F -> both domain copies of every variable whose target status differs
    for target in list(changed_targets):
        for domain in (0, 1):
            copy_node = twin_graph.get_node(f"X{domain}_{target}")
            twin_graph.add_edge(Edge(regime_node, copy_node, Endpoint.TAIL, Endpoint.ARROW))

    aug_mag = get_aug_mag_from_aug(twin_graph, target_dict, n)
    return get_twin_aug_mag_from_aug_mag(aug_mag)


def get_twin_aug_mag_from_aug_mag(G_aug_mag: GeneralGraph):
    """
    Make the F adjacencies symmetric across the two domains, in place.

    For every node "Xd_i" adjacent to "F", the node "X(1-d)_i" of the other
    domain receives an edge F -> X(1-d)_i unless it is already adjacent to F.

    Args:
        G_aug_mag (GeneralGraph): Augmented MAG with nodes "X0_i", "X1_i" and "F".

    Returns:
        GeneralGraph: The same graph object, after modification.
    """
    regime_node = G_aug_mag.get_node("F")
    for adjacent in G_aug_mag.get_adjacent_nodes(regime_node):
        domain, index = extract_domain_id(adjacent)
        counterpart = G_aug_mag.get_node(f"X{1 - domain}_{index}")
        if G_aug_mag.is_adjacent_to(regime_node, counterpart):
            continue
        G_aug_mag.add_edge(Edge(regime_node, counterpart, Endpoint.TAIL, Endpoint.ARROW))

    return G_aug_mag


def get_aug_mag_from_aug(G_aug: GeneralGraph, target_dict: dict, n: int):
    """
    Derive a MAG over the nodes of a two-domain augmented graph.

    For every pair of nodes:
      - an existing edge that touches "F" is copied unchanged;
      - two non-F nodes from different domains are never joined;
      - any other pair is joined unless is_separable reports it separable.
        The new edge points from ancestor to descendant according to
        G_aug.is_ancestor_of, or is bidirected when neither is an ancestor.

    Args:
        G_aug (GeneralGraph): Augmented graph with nodes "Xd_i" and "F".
        target_dict (dict): Target sets (not used here).
        n (int): Number of observed variables; conditioning ids are 0 .. n-1.

    Returns:
        GeneralGraph: New graph with freshly created nodes of the same names.
    """
    observed_ids = set(range(n))

    mag = GeneralGraph([])
    for original in G_aug.get_nodes():
        mag.add_node(GraphNode(original.get_name()))

    for first, second in combinations(G_aug.get_nodes(), 2):
        first_name = first.get_name()
        second_name = second.get_name()

        if first_name == "F" or second_name == "F":
            if G_aug.is_adjacent_to(first, second):
                # Copy the existing F edge, endpoints included
                existing = G_aug.get_edge(first, second)
                mag.add_edge(Edge(mag.get_node(existing.get_node1().get_name()),
                                  mag.get_node(existing.get_node2().get_name()),
                                  existing.get_endpoint1(), existing.get_endpoint2()))
                continue
        elif extract_domain_id(first)[0] != extract_domain_id(second)[0]:
            # Nodes of different domains are left non-adjacent
            continue

        if is_separable(G_aug, first, second, observed_ids):
            continue

        # Orient by ancestral relation; fall back to a bidirected edge
        if G_aug.is_ancestor_of(first, second):
            start, stop, start_mark = first_name, second_name, Endpoint.TAIL
        elif G_aug.is_ancestor_of(second, first):
            start, stop, start_mark = second_name, first_name, Endpoint.TAIL
        else:
            start, stop, start_mark = first_name, second_name, Endpoint.ARROW
        mag.add_edge(Edge(mag.get_node(start), mag.get_node(stop), start_mark, Endpoint.ARROW))

    return mag


def is_separable_soft(G_aug: GeneralGraph, node1: GraphNode, node2: GraphNode, obs_ids: set):
    """
    Check if two nodes of a soft-intervention augmented graph are d-separable.

    Adjacent nodes are never separable. Otherwise every subset of the observed
    ids (minus the ids of the two nodes themselves) is tried as a conditioning
    set via G_aug.is_dseparated_from:
      - if one of the nodes is "F", the subset alone is the conditioning set;
      - otherwise "F" is always added to the conditioning set.

    Args:
        G_aug (GeneralGraph): Augmented graph with nodes "X_i" and "F".
        node1 (GraphNode): The first node.
        node2 (GraphNode): The second node.
        obs_ids (set): Integer ids of the observed nodes.

    Returns:
        bool: True if some conditioning set separates the nodes, False otherwise.
    """
    if G_aug.is_adjacent_to(node1, node2):
        return False

    if node1.get_name() == "F" or node2.get_name() == "F":
        # Identify which argument is F; the search below must run in BOTH
        # cases (it used to run only when F was passed as node2).
        if node1.get_name() == "F":
            F_node, node = node1, node2
        else:
            F_node, node = node2, node1
        node_id = int(node.get_name().split("_")[1])
        for condition_set in powerset(obs_ids - {node_id}):
            condition_nodes = [G_aug.get_node(f"X_{idx}") for idx in condition_set]
            if G_aug.is_dseparated_from(node, F_node, condition_nodes):
                return True
        return False

    # Two observed nodes: condition on F plus any subset of the other observed nodes
    id1 = int(node1.get_name().split("_")[1])
    id2 = int(node2.get_name().split("_")[1])
    F_node = G_aug.get_node("F")
    for condition_set in powerset(obs_ids - {id1, id2}):
        condition_nodes = [G_aug.get_node(f"X_{idx}") for idx in condition_set]
        condition_nodes.append(F_node)
        if G_aug.is_dseparated_from(node1, node2, condition_nodes):
            return True

    return False


def is_separable(G_aug: GeneralGraph, node1: GraphNode, node2: GraphNode, obs_ids: set):
    """
    Check if two nodes are d-separable given a subset of the observed nodes.

    Adjacent nodes are never separable. Two non-F nodes from different domains
    are always reported separable. Otherwise every subset of the observed ids
    (minus the ids of the two nodes themselves) is tried as a conditioning set
    via G_aug.is_dseparated_from, using the nodes of the same domain:
      - if one of the nodes is "F", the subset alone is the conditioning set;
      - otherwise "F" is always added to the conditioning set.

    Args:
        G_aug (GeneralGraph): The augmented graph with nodes "Xd_i" and "F".
        node1 (GraphNode): The first node.
        node2 (GraphNode): The second node.
        obs_ids (set): The set of observed node ids.

    Returns:
        bool: True if the nodes are d-separable, False otherwise.
    """
    if G_aug.is_adjacent_to(node1, node2):
        return False

    if node1.get_name() == "F" or node2.get_name() == "F":
        # Identify which argument is F; the search below must run in BOTH
        # cases (it used to run only when F was passed as node2).
        if node1.get_name() == "F":
            F_node, node = node1, node2
        else:
            F_node, node = node2, node1
        dom, node_id = extract_domain_id(node)
        for condition_set in powerset(obs_ids - {node_id}):
            condition_nodes = [G_aug.get_node(f"X{dom}_{idx}") for idx in condition_set]
            if G_aug.is_dseparated_from(node, F_node, condition_nodes):
                return True
        return False

    dom1, id1 = extract_domain_id(node1)
    dom2, id2 = extract_domain_id(node2)
    if dom1 != dom2:
        return True

    # Same domain: condition on F plus any subset of the other nodes of that domain
    F_node = G_aug.get_node("F")
    for condition_set in powerset(obs_ids - {id1, id2}):
        condition_nodes = [G_aug.get_node(f"X{dom1}_{idx}") for idx in condition_set]
        condition_nodes.append(F_node)
        if G_aug.is_dseparated_from(node1, node2, condition_nodes):
            return True

    return False


def random_admg(n, p_extra_dir=0.5, p_extra_bi=0.5, seed=None):
    """
    Generate a random ADMG on n nodes whose skeleton is connected.

    Construction:
      1. Shuffle the nodes to obtain a topological order; every directed edge
         points from the earlier node to the later one, so the directed part
         is acyclic.
      2. Shuffle all unordered pairs and pick a spanning tree with union-find.
      3. Each tree edge becomes directed or bidirected with probability 0.5 each.
      4. Each remaining pair becomes directed with probability p_extra_dir;
         otherwise it becomes bidirected with probability p_extra_bi;
         otherwise it is left out.

    :param n: number of nodes, labeled 0..n-1
    :param p_extra_dir: probability that a non-tree pair becomes a directed edge
    :param p_extra_bi: probability that a non-tree pair that was not made
                       directed becomes a bidirected edge
    :param seed: optional seed; when given, the global `random` module is seeded
    :return: A dict {'directed': [...], 'bidirected': [...]}
    """
    if seed is not None:
        random.seed(seed)

    # Random topological order and the rank of every node in it
    ordering = list(range(n))
    random.shuffle(ordering)
    rank_of = {node: place for place, node in enumerate(ordering)}

    def oriented(a, b):
        # Directed edge between a and b that respects the topological order
        return (a, b) if rank_of[a] < rank_of[b] else (b, a)

    # Union-find bookkeeping for the spanning tree
    leader = list(range(n))
    depth = [0] * n

    def root(x):
        # Representative of x's component, with path compression
        if leader[x] != x:
            leader[x] = root(leader[x])
        return leader[x]

    def merge(a, b):
        # Join the components of a and b; False if they were already joined
        root_a, root_b = root(a), root(b)
        if root_a == root_b:
            return False
        if depth[root_a] < depth[root_b]:
            leader[root_a] = root_b
        else:
            leader[root_b] = root_a
            if depth[root_a] == depth[root_b]:
                depth[root_a] += 1
        return True

    # Every unordered pair (a, b) with a < b, in random order
    candidate_pairs = list(combinations(range(n), 2))
    random.shuffle(candidate_pairs)

    # Keep the pairs that connect two different components until one remains
    tree_pairs = []
    remaining_components = n
    for a, b in candidate_pairs:
        if not merge(a, b):
            continue
        tree_pairs.append((a, b))
        remaining_components -= 1
        if remaining_components == 1:
            break

    directed_edges = []
    bidirected_edges = []

    # Tree edges: fair coin between directed and bidirected
    for a, b in tree_pairs:
        if random.random() < 0.5:
            directed_edges.append(oriented(a, b))
        else:
            bidirected_edges.append((a, b))

    # Non-tree pairs: directed, else possibly bidirected, else absent
    in_tree = set(tree_pairs)
    for a, b in candidate_pairs:
        if (a, b) in in_tree:
            continue
        if random.random() < p_extra_dir:
            directed_edges.append(oriented(a, b))
        elif random.random() < p_extra_bi:
            bidirected_edges.append((a, b))

    return {
        'directed': directed_edges,
        'bidirected': bidirected_edges
    }


def extract_domain_id(node: GraphNode):
    """
    Parse a node name of the form "X<domain>_<id>" into two integers.

    The domain is the text between the first character and the first
    underscore; the id is the text between the first and second underscore
    (or the end of the name).

    Args:
        node (GraphNode): Node whose name is parsed.

    Returns:
        tuple: (domain, id) as ints.
    """
    prefix = node.get_name().split("_")[0]
    index_text = node.get_name().split("_")[1]
    domain_text = prefix[1:]  # drop the leading letter
    return int(domain_text), int(index_text)


def get_domain_skeleton(admg: dict, target):
    """
    Return the edges of ``admg`` that remain for a given target set.

    A directed edge (u, v) is dropped when v is in ``target``; a bidirected
    edge is dropped when either endpoint is in ``target``. The input dict and
    its lists are not modified.

    Args:
        admg (dict): {'directed': [(u, v), ...], 'bidirected': [(u, v), ...]}.
        target: Container of target nodes (membership is tested with ``in``).

    Returns:
        dict: {'directed': [...], 'bidirected': [...]} with the kept edges in
        their original order.
    """
    directed_all = admg['directed']
    bidirected_all = admg['bidirected']
    return {
        'directed': [pair for pair in directed_all if pair[1] not in target],
        'bidirected': [
            pair for pair in bidirected_all
            if pair[1] not in target and pair[0] not in target
        ],
    }


def estimate_conditional_probabilities(df, y, w):
    """
    Estimate P(y = '1' | w) from samples using relative frequencies.

    Args:
        df (pd.DataFrame): The dataset.
        y (str): Column name of the target variable.
        w (tuple): Column names of the conditioning set; may be empty.

    Returns:
        dict: For empty ``w``, {(): P(y = '1')}. Otherwise one entry per group
        of ``df.groupby(list(w))``, keyed by the group key, holding the
        relative frequency of '1' in that group (0 if '1' never occurs).
    """
    # NOTE(review): the lookup key is the string '1', so the values of column
    # `y` are presumably stored as strings - confirm against the data loader.
    def share_of_ones(frame):
        frequencies = frame[y].value_counts(normalize=True).to_dict()
        return frequencies.get('1', 0)

    if not w:
        # No conditioning set: marginal frequency under the empty-tuple key
        return {(): share_of_ones(df)}

    return {key: share_of_ones(part) for key, part in df.groupby(list(w))}


def compare_binary_distributions(p1, p2, alpha):
    """
    Look for a conditioning value on which two binary conditionals agree.

    For every key present in both dicts, the two probabilities are turned into
    pseudo-counts out of 1000 samples and compared with a Chi-Square test
    (p1 as observed, p2 as expected).

    Args:
        p1 (dict): Estimated P(y=1 | w) per value of w, first dataset.
        p2 (dict): Estimated P(y=1 | w) per value of w, second dataset.
        alpha (float): Significance level for the Chi-Square test.

    Returns:
        bool: True as soon as ONE shared key yields a p-value above ``alpha``;
        False if no shared key does (including when there are no shared keys).
    """
    pseudo_total = 1000  # assumed sample size used to build the count vectors

    def as_counts(prob):
        # [count of y=1, count of y=0] for the assumed sample size
        return np.array([prob * pseudo_total, (1 - prob) * pseudo_total])

    shared_keys = set(p1.keys()).intersection(set(p2.keys()))
    for key in shared_keys:
        outcome = chisquare(as_counts(p1[key]), f_exp=as_counts(p2[key]))
        if outcome.pvalue > alpha:
            return True

    return False


def find_w_for_given_y(y: int, df1, df2, alpha=0.05):
    """
    Search for a conditioning set w accepted by compare_binary_distributions.

    Candidate sets are all subsets of the columns of ``df1`` other than ``y``,
    tried in order of increasing size starting with the empty set. For each
    candidate, P(y | w) is estimated in both datasets and compared.

    Args:
        y: The target variable; converted to ``str`` to obtain the column name.
        df1 (pd.DataFrame): Samples from the first dataset.
        df2 (pd.DataFrame): Samples from the second dataset.
        alpha (float): Significance level for the Chi-Square test.

    Returns:
        tuple: (True, set(w)) for the first accepted candidate, or
        (False, None) if no candidate is accepted.
    """
    y = str(y)
    candidates = set(df1.columns) - {y}  # w never contains y itself

    # Smallest conditioning sets first, beginning with the empty set
    by_size = (combinations(candidates, size) for size in range(len(candidates) + 1))
    for w in chain.from_iterable(by_size):
        estimate_1 = estimate_conditional_probabilities(df1, y, w)
        estimate_2 = estimate_conditional_probabilities(df2, y, w)
        if compare_binary_distributions(estimate_1, estimate_2, alpha):
            return True, set(w)

    return False, None


def add_F_y_edges(G_aug, y):
    """
    Add the edges F -> X0_y and F -> X1_y to the augmented graph, in place.

    Args:
        G_aug (GeneralGraph): The augmented graph; must contain "F", "X0_<y>" and "X1_<y>".
        y: Id of the target variable, inserted into the node names.

    Returns:
        GeneralGraph: The same graph object, after modification.
    """
    regime_node = G_aug.get_node("F")
    for domain in (0, 1):
        target_node = G_aug.get_node(f"X{domain}_{y}")
        # Arrowhead at the target, tail at F: F -> target
        G_aug.add_edge(Edge(target_node, regime_node, Endpoint.ARROW, Endpoint.TAIL))

    return G_aug


def get_augmented_graph(learned_pags: dict, target_dict: dict):
    """
    Combine the learned per-domain PAGs into one augmented graph with an F node.

    Every domain ``dom`` contributes nodes "X<dom>_0" .. "X<dom>_{n-1}", where
    n is the node count of the first PAG. A PAG node named "<letter><k>" is
    mapped to "X<dom>_<k-1>", and each edge is copied with its endpoints.
    Finally "F" is added with edges F -> X0_t and F -> X1_t for every t in the
    symmetric difference of target_dict[0] and target_dict[1].

    Args:
        learned_pags (dict): Maps a domain to its learned PAG (causallearn GeneralGraph).
        target_dict (dict): Target sets; entries 0 and 1 are read.

    Returns:
        GeneralGraph: The augmented graph.
    """
    augmented = GeneralGraph([])
    node_count = len(list(learned_pags.values())[0].get_nodes())

    for dom in learned_pags.keys():
        for idx in range(node_count):
            augmented.add_node(GraphNode(f"X{dom}_{idx}"))

    def renamed(dom, pag_node):
        # PAG names are 1-based after the leading letter; augmented names are 0-based
        position = int(pag_node.get_name()[1:])
        return f"X{dom}_{position - 1}"

    for dom, pag in learned_pags.items():
        for pag_edge in pag.get_graph_edges():
            first_name = renamed(dom, pag_edge.get_node1())
            second_name = renamed(dom, pag_edge.get_node2())
            first_mark = pag_edge.get_endpoint1()
            second_mark = pag_edge.get_endpoint2()
            augmented.add_edge(Edge(augmented.get_node(first_name),
                                    augmented.get_node(second_name),
                                    first_mark, second_mark))

    regime_node = GraphNode("F")
    augmented.add_node(regime_node)

    # F -> both domain copies of every variable whose target status differs
    changed_targets = target_dict[0].symmetric_difference(target_dict[1])
    for target in list(changed_targets):
        for dom in (0, 1):
            augmented.add_edge(Edge(regime_node, augmented.get_node(f"X{dom}_{target}"),
                                    Endpoint.TAIL, Endpoint.ARROW))

    return augmented


def union_renamed_graphs(G1, G2):
    """
    Creates a union of two GeneralGraph objects with renamed nodes and preserved edges.

    A node named "<letter><k>" becomes "X0_<k-1>" when it comes from G1 and
    "X1_<k-1>" when it comes from G2. Every edge is copied between the renamed
    nodes with the same two endpoints as the original edge.

    Args:
        G1 (GeneralGraph): First graph (renamed with domain id 0).
        G2 (GeneralGraph): Second graph (renamed with domain id 1).

    Returns:
        GeneralGraph: A new graph with renamed nodes and all edges.
    """
    G_union = GeneralGraph([])

    def rename_nodes(graph, graph_id):
        # Create the renamed copy of every node and register it in the union;
        # returns a map from the old name to the new node object.
        renamed_nodes = {}
        for node in graph.get_nodes():
            old_name = node.get_name()
            new_node = GraphNode(f"X{graph_id}_{int(old_name[1:]) - 1}")
            renamed_nodes[old_name] = new_node
            G_union.add_node(new_node)
        return renamed_nodes

    def add_edges(graph, renamed_nodes):
        # add_edge expects an Edge object (as everywhere else in this file),
        # so rebuild each edge on the renamed nodes and keep its endpoints.
        for edge in graph.get_graph_edges():
            node1 = renamed_nodes[edge.get_node1().get_name()]
            node2 = renamed_nodes[edge.get_node2().get_name()]
            G_union.add_edge(Edge(node1, node2, edge.get_endpoint1(), edge.get_endpoint2()))

    # All nodes first, then all edges
    renamed_nodes_G1 = rename_nodes(G1, 0)
    renamed_nodes_G2 = rename_nodes(G2, 1)
    add_edges(G1, renamed_nodes_G1)
    add_edges(G2, renamed_nodes_G2)

    return G_union


def powerset(iterable):
    """
    Iterate over all subsets of ``iterable`` as tuples, smallest first.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    items = list(iterable)
    subset_sizes = range(len(items) + 1)
    return chain.from_iterable(combinations(items, size) for size in subset_sizes)


"""
def is_ancestor_of(graph, A, B):

    #Return True iff there is a directed path from node A to node B
    #in the given 'graph'. This implies A is an ancestor of B in a DAG.

    visited = set()
    queue = deque([A])
    visited.add(A)

    while queue:
        current = queue.popleft()
        if current == B:
            return True  # Found a path A -> ... -> B
        # Check all neighbors of 'current'
        for neighbor in graph.get_adjacent_nodes(current):
            # We only follow edges if current -> neighbor is an arrow
            # from 'current' to 'neighbor'
            if (graph.get_endpoint(current, neighbor) == Endpoint.ARROW and
                    graph.get_endpoint(neighbor, current) == Endpoint.TAIL):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
    return False
"""


def random_targets(n, k):
    """
    Sample k distinct target sets, each empty or holding a single node of 0..n-1.

    Args:
        n (int): Number of nodes; candidates are the empty set and {0}, ..., {n-1}.
        k (int): Number of target sets to sample (at most n + 1).

    Returns:
        target_dict: a dictionary with keys 0..k-1 whose values are the sampled sets
    """
    # Candidate subsets with 0 or 1 element: [], [0], [1], ..., [n-1]
    candidates = [[]] + [[element] for element in range(n)]

    # Sampling without replacement keeps the k subsets distinct
    chosen = random.sample(candidates, k)

    return {position: set(subset) for position, subset in enumerate(chosen)}


def generate_random_admg(n: int, rho: float, max_degree=3, latent_prob=0.9):
    """
    Rejection-sample a random ADMG encoded as a DiGraph with explicit latent nodes.

    An Erdos-Renyi graph on n nodes is drawn and its edges are copied into a
    DiGraph in the orientation networkx reports them. Then, for every pair of
    the n observed nodes, a new latent node pointing to both is added with
    probability ``latent_prob`` (this stands for a bidirected edge). The draw
    is repeated until the maximum degree of the result is below ``max_degree``.

    Args:
        n (int): Number of observed nodes.
        rho (float): Edge probability of the Erdos-Renyi graph.
        max_degree (int): Exclusive upper bound on the maximum degree
            (in-degree + out-degree, latent nodes included).
        latent_prob (float): Probability of adding a latent confounder for
            each pair of observed nodes. Defaults to the previously
            hard-coded value 0.9.

    Returns:
        nx.DiGraph: observed nodes plus one extra node per latent confounder.

    NOTE(review): every latent confounder adds one to the degree of both of
    its children, so with a high ``latent_prob`` and a small ``max_degree``
    the acceptance condition may be hit very rarely (or never) once n grows;
    the loop has no iteration cap.
    """
    while True:
        # Generate a random nx ER graph
        G = nx.erdos_renyi_graph(n, rho)

        # Directed graph with the same nodes; G's edges become directed edges
        D = nx.DiGraph()
        D.add_nodes_from(G.nodes())
        D.add_edges_from(G.edges())

        # Snapshot the observed nodes before latent nodes are appended
        observed = list(D.nodes())

        # Add random bidirected edges, each encoded as a latent common parent
        for u, v in combinations(observed, 2):
            if np.random.rand() < latent_prob:
                latent_node = max(D.nodes()) + 1
                D.add_node(latent_node)
                D.add_edge(latent_node, u)
                D.add_edge(latent_node, v)

        # Accept only when the maximum degree is strictly below max_degree
        if max(dict(D.degree()).values()) < max_degree:
            return D


def rename_node(graph, old_name, new_name):
    """
    Give the first node called ``old_name`` the name ``new_name``.

    Only the first match returned by ``graph.get_nodes()`` is renamed. When no
    node matches, a message is printed and the graph is left untouched.

    Args:
        graph (GeneralGraph): The graph containing the node.
        old_name (str): The current name of the node.
        new_name (str): The new name for the node.
    """
    match = next(
        (candidate for candidate in graph.get_nodes() if candidate.get_name() == old_name),
        None,
    )
    if match is None:
        print(f"Node with name '{old_name}' not found.")
        return
    match.set_name(new_name)


def filter_columns_by_name(df, n):
    """
    Keep only the columns whose label is an integer in 0..n-1, in numeric order.

    A label qualifies when it is an ``int`` or ``str`` whose string form
    consists of digits only and whose integer value is below n.

    Args:
        df (pd.DataFrame): Input DataFrame with unordered column names.
        n (int): Exclusive upper bound on the variable index to keep.

    Returns:
        pd.DataFrame: The selected columns, sorted by their integer value.
    """

    def _is_wanted(label):
        # Reject anything that is neither an int nor a str label
        if not isinstance(label, (int, str)):
            return False
        return str(label).isdigit() and int(label) < n

    kept = [label for label in df.columns if _is_wanted(label)]
    # Numeric (not lexicographic) ordering of the surviving labels
    kept.sort(key=int)
    return df[kept]


def shanmugam_random_chordal(nnodes, density):
    """
    Sample a random DAG whose parent sets are closed into cliques, then shuffle labels.

    Construction:
      1. For every node i >= 1, pick max(1, Binomial(i, density)) parents
         uniformly among the nodes 0..i-1.
      2. Going from the last node to the first, connect every pair of parents
         of a node (edge from the smaller to the larger label).
      3. Relabel the nodes with a random permutation.

    Args:
        nnodes (int): Number of nodes.
        density (float): Success probability of the binomial draw for the
            number of parents.

    Returns:
        nx.DiGraph: the relabelled directed graph.
    """
    d = nx.DiGraph()
    d.add_nodes_from(range(nnodes))
    order = list(range(1, nnodes))
    for i in order:
        # Every non-root node gets at least one parent
        num_parents_i = max(1, np.random.binomial(i, density))
        parents_i = random.sample(list(range(i)), num_parents_i)
        d.add_edges_from((p, i) for p in parents_i)
    for i in reversed(order):
        # Snapshot the predecessors: edges are added to d inside the loop.
        # `combinations` is the name imported at the top of the file; the
        # previous `itr.combinations` relied on an alias that is not among
        # the file's visible imports.
        for j, k in combinations(list(d.predecessors(i)), 2):
            d.add_edge(min(j, k), max(j, k))

    perm = np.random.permutation(list(range(nnodes)))
    d = nx.relabel.relabel_nodes(d, dict(enumerate(perm)))

    return d


# Write a nx graph for Julia loading, the nodes will be 1 to n, returns a dictionary that maps Julia graph id to nx node id
def write_graph(G, graph_dir='graph.gr'):
    """
    Dump a networkx graph to a text file for loading from Julia.

    Nodes are renumbered 1..n in the order of ``G.nodes``. The file holds a
    header line "<#nodes> <#edges>", a blank line, and one "<src> <dst>" line
    per edge using the renumbered ids.

    Args:
        G: networkx graph to write.
        graph_dir (str): Path of the output file.

    Returns:
        dict: maps each Julia id (1..n) back to the original networkx node.
    """
    nv = nx.number_of_nodes(G)
    ne = nx.number_of_edges(G)
    jl_ids = np.arange(nv) + 1

    jl_id_to_nx_id_map = {}
    nx_id_to_jl_id_map = {}
    for jl_id, nx_id in zip(jl_ids, list(G.nodes)):
        jl_id_to_nx_id_map[jl_id] = nx_id
        nx_id_to_jl_id_map[nx_id] = jl_id

    with open(graph_dir, 'w') as out:
        # Header: node and edge counts followed by an empty line
        out.write("{} {}\n\n".format(nv, ne))
        for src, dst in G.edges:
            out.write("{} {}\n".format(nx_id_to_jl_id_map[src], nx_id_to_jl_id_map[dst]))

    return jl_id_to_nx_id_map


# Read a DAG sampled from Julia
def read_DAG(jl_id_to_nx_id_map, graph_dir='./samples/sample.lgz'):
    """
    Read a DAG written by the Julia sampler and translate it back to nx node ids.

    The file is parsed line by line as comma-separated fields:
      * a line with more than two fields is treated as the header; its first
        field is the number of nodes,
      * a line with exactly two fields is an arc "<src>,<dst>" in Julia ids.

    Args:
        jl_id_to_nx_id_map (dict): Julia id (1..n) -> networkx node id, as
            returned by write_graph.
        graph_dir (str): Path of the sample file.

    Returns:
        nx.DiGraph: the sampled DAG over the networkx node ids.
    """
    dag = nx.DiGraph()
    with open(graph_dir, "r") as file:
        for line in file:
            fields = line.split(',')
            if len(fields) > 2:
                n_nodes = int(fields[0])
                # Register nodes under their networkx ids. Adding the raw
                # Julia ids 1..n here would create spurious nodes whenever
                # the two labelings differ, since arcs use the mapped ids.
                dag.add_nodes_from(
                    jl_id_to_nx_id_map[jl_id] for jl_id in range(1, n_nodes + 1)
                )
            elif len(fields) == 2:
                src_jl = int(fields[0])
                dst_jl = int(fields[1])
                dag.add_edge(jl_id_to_nx_id_map[src_jl], jl_id_to_nx_id_map[dst_jl])
    return dag


# Calculate the MEC size of a CPDAG
def get_CPDAG_MECsize(jl, CPDAG: graphical_models.PDAG):
    """
    Return the MEC size of a CPDAG as the product over its chain components.

    Each chain component is written to 'graph.gr' and counted by the Julia
    routine; a CPDAG without chain components has size 1.

    Args:
        jl: Julia handle forwarded to call_JLMECsize.
        CPDAG (graphical_models.PDAG): The CPDAG to evaluate.
    """
    scratch_file = 'graph.gr'
    size = 1
    # An empty component list leaves the product at 1
    for component in CPDAG.chain_components():
        write_graph(component.to_nx(), scratch_file)
        size = size * call_JLMECsize(jl, scratch_file)
    return size


# Call the Julia MEC counting function
def call_JLMECsize(jl, graph_dir='graph.gr'):
    """
    Run the Julia MEC counting routine on the graph stored at ``graph_dir``.

    The path is first bound to the Julia variable ``graph_dir``; the value
    returned by evaluating ``call_MECCounting(graph_dir)`` is passed back.
    """
    jl.seval(f'graph_dir = "{graph_dir}"')
    return jl.seval("call_MECCounting(graph_dir)")


# Call the Julia MEC sampling function
def call_JLMECsampler(jl, jl_id_to_nx_id_map, graph_dir='graph.gr',
                      sample_dir='./samples/sample.lgz'):
    """
    Run the Julia MEC sampler and load the sampled DAG.

    Both paths are bound to Julia variables of the same name, then
    ``call_MECSampling(graph_dir, sample_dir)`` is evaluated and the file at
    ``sample_dir`` is read back with read_DAG.

    Returns:
        The DAG returned by read_DAG, expressed in networkx node ids.
    """
    for julia_statement in (
        f'graph_dir = "{graph_dir}"',
        f'sample_dir = "{sample_dir}"',
        "call_MECSampling(graph_dir, sample_dir)",
    ):
        jl.seval(julia_statement)
    return read_DAG(jl_id_to_nx_id_map, sample_dir)


# transform graphical model PDAG to nx Digraph
def PDAG_to_nx(PDAG: graphical_models.PDAG):
    """
    Build a networkx DiGraph from the arcs of a PDAG.

    Only ``PDAG.arcs`` is copied; nodes appear in the result only as
    endpoints of those arcs.
    """
    digraph = nx.DiGraph()
    digraph.add_edges_from(PDAG.arcs)
    return digraph


# Get a dictionary that map from BN names to ids
def BN_names_to_id_map(BN):
    """
    Return a dict mapping each variable name of ``BN`` to its node id.

    The id is the first element of ``BN.nodeset([name])``.
    """
    return {name: next(iter(BN.nodeset([name]))) for name in BN.names()}


# Get the reverse map
def BN_id_to_names_map(BN):
    """
    Return a dict mapping each node id of ``BN`` to its variable name.

    This is the inverse of the name -> id lookup; the id of a name is the
    first element of ``BN.nodeset([name])``.
    """
    return {next(iter(BN.nodeset([name]))): name for name in BN.names()}


# Get skeleton from BN, use as essential graph for now
def Essential_from_BN(BN):
    """
    Return the undirected skeleton of ``BN`` as a networkx Graph.

    All nodes of the BN are added, and every arc becomes an undirected edge.
    The skeleton is currently used in place of the essential graph.
    """
    skeleton = nx.Graph()
    skeleton.add_nodes_from(BN.nodes())
    skeleton.add_edges_from(BN.arcs())
    return skeleton


# Compute the UCCGs from an essential graph
def UCCG_from_BN_ess(BN_ess):
    """
    Extract the chain components of an essential graph as networkx graphs.

    A PDAG is built from the nodes, edges and arcs of ``BN_ess``; for each of
    its chain components the arc set is emptied before the component is
    converted with ``to_nx()``.

    Returns:
        list: one networkx graph per chain component.
    """
    cpdag = graphical_models.PDAG(nodes=BN_ess.nodes(), edges=BN_ess.edges(), arcs=BN_ess.arcs())
    undirected_components = []
    for component in cpdag.chain_components():
        # Drop any arcs so that only undirected edges are exported
        component._arcs = set()
        undirected_components.append(component.to_nx())
    return undirected_components


# Calculate the skeleton of a given DAG
def skeleton_from_DAG(DAG):
    """
    Build the skeleton of ``DAG`` by turning every arc into an undirected edge.

    The DAG is first converted with ``graphical_models.PDAG.from_nx``; each
    arc is then passed to the private ``_replace_arc_with_edge`` method.

    NOTE(review): ``skeleton`` is rebound to the return value of
    ``_replace_arc_with_edge`` on every iteration while ``skeleton.arcs`` is
    being iterated. This only works if that method returns the PDAG and if
    ``arcs`` hands back a snapshot rather than the live set -- TODO confirm
    against the installed graphical_models version.
    """
    skeleton = graphical_models.PDAG.from_nx(DAG)
    for arc in skeleton.arcs:
        # Rebinds `skeleton` to whatever the private helper returns
        skeleton = skeleton._replace_arc_with_edge(arc)
    return skeleton


def nodeset_to_nameset(nodeset: set):
    """Return the set of ``str(v)`` for every element ``v`` of ``nodeset``."""
    return {str(member) for member in nodeset}


# Calculate the joint probability of a BN
def get_joint_prob(BN):
    """
    Multiply the CPTs of all variables of ``BN`` together.

    The product starts from the scalar 1 and is accumulated in the order of
    ``BN.names()``.
    """
    product = 1
    for name in BN.names():
        product = product * BN.cpt(name)
    return product


# Calculate the interventional joint prob
def get_int_joint_prob_DAG(BN, joint_prob, DAG, int_V):
    """
    Build the interventional joint distribution implied by ``DAG`` as an array.

    For every node v, P(v | Pa_v) is derived from ``joint_prob`` using the
    parents of v in ``DAG``. For nodes in ``int_V`` that factor is overwritten
    with the vector [0, 1] along v's axis (after ``putFirst``). The factors are
    multiplied, the variables are reordered with ``putFirst(str(i))`` for
    i = 0..len(DAG.nodes)-1, and the result is converted with ``toarray()``.

    Args:
        BN: Bayesian network, only used to look up intervened variables.
        joint_prob: Observational joint distribution.
        DAG: Graph exposing ``nodes`` and ``parents``.
        int_V: Set of intervened nodes.

    Returns:
        The interventional joint as an array; an assertion checks that it
        sums to 1 within 0.001.
    """
    intervened_names = nodeset_to_nameset(int_V)
    product = 1
    for node in list(DAG.nodes):
        name = str(node)
        is_target = name in intervened_names
        if is_target:
            # Lookup retained from the original code; its result is unused
            BN.variableFromName(name)

        # Factor P(node | parents) under the given DAG
        parent_names = nodeset_to_nameset(DAG.parents[node])
        factor = get_conditional_prob(joint_prob, {name}, parent_names)

        if is_target:
            # Replace the factor of an intervened node by the fixed vector [0, 1]
            factor = factor.putFirst(name)
            forced = np.zeros(factor.shape)
            forced[..., :] = [0, 1]
            factor[:] = forced

        product = product * factor

    # Extract the array in the order given by the integer node names
    for position in range(len(DAG.nodes)):
        product = product.putFirst(str(position))
    dense = product.toarray()

    assert np.abs(dense.sum() - 1) <= 0.001
    return dense


# Estimate sample likelihood from P_int
def get_sample_likelihood_from_P_int(x, P_int: dict, P_likelihood: dict):
    """
    Look up the probability of sample ``x`` in every table of ``P_int``.

    For each key, the table is indexed one axis at a time: axis i is indexed
    with the integer label found in the first row of column ``str(i)`` of
    ``x``. The result is stored under the same key in ``P_likelihood``, which
    is updated in place and returned.
    """
    for comb_name, table in P_int.items():
        entry = copy.copy(table)
        for axis in range(len(table.shape)):
            label = int(x[str(axis)].values[0])
            entry = entry[label]
        P_likelihood[comb_name] = entry

    return P_likelihood


# Calculate the prior of a given UCCG
def get_Prior_of_UCCG(jl, UCCG: nx.Graph, partition='root'):
    """
    Compute, for every node v of ``UCCG``, the prior of v being the root.

    The prior of v is the product of the MEC sizes of the chain components of
    the v-rooted CPDAG, divided by the MEC size of the whole UCCG. When the
    v-rooted CPDAG has no chain component left, the prior is 1 / MEC size.

    Args:
        jl: Julia handle used for MEC counting.
        UCCG (nx.Graph): Undirected graph to analyse.
        partition (str): Only 'root' is implemented.

    Returns:
        dict: ``str(v)`` -> prior.

    Raises:
        NotImplementedError: if ``partition`` is not 'root'.
    """
    # Other partitions (e.g. clique picking) are not supported yet
    if partition != 'root':
        raise NotImplementedError

    scratch_file = 'graph.gr'
    write_graph(UCCG, scratch_file)
    total_size = call_JLMECsize(jl, scratch_file)

    P_prior = {}
    for v in list(UCCG.nodes):
        key = str(v)
        rooted_cpdag = UCCG_to_V_rooted_CPDAG(UCCG, {v})
        components = rooted_cpdag.chain_components()

        # Nothing left undirected: exactly one DAG has v as root
        if components == []:
            P_prior[key] = 1 / total_size
            continue

        # Otherwise count the DAGs of every remaining component
        component_sizes = []
        for component in components:
            write_graph(component.to_nx(), scratch_file)
            component_sizes.append(call_JLMECsize(jl, scratch_file))
        P_prior[key] = np.prod(component_sizes) / total_size

    return P_prior


# Calculate the conditional probability from a given joint probability
def get_conditional_prob(joint_prob, v_name: set, Condition_v_names: set):
    """
    Derive P(v_name | Condition_v_names) from a joint distribution.

    With an empty conditioning set the marginal of ``v_name`` is returned.
    Otherwise the result is the joint over ``v_name`` and the conditioning
    variables divided by the marginal of the conditioning variables; all
    marginals are obtained with ``margSumOut``.

    Args:
        joint_prob: Joint distribution exposing ``names`` and ``margSumOut``.
        v_name (set): Names of the target variables.
        Condition_v_names (set): Names of the conditioning variables.
    """
    working_copy = copy.copy(joint_prob)
    all_names = set(joint_prob.names)

    # No condition: plain marginal of the target variables
    if len(Condition_v_names) == 0:
        return working_copy.margSumOut(list(all_names - v_name))

    denominator = working_copy.margSumOut(list(all_names - Condition_v_names))

    # Variables that are neither targets nor conditions must be summed out
    leftover = all_names - v_name - Condition_v_names
    if leftover == set():
        numerator = joint_prob
    else:
        numerator = working_copy.margSumOut(list(leftover))
    return numerator / denominator


# Given the UCCG and a set of source nodes, calculate the CPDAG
def UCCG_to_V_rooted_CPDAG(UCCG, V_ids: set):
    """
    Orient a UCCG so that the nodes in ``V_ids`` act as sources.

    A PDAG is created from the nodes and edges of ``UCCG``; every edge between
    a node of ``V_ids`` and a neighbor outside ``V_ids`` is replaced by an arc
    leaving the source node, and ``to_complete_pdag()`` is then called on the
    result.

    Returns:
        graphical_models.PDAG: the oriented graph.
    """
    rooted = graphical_models.PDAG(nodes=UCCG.nodes, edges=UCCG.edges)
    for source in list(V_ids):
        # Edges between two source nodes are left undirected
        outside_neighbors = list(rooted.neighbors_of(source) - V_ids)
        for neighbor in outside_neighbors:
            rooted.replace_edge_with_arc((source, neighbor))

    rooted.to_complete_pdag()
    return rooted


# Get an interventional BN, V_names is the intervention set
def get_intervened_BN(BN, V_names: set):
    """
    Return a copy of ``BN`` in which every variable of ``V_names`` is forced.

    The CPT of each listed variable is overwritten with the vector [0, 1]
    broadcast over the table's shape. The copy is made with ``copy.copy``.

    NOTE(review): whether ``copy.copy`` fully detaches the CPTs from the
    input network depends on the BN class -- TODO confirm the original BN is
    not modified.
    """
    intervened = copy.copy(BN)
    for name in V_names:
        table_shape = intervened.cpt(name).shape
        forced = np.zeros(table_shape)
        forced[...] = [0, 1]
        intervened.cpt(name)[:] = forced
    return intervened


# Sample a DAG that has V as source nodes
def DAG_sample_from_CPDAG(jl, v_rooted_CPDAG):
    """
    Sample one DAG consistent with a (V-rooted) CPDAG.

    If the CPDAG has no chain component it is returned as is (same object).
    Otherwise a deep copy is made, and for each chain component a DAG is drawn
    with the Julia sampler and its arcs are used to orient the corresponding
    edges of the copy.
    """
    components = v_rooted_CPDAG.chain_components()
    # Fully oriented already: the CPDAG itself is the only DAG
    if components == []:
        return v_rooted_CPDAG

    scratch_file = 'graph.gr'
    oriented = copy.deepcopy(v_rooted_CPDAG)
    for component in components:
        # Sample an orientation for this undirected component
        id_map = write_graph(component.to_nx(), scratch_file)
        sampled = call_JLMECsampler(jl, id_map, scratch_file)
        sampled_pdag = graphical_models.PDAG(nodes=sampled.nodes, arcs=sampled.edges)
        # Copy the sampled orientation into the result
        for arc in sampled_pdag.arcs:
            oriented.replace_edge_with_arc(arc)
    return oriented


# Get the name set from an id set
def get_V_names(id_name_map, V: set):
    """Translate the id set ``V`` into the set of names given by ``id_name_map``."""
    return {id_name_map[v_id] for v_id in V}


# Calculate the likelihood when u==v
def get_P_X_given_V(x_sample, joint_prob_obs, int_V_names: set):
    """
    Evaluate P(remaining variables | int_V_names) at the sample ``x_sample``.

    The conditional is obtained from the observational joint with
    get_conditional_prob, then indexed once per column of ``x_sample``, going
    through ``variablesSequence()`` from position ``shape[1] - 1`` down to 0
    and using the label in the first row of the matching column.
    """
    observed_copy = copy.copy(joint_prob_obs)
    free_names = set(joint_prob_obs.names) - int_V_names
    conditional = get_conditional_prob(observed_copy, free_names, int_V_names)

    n_columns = x_sample.shape[1]
    ordering = conditional.variablesSequence()
    # Walk the variable sequence backwards, fixing one variable per step
    for position in range(n_columns - 1, -1, -1):
        column = ordering[position].name()
        conditional = conditional[int(x_sample[column].values[0])]
    return conditional


# Calculate the likelihood from a sampled DAG
def P_likelihood_given_DAG(x, DAG, BN, joint_prob, int_V: set):
    """
    Evaluate the likelihood of sample ``x`` under ``DAG`` with interventions on ``int_V``.

    The joint is assembled as the product of P(v | Pa_v) over the nodes of
    ``DAG`` (factors derived from ``joint_prob``); for intervened nodes the
    factor is overwritten with the vector [0, 1] along v's axis. An assertion
    checks that the product sums to 1 within 0.0001. The product is then
    indexed with the labels of ``x``, walking ``variablesSequence()`` from the
    last position to the first.

    Args:
        x: Sample; column ``name`` holds the label of variable ``name``.
        DAG: Graph exposing ``nodes`` and ``parents``.
        BN: Bayesian network, only used to look up intervened variables.
        joint_prob: Observational joint distribution.
        int_V (set): Intervened nodes.
    """
    intervened_names = nodeset_to_nameset(int_V)
    dag_joint = 1
    for node in list(DAG.nodes):
        name = str(node)
        is_target = name in intervened_names
        if is_target:
            # Lookup retained from the original code; its result is unused
            BN.variableFromName(name)

        # Factor P(node | parents) under the given DAG
        parent_names = nodeset_to_nameset(DAG.parents[node])
        factor = get_conditional_prob(joint_prob, {name}, parent_names)

        if is_target:
            # Replace the factor of an intervened node by the fixed vector [0, 1]
            factor = factor.putFirst(name)
            forced = np.zeros(factor.shape)
            forced[..., :] = [0, 1]
            factor[:] = forced

        dag_joint = dag_joint * factor

    # The product of the factors must be a proper joint distribution
    assert np.abs(dag_joint.sum() - 1) <= 0.0001

    likelihood = copy.copy(dag_joint)
    n_vars = len(dag_joint.names)
    for position in range(n_vars - 1, -1, -1):
        column = dag_joint.variablesSequence()[position].name()
        likelihood = likelihood[int(x[column].values[0])]
    return likelihood


# Normalize the posterior and update the prior
def Normalize_posterior(posterior: dict, likelihood: dict, prior: dict, trace=True):
    """
    Perform one Bayesian update: posterior = prior * likelihood / evidence.

    The evidence is the sum of likelihood[key] * prior[key] over the keys of
    ``likelihood``. The terms are paired by key, so the two dictionaries do
    not need to share the same insertion order.

    Args:
        posterior (dict): Keys to update. With ``trace=True`` each value must
            be a list to which the new posterior is appended.
        likelihood (dict): Likelihood per key.
        prior (dict): Prior per key; overwritten in place with the new
            posterior so it can serve as the prior of the next update.
        trace (bool): If True, append to the per-key history lists; otherwise
            store the new posterior value directly.

    Returns:
        tuple: (normed_posterior, prior). ``normed_posterior`` is a shallow
        copy of ``posterior``; with ``trace=True`` the history lists are
        shared with (and therefore also extend) the input ``posterior``.
    """
    normed_posterior = copy.copy(posterior)

    # Evidence: pair likelihood and prior by key, not by position
    sum_post = np.sum([likelihood[key] * prior[key] for key in likelihood])

    # Normalize posterior and update prior
    for V_names in list(posterior.keys()):
        updated = prior[V_names] * likelihood[V_names] / sum_post
        if trace:
            normed_posterior[V_names].append(updated)
        else:
            normed_posterior[V_names] = updated
        prior[V_names] = updated
    return normed_posterior, prior


# Check edge strength
def check_edge(joint_prob, u_name, v_name):
    """
    Print diagnostics about the dependence between two variables.

    Prints the full joint, the joint marginal of (u, v), P(u | v) and
    P(v | u).

    Args:
        joint_prob: Joint distribution exposing ``names`` and ``margSumOut``.
        u_name: Name of the first variable (a string), or an iterable of names.
        v_name: Name of the second variable (a string), or an iterable of names.

    Returns:
        None
    """

    def _as_name_set(name):
        # set('10') would give {'1', '0'}; a string is one variable name
        return {name} if isinstance(name, str) else set(name)

    u_set = _as_name_set(u_name)
    v_set = _as_name_set(v_name)

    print('Joint obs ')
    print(joint_prob)
    names = set(joint_prob.names)
    print('Joint {} {} '.format(u_name, v_name))
    print(joint_prob.margSumOut(list(names - u_set - v_set)))
    print('P {} given {}'.format(u_name, v_name))
    print(get_conditional_prob(joint_prob, u_set, v_set))
    print('P {} given {}'.format(v_name, u_name))
    print(get_conditional_prob(joint_prob, v_set, u_set))
    return None


# Plot the posterior vs samples
def plot_posterior_vs_samples(posterior, int_name, n_data_samples=100):
    """
    Plot every posterior trace against the sample count 1..n_data_samples.

    One curve is drawn per key of ``posterior``, colored along the 'gnuplot'
    colormap; the keys are used as legend entries and ``int_name`` appears in
    the title. The figure is shown with ``plt.show()``.
    """
    var_list = list(posterior.keys())
    palette = plt.get_cmap('gnuplot')
    colors = [palette(fraction) for fraction in np.linspace(0, 1, len(var_list))]
    x_axis = np.arange(n_data_samples) + 1

    plt.figure()
    for V_names, shade in zip(var_list, colors):
        plt.plot(x_axis, posterior[V_names], c=shade)
    plt.title('Posterior V rooted do {}'.format(int_name))
    plt.xlabel('# Samples')
    plt.ylabel('Posterior')
    plt.legend(var_list)
    plt.show()

    return None


# Plot root posterior with different number of DAG samples
def verify_root_consistency(posteriors, root_name, n_data_samples=1000):
    """
    Plot the root posterior obtained with different numbers of DAG samples.

    ``posteriors`` maps a number of DAG samples to a posterior trace; one
    curve per key is drawn against 1..n_data_samples, colored along the
    'gnuplot' colormap, with the keys as legend entries.
    """
    n_DAG_sample_list = list(posteriors.keys())
    palette = plt.get_cmap('gnuplot')
    colors = [palette(fraction) for fraction in np.linspace(0, 1, len(n_DAG_sample_list))]
    x_axis = np.arange(n_data_samples) + 1

    plt.figure()
    for n_DAG_sample, shade in zip(n_DAG_sample_list, colors):
        plt.plot(x_axis, posteriors[n_DAG_sample], color=shade)
    plt.title('Posterior {} rooted do {}'.format(root_name, root_name))
    plt.xlabel('# Samples')
    plt.ylabel('Posteriors')
    plt.legend(n_DAG_sample_list)
    plt.show()
    return None


# Smooth plots
def smooth_array(a: np.ndarray, window: int = 100):
    """
    Smooth ``a`` in place by replacing each block of ``window`` entries with its mean.

    Blocks are taken along the first axis, starting at index 0; the last block
    may be shorter. The mean of a block is taken over all of its elements.

    Args:
        a (np.ndarray): Array to smooth; it is modified in place.
        window (int): Block length, at least 1. Defaults to 100.

    Returns:
        np.ndarray: the same array object ``a``.

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    for start in range(0, a.shape[0], window):
        block = slice(start, start + window)
        a[block] = np.mean(a[block])

    return a


# Generate admgs

import networkx as nx
from itertools import permutations, combinations
import pickle

import os
import pickle


def all_labeled_dags_dict(n):
    """
    Generate all labeled DAGs on nodes {0..n-1}, where each DAG is stored as:
       {'directed': [(u,v), ...]}
    with (u,v) meaning u->v.

    No self-loops, no directed cycles. Each ordered pair (u,v), u!=v,
    is either absent or present as "u->v" in the final graph.

    The ordered pairs are decided one by one ("exclude" is explored before
    "include"). An edge is only included if it does not close a directed
    cycle with the edges chosen so far; since adding further edges can never
    remove a cycle, this prunes exactly the branches that contain no DAG and
    leaves the set and the order of the yielded DAGs unchanged.

    Yields one DAG dictionary per valid acyclic configuration.
    """
    possible_edges = [(u, v) for u in range(n) for v in range(n) if u != v]

    chosen = []  # (u, v) edges picked on the current branch
    children = [[] for _ in range(n)]  # adjacency lists of the current branch

    def _reaches(src, dst):
        """Return True if dst is reachable from src along the chosen edges."""
        stack = [src]
        seen = {src}
        while stack:
            node = stack.pop()
            if node == dst:
                return True
            for child in children[node]:
                if child not in seen:
                    seen.add(child)
                    stack.append(child)
        return False

    def _backtrack(idx):
        if idx == len(possible_edges):
            # Every pair decided; the branch is acyclic by construction
            yield {'directed': list(chosen)}
            return

        u, v = possible_edges[idx]

        # Option A: do NOT include (u->v)
        yield from _backtrack(idx + 1)

        # Option B: include (u->v), unless v already reaches u (cycle)
        if not _reaches(v, u):
            chosen.append((u, v))
            children[u].append(v)
            yield from _backtrack(idx + 1)
            children[u].pop()
            chosen.pop()

    yield from _backtrack(0)


def save_all_dags_in_files(n):
    """
    Pickle every labeled DAG on ``n`` nodes into its own file.

    The DAGs come from all_labeled_dags_dict(n) and are written to
    'dags_{n}/dags_{n}_{i}.pkl', where i is the enumeration index starting at
    0. The folder is created if needed and a summary line is printed.
    """
    folder_name = f"dags_{n}"
    os.makedirs(folder_name, exist_ok=True)

    saved = 0
    for dag in all_labeled_dags_dict(n):
        target = os.path.join(folder_name, f"dags_{n}_{saved}.pkl")
        with open(target, 'wb') as handle:
            pickle.dump(dag, handle)
        saved += 1

    print(f"Saved {saved} DAGs (each in a separate pickle file) in folder '{folder_name}'.")


def admgs_from_dag_dict(dag, n):
    """
    Enumerate every ADMG that extends ``dag`` with bidirected edges.

    ``dag`` is a dictionary {'directed': [(u,v), ...]} on n nodes. One ADMG is
    yielded per subset of the unordered pairs (i, j), i < j, with subsets
    produced by increasing size:

      {
        'directed':   dag['directed'],   # the same list object as in dag
        'bidirected': [(i, j), ...]      # the chosen pairs
      }
    """
    # All unordered pairs i<j, in lexicographic order
    candidate_pairs = list(itertools.combinations(range(n), 2))

    for size in range(len(candidate_pairs) + 1):
        for chosen in itertools.combinations(candidate_pairs, size):
            yield {'directed': dag['directed'], 'bidirected': list(chosen)}


def load_dags_from_folder(n):
    """
    Generator that yields each DAG (dictionary {'directed': [...]})
    stored in the folder 'dags_{n}'.

    Files ending in '.pkl' are processed in sorted order of their filenames.
    The sort is lexicographic, so 'dags_{n}_10.pkl' comes before
    'dags_{n}_2.pkl'.

    Raises:
        FileNotFoundError: on first iteration, if the folder does not exist.

    NOTE: the files are read with pickle, which can execute arbitrary code;
    only use this on folders produced by a trusted run of
    save_all_dags_in_files.
    """
    folder_name = f"dags_{n}"
    if not os.path.isdir(folder_name):
        raise FileNotFoundError(f"Folder '{folder_name}' does not exist.")

    filenames = sorted(name for name in os.listdir(folder_name) if name.endswith('.pkl'))

    for filename in filenames:
        path = os.path.join(folder_name, filename)
        with open(path, 'rb') as f:
            dag = pickle.load(f)
        # Yield after the file is closed so no handle stays open while the
        # consumer holds the generator suspended.
        yield dag

# Fix random seeds for reproducibility. Both generators are seeded:
# random_targets() draws from the stdlib `random` module, while other helpers
# in this file draw from `np.random`.
np.random.seed(42)
random.seed(42)

# Compare I-MEC size
n_dag = 50  # Number of experiments
n_list = [3]
k = 2  # Number of targets

for n in n_list:
    print('Number of nodes: ', n)
    c_list = []
    c_soft_list = []

    # Load the enumerated DAGs from the dags_n folder once; the same list is
    # reused for every experiment below instead of re-reading the pickles.
    all_dags = list(load_dags_from_folder(n))
    # Total number of candidate ADMGs: every DAG combined with every subset
    # of the n*(n-1)/2 possible bidirected edges (integer arithmetic).
    c_total = len(all_dags) * 2 ** (n * (n - 1) // 2)

    for i_dag in range(n_dag):
        print('counting dags: ', i_dag, 'out of ', n_dag)
        # Get k different random targets from obs
        target_dict = random_targets(n, k)
        # print('intervention: ', target_dict)

        # Counters
        c = 0
        c_soft = 0

        # Get a random admg
        G = random_admg(n)
        #G = random_complete_admg(n)
        # print('true admg: ', G)

        # Get the augmented MAGs of the true G
        G_aug_adj_dir, G_aug_adj_bi = get_adj_of_aug_mag(G, n, target_dict)
        G_aug_adj_soft_dir, G_aug_adj_soft_bi = get_adj_of_aug_mag_soft(G, n, target_dict)

        # For each DAG:
        for dag_index, dag in enumerate(all_dags):
            for admg in admgs_from_dag_dict(dag, n):

                #print('counting admgs: ', i, 'out of ', c_total)
                # Get the augmented MAGs
                aug_adj_dir, aug_adj_bi = get_adj_of_aug_mag(admg, n, target_dict)
                aug_adj_soft_dir, aug_adj_soft_bi = get_adj_of_aug_mag_soft(admg, n, target_dict)
                # Verify equivalence
                if verify_MAG_equi_adj(G_aug_adj_dir, G_aug_adj_bi, aug_adj_dir, aug_adj_bi):
                    #print('IMEC admg: ', admg)
                    c += 1
                if verify_MAG_equi_adj(G_aug_adj_soft_dir, G_aug_adj_soft_bi, aug_adj_soft_dir, aug_adj_soft_bi):
                    #print('Soft IMEC admg: ', admg)
                    c_soft += 1

        c_list.append(c)
        c_soft_list.append(c_soft)

    print('{} out of {} ADMGs are valid under hard interventions'.format(np.mean(c_list), c_total))
    print(' std: ', np.std(c_list))
    print('{} out of {} ADMGs are valid under soft interventions'.format(np.mean(c_soft_list), c_total))
    print(' std: ', np.std(c_soft_list))
    print('ratio (hard/soft): ', np.mean(c_list)/np.mean(c_soft_list))
    print('ratio std: ', np.std(np.array(c_list)/np.array(c_soft_list)))

    # Save the per-experiment counts to a file
    results = {'hard': c_list, 'soft': c_soft_list}
    with open('results_{}_complete.pkl'.format(n), 'wb') as f:
        pickle.dump(results, f)