from __future__ import annotations
import numpy as np
import itertools
from typing import Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, Union
from collections import deque

import pyAgrum as gum  # pyAgrum >=2 uses lowercase package name
# pandas is used only for fast conversion from BNDatabaseGenerator
import pandas as pd

Array = np.ndarray

def random_admg(
    n_obs: int,
    rho_dag: float,
    rho_bi: float,
    seed: Optional[int] = None,
    allow_bow: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw a random acyclic directed mixed graph (ADMG) over n_obs observed nodes.

    Encoding of the result:
      - A_di (n_obs x n_obs), directed part:
          * i -> j is stored as A_di[i, j] = 1 together with A_di[j, i] = 2
          * every other entry is 0
      - A_bi (n_obs x n_obs), bidirected part:
          * i <-> j is stored as A_bi[i, j] = A_bi[j, i] = 1
          * every other entry is 0

    Procedure:
      1) Draw a uniformly random node ordering.
      2) Visit every pair (a, b) with a earlier than b in that ordering and add
         a -> b with probability rho_dag. Edges only point forward, so the
         directed part is acyclic by construction.
      3) Visit every unordered pair {i, j} (i < j) and add i <-> j with
         probability rho_bi. With allow_bow=False a pair that already carries
         a directed edge never receives a bidirected edge.

    Args:
      n_obs: number of observed nodes (must be > 0).
      rho_dag: per-pair probability of a directed edge, in [0, 1].
      rho_bi: per-pair probability of a bidirected edge, in [0, 1].
      seed: seed handed to numpy's default_rng (None = nondeterministic).
      allow_bow: whether a directed and a bidirected edge may share endpoints.

    Returns:
      (A_di, A_bi), both int8 arrays with zero diagonals.

    Raises:
      ValueError: if n_obs is not positive or a probability is outside [0, 1].
    """
    if n_obs <= 0:
        raise ValueError("n_obs must be a positive integer.")
    for label, prob in (("rho_dag", rho_dag), ("rho_bi", rho_bi)):
        if not (0.0 <= prob <= 1.0):
            raise ValueError(f"{label} must be in [0, 1], got {prob}.")

    rng = np.random.default_rng(seed)
    shape = (n_obs, n_obs)
    A_di = np.zeros(shape, dtype=np.int8)
    A_bi = np.zeros(shape, dtype=np.int8)

    # Directed part: one uniform draw per ordered pair of the random ordering.
    ordering = rng.permutation(n_obs)
    for earlier, later in itertools.combinations(ordering, 2):
        if rng.random() < rho_dag:
            A_di[earlier, later] = 1
            A_di[later, earlier] = 2

    # Bidirected part: the uniform draw happens for every pair, even when the
    # edge is later rejected because of allow_bow, so the random stream does
    # not depend on the directed part.
    for i, j in itertools.combinations(range(n_obs), 2):
        if rng.random() >= rho_bi:
            continue
        has_directed = A_di[i, j] != 0 or A_di[j, i] != 0
        if has_directed and not allow_bow:
            continue
        A_bi[i, j] = 1
        A_bi[j, i] = 1

    # Defensive: no self-loops of either kind.
    np.fill_diagonal(A_di, 0)
    np.fill_diagonal(A_bi, 0)

    return A_di, A_bi

def random_intervention_targets(
    n_obs: int,
    n_targets: int = 2,
    max_target_size: int = 1,
    seed: Optional[int] = None,
    include_observational: bool = True,
) -> List[FrozenSet[int]]:
    """
    Draw n_targets distinct intervention targets (subsets of the observed nodes).

    The candidate pool is every subset of {0, ..., n_obs-1} with at most
    min(max_target_size, n_obs) elements, enumerated by increasing size.

    Args:
      n_obs: number of observed nodes (must be > 0).
      n_targets: number of targets to return (must be > 0).
      max_target_size: largest allowed target size (must be >= 0).
      seed: seed handed to numpy's default_rng.
      include_observational: if True, the first returned target is always the
        empty set (observational domain) and the remaining n_targets - 1 are
        drawn without replacement from the non-empty candidates. If False, all
        n_targets are drawn without replacement from the full pool, so the
        empty set may or may not appear.

    Returns:
      A list of n_targets distinct frozensets.

    Raises:
      ValueError: on invalid arguments or if fewer than n_targets distinct
        candidates exist.
    """
    if n_obs <= 0:
        raise ValueError("n_obs must be a positive integer.")
    if n_targets <= 0:
        raise ValueError("n_targets must be a positive integer.")
    if max_target_size < 0:
        raise ValueError("max_target_size must be >= 0.")

    rng = np.random.default_rng(seed)

    largest = min(max_target_size, n_obs)
    pool: List[FrozenSet[int]] = [
        frozenset(members)
        for size in range(largest + 1)
        for members in itertools.combinations(range(n_obs), size)
    ]

    if not include_observational:
        # Unconstrained draw: the observational domain is just another candidate.
        if n_targets > len(pool):
            raise ValueError(
                f"Requested n_targets={n_targets} but only {len(pool)} unique targets exist."
            )
        picks = rng.choice(len(pool), size=n_targets, replace=False)
        return [pool[i] for i in picks]

    # Domain 0 is pinned to the observational (empty) target.
    empty: FrozenSet[int] = frozenset()
    non_empty = [t for t in pool if t != empty]
    n_extra = n_targets - 1
    if n_extra > len(non_empty):
        raise ValueError(
            f"Requested n_targets={n_targets} but only {len(non_empty)+1} unique "
            f"targets exist (including observational)."
        )
    picks = rng.choice(len(non_empty), size=n_extra, replace=False)
    return [empty] + [non_empty[i] for i in picks]

def ensure_obs_domain_zero(target_dict: Dict[int, FrozenSet[int]]) -> Dict[int, FrozenSet[int]]:
    """
    Guarantee that domain 0 is present and observational (empty target).

    - Key 0 present with an empty target: the input dict itself is returned.
    - Key 0 present with a non-empty target: ValueError (refuse to guess).
    - Key 0 absent: a new dict is returned in which every existing domain
      index is increased by one and domain 0 maps to the empty frozenset.
      The input dict is not modified.
    """
    if 0 not in target_dict:
        # Renumber first, then add the observational domain last.
        renumbered = dict((idx + 1, tgt) for idx, tgt in target_dict.items())
        renumbered[0] = frozenset()
        return renumbered

    if len(target_dict[0]) != 0:
        raise ValueError("Domain 0 must be observational (empty target).")
    return target_dict

def _sigmoid(x: Array) -> Array:
    """
    Elementwise logistic function 1 / (1 + exp(-x)), evaluated without overflow.

    Each element uses whichever algebraic form keeps the exponent non-positive:
      x >= 0:  1 / (1 + exp(-x))
      x <  0:  exp(x) / (1 + exp(x))

    Returns a float64 array shaped like x.
    """
    is_nonneg = x >= 0
    is_neg = ~is_nonneg
    result = np.empty_like(x, dtype=np.float64)

    # Negative side: exp(x) <= 1 here, so the ratio cannot overflow.
    shrink = np.exp(x[is_neg])
    result[is_neg] = shrink / (1.0 + shrink)

    # Non-negative side: exp(-x) <= 1 here.
    result[is_nonneg] = 1.0 / (1.0 + np.exp(-x[is_nonneg]))
    return result

def _toposort_dag(adj: Array) -> List[int]:
    """
    Topologically sort a DAG given as a 0/1 adjacency matrix (adj[i, j] = 1 is i -> j).

    Kahn's algorithm with a LIFO worklist: the most recently freed node is
    emitted next. The worklist discipline determines which of the valid
    orders is returned.

    Raises:
      ValueError: if not every node can be emitted (the graph has a cycle).
    """
    n_nodes = adj.shape[0]
    # Column sums give the in-degree of every node.
    remaining_in = adj.sum(axis=0).astype(int)
    ready = [node for node in range(n_nodes) if remaining_in[node] == 0]

    result: List[int] = []
    while ready:
        current = ready.pop()
        result.append(current)
        # Releasing `current` removes one incoming edge from each child.
        for child in np.flatnonzero(adj[current] == 1):
            remaining_in[child] -= 1
            if remaining_in[child] == 0:
                ready.append(child)

    if len(result) != n_nodes:
        raise ValueError("Directed part is not a DAG (cycle detected).")
    return result

def _directed_adj_from_encoded(A_di: Array) -> Array:
    """
    Decode the directed-edge encoding into a plain 0/1 adjacency matrix.

    In the encoding, i -> j is stored as A_di[i, j] = 1 and A_di[j, i] = 2.
    Only the 1-entries are kept, so the result D satisfies D[i, j] = 1 iff
    A_di[i, j] == 1. The result has dtype int8.

    Raises:
      ValueError: if A_di is not a square 2-D array.
    """
    is_square = A_di.ndim == 2 and A_di.shape[0] == A_di.shape[1]
    if not is_square:
        raise ValueError("A_di must be a square matrix.")
    return np.equal(A_di, 1).astype(np.int8)

def _latent_realization_parents(
    A_di: Array,
    A_bi: Array
) -> Tuple[List[List[int]], List[int], int]:
    """
    Parent lists of the latent-variable DAG that realizes an ADMG.

    Node numbering in the expanded DAG:
      - observed nodes keep their indices 0..n_obs-1
      - each bidirected edge i<->j (i<j, scanned row by row) gets its own
        latent root, numbered n_obs, n_obs+1, ... in scan order

    For every observed node the parent list holds its directed parents first
    (in increasing index order), followed by the latents of its bidirected
    edges. Latent nodes have empty parent lists.

    Returns:
      parents: list of length n_obs + n_latents; parents[u] lists u's parents
      obs_topo: topological order of the observed nodes w.r.t. directed edges
      n_obs: number of observed nodes

    Raises:
      ValueError: if A_bi's shape differs from (n_obs, n_obs), or (from the
        helpers) if A_di is not square or its directed part has a cycle.
    """
    n_obs = A_di.shape[0]
    if A_bi.shape != (n_obs, n_obs):
        raise ValueError("A_bi must have the same shape as A_di.")

    directed = _directed_adj_from_encoded(A_di)
    topo = _toposort_dag(directed)

    # One latent confounder per bidirected edge, upper triangle only.
    confounded = [
        (a, b)
        for a in range(n_obs)
        for b in range(a + 1, n_obs)
        if A_bi[a, b] == 1
    ]

    parents: List[List[int]] = [[] for _ in range(n_obs + len(confounded))]

    # Directed parents first; latent entries (roots) stay empty.
    for src in range(n_obs):
        for dst in np.flatnonzero(directed[src] == 1):
            parents[dst].append(src)

    # Then attach each latent to both endpoints of its bidirected edge.
    for offset, (a, b) in enumerate(confounded):
        latent = n_obs + offset
        parents[a].append(latent)
        parents[b].append(latent)

    return parents, topo, n_obs

def generate_interventional_data_from_admg(
    A_di: Array,
    A_bi: Array,
    targets: Iterable[Union[set, FrozenSet[int]]],
    n_samples: int = 1000,
    variable_type: str = "binary",  # "binary" or "gaussian"
    seed: Optional[int] = None,
) -> Dict[int, Dict[str, object]]:
    """
    Generate one dataset per intervention target from an ADMG.

    Each bidirected edge is realized as a latent root parent of its two
    endpoints (one latent per bidirected edge); data are then sampled
    ancestrally from the resulting DAG and the latent columns are dropped.

    Mechanisms:
      - binary:   P(X_v = 1 | pa) = sigmoid(b + sum_i w_i * X_pa_i);
                  latents are Bernoulli(0.5).
      - gaussian: X_v = b + sum_i w_i * X_pa_i + eps, eps ~ N(0, 1);
                  latents are N(0, 1).

    Soft interventions:
      - Every node that appears in at least one target gets ONE alternative
        mechanism (freshly drawn b, w). In a domain with target T, the nodes
        in T use their alternative mechanism; all other nodes use the shared
        base mechanism. A node targeted in several domains therefore uses the
        same alternative mechanism in each of them.

    Args:
      A_di: encoded directed adjacency (i -> j: A_di[i, j] = 1, A_di[j, i] = 2).
      A_bi: symmetric bidirected adjacency.
      targets: iterable of node sets; domain k is the k-th element.
      n_samples: rows per domain (must be > 0).
      variable_type: "binary" or "gaussian".
      seed: seed handed to numpy's default_rng.

    Returns:
      out[k] = {"target": frozenset(...), "data": X_obs} where X_obs has shape
      (n_samples, n_obs) and contains observed columns only (int8 for binary,
      float64 for gaussian).

    Raises:
      ValueError: for invalid n_samples / variable_type, or a target node
        outside 0..n_obs-1.
    """
    if n_samples <= 0:
        raise ValueError("n_samples must be a positive integer.")
    if variable_type not in {"binary", "gaussian"}:
        raise ValueError("variable_type must be 'binary' or 'gaussian'.")

    rng = np.random.default_rng(seed)
    is_binary = variable_type == "binary"

    # Expand ADMG into latent-variable DAG
    parents, obs_topo, n_obs = _latent_realization_parents(A_di, A_bi)
    n_total = len(parents)
    latent_nodes = range(n_obs, n_total)

    # Normalize targets to frozensets and validate
    targets_list: List[FrozenSet[int]] = []
    for t in targets:
        ft = frozenset(t)
        for v in ft:
            if not (0 <= v < n_obs):
                raise ValueError(f"Target contains node {v}, but valid range is 0..{n_obs-1}.")
        targets_list.append(ft)

    def draw_mechanism(v: int) -> Tuple[float, Array, Optional[float]]:
        """Draw (bias, parent weights, noise sigma) for node v; sigma is None for binary."""
        n_pa = len(parents[v])
        # Weight scale shrinks with the number of parents so the linear
        # predictor stays in a moderate range for high in-degree nodes.
        if is_binary:
            w_scale = 1.0 / np.sqrt(max(1, n_pa))
            b = rng.normal(0.0, 0.5)
            sigma = None
        else:
            w_scale = 0.7 / np.sqrt(max(1, n_pa))
            b = rng.normal(0.0, 1.0)
            sigma = 1.0
        w = rng.normal(0.0, w_scale, size=n_pa) if n_pa > 0 else np.zeros(0, dtype=np.float64)
        return b, w, sigma

    # Base mechanisms. Latent nodes are included although their draws are not
    # used below (latents are sampled from fixed root distributions); the
    # draws are kept so that seeded output stays unchanged.
    base_params = {v: draw_mechanism(v) for v in range(n_total)}

    # One alternative mechanism per node that is targeted anywhere.
    intervened_nodes = set().union(*targets_list)
    intervened_params = {v: draw_mechanism(v) for v in intervened_nodes}

    def sample_domain(domain_target: FrozenSet[int]) -> Array:
        """Ancestrally sample one domain and return its observed columns."""
        params = dict(base_params)
        for v in domain_target:
            params[v] = intervened_params[v]

        X = np.zeros((n_samples, n_total), dtype=np.float64)

        # Latents are roots, so they can all be sampled up front.
        for l in latent_nodes:
            if is_binary:
                X[:, l] = rng.binomial(1, 0.5, size=n_samples)
            else:
                X[:, l] = rng.normal(0.0, 1.0, size=n_samples)

        # Observed nodes in topological order of the directed part; every
        # parent column (observed or latent) is already filled when read.
        for v in obs_topo:
            b, w, sigma = params[v]
            pa = parents[v]
            if pa:
                lin = b + X[:, pa] @ w
            else:
                lin = np.full(n_samples, b, dtype=np.float64)  # bias only

            if is_binary:
                X[:, v] = rng.binomial(1, _sigmoid(lin))
            else:
                X[:, v] = lin + rng.normal(0.0, sigma, size=n_samples)

        # Return ONLY observed variables
        X_obs = X[:, :n_obs]
        if is_binary:
            X_obs = X_obs.astype(np.int8)
        return X_obs

    out: Dict[int, Dict[str, object]] = {}
    for k, t in enumerate(targets_list):
        out[k] = {"target": t, "data": sample_domain(t)}

    return out

# ============================================================
# Helpers: parse ADMG adjacency encoding
# ============================================================

def directed_adj_from_encoded(A_di: Array) -> Array:
    """
    Extract the directed edges of an encoded ADMG as a 0/1 adjacency matrix.

    The encoding marks i -> j with A_di[i, j] == 1 (and A_di[j, i] == 2).
    The returned int8 matrix D has D[i, j] = 1 exactly where A_di[i, j] == 1.

    Raises:
      ValueError: if A_di is not a square 2-D array.
    """
    square = A_di.ndim == 2 and A_di.shape[0] == A_di.shape[1]
    if not square:
        raise ValueError("A_di must be square.")
    arrow_into_column = np.equal(A_di, 1)
    return arrow_into_column.astype(np.int8)


def children_from_directed_adj(D: Array) -> List[List[int]]:
    """Return, for each node i, the increasing list of j with D[i, j] == 1 (i -> j)."""
    child_lists = []
    for node in range(D.shape[0]):
        (targets,) = np.nonzero(D[node] == 1)
        child_lists.append(list(targets))
    return child_lists


def parents_from_directed_adj(D: Array) -> List[List[int]]:
    """Return, for each node j, the increasing list of i with D[i, j] == 1 (i -> j)."""
    parent_lists = []
    for node in range(D.shape[0]):
        (sources,) = np.nonzero(D[:, node] == 1)
        parent_lists.append(list(sources))
    return parent_lists


# ============================================================
# Canonical latent-DAG realization of an ADMG
# (1 latent root per bidirected edge i<->j)
# ============================================================

def build_canonical_latent_dag(A_di: Array, A_bi: Array) -> Tuple[List[List[int]], List[List[int]], int]:
    """
    Expand an ADMG into its canonical latent-variable DAG.

    Observed nodes keep indices 0..n_obs-1. Every bidirected edge i<->j
    (i<j, taken from the upper triangle of A_bi in row-major order) is
    replaced by a fresh latent root L with edges L -> i and L -> j; latents
    are numbered n_obs, n_obs+1, ... in that order.

    Returns:
      parents:  parents[u] lists the parents of u (directed parents first,
                then latents)
      children: children[u] lists the children of u
      n_obs:    number of observed nodes

    Raises:
      ValueError: if A_bi does not have shape (n_obs, n_obs).
    """
    n_obs = A_di.shape[0]
    if A_bi.shape != (n_obs, n_obs):
        raise ValueError("A_bi must have the same shape as A_di.")

    directed = directed_adj_from_encoded(A_di)

    confounded = [
        (a, b)
        for a, b in itertools.combinations(range(n_obs), 2)
        if A_bi[a, b] == 1
    ]
    n_total = n_obs + len(confounded)

    parents: List[List[int]] = [[] for _ in range(n_total)]
    children: List[List[int]] = [[] for _ in range(n_total)]

    def link(src: int, dst: int) -> None:
        # Record the edge src -> dst in both adjacency views.
        children[src].append(dst)
        parents[dst].append(src)

    # Directed edges among observed nodes (np.nonzero scans row by row).
    for src, dst in zip(*np.nonzero(directed == 1)):
        link(int(src), int(dst))

    # One latent root per bidirected edge, pointing at both endpoints.
    for offset, (a, b) in enumerate(confounded):
        latent = n_obs + offset
        link(latent, a)
        link(latent, b)

    return parents, children, n_obs


# ============================================================
# Bayes-ball d-separation for DAG
# ============================================================

def d_separated_bayes_ball(
    parents: Sequence[Sequence[int]],
    children: Sequence[Sequence[int]],
    x: int,
    y: int,
    Z: Set[int],
) -> bool:
    """
    Decide d-separation of x and y given Z in a DAG with the Bayes-ball search.

    The DAG is supplied as parent and child lists indexed by node. The search
    explores states (node, arrived_via), where arrived_via records the side
    the ball came in from:
      - "parent": the ball travelled down an edge parent -> node
      - "child" : the ball travelled up an edge node -> child

    How the ball leaves a node:
      arrived via parent, node in Z      -> up to the node's parents
      arrived via parent, node not in Z  -> down to the node's children
      arrived via child,  node in Z      -> blocked
      arrived via child,  node not in Z  -> up to parents and down to children

    Returns:
      False as soon as the ball reaches y (an active trail exists), True if
      the search is exhausted without reaching y.

    Raises:
      ValueError: if x or y is a member of Z.
    """
    if x in Z or y in Z:
        raise ValueError("For (d-)separation queries, assume x,y not in Z.")

    seen = set()
    # Seed the search from x as if the ball had arrived from either side.
    frontier = deque([(x, "parent"), (x, "child")])

    while frontier:
        state = frontier.popleft()
        if state in seen:
            continue
        seen.add(state)

        node, arrived_via = state
        if node == y:
            # The ball got to y: x and y are d-connected given Z.
            return False

        conditioned = node in Z
        if arrived_via == "parent":
            go_up, go_down = conditioned, not conditioned
        elif arrived_via == "child":
            go_up = go_down = not conditioned
        else:
            raise ValueError("came_from must be 'parent' or 'child'.")

        if go_up:
            frontier.extend((p, "child") for p in parents[node])
        if go_down:
            frontier.extend((c, "parent") for c in children[node])

    return True


# ============================================================
# m-separation in ADMG (via Bayes-ball on canonical latent DAG)
# ============================================================

def m_separated_admg_given_Z(
    A_di: Array,
    A_bi: Array,
    x: int,
    y: int,
    Z: Iterable[int],
    *,
    _cache: Optional[Tuple[List[List[int]], List[List[int]], int]] = None,
) -> bool:
    """
    Test whether x and y are m-separated given Z in an ADMG.

    The query is answered by Bayes-ball d-separation on the canonical latent
    DAG of the ADMG (one latent root per bidirected edge). Observed nodes
    keep their indices in that DAG and latents are numbered from n_obs up, so
    x, y and Z can be passed through unchanged.

    Args:
      A_di, A_bi: encoded directed / bidirected adjacency of the ADMG.
      x, y: observed node indices, neither of which may be in Z.
      Z: conditioning set of observed node indices.
      _cache: optional (parents, children, n_obs) as returned by
        build_canonical_latent_dag, to skip rebuilding the DAG per query.
        When given, A_di and A_bi are not consulted.

    Returns:
      True iff x and y are m-separated given Z. For x == y the function
      returns True before any index validation.
      NOTE(review): returning True for x == y looks like a convention of this
      module rather than a graphical fact — confirm callers rely on it.

    Raises:
      ValueError: if x or y is in Z, or any index is not an observed node.
    """
    conditioning = set(Z)
    if x in conditioning or y in conditioning:
        raise ValueError("Assume x,y not in Z for m-separation queries.")
    if x == y:
        return True

    dag = build_canonical_latent_dag(A_di, A_bi) if _cache is None else _cache
    parents, children, n_obs = dag

    if not (0 <= x < n_obs) or not (0 <= y < n_obs):
        raise ValueError("x and y must be observed node indices.")
    if any(not (0 <= z < n_obs) for z in conditioning):
        raise ValueError("Z must contain only observed node indices.")

    return d_separated_bayes_ball(parents, children, x, y, conditioning)


def m_separable_admg(
    A_di: Array,
    A_bi: Array,
    x: int,
    y: int,
    *,
    max_cond_set_size: Optional[int] = None,
    return_sepset: bool = False,
) -> Tuple[bool, Optional[FrozenSet[int]]]:
    """
    Search for a conditioning set that m-separates x and y in an ADMG.

    Candidate sets are all subsets of the observed nodes other than x and y.
    They are tried exhaustively in order of increasing size (and, within a
    size, in itertools.combinations order), up to max_cond_set_size elements
    when that limit is given. The first separating set found ends the search.

    Returns:
      (True, sepset) if a separating set was found; sepset is the frozenset
      found when return_sepset=True and None otherwise.
      (False, None) if no candidate set within the size limit separates them.

    Raises:
      ValueError: if x and y are not distinct observed node indices.
    """
    cache = build_canonical_latent_dag(A_di, A_bi)
    n_obs = cache[2]

    if x == y or not (0 <= x < n_obs and 0 <= y < n_obs):
        raise ValueError("x,y must be distinct observed node indices.")

    pool = [v for v in range(n_obs) if v != x and v != y]
    limit = len(pool) if max_cond_set_size is None else min(max_cond_set_size, len(pool))

    # Smallest sets first: cheap to test and usually sufficient.
    for size in range(limit + 1):
        for combo in itertools.combinations(pool, size):
            if not m_separated_admg_given_Z(A_di, A_bi, x, y, set(combo), _cache=cache):
                continue
            found = frozenset(combo) if return_sepset else None
            return True, found

    return False, None


# ============================================================
# Ancestors in the directed part (for orienting MAG edges)
# ============================================================

def ancestor_matrix_from_directed(A_di: Array) -> Array:
    """
    Boolean reachability matrix of the directed part of an ADMG.

    The result `anc` satisfies anc[i, j] = True iff there is a directed path
    of length >= 1 from i to j (i is an ancestor of j). Bidirected edges play
    no role. Each node's descendants are collected with an iterative
    depth-first search over the child lists.
    """
    D = directed_adj_from_encoded(A_di)
    kids = children_from_directed_adj(D)

    n = D.shape[0]
    reach = np.zeros((n, n), dtype=bool)

    for root in range(n):
        visited = set()
        pending = list(kids[root])
        while pending:
            node = pending.pop()
            if node not in visited:
                visited.add(node)
                reach[root, node] = True
                pending.extend(kids[node])

    return reach


# ============================================================
# ADMG -> MAG
# ============================================================

def admg_to_mag(
    A_di: Array,
    A_bi: Array,
    *,
    max_cond_set_size: Optional[int] = None,
) -> Tuple[Array, Array]:
    """
    Build a (maximal) ancestral graph (MAG) over the observed nodes of an ADMG.

    For every unordered pair {i, j} with i < j:
      - Adjacency: i and j are adjacent in the MAG unless some conditioning
        set Z (drawn from the other nodes, at most max_cond_set_size elements
        when that limit is given) m-separates them. Sets are tried by
        increasing size and the search stops at the first separating set.
      - Orientation of an adjacent pair, using ancestry in the directed part
        of the ADMG:
            i is an ancestor of j  ->  i -> j
            j is an ancestor of i  ->  j -> i
            neither                ->  i <-> j

    Returns:
      (M_di, M_bi), int8 matrices in the same encoding as the input:
        - M_di: i -> j stored as M_di[i, j] = 1 and M_di[j, i] = 2
        - M_bi: i <-> j stored as symmetric 1s

    Raises:
      ValueError: if A_bi's shape differs from A_di's.
    """
    size = A_di.shape[0]
    if A_bi.shape != (size, size):
        raise ValueError("A_bi must have same shape as A_di.")

    is_ancestor = ancestor_matrix_from_directed(A_di)

    M_di = np.zeros((size, size), dtype=np.int8)
    M_bi = np.zeros((size, size), dtype=np.int8)

    # The latent DAG is identical for every query, so build it once.
    latent_dag = build_canonical_latent_dag(A_di, A_bi)

    for i, j in itertools.combinations(range(size), 2):
        others = [v for v in range(size) if v != i and v != j]
        limit = len(others) if max_cond_set_size is None else min(max_cond_set_size, len(others))

        # Lazily enumerate conditioning sets, smallest first; any() stops at
        # the first one that separates i and j.
        cond_sets = itertools.chain.from_iterable(
            itertools.combinations(others, k) for k in range(limit + 1)
        )
        if any(
            m_separated_admg_given_Z(A_di, A_bi, i, j, zs, _cache=latent_dag)
            for zs in cond_sets
        ):
            continue  # separable pair: not adjacent in the MAG

        # Inseparable pair: adjacent, oriented by ancestry.
        if is_ancestor[i, j]:
            M_di[i, j], M_di[j, i] = 1, 2
        elif is_ancestor[j, i]:
            M_di[j, i], M_di[i, j] = 1, 2
        else:
            M_bi[i, j] = M_bi[j, i] = 1

    np.fill_diagonal(M_di, 0)
    np.fill_diagonal(M_bi, 0)
    return M_di, M_bi

# ============================================================
# Conversions / normalization
# ============================================================

def dag_adj01_to_marks(A01: np.ndarray) -> np.ndarray:
    """
    Turn a 0/1 adjacency matrix (A01[i, j] = 1 meaning i -> j) into a mark matrix.

    Resulting int8 matrix M:
      - i -> j only:          M[i, j] = 1 (arrowhead at j), M[j, i] = 2 (tail at i)
      - i -> j and j -> i:    M[i, j] = M[j, i] = 1, i.e. read as i <-> j
      - diagonal entries of A01 are ignored; M has a zero diagonal

    Raises:
      ValueError: if A01 is not square or holds values other than 0 and 1.
    """
    A01 = np.asarray(A01)
    if A01.ndim != 2 or A01.shape[0] != A01.shape[1]:
        raise ValueError("A01 must be square.")
    if not np.isin(A01, [0, 1]).all():
        raise ValueError("dag_adj01_to_marks expects only 0/1 entries.")

    # Edge indicator with self-loops discarded (this is a fresh array, the
    # input is not touched).
    has_edge = A01 == 1
    np.fill_diagonal(has_edge, False)

    marks = np.zeros(A01.shape, dtype=np.int8)
    # Arrowhead at the column node of every listed edge.
    marks[has_edge] = 1
    # Tail at the source end, unless the reverse edge is listed as well (then
    # that cell already holds an arrowhead and the pair reads as bidirected).
    marks[has_edge.T & ~has_edge] = 2
    return marks


def mag_two_mats_to_marks(M_di: np.ndarray, M_bi: np.ndarray) -> np.ndarray:
    """
    Merge the two-matrix MAG representation into a single mark matrix.

    Inputs:
      - M_di: i -> j encoded as M_di[i, j] = 1 together with M_di[j, i] = 2.
        A pair is only treated as a directed edge when both entries are
        present.
      - M_bi: i <-> j encoded as 1; only the upper triangle (i < j) is read.

    Output (int8, zero diagonal), entry = mark at the column node:
      - i -> j:  M[i, j] = 1 (arrowhead at j), M[j, i] = 2 (tail at i)
      - i <-> j: M[i, j] = M[j, i] = 1
    When a pair appears in both inputs, the bidirected marks win.

    Raises:
      ValueError: if the inputs are not square matrices of identical shape.
    """
    M_di = np.asarray(M_di)
    M_bi = np.asarray(M_bi)
    if M_di.shape != M_bi.shape or M_di.ndim != 2 or M_di.shape[0] != M_di.shape[1]:
        raise ValueError("M_di and M_bi must be same square shape.")

    marks = np.zeros(M_di.shape, dtype=np.int8)

    # arrow[i, j] is True exactly when the pair is fully encoded as i -> j.
    arrow = (M_di == 1) & (M_di.T == 2)
    marks[arrow] = 1
    marks[arrow.T] = 2

    # Bidirected edges are written afterwards so they override directed marks.
    upper = np.triu(M_bi == 1, k=1)
    marks[upper | upper.T] = 1

    np.fill_diagonal(marks, 0)
    return marks


def normalize_to_marks(A: np.ndarray) -> np.ndarray:
    """
    Bring an adjacency-like square matrix into mark-matrix form (codes 0..3).

    Dispatch on the set of distinct values in A:
      - subset of {0, 1}: A is read as a 0/1 adjacency (A[i, j] = 1 meaning
        i -> j) and converted with dag_adj01_to_marks.
      - otherwise, subset of {0, 1, 2, 3}: A is taken to be a mark matrix
        already; an int8 copy with a zeroed diagonal is returned.
      - anything else: ValueError.

    Raises:
      ValueError: if A is not square or contains unsupported values.
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("A must be square.")

    distinct = np.unique(A)

    looks_like_adjacency = np.isin(distinct, [0, 1]).all()
    if looks_like_adjacency:
        return dag_adj01_to_marks(A)

    looks_like_marks = np.isin(distinct, [0, 1, 2, 3]).all()
    if not looks_like_marks:
        raise ValueError(f"Unsupported adjacency values {distinct}; expected subset of {{0,1}} or {{0,1,2,3}}.")

    # Work on a copy so the caller's matrix keeps its diagonal.
    marks = A.astype(np.int8, copy=True)
    np.fill_diagonal(marks, 0)
    return marks


# ============================================================
# Metrics: ACC and DIS
# ============================================================

def edge_mark_metrics(
    A_pred: np.ndarray,
    A_true_mag: np.ndarray,
    *,
    acc_empty_value: float = np.nan,
    dis_empty_value: float = np.nan,
    cnt_empty_value: float = np.nan,
) -> Tuple[float, float, dict]:
    """
    Score a predicted mixed graph against the true MAG on endpoint marks.

    Both inputs are first passed through normalize_to_marks. Mark codes
    (entry [i, j] is the mark at the column node j):
      - 0: no edge-mark
      - 1: arrowhead
      - 2: tail
      - 3: circle (unknown / not revealed)

    Only off-diagonal entries are scored.

    Metrics:
      - ACC = (# predicted marks in {1, 2} that equal the true mark)
              / (# predicted marks in {1, 2})
      - DIS = (# positions with a non-zero true mark where the prediction differs)
              / (# non-zero true marks)
      - CNT = (# predicted marks in {1, 2}) / (# non-zero true marks)
        CNT is a coverage ratio; predicted edges absent from the truth can
        push it above 1.0.

    Args:
      acc_empty_value: returned as ACC when no mark in {1, 2} is predicted.
      dis_empty_value: returned as DIS when the truth has no marks.
      cnt_empty_value: returned as CNT when the truth has no marks.

    Returns:
      (ACC, DIS, details), where details carries the four raw counts and CNT.

    Raises:
      ValueError: if the normalized matrices differ in shape.
    """
    pred = normalize_to_marks(A_pred)
    truth = normalize_to_marks(A_true_mag)

    if pred.shape != truth.shape:
        raise ValueError("A_pred and A_true_mag must have the same shape.")

    # Flatten to the off-diagonal entries only.
    off_diag = ~np.eye(pred.shape[0], dtype=bool)
    p_marks = pred[off_diag]
    t_marks = truth[off_diag]
    agree = p_marks == t_marks

    # ACC: precision over the marks the prediction commits to (not 0, not 3).
    committed = ~((p_marks == 0) | (p_marks == 3))
    n_committed = int(np.count_nonzero(committed))
    n_committed_right = int(np.count_nonzero(agree & committed))

    # DIS: fraction of true endpoint marks the prediction gets wrong.
    in_truth = t_marks != 0
    n_truth = int(np.count_nonzero(in_truth))
    n_truth_wrong = int(np.count_nonzero(~agree & in_truth))

    if n_committed > 0:
        ACC = n_committed_right / n_committed
    else:
        ACC = acc_empty_value

    if n_truth > 0:
        DIS = n_truth_wrong / n_truth
        CNT = n_committed / n_truth
    else:
        DIS = dis_empty_value
        CNT = cnt_empty_value

    details = {
        "acc_num_correct_revealed": n_committed_right,
        "acc_den_revealed": n_committed,
        "dis_num_marks_to_change": n_truth_wrong,
        "dis_den_truth_marks": n_truth,
        "CNT": CNT,
    }
    return ACC, DIS, details


def edge_mark_metrics_acc_dis_cnt(
    A_pred: np.ndarray,
    A_true_mag: np.ndarray,
    *,
    acc_empty_value: float = np.nan,
    dis_empty_value: float = np.nan,
    cnt_empty_value: float = np.nan,
) -> Tuple[float, float, float, dict]:
    """
    Same computation as edge_mark_metrics, with CNT promoted to a return value.

    Returns:
      (ACC, DIS, CNT, details); CNT is read from details["CNT"] and converted
      to float (NaN if the key were missing).
    """
    empty_values = {
        "acc_empty_value": acc_empty_value,
        "dis_empty_value": dis_empty_value,
        "cnt_empty_value": cnt_empty_value,
    }
    acc, dis, details = edge_mark_metrics(A_pred, A_true_mag, **empty_values)
    cnt = float(details.get("CNT", np.nan))
    return acc, dis, cnt, details


def _directed_edges_from_encoded(A_di: np.ndarray) -> List[Tuple[int, int]]:
    """
    List the directed edges (i, j) of an encoded matrix, i.e. all off-diagonal
    positions with A_di[i, j] == 1, in row-major order.
    """
    encoded = np.asarray(A_di)
    found: List[Tuple[int, int]] = []
    for src in range(encoded.shape[0]):
        heads = np.flatnonzero(encoded[src] == 1)
        found.extend((src, int(dst)) for dst in heads if dst != src)
    return found


def _bidirected_pairs(A_bi: np.ndarray) -> List[Tuple[int, int]]:
    """
    List the bidirected edges as pairs (i, j) with i < j.

    Only the upper triangle of A_bi is inspected; an entry equal to 1 means
    i <-> j. Pairs come out in row-major order.
    """
    mat = np.asarray(A_bi)
    n_nodes = mat.shape[0]
    return [
        (i, j)
        for i, j in itertools.combinations(range(n_nodes), 2)
        if mat[i, j] == 1
    ]


def _sample_binary_ps_for_parent_configs(
    n_parents: int,
    rng: np.random.Generator,
    *,
    min_prob: float = 0.02,
    beta_a: float = 0.5,
    beta_b: float = 0.5,
    min_variation: float = 0.4,
    max_tries: int = 200,
) -> Dict[Tuple[int, ...], float]:
    """
    Sample P(X=1 | parents=config) for every configuration of n_parents binary parents.

    Each probability is drawn from Beta(beta_a, beta_b) and clipped to
    [min_prob, 1 - min_prob], which keeps the CPT strictly positive.

    For n_parents > 0 the whole table is redrawn until the spread
    max(p) - min(p) reaches min_variation, so the variable depends visibly on
    its parents. After max_tries unsuccessful draws the last table is returned
    as is. At least one table is always drawn, so max_tries <= 0 behaves like
    max_tries = 1 instead of failing on an undefined table.

    Args:
      n_parents: number of binary parents (0 gives a single root probability).
      rng: numpy random Generator supplying the Beta draws.
      min_prob: clipping bound for the probabilities.
      beta_a, beta_b: Beta distribution parameters.
      min_variation: required spread of the table when n_parents > 0.
      max_tries: maximum number of tables to draw.

    Returns:
      dict mapping each parent configuration (a tuple of 0/1 of length
      n_parents, in itertools.product order) to P(X=1 | config). For
      n_parents == 0 the only key is the empty tuple.
    """
    lower, upper = min_prob, 1.0 - min_prob

    if n_parents == 0:
        p = float(rng.beta(beta_a, beta_b))
        return {(): float(np.clip(p, lower, upper))}

    configs = list(itertools.product([0, 1], repeat=n_parents))

    # max(1, ...) guarantees `ps` is bound when the loop ends, even for a
    # non-positive max_tries.
    for _ in range(max(1, max_tries)):
        ps = rng.beta(beta_a, beta_b, size=len(configs)).astype(float)
        ps = np.clip(ps, lower, upper)
        if float(ps.max() - ps.min()) >= min_variation:
            break

    # Either the first table with enough spread, or the last one drawn
    # (still strictly positive).
    return {cfg: float(p) for cfg, p in zip(configs, ps)}


def _set_binary_cpt_from_probs(
    bn: gum.BayesNet,
    var_name: str,
    parent_names: Sequence[str],
    probs: Dict[Tuple[int, ...], float],
) -> None:
    """
    Write the CPT of the binary variable var_name into bn.

    probs maps each parent configuration (a tuple of values aligned with
    parent_names) to p = P(var_name = 1 | configuration); every CPT entry is
    written as [1 - p, p].

    - Without parents, probs[()] is used and the whole CPT is filled at once.
    - With parents, each configuration is written through dictionary
      subscripting: bn.cpt(var)[{parent1: v1, parent2: v2, ...}] = [P(0|...), P(1|...)].
    """
    if len(parent_names) != 0:
        for config, prob_one in probs.items():
            selector = {
                parent: int(value) for parent, value in zip(parent_names, config)
            }
            p1 = float(prob_one)
            bn.cpt(var_name)[selector] = [1.0 - p1, p1]
        return

    # Root variable: a single distribution, no parent selector needed.
    root_p1 = probs[()]
    bn.cpt(var_name).fillWith([1.0 - root_p1, root_p1])


def generate_interventional_data_pyagrum_from_admg(
    A_di: np.ndarray,
    A_bi: np.ndarray,
    targets: Union[Iterable[Set[int]], Iterable[FrozenSet[int]], Dict[int, FrozenSet[int]]],
    *,
    n_samples: int = 50_000,
    seed: Optional[int] = None,
    # CPT sampling controls
    min_prob: float = 0.02,
    beta_a: float = 0.7,
    beta_b: float = 0.7,
    min_variation: float = 0.25,
) -> Dict[int, Dict[str, object]]:
    """
    Build a latent-variable BN (one latent per bidirected edge) using pyAgrum and sample data.

    Controlled soft intervention (per-node):
      - For each observed node v that appears in ANY target, we sample ONE intervened CPT for v.
      - In domain k with target T_k, we replace CPTs of v in T_k with that fixed intervened CPT.

    Args:
      A_di: encoded directed adjacency (square), decoded by _directed_edges_from_encoded.
      A_bi: bidirected adjacency of the same shape, decoded by _bidirected_pairs.
      targets: intervention targets, either an iterable of node sets (domain ids are
        the enumeration indices) or a dict domain_id -> node set. The normalized dict
        is passed through ensure_obs_domain_zero before use (defined elsewhere in
        this file; presumably it guarantees an observational domain 0 - confirm there).
      n_samples: number of rows drawn per domain (must be positive).
      seed: seeds the NumPy generator used for all CPT sampling. When given, a
        non-zero seed derived from it is also passed to pyAgrum's global random
        generator (gum.initRandom) so the drawn samples are repeatable; when None,
        pyAgrum's generator is left untouched.
      min_prob, beta_a, beta_b, min_variation: forwarded to
        _sample_binary_ps_for_parent_configs for observed-node CPTs.

    Returns:
      out[k] = {"target": frozenset(...), "data": X_obs}  # only observed columns

    Raises:
      ValueError: on non-square A_di, mismatched A_bi shape, non-positive n_samples,
        no domains, or a target node outside 0..n_obs-1.

    Notes:
      - pyAgrum BNs are discrete (so this generator is for binary variables).
      - CPT editing uses bn.cpt(...) with dict subscripting.
      - Sampling uses BNDatabaseGenerator.drawSamples and to_pandas.
    """
    rng = np.random.default_rng(seed)

    A_di = np.asarray(A_di)
    A_bi = np.asarray(A_bi)
    if A_di.ndim != 2 or A_di.shape[0] != A_di.shape[1]:
        raise ValueError("A_di must be a square 2-D matrix.")
    n_obs = A_di.shape[0]
    if A_bi.shape != (n_obs, n_obs):
        raise ValueError("A_bi must have same shape as A_di.")
    if n_samples <= 0:
        raise ValueError("n_samples must be positive.")

    # Normalize targets into an indexed dict
    target_dict: Dict[int, FrozenSet[int]]
    if isinstance(targets, dict):
        target_dict = {int(k): frozenset(v) for k, v in targets.items()}
    else:
        target_dict = {k: frozenset(t) for k, t in enumerate(targets)}

    target_dict = ensure_obs_domain_zero(target_dict)
    domain_ids = sorted(target_dict.keys())
    if not domain_ids:
        raise ValueError("Need at least one domain/target.")

    for t in target_dict.values():
        for v in t:
            if not (0 <= v < n_obs):
                raise ValueError(f"Target contains node {v}, valid range is 0..{n_obs-1}.")

    # --- Build latent-realized BN structure ---
    bn = gum.BayesNet("LatentRealizedADMG")

    obs_names = [f"X{i}" for i in range(n_obs)]
    for name in obs_names:
        bn.add(name, 2)  # binary domain 0/1

    bi_pairs = _bidirected_pairs(A_bi)
    lat_names = [f"L{k}" for k in range(len(bi_pairs))]
    for name in lat_names:
        bn.add(name, 2)

    # Track parents in a stable order (we control insertion order):
    # directed parents first, then latent parents.
    parents_of: Dict[str, List[str]] = {name: [] for name in obs_names + lat_names}

    # Add directed arcs among observed
    for i, j in _directed_edges_from_encoded(A_di):
        src = obs_names[i]
        dst = obs_names[j]
        bn.addArc(src, dst)
        parents_of[dst].append(src)

    # Add latent arcs for each bidirected edge: L_k -> X_i and L_k -> X_j
    for lname, (i, j) in zip(lat_names, bi_pairs):
        for child in (obs_names[i], obs_names[j]):
            bn.addArc(lname, child)
            parents_of[child].append(lname)

    # --- Sample base CPTs (strictly positive, "strong-ish") ---
    base_probs: Dict[str, Dict[Tuple[int, ...], float]] = {}

    # latents: simple roots (still strictly positive)
    for lname in lat_names:
        base_probs[lname] = _sample_binary_ps_for_parent_configs(
            0, rng, min_prob=min_prob, beta_a=0.5, beta_b=0.5, min_variation=0.4
        )
        _set_binary_cpt_from_probs(bn, lname, [], base_probs[lname])

    # observed nodes
    for vname in obs_names:
        pa = parents_of[vname]
        base_probs[vname] = _sample_binary_ps_for_parent_configs(
            len(pa), rng,
            min_prob=min_prob, beta_a=beta_a, beta_b=beta_b,
            min_variation=min_variation
        )
        _set_binary_cpt_from_probs(bn, vname, pa, base_probs[vname])

    # --- Controlled per-node intervention CPTs ---
    intervened_nodes: Set[int] = set().union(*target_dict.values())
    intervened_probs: Dict[str, Dict[Tuple[int, ...], float]] = {}

    for v in intervened_nodes:
        vname = obs_names[v]
        pa = parents_of[vname]
        intervened_probs[vname] = _sample_binary_ps_for_parent_configs(
            len(pa), rng,
            min_prob=min_prob, beta_a=beta_a, beta_b=beta_b,
            min_variation=min_variation
        )

    # pyAgrum samples with its own global generator, which the NumPy seed does
    # not control. Derive a seed for it only now, after every CPT draw, so the
    # CPTs obtained for a given `seed` are unchanged. The value is kept non-zero
    # because initRandom(0) means "seed from the clock".
    if seed is not None:
        gum.initRandom(int(rng.integers(1, 2**31 - 1)))

    # --- Sample each domain ---
    out: Dict[int, Dict[str, object]] = {}
    for d in domain_ids:
        t = target_dict[d]

        # Copy BN (BayesNet(source) copies a BN) so interventions stay local to this domain.
        bn_d = gum.BayesNet(bn)

        # Apply interventions: replace CPT of each targeted node with the controlled one
        for v in t:
            vname = obs_names[v]
            pa = parents_of[vname]
            _set_binary_cpt_from_probs(bn_d, vname, pa, intervened_probs[vname])

        # Sample
        gen = gum.BNDatabaseGenerator(bn_d)
        gen.setRandomVarOrder()
        gen.drawSamples(n_samples)
        df = gen.to_pandas()

        # Keep only observed columns, in X0..X{n_obs-1} order (latents are dropped).
        X = df[obs_names].to_numpy()
        # ensure integer 0/1 (robust if dtype is object)
        if X.dtype.kind not in "iu":
            X = X.astype(str).astype(int)
        else:
            X = X.astype(np.int8, copy=False)

        out[d] = {"target": t, "data": X}

    return out


# Endpoint-mark codes stored in mark matrices (see mark_at for the layout).
NOEDGE = 0  # no edge between the two vertices
ARROW = 1   # arrowhead endpoint
TAIL = 2    # tail endpoint
CIRCLE = 3  # circle endpoint; may later be refined to ARROW or TAIL

def adjacent(A: np.ndarray, i: int, j: int) -> bool:
    """Return whether i and j share an edge, i.e. either A[i, j] or A[j, i] holds a mark."""
    forward = A[i, j] != NOEDGE
    if forward:
        return forward
    # Only consult the reverse entry when the forward one is empty.
    return A[j, i] != NOEDGE

def neighbors(A: np.ndarray, i: int) -> List[int]:
    """List, in increasing index order, the vertices other than i that share an edge with i."""
    # A vertex counts as a neighbor if row i or column i holds a mark for it.
    linked = (A[i, :] != NOEDGE) | (A[:, i] != NOEDGE)
    result: List[int] = []
    for raw in np.where(linked)[0]:
        k = int(raw)
        if k != i:
            result.append(k)
    return result

def mark_at(A: np.ndarray, u: int, v: int) -> int:
    """Return the endpoint mark at v on edge u-v; by convention it lives in A[u, v]."""
    endpoint = A[u, v]
    return int(endpoint)

def set_mark_at(A: np.ndarray, u: int, v: int, new_mark: int) -> bool:
    """
    Try to set the mark at v on edge u-v (stored in A[u, v]) to new_mark.

    The only permitted change is refining a CIRCLE into an ARROW or a TAIL.
    Everything else leaves A untouched: a missing edge, an unchanged mark,
    and any attempt to overwrite an already fixed ARROW or TAIL.

    Returns True iff the matrix was modified.
    """
    existing = int(A[u, v])
    refines_circle = existing == CIRCLE and new_mark in (ARROW, TAIL)
    if not refines_circle:
        return False
    A[u, v] = new_mark
    return True

def set_edge(A: np.ndarray, u: int, v: int, mark_u: int, mark_v: int) -> bool:
    """
    Apply mark_u at the u endpoint and mark_v at the v endpoint of edge u-v.

    Does nothing when u and v are not adjacent. Each endpoint is updated via
    set_mark_at, so its refinement rules apply. Returns True if at least one
    endpoint changed.
    """
    if not adjacent(A, u, v):
        return False
    # The mark at u is stored in A[v, u]; the mark at v in A[u, v].
    changed_at_u = set_mark_at(A, v, u, mark_u)
    changed_at_v = set_mark_at(A, u, v, mark_v)
    return changed_at_u or changed_at_v

def orient_directed(A: np.ndarray, u: int, v: int) -> bool:
    """Try to orient the edge as u -> v (TAIL at u, ARROW at v); True if anything changed."""
    return set_edge(A, u, v, mark_u=TAIL, mark_v=ARROW)

def orient_bidirected(A: np.ndarray, u: int, v: int) -> bool:
    """Try to orient the edge as u <-> v (ARROW at both ends); True if anything changed."""
    return set_edge(A, u, v, mark_u=ARROW, mark_v=ARROW)

def has_arrow_at(A: np.ndarray, u: int, v: int) -> bool:
    """Return whether u and v are adjacent and the mark at v on edge u-v is an arrowhead."""
    linked = adjacent(A, u, v)
    if not linked:
        return linked
    return mark_at(A, u, v) == ARROW

def has_tail_at(A: np.ndarray, u: int, v: int) -> bool:
    """Return whether u and v are adjacent and the mark at v on edge u-v is a tail."""
    linked = adjacent(A, u, v)
    if not linked:
        return linked
    return mark_at(A, u, v) == TAIL

def has_circle_at(A: np.ndarray, u: int, v: int) -> bool:
    """Return whether u and v are adjacent and the mark at v on edge u-v is a circle."""
    linked = adjacent(A, u, v)
    if not linked:
        return linked
    return mark_at(A, u, v) == CIRCLE

def is_directed(A: np.ndarray, u: int, v: int) -> bool:
    """Return whether the edge between u and v is u -> v (TAIL at u, ARROW at v)."""
    linked = adjacent(A, u, v)
    if not linked:
        return linked
    if mark_at(A, v, u) != TAIL:
        return False
    return mark_at(A, u, v) == ARROW

def is_bidirected(A: np.ndarray, u: int, v: int) -> bool:
    """Return whether the edge between u and v is u <-> v (ARROW at both endpoints)."""
    linked = adjacent(A, u, v)
    if not linked:
        return linked
    if mark_at(A, u, v) != ARROW:
        return False
    return mark_at(A, v, u) == ARROW

def is_unshielded_triple(A: np.ndarray, a: int, b: int, c: int) -> bool:
    """Return whether a-b and b-c are edges while a and c are not adjacent."""
    left_edge = adjacent(A, a, b)
    if not left_edge:
        return left_edge
    right_edge = adjacent(A, b, c)
    if not right_edge:
        return right_edge
    return not adjacent(A, a, c)

def get_sepset(sepsets: Dict[Tuple[int, int], FrozenSet[int]], i: int, j: int) -> FrozenSet[int]:
    """
    Look up the separating set recorded for the unordered pair {i, j}.

    Keys of `sepsets` are stored as (smaller, larger); an empty frozenset is
    returned when the pair has no entry.
    """
    if j < i:
        i, j = j, i
    return sepsets.get((i, j), frozenset())

# ---------- Potentially directed & uncovered path helpers (R9/R10) ----------

def step_is_pd(A: np.ndarray, cur: int, nxt: int) -> bool:
    """
    Decide whether the step cur -> nxt may belong to a potentially directed path.

    The edge must exist and must be
      - not into cur   (no arrowhead at cur): mark_at(nxt, cur) != ARROW
      - not out of nxt (no tail at nxt):      mark_at(cur, nxt) != TAIL
    which follows Zhang's definition of a "potentially directed" path.
    """
    if not adjacent(A, cur, nxt):
        return False
    arrow_into_cur = mark_at(A, nxt, cur) == ARROW
    if arrow_into_cur:
        return False
    tail_at_nxt = mark_at(A, cur, nxt) == TAIL
    return not tail_at_nxt

def find_uncovered_pd_paths_first_hops(
    A: np.ndarray,
    start: int,
    end: int,
    *,
    max_len: Optional[int] = None,
) -> Dict[int, List[int]]:
    """
    Returns dict: first_hop -> one example path [start, first_hop, ..., end]
    for uncovered potentially directed paths start ~> end.

    A path qualifies when every step satisfies step_is_pd and, for every three
    consecutive vertices (prev, cur, nxt), prev and nxt are not adjacent
    ("uncovered"). Paths are simple (no repeated vertex).

    Args:
        A: mark matrix.
        start, end: path endpoints. If start == end the result is empty.
        max_len: maximum number of vertices allowed on a path; defaults to the
            number of vertices in A.

    Used by R9 (needs specific second vertex) and R10 (needs μ/ω first hops).
    """
    n = A.shape[0]
    if max_len is None:
        max_len = n

    out: Dict[int, List[int]] = {}
    path: List[int] = [start]
    visited: Set[int] = {start}

    def dfs(prev: Optional[int], cur: int, first_hop: Optional[int]) -> None:
        if len(path) > max_len:
            return
        if cur == end:
            if first_hop is not None and first_hop not in out:
                out[first_hop] = path.copy()
            return
        for nxt in neighbors(A, cur):
            # Only the first path found per first hop is kept, so once one is
            # recorded the rest of this subtree cannot change `out`. Stopping
            # here avoids enumerating every remaining simple path.
            if first_hop is not None and first_hop in out:
                return
            if nxt in visited:
                continue
            if not step_is_pd(A, cur, nxt):
                continue
            # uncovered constraint: prev and nxt not adjacent
            if prev is not None and adjacent(A, prev, nxt):
                continue
            visited.add(nxt)
            path.append(nxt)
            dfs(cur, nxt, nxt if first_hop is None else first_hop)
            path.pop()
            visited.remove(nxt)

    dfs(None, start, None)
    return out

# ---------- Discriminating path helper (R4) ----------

def find_discriminating_path(
    A: np.ndarray,
    beta: int,
    gamma: int,
) -> Optional[List[int]]:
    """
    Find one discriminating path p = [theta, ..., alpha, beta, gamma] for beta,
    using Zhang's definition: p has at least three edges, theta is not adjacent
    to gamma, and every vertex between theta and beta is a collider on p and a
    parent of gamma.

    Implementation uses the key structural fact:
      consecutive colliders imply bidirected edges between them (arrowheads at both ends).
    So the vertices between theta and beta form a chain of bidirected edges
    inside the set of parents of gamma, which is explored by BFS from alpha.

    Returns the path as a list of vertices in the order
    [theta, ..., alpha, beta, gamma], or None if beta and gamma are not adjacent
    or no such path is found.
    """
    if not adjacent(A, beta, gamma):
        return None

    # Candidate collider/parent-of-gamma vertices: v -> gamma
    parent_of_gamma = {v for v in range(A.shape[0]) if v != gamma and is_directed(A, v, gamma)}

    # alpha must satisfy: beta *-> alpha (arrow at alpha) AND alpha -> gamma
    for alpha in neighbors(A, beta):
        if alpha == gamma or alpha not in parent_of_gamma:
            continue
        if not has_arrow_at(A, beta, alpha):
            continue  # need arrowhead at alpha on beta-alpha

        # BFS over the collider chain: bidirected edges between parents of gamma.
        queue = deque([alpha])
        bfs_parent: Dict[int, Optional[int]] = {alpha: None}

        while queue:
            v = queue.popleft()

            # Chain in path order: v, ..., alpha (following BFS parent pointers).
            # It must NOT be reversed: theta attaches to v and beta to alpha.
            chain = [v]
            while chain[-1] != alpha:
                chain.append(bfs_parent[chain[-1]])

            # theta adjacent to v with arrowhead at v, and theta not adjacent gamma
            for theta in neighbors(A, v):
                if theta == beta or theta == gamma:
                    continue
                if adjacent(A, theta, gamma):
                    continue
                if not has_arrow_at(A, theta, v):
                    continue  # need theta *-> v so v is collider

                # theta - v - ... - alpha - beta - gamma; always >= 4 vertices,
                # i.e. at least the three edges the definition requires.
                p = [theta] + chain + [beta, gamma]

                # Defensive re-check: every vertex strictly between theta and
                # beta (indices 1..len(p)-3) is a collider on p and a parent of gamma.
                if all(
                    has_arrow_at(A, p[k - 1], p[k])
                    and has_arrow_at(A, p[k + 1], p[k])
                    and is_directed(A, p[k], gamma)
                    for k in range(1, len(p) - 2)
                ):
                    return p

            for u in neighbors(A, v):
                # beta is excluded so that it cannot appear twice on the path.
                if u in bfs_parent or u == beta:
                    continue
                if u in parent_of_gamma and is_bidirected(A, v, u):
                    bfs_parent[u] = v
                    queue.append(u)

    return None
# ============================================================
# MAG/ADMG helpers for MAG-listing / validation
# ============================================================

def marks_to_admg_mats(
    M: np.ndarray,
    *,
    allow_undirected: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Translate a fully oriented mark matrix (no circles) into (A_di, A_bi).

    Edge types handled:
      - u -> v  (TAIL at u, ARROW at v): A_di[u, v] = 1 and A_di[v, u] = 2
      - u <-> v (ARROW at both ends):    A_bi[u, v] = A_bi[v, u] = 1

    Undirected edges (TAIL-TAIL) cannot be represented in this ADMG encoding.
    They always raise ValueError; allow_undirected only selects which message
    is raised. Selection bias has to be handled elsewhere.

    Raises:
      ValueError: if the normalized matrix is not square, contains a circle
        mark, has a mark on only one side of an edge, or holds an unsupported
        endpoint combination.

    Notes:
      - This conversion is used for MAG validation and IMAG realizability checks.
      - The (A_di, A_bi) representation itself can hold both a directed and a
        bidirected edge between the same pair (a latent confounder plus a
        direct causal edge), since the two matrices are independent.
    """
    M = normalize_to_marks(M)
    n = M.shape[0]
    if M.shape[1] != n:
        raise ValueError('marks_to_admg_mats expects a square matrix.')

    if np.any(M == CIRCLE):
        raise ValueError('marks_to_admg_mats expects no circle marks (fully oriented graph).')

    A_di = np.zeros((n, n), dtype=np.int8)
    A_bi = np.zeros((n, n), dtype=np.int8)

    # Visit each unordered pair once, in the same (i < j) order as a nested loop.
    for i, j in itertools.combinations(range(n), 2):
        if not adjacent(M, i, j):
            continue

        mark_i = int(M[j, i])  # endpoint mark at i
        mark_j = int(M[i, j])  # endpoint mark at j

        # An edge must carry a mark at both of its endpoints.
        if NOEDGE in (mark_i, mark_j):
            raise ValueError(f'Invalid asymmetric adjacency at ({i},{j}).')

        endpoints = (mark_i, mark_j)
        if endpoints == (TAIL, ARROW):
            # i -> j
            A_di[i, j], A_di[j, i] = 1, 2
        elif endpoints == (ARROW, TAIL):
            # j -> i
            A_di[j, i], A_di[i, j] = 1, 2
        elif endpoints == (ARROW, ARROW):
            # i <-> j
            A_bi[i, j] = A_bi[j, i] = 1
        elif endpoints == (TAIL, TAIL):
            if allow_undirected:
                raise ValueError('Undirected (TAIL-TAIL) edges are not supported by this ADMG encoding.')
            raise ValueError('Encountered an undirected (TAIL-TAIL) edge but allow_undirected=False.')
        else:
            raise ValueError(f'Invalid endpoint mark combination on edge {i}-{j}: (mark_i={mark_i}, mark_j={mark_j}).')

    np.fill_diagonal(A_di, 0)
    np.fill_diagonal(A_bi, 0)
    return A_di, A_bi


def is_mag_mark_matrix(
    M: np.ndarray,
    *,
    allow_undirected: bool = False,
    check_maximality: bool = True,
    max_cond_set_size: Optional[int] = None,
    return_reason: bool = False,
) -> Union[bool, Tuple[bool, str]]:
    """
    Heuristic test of whether a *fully oriented* mark matrix is a MAG.

    Checks, in order (no selection bias assumed by default):
      1) No circle marks, and the marks convert cleanly via marks_to_admg_mats.
      2) The directed part is acyclic.
      3) For every bidirected edge u<->v, neither endpoint is an ancestor of the other.
      4) (Optional) Maximality: every non-adjacent pair has *some* separating set.

    The maximality check brute-forces m-separation over conditioning sets of
    increasing size (at most max_cond_set_size when given) and is exponential
    in n in the worst case. For n <= ~10-12 this is usually fine.

    Returns a bool, or (ok, reason) when return_reason=True. On success the
    reason is ''; on failure it is one of 'has_circle', 'bad_marks: ...',
    'directed_cycle', 'bidir_ancestor_violation', 'non_maximal'.
    """
    def verdict(ok: bool, reason: str) -> Union[bool, Tuple[bool, str]]:
        # Shape the result according to return_reason.
        return (ok, reason) if return_reason else ok

    M = normalize_to_marks(M)
    if np.any(M == CIRCLE):
        return verdict(False, 'has_circle')

    try:
        A_di, A_bi = marks_to_admg_mats(M, allow_undirected=allow_undirected)
    except ValueError as e:
        return verdict(False, f'bad_marks: {e}')

    # Check 2: directed acyclicity (any failure of the topological sort is
    # reported as a cycle).
    D = directed_adj_from_encoded(A_di)
    try:
        _toposort_dag(D)
    except Exception:
        return verdict(False, 'directed_cycle')

    # Check 3: ancestrality w.r.t. bidirected edges
    anc = ancestor_matrix_from_directed(A_di)
    n = M.shape[0]
    pairs = list(itertools.combinations(range(n), 2))
    for i, j in pairs:
        if A_bi[i, j] == 1 and (anc[i, j] or anc[j, i]):
            return verdict(False, 'bidir_ancestor_violation')

    if not check_maximality:
        return verdict(True, '')

    # Check 4: maximality (brute-force); the canonical latent DAG is built once
    # and handed to every m-separation query.
    parents, children, n_obs = build_canonical_latent_dag(A_di, A_bi)
    cache = (parents, children, n_obs)

    for i, j in pairs:
        if adjacent(M, i, j):
            continue

        others = [v for v in range(n) if v != i and v != j]
        if max_cond_set_size is None:
            largest = len(others)
        else:
            largest = min(max_cond_set_size, len(others))

        # Smallest conditioning sets first; stop at the first separating set.
        separable = any(
            m_separated_admg_given_Z(A_di, A_bi, i, j, comb, _cache=cache)
            for k in range(largest + 1)
            for comb in itertools.combinations(others, k)
        )
        if not separable:
            return verdict(False, 'non_maximal')

    return verdict(True, '')
