from __future__ import annotations

from Utils import *
import itertools
from dataclasses import dataclass, field
from typing import Callable, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, Union

import numpy as np
from scipy.stats import chi2, chi2_contingency, fisher_exact


# ============================================================
# Graph / node bookkeeping
# ============================================================

FNode = Tuple[str, int, int]  # key of an F node: ("F", i, j) with domain ids i < j


def make_fnode(i: int, j: int) -> FNode:
    """
    Build the canonical F-node key for the unordered domain pair {i, j}.

    The domain ids are stored in ascending order, so ``make_fnode(i, j)`` and
    ``make_fnode(j, i)`` return the same tuple.

    Raises:
        ValueError: if ``i == j`` (an F node needs two distinct domains).
    """
    if i == j:
        raise ValueError("F node requires i!=j")
    if i < j:
        return ("F", i, j)
    return ("F", j, i)


def is_fnode(node) -> bool:
    """Return True when ``node`` looks like an F-node key: a 3-tuple whose first item is "F"."""
    if not isinstance(node, tuple) or len(node) != 3:
        return False
    return node[0] == "F"


@dataclass
class IFCIBaseState:
    """
    Container for the output of IFCI-base skeleton learning and orientation.

    Indices into ``A`` are "expanded" indices: 0..n_obs-1 are the observed
    variables, and the remaining indices are F nodes (one per domain pair).
    """

    # Edge-mark matrix over observed + F nodes (0 none, 1 arrow, 2 tail, 3 circle).
    A: np.ndarray
    # Expanded index -> observed variable (int) or F-node key.
    idx2node: Dict[int, Union[int, FNode]]
    # Inverse of idx2node.
    node2idx: Dict[Union[int, FNode], int]
    # Sorted (domain_i, domain_j) -> expanded index of that pair's F node.
    fnode_of_pair: Dict[Tuple[int, int], int]
    # Sorted (u, v) expanded-index pair -> separating set (expanded indices).
    sepsets: Dict[Tuple[int, int], FrozenSet[int]]
    # Domain id -> intervention target set over observed variables.
    targets_dict: Dict[int, FrozenSet[int]]
    # Number of observed variables.
    n_obs: int

    # Filled in by fast_soft_fci's local rules and the completion step.
    transit_pairs: Set[Tuple[int, int]] = field(default_factory=set)
    extras: Dict[str, object] = field(default_factory=dict)

    def is_F_idx(self, idx: int) -> bool:
        """Return True when expanded index ``idx`` refers to an F node."""
        node = self.idx2node[idx]
        return is_fnode(node)

# ============================================================
# Discrete CI tests (binary default)
# ============================================================

def _chisq_or_g2_conditional(
    data: np.ndarray,
    x: int,
    y: int,
    cond: Sequence[int],
    *,
    test: str = "chisq",  # "chisq" or "g2"
) -> float:
    """
    Stratified Pearson chi-square (or G^2) test of X ⟂ Y | cond for discrete data.

    Rows are grouped by the joint value of the ``cond`` columns. Each group
    contributes its own X-by-Y contingency statistic and degrees of freedom
    (all-zero rows/columns dropped). The summed statistic is referred to a
    chi-square distribution with the summed degrees of freedom.

    Args:
        data: (n_samples, n_vars) array of discrete values. Binary 0/1 is the
            intended case, but any small alphabet works.
        x, y: column indices of the two tested variables.
        cond: column indices to condition on (may be empty).
        test: "chisq" for Pearson's statistic, "g2" for the likelihood-ratio
            statistic 2 * sum O * log(O / E).

    Returns:
        The p-value for H0: X ⟂ Y | cond. Returns 1.0 when no stratum carries
        information (empty data, or X or Y constant within every stratum).

    Raises:
        ValueError: an index is out of range, ``cond`` contains ``x`` or
            ``y``, or ``test`` is not "chisq"/"g2".
    """
    data = np.asarray(data)
    n, p = data.shape
    if not (0 <= x < p and 0 <= y < p):
        raise ValueError("x, y out of range")
    for c in cond:
        if not (0 <= c < p):
            raise ValueError("cond index out of range")
    if x in cond or y in cond:
        raise ValueError("cond cannot include x or y")
    # Validate before the stratum loop so a bad name is rejected even when
    # every stratum turns out to be degenerate and is skipped.
    if test not in ("chisq", "g2"):
        raise ValueError("test must be 'chisq' or 'g2'")
    if n == 0:
        return 1.0  # no samples, no information

    # Integer level codes for X and Y (0..n_x-1, 0..n_y-1), computed in C
    # instead of a per-sample dict lookup.
    x_levels, x_codes = np.unique(data[:, x], return_inverse=True)
    y_levels, y_codes = np.unique(data[:, y], return_inverse=True)
    x_codes = np.ravel(x_codes)
    y_codes = np.ravel(y_codes)
    n_x, n_y = len(x_levels), len(y_levels)

    # Stratum id per row: one id per distinct joint value of the cond columns.
    if len(cond) == 0:
        strata = np.zeros(n, dtype=np.intp)
        n_strata = 1
    else:
        _, strata = np.unique(data[:, list(cond)], axis=0, return_inverse=True)
        # ravel guards against NumPy versions that return a non-1-D inverse
        # when ``axis`` is given.
        strata = np.ravel(strata)
        n_strata = int(strata.max()) + 1

    # Sort rows by stratum once so that each stratum is a contiguous slice of
    # ``order``. This avoids rescanning all n rows for every stratum.
    order = np.argsort(strata, kind="stable")
    stops = np.cumsum(np.bincount(strata, minlength=n_strata))

    chi_stat = 0.0
    df_total = 0
    start = 0
    for stop in stops:
        rows = order[start:stop]
        start = stop
        if rows.size == 0:
            continue

        # n_x-by-n_y contingency table for this stratum.
        tab = np.bincount(
            x_codes[rows] * n_y + y_codes[rows], minlength=n_x * n_y
        ).reshape(n_x, n_y)

        # Effective df: drop all-zero rows/columns.
        row_sums = tab.sum(axis=1)
        col_sums = tab.sum(axis=0)
        r_eff = int((row_sums > 0).sum())
        c_eff = int((col_sums > 0).sum())
        if r_eff <= 1 or c_eff <= 1:
            continue  # X or Y is constant in this stratum: no information

        total = tab.sum()  # > 0 here because r_eff > 1
        expected = (row_sums[:, None] * col_sums[None, :]) / total
        mask = expected > 0
        obs = tab[mask].astype(float)
        exp = expected[mask]
        if test == "chisq":
            chi_stat += float(((obs - exp) ** 2 / exp).sum())
        else:  # "g2": 2 * sum O log(O/E), cells with O == 0 contribute 0
            nz = obs > 0
            chi_stat += float((2.0 * obs[nz] * np.log(obs[nz] / exp[nz])).sum())

        df_total += (r_eff - 1) * (c_eff - 1)

    # No informative stratum: treat as independent.
    if df_total == 0:
        return 1.0
    return float(chi2.sf(chi_stat, df_total))


def discrete_ci_pvalue(
    data: np.ndarray,
    x: int,
    y: int,
    cond: Sequence[int],
    *,
    method: str = "chisq",   # "chisq", "g2", or "fisher"
) -> float:
    """
    Return the p-value of a discrete CI test of X ⟂ Y | cond.

    Args:
        data: (n_samples, n_vars) array-like of discrete values.
        x, y: column indices of the tested variables.
        cond: conditioning column indices (may be empty).
        method: "chisq" (default) or "g2" for the stratified tests, or
            "fisher" for Fisher's exact test.

    Note:
        Fisher's exact test is only used for *unconditional* tests where both
        X and Y take exactly two values. Every other "fisher" request falls
        back to the stratified chi-square test.

    Raises:
        ValueError: unknown ``method`` (or invalid arguments for the
            underlying stratified test).
    """
    if method in ("chisq", "g2"):
        return _chisq_or_g2_conditional(data, x, y, cond, test=method)
    if method != "fisher":
        raise ValueError("method must be one of {'chisq','g2','fisher'}")

    # Convert up front so that array-likes (e.g. nested lists) support the
    # column slicing below, matching the stratified test's input handling.
    data = np.asarray(data)
    if len(cond) != 0:
        # TODO: stratified Fisher / CMH; for now fall back to chi-square.
        return _chisq_or_g2_conditional(data, x, y, cond, test="chisq")

    x_col = data[:, x]
    y_col = data[:, y]
    x_vals = np.unique(x_col)
    y_vals = np.unique(y_col)
    if len(x_vals) != 2 or len(y_vals) != 2:
        return _chisq_or_g2_conditional(data, x, y, cond, test="chisq")

    # Code the smaller level as 0 and the other as 1, then count the 2x2 cells.
    x_bit = (x_col != x_vals[0]).astype(np.intp)
    y_bit = (y_col != y_vals[0]).astype(np.intp)
    tab = np.bincount(2 * x_bit + y_bit, minlength=4).reshape(2, 2)
    _, p = fisher_exact(tab, alternative="two-sided")
    return float(p)


def find_sepset_discrete(
    data: np.ndarray,
    x: int,
    y: int,
    candidates: Sequence[int],
    *,
    max_k: int,
    alpha: float,
    ci_method: str,
) -> Optional[Tuple[int, ...]]:
    """
    Search for a separating set of X and Y among ``candidates``.

    Subsets are tried in order of increasing size (0, 1, ..., at most
    ``max_k``), in ``itertools.combinations`` order within each size. The
    first subset S whose CI p-value for X ⟂ Y | S exceeds ``alpha`` is
    returned as a tuple. Returns None when no tried subset separates X and Y.
    """
    largest = min(max_k, len(candidates))
    subsets = itertools.chain.from_iterable(
        itertools.combinations(candidates, size) for size in range(largest + 1)
    )
    for subset in subsets:
        if discrete_ci_pvalue(data, x, y, subset, method=ci_method) > alpha:
            return subset
    return None






# ============================================================
# IFCI-base Phase 1+2
# ============================================================

def ifci_base_phase12(
    datasets: Dict[int, Dict[str, object]],
    targets: Optional[Dict[int, FrozenSet[int]]] = None,
    *,
    max_cond_set_size: int = 3,
    alpha: float = 0.05,
    ci_method: str = "chisq",   # default for binary
    seed: Optional[int] = None,
) -> IFCIBaseState:
    """
    Run IFCI-base Phases 1 and 2: build the augmented graph and learn its skeleton.

    Phase 1 adds one F node per unordered pair of domains. It starts from a
    complete graph with circle marks (3) on every off-diagonal entry.

    Phase 2 visits all node pairs in a random order (seeded by ``seed``) and
    removes an edge once a separating set is found:

      * F - F: always removed, sepset = {}.
      * X - Y (observed): CI search in the observational domain (domain 0 if
        its target set is empty, otherwise the first sorted domain with an
        empty target set; if there is none, a random domain per pair).
        Sepset = S ∪ {all F nodes}.
      * F - Y: the two domains of F are pooled with an indicator column E,
        and Y ⟂ E | W is searched over observed W. Sepset = W ∪ {all other
        F nodes}.

    Args:
        datasets: ``datasets[d]["data"]`` is an (n_samples, n_obs) array for
            domain ``d``. ``datasets[d]["target"]`` is read when ``targets``
            is None.
        targets: optional mapping domain -> intervention targets (observed
            indices). The values are copied into frozensets.
        max_cond_set_size: largest conditioning set tried by the CI search.
        alpha: significance level; p > alpha counts as independence.
        ci_method: test name forwarded to the discrete CI test.
        seed: seed for the pair order and the fallback domain choice.

    Returns:
        IFCIBaseState with ``A`` (marks 0/3 only, since Phase 2 only removes
        edges), the index/node maps, ``fnode_of_pair``, ``sepsets`` keyed by
        sorted expanded-index pairs, ``targets_dict`` and ``n_obs``.

    Raises:
        ValueError: fewer than two domains; ``targets`` is None and some
            domain has no ``"target"``; some domain's data is not 2-D; or
            domains disagree on the number of observed variables.
    """
    rng = np.random.default_rng(seed)

    domain_ids = sorted(datasets.keys())
    if len(domain_ids) < 2:
        raise ValueError("Need at least 2 domains to form F nodes.")

    # Build a fresh dict of frozensets. This keeps the returned state from
    # aliasing the caller's mapping, and lets downstream rules call
    # ``symmetric_difference`` even when lists or sets were supplied.
    if targets is None:
        targets = {}
        for d in domain_ids:
            if "target" not in datasets[d]:
                raise ValueError("targets not provided and datasets[d] has no 'target'.")
            targets[d] = frozenset(datasets[d]["target"])  # type: ignore
    else:
        targets = {d: frozenset(t) for d, t in targets.items()}

    # Observational domain: 0 if present with no targets, otherwise the first
    # empty-target domain, or None if every domain is interventional.
    if 0 in targets and len(targets[0]) == 0:
        obs_domain_id = 0
    else:
        obs_domain_id = next((d for d in domain_ids if len(targets[d]) == 0), None)

    # Convert and validate every domain's data once. All must be 2-D with the
    # same column count; checking only the first domain's rank would let a
    # malformed later domain fail obscurely (or not at all).
    data_by_domain: Dict[int, np.ndarray] = {}
    for d in domain_ids:
        Xd = np.asarray(datasets[d]["data"])
        if Xd.ndim != 2:
            raise ValueError("datasets[d]['data'] must be a 2D array.")
        data_by_domain[d] = Xd
    n_obs = data_by_domain[domain_ids[0]].shape[1]
    if any(Xd.shape[1] != n_obs for Xd in data_by_domain.values()):
        raise ValueError("All domains must have the same number of observed variables.")

    # --- Phase 1: one F node per unordered pair of domains ---
    f_pairs: List[Tuple[int, int]] = [
        (di, dj) for i, di in enumerate(domain_ids) for dj in domain_ids[i + 1:]
    ]

    n_F = len(f_pairs)
    n_total = n_obs + n_F

    idx2node: Dict[int, Union[int, FNode]] = {}
    node2idx: Dict[Union[int, FNode], int] = {}

    # Observed nodes occupy indices 0..n_obs-1.
    for v in range(n_obs):
        idx2node[v] = v
        node2idx[v] = v

    # F nodes occupy indices n_obs..n_total-1.
    fnode_of_pair: Dict[Tuple[int, int], int] = {}
    for k, (di, dj) in enumerate(f_pairs):
        idx = n_obs + k
        fn = make_fnode(di, dj)
        idx2node[idx] = fn
        node2idx[fn] = idx
        fnode_of_pair[(min(di, dj), max(di, dj))] = idx

    # Complete circle graph (off-diagonal 3).
    A = np.full((n_total, n_total), 3, dtype=np.int8)
    np.fill_diagonal(A, 0)

    # Sepsets are recorded on expanded indices, keyed by (u, v) with u < v.
    sepsets: Dict[Tuple[int, int], FrozenSet[int]] = {}

    all_F_idx = list(range(n_obs, n_total))

    # Random order over all unordered pairs (tuples keep u < v).
    pairs = [(u, v) for u in range(n_total) for v in range(u + 1, n_total)]
    rng.shuffle(pairs)

    # --- Phase 2: skeleton learning ---
    for u, v in pairs:
        if A[u, v] == 0:
            continue  # already removed

        nu = idx2node[u]
        nv = idx2node[v]
        u_is_F = is_fnode(nu)
        v_is_F = is_fnode(nv)
        key = (u, v)

        # Case 1: F - F is never an edge.
        if u_is_F and v_is_F:
            sepsets[key] = frozenset()
            A[u, v] = 0
            A[v, u] = 0
            continue

        # Case 2: observed - observed, tested in the observational domain.
        if not u_is_F and not v_is_F:
            x = int(nu)
            y = int(nv)
            if obs_domain_id is None:
                # No empty-target domain exists: pick one at random per pair.
                d = int(rng.choice(domain_ids))
            else:
                d = int(obs_domain_id)
            data = data_by_domain[d]

            candidates = [w for w in range(n_obs) if w not in (x, y)]
            S = find_sepset_discrete(
                data, x, y, candidates,
                max_k=max_cond_set_size, alpha=alpha, ci_method=ci_method
            )
            if S is not None:
                # Record sepset as S ∪ {all F nodes}.
                sepsets[key] = frozenset(set(S) | set(all_F_idx))
                A[u, v] = 0
                A[v, u] = 0
            continue

        # Case 3: observed - F (exactly one endpoint is an F node here).
        if u_is_F:
            f_idx, fnode, y = u, nu, int(nv)
        else:
            f_idx, fnode, y = v, nv, int(nu)
        _, di, dj = fnode  # type: ignore

        data_i = data_by_domain[di]
        data_j = data_by_domain[dj]

        # Pool both domains with an environment indicator E as the last
        # column, then test Y ⟂ E | W.
        Ei = np.zeros((data_i.shape[0], 1), dtype=data_i.dtype)
        Ej = np.ones((data_j.shape[0], 1), dtype=data_j.dtype)
        pooled = np.vstack([np.hstack([data_i, Ei]), np.hstack([data_j, Ej])])

        E_col = n_obs  # index of the indicator column in ``pooled``
        candidates_W = [w for w in range(n_obs) if w != y]

        W = find_sepset_discrete(
            pooled, y, E_col, candidates_W,
            max_k=max_cond_set_size, alpha=alpha, ci_method=ci_method
        )
        if W is not None:
            # Record sepset as W ∪ {all F nodes except this one}.
            sepsets[key] = frozenset(set(W) | (set(all_F_idx) - {f_idx}))
            A[u, v] = 0
            A[v, u] = 0

    return IFCIBaseState(
        A=A,
        idx2node=idx2node,
        node2idx=node2idx,
        fnode_of_pair=fnode_of_pair,
        sepsets=sepsets,
        targets_dict=targets,
        n_obs=n_obs,
    )

def ifci_base_phase3(
    A: np.ndarray,
    sepsets: Dict[Tuple[int, int], FrozenSet[int]],
    is_F: Callable[[int], bool],
    *,
    max_path_len: int | None = None,
) -> np.ndarray:
    """
    Run IFCI-base Phase III orientation on a copy of ``A``.

    Steps:
      1. Orient every edge incident to an F node as F -> X.
      2. Apply R0 (unshielded colliders, decided via ``sepsets``) once.
      3. Sweep R1-R4 and R8-R10 repeatedly until a full sweep changes
         nothing. The selection-bias rules R5-R7 are not applied.

    ``max_path_len`` is forwarded to R9/R10's path search. The input array is
    not modified; the oriented copy is returned.
    """
    G = A.copy()

    f_nodes = [v for v in range(G.shape[0]) if is_F(v)]
    for f in f_nodes:
        for x in neighbors(G, f):
            orient_directed(G, f, x)  # force f -> x

    _apply_R0(G, sepsets)

    while True:
        # Every rule runs in each sweep (the list is built eagerly).
        fired = [
            _apply_R1(G),
            _apply_R2(G),
            _apply_R3(G),
            _apply_R4(G, sepsets),
            _apply_R8(G),
            _apply_R9(G, max_path_len=max_path_len),
            _apply_R10(G, max_path_len=max_path_len),
        ]
        if not any(fired):
            return G


def _apply_R0(A: np.ndarray, sepsets: Dict[Tuple[int, int], FrozenSet[int]]) -> bool:
    """
    R0: for each unshielded triple a *-* b *-* c with b not in Sepset(a, c),
    put arrowheads at b on both edges (a *-> b <-* c).

    Mutates ``A`` in place and returns True if any mark changed.
    """
    oriented = False
    for mid in range(A.shape[0]):
        for left, right in itertools.combinations(neighbors(A, mid), 2):
            if not is_unshielded_triple(A, left, mid, right):
                continue
            if mid in get_sepset(sepsets, left, right):
                continue
            # Only the arrowheads at ``mid`` are forced.
            oriented |= set_mark_at(A, left, mid, ARROW)
            oriented |= set_mark_at(A, right, mid, ARROW)
    return oriented


def _apply_R1(A: np.ndarray) -> bool:
    """
    R1: if α *-> β o-* γ and α, γ are not adjacent, orient β -> γ.

    Mutates ``A`` in place and returns True if any mark changed.
    """
    any_change = False
    for beta in range(A.shape[0]):
        for alpha in neighbors(A, beta):
            if not has_arrow_at(A, alpha, beta):
                continue
            for gamma in neighbors(A, beta):
                # β o-* γ: circle at β on the β-γ edge.
                eligible = (
                    gamma != alpha
                    and not adjacent(A, alpha, gamma)
                    and has_circle_at(A, gamma, beta)
                )
                if eligible:
                    any_change |= orient_directed(A, beta, gamma)
    return any_change


def _apply_R2(A: np.ndarray) -> bool:
    """
    R2: if (α -> β *-> γ or α *-> β -> γ) and α *-o γ, orient α *-> γ.

    Mutates ``A`` in place and returns True if any mark changed.
    """
    any_change = False
    for alpha in range(A.shape[0]):
        for beta in neighbors(A, alpha):
            for gamma in neighbors(A, beta):
                if gamma == alpha:
                    continue
                via_first = is_directed(A, alpha, beta) and has_arrow_at(A, beta, gamma)
                via_second = has_arrow_at(A, alpha, beta) and is_directed(A, beta, gamma)
                # α *-o γ: circle at γ on the α-γ edge.
                if (via_first or via_second) and has_circle_at(A, alpha, gamma):
                    any_change |= set_mark_at(A, alpha, gamma, ARROW)
    return any_change


def _apply_R3(A: np.ndarray) -> bool:
    """
    R3 (triangle rule): if α *-> β <-* γ, α *-o θ o-* γ, α and γ are not
    adjacent, and θ *-o β, then orient θ *-> β.

    Mutates ``A`` in place and returns True if any mark changed.
    """
    any_change = False
    for beta in range(A.shape[0]):
        # Neighbours with an arrowhead at β, captured once per β.
        into_beta = [v for v in neighbors(A, beta) if has_arrow_at(A, v, beta)]
        for alpha, gamma in itertools.combinations(into_beta, 2):
            if adjacent(A, alpha, gamma):
                continue
            for theta in neighbors(A, alpha):
                if theta in (alpha, beta, gamma):
                    continue
                qualifies = (
                    has_circle_at(A, alpha, theta)      # α *-o θ
                    and adjacent(A, theta, gamma)
                    and has_circle_at(A, gamma, theta)  # θ o-* γ
                    and adjacent(A, theta, beta)
                    and has_circle_at(A, theta, beta)   # θ *-o β
                )
                if qualifies:
                    any_change |= set_mark_at(A, theta, beta, ARROW)  # θ *-> β
    return any_change


def _apply_R4(A: np.ndarray, sepsets: Dict[Tuple[int, int], FrozenSet[int]]) -> bool:
    """
    R4 (discriminating path): if u = <θ, ..., α, β, γ> discriminates β and
    β o-* γ, then
      - if β is in Sepset(θ, γ): orient β -> γ;
      - otherwise: orient α <-> β <-> γ.

    Assumes ``find_discriminating_path`` returns the path as a sequence that
    starts at θ and ends with (α, β, γ). TODO confirm against its definition.
    Mutates ``A`` in place and returns True if any mark changed.
    """
    any_change = False
    for beta in range(A.shape[0]):
        for gamma in neighbors(A, beta):
            if not has_circle_at(A, gamma, beta):
                continue  # need β o-* γ
            path = find_discriminating_path(A, beta, gamma)
            if path is None:
                continue
            theta, alpha = path[0], path[-3]
            if beta in get_sepset(sepsets, theta, gamma):
                any_change |= orient_directed(A, beta, gamma)
            else:
                any_change |= orient_bidirected(A, alpha, beta)
                any_change |= orient_bidirected(A, beta, gamma)
    return any_change


def _apply_R8(A: np.ndarray) -> bool:
    """
    R8: if (α -> β -> γ or α --o β -> γ) and α o-> γ, orient α -> γ.

    α --o β means a tail at α and a circle at β, read directly from
    ``A[beta, alpha]`` / ``A[alpha, beta]``. Mutates ``A`` in place and
    returns True if any mark changed.
    """
    any_change = False
    for alpha in range(A.shape[0]):
        for gamma in neighbors(A, alpha):
            # α o-> γ: circle at α, arrowhead at γ.
            circle_arrow = has_circle_at(A, gamma, alpha) and has_arrow_at(A, alpha, gamma)
            if not circle_arrow:
                continue
            for beta in neighbors(A, alpha):
                if beta == gamma or not adjacent(A, beta, gamma):
                    continue
                directed_chain = is_directed(A, alpha, beta) and is_directed(A, beta, gamma)
                tail_circle_chain = (
                    adjacent(A, alpha, beta)
                    and A[beta, alpha] == TAIL
                    and A[alpha, beta] == CIRCLE
                    and is_directed(A, beta, gamma)
                )
                if directed_chain or tail_circle_chain:
                    # α -> γ: tail at α on the α-γ edge.
                    any_change |= set_mark_at(A, gamma, alpha, TAIL)
    return any_change


def _apply_R9(A: np.ndarray, *, max_path_len: int | None = None) -> bool:
    """
    R9: if α o-> γ and there is an uncovered p.d. path p = <α, β, θ, ..., γ>
    with γ and β not adjacent, orient α -> γ.

    Candidate first hops β come from ``find_uncovered_pd_paths_first_hops``,
    which is searched up to ``max_path_len``. Each qualifying edge is oriented
    at most once: the search for that edge stops at the first witness. This
    early exit is decided per edge, not by the function's cumulative
    ``changed`` flag. Mutates ``A`` in place and returns True if any mark
    changed.
    """
    changed = False
    for alpha in range(A.shape[0]):
        for gamma in neighbors(A, alpha):
            if not (has_circle_at(A, gamma, alpha) and has_arrow_at(A, alpha, gamma)):
                continue  # need α o-> γ

            first_hops = find_uncovered_pd_paths_first_hops(
                A, alpha, gamma, max_len=max_path_len
            )
            # A witness is a first hop β != γ that is not adjacent to γ.
            if any(beta != gamma and not adjacent(A, gamma, beta) for beta in first_hops):
                changed |= set_mark_at(A, gamma, alpha, TAIL)  # α -> γ
    return changed


def _apply_R10(A: np.ndarray, *, max_path_len: int | None = None) -> bool:
    """
    R10: if α o-> γ, β -> γ <- θ, and there are uncovered p.d. paths
    p1: α ~> β (first hop μ) and p2: α ~> θ (first hop ω) with μ != ω and
    μ, ω not adjacent, orient α -> γ.

    The early exit is per edge: once a witness is found for one α o-> γ edge,
    the search for that edge stops. Every other candidate edge is still
    searched in full, whether or not an earlier edge was oriented in this
    call. Mutates ``A`` in place and returns True if any mark changed.
    """
    changed = False
    for alpha in range(A.shape[0]):
        for gamma in neighbors(A, alpha):
            if not (has_circle_at(A, gamma, alpha) and has_arrow_at(A, alpha, gamma)):
                continue  # need α o-> γ

            # Parents of γ: v -> γ.
            parents = [v for v in neighbors(A, gamma) if is_directed(A, v, gamma)]
            if len(parents) < 2:
                continue

            # First hops of uncovered p.d. paths α ~> v, for every parent v.
            first_hops: Dict[int, Set[int]] = {
                v: set(find_uncovered_pd_paths_first_hops(A, alpha, v, max_len=max_path_len).keys())
                for v in parents
            }
            if _r10_witness_exists(A, parents, first_hops):
                changed |= set_mark_at(A, gamma, alpha, TAIL)  # α -> γ
    return changed


def _r10_witness_exists(
    A: np.ndarray,
    parents: List[int],
    first_hops: Dict[int, Set[int]],
) -> bool:
    """Return True if two distinct parents have first hops μ != ω that are not adjacent."""
    for beta, theta in itertools.combinations(parents, 2):
        for mu in first_hops[beta]:
            for omega in first_hops[theta]:
                if mu != omega and not adjacent(A, mu, omega):
                    return True
    return False


# -----------------------------------------
# Orientation-rule plugin interface (use IFCIBaseState)
# -----------------------------------------
# Signature of a pluggable orientation rule: it receives the mark matrix and
# the learned state, may modify the matrix in place, and returns a truthy
# value when it changed something.
OrientationRule = Callable[[np.ndarray, IFCIBaseState], bool]


def _is_F_idx(state: IFCIBaseState, idx: int) -> bool:
    """Return True when expanded index ``idx`` of ``state`` maps to an F node."""
    node = state.idx2node[idx]
    return is_fnode(node)


def orient_ifci_base_closure(
    A: np.ndarray,
    state: IFCIBaseState,
    *,
    extra_rules: Optional[List[OrientationRule]] = None,
    max_path_len: Optional[int] = None,
) -> np.ndarray:
    """
    Apply the IFCI-base Phase III closure, plus optional extra rules, to a copy of ``A``.

    Steps:
      1. Orient every edge incident to an F node as F -> X.
      2. Apply R0 once using ``state.sepsets``.
      3. Sweep R1-R4, R8-R10 and then each rule in ``extra_rules`` (called
         as ``rule(A, state)``) until a full sweep changes nothing.

    Returns the oriented copy; the input array is not modified. Extra rules
    may still update ``state`` itself.
    """
    plugins = [] if extra_rules is None else extra_rules
    G = A.copy()

    # (1) Edges out of F nodes.
    for f in (i for i in range(G.shape[0]) if _is_F_idx(state, i)):
        for x in neighbors(G, f):
            orient_directed(G, f, x)

    # (2) Unshielded colliders, a single time.
    _apply_R0(G, state.sepsets)

    # (3) Fixpoint: every rule runs each sweep; stop after a quiet sweep.
    while True:
        fired = [
            _apply_R1(G),
            _apply_R2(G),
            _apply_R3(G),
            _apply_R4(G, state.sepsets),
            _apply_R8(G),
            _apply_R9(G, max_path_len=max_path_len),
            _apply_R10(G, max_path_len=max_path_len),
        ]
        fired.extend(bool(rule(G, state)) for rule in plugins)
        if not any(fired):
            return G


def ifci_base(
    datasets: Dict[int, Dict[str, object]],
    targets: Dict[int, FrozenSet[int]],
    *,
    max_cond_set_size: int = 3,
    ci_method: str = "chisq",
    alpha: float = 0.01,
    seed: Optional[int] = None,
    max_path_len: Optional[int] = None,
    run_phase3: bool = True,
    extra_orientation_rules: Optional[List[OrientationRule]] = None,
) -> IFCIBaseState:
    """
    Run IFCI-base: Phases 1+2 (skeleton), then optionally the Phase III closure.

    When ``run_phase3`` is True, ``state.A`` is replaced by the oriented
    matrix produced with ``extra_orientation_rules`` included in the closure.
    """
    state = ifci_base_phase12(
        datasets=datasets,
        targets=targets,
        max_cond_set_size=max_cond_set_size,
        alpha=alpha,
        ci_method=ci_method,
        seed=seed,
    )
    if not run_phase3:
        return state

    state.A = orient_ifci_base_closure(
        state.A,
        state,
        extra_rules=extra_orientation_rules,
        max_path_len=max_path_len,
    )
    return state



# ------------------------------------------------------------
# Kocaoglu et al. 2019 IFCI = IFCI-base + one extra rule (9)
# ------------------------------------------------------------

def rule_kocaoglu2019_inducing_paths(A: np.ndarray, state: IFCIBaseState) -> bool:
    """
    Extra IFCI orientation rule (Kocaoglu et al., 2019).

    For each F node whose two domains' target sets differ in exactly one
    variable X, and for every observed neighbour Y != X of that F node that
    is adjacent to X (X observed too), orient X -> Y.

    Mutates ``A`` in place. Returns True if any ``orient_directed`` call
    reported a change.
    """
    any_change = False
    for f in range(A.shape[0]):
        if not state.is_F_idx(f):
            continue
        node = state.idx2node[f]
        dom_i, dom_j = int(node[1]), int(node[2])
        diff = state.targets_dict.get(dom_i, frozenset()).symmetric_difference(
            state.targets_dict.get(dom_j, frozenset())
        )
        if len(diff) != 1:
            continue
        (X,) = tuple(diff)

        for Y in neighbors(A, f):
            if Y == X or X >= state.n_obs or Y >= state.n_obs:
                continue
            if adjacent(A, X, Y):
                any_change |= orient_directed(A, X, Y)
    return any_change


def ifci(
    datasets: Dict[int, Dict[str, object]],
    targets: Dict[int, FrozenSet[int]],
    *,
    max_cond_set_size: int = 3,
    ci_method: str = "chisq",
    alpha: float = 0.01,
    seed: Optional[int] = None,
    max_path_len: Optional[int] = None,
) -> IFCIBaseState:
    """
    Run IFCI (Kocaoglu et al., 2019): full IFCI-base, then the extra rule.

    The extra rule is applied once to the closed graph. If it changes
    anything, the IFCI-base closure is run again (without extra rules).
    NOTE(review): the extra rule is not re-applied after that second closure.
    """
    state = ifci_base(
        datasets=datasets,
        targets=targets,
        max_cond_set_size=max_cond_set_size,
        ci_method=ci_method,
        alpha=alpha,
        seed=seed,
        max_path_len=max_path_len,
        run_phase3=True,
        extra_orientation_rules=None,
    )

    extra_fired = rule_kocaoglu2019_inducing_paths(state.A, state)
    if extra_fired:
        state.A = orient_ifci_base_closure(
            state.A, state, extra_rules=None, max_path_len=max_path_len
        )
    return state

def _global_targets(state: IFCIBaseState) -> Set[int]:
    """Return the union of every domain's intervention targets as a new set."""
    return set().union(*state.targets_dict.values())


def _tar_of_F_idx(state: IFCIBaseState, f_idx: int) -> Set[int]:
    """
    Return tar(F_{i,j}) = targets[i] Δ targets[j] for the F node at ``f_idx``.

    Non-F indices give the empty set. A domain missing from
    ``state.targets_dict`` is treated as having no targets.
    """
    node = state.idx2node[f_idx]
    if is_fnode(node):
        _, first_domain, second_domain = node
        lookup = state.targets_dict.get
        return set(
            lookup(first_domain, frozenset()).symmetric_difference(
                lookup(second_domain, frozenset())
            )
        )
    return set()


def rule9_local_transit(A: np.ndarray, state: IFCIBaseState) -> bool:
    """
    Rule 9 (local transit).

    For every F node and every observed Y adjacent to it with Y ∉ tar(F),
    let TF(Y) = {X observed, X ∈ tar(F) ∩ Adj(Y) : the mark at X on the
    X *-* Y edge is not an arrowhead}. If TF(Y) has exactly one element X,
    orient X -> Y and record (X, Y) in ``state.transit_pairs``.

    Y only has to be a non-target for this particular F; it may be targeted
    in other domains.

    Mutates ``A`` and ``state.transit_pairs`` in place; returns True if any
    mark changed.

    NOTE(review): a later ``def rule9_local_transit`` in this module rebinds
    this name, so lookups made at call time get that definition.
    """
    changed = False
    n_obs = state.n_obs

    for f in range(A.shape[0]):
        if not _is_F_idx(state, f):
            continue
        tar_f = _tar_of_F_idx(state, f)
        if not tar_f:
            continue

        for Y in neighbors(A, f):
            if Y >= n_obs or Y in tar_f:
                continue  # Y must be observed and a non-target w.r.t. this F

            # TF(Y): observed targets of F adjacent to Y that are still
            # possible parents (no arrowhead at X on X *-* Y).
            tf = [
                X for X in neighbors(A, Y)
                if X < n_obs and X in tar_f and mark_at(A, Y, X) != ARROW
            ]
            if len(tf) != 1:
                continue

            X = tf[0]
            if adjacent(A, X, Y):
                changed |= orient_directed(A, X, Y)
            state.transit_pairs.add((X, Y))

    return changed


def rule10_propagate_transit(A: np.ndarray, state: IFCIBaseState) -> bool:
    """
    Rule 10 (transit propagation).

    For every recorded transit pair (X, Y) and every observed Z adjacent to
    X but not to Y (Z != Y), orient X -> Z. If ``orient_directed`` reports no
    change, at least a tail is forced at X on the X-Z edge.

    Observed adjacency is snapshotted once at the start. Mutates ``A`` in
    place and returns True if anything changed.
    """
    if not state.transit_pairs:
        return False

    n_obs = state.n_obs
    # Observed-only neighbourhoods, one per observed variable.
    adj_obs = [{v for v in neighbors(A, i) if v < n_obs} for i in range(n_obs)]

    changed = False
    for X, Y in list(state.transit_pairs):
        if X >= n_obs or Y >= n_obs:
            continue
        for Z in adj_obs[X] - adj_obs[Y]:
            if Z == Y or not adjacent(A, X, Z):
                continue
            if orient_directed(A, X, Z):
                changed = True
            else:
                changed |= set_mark_at(A, Z, X, TAIL)
    return changed


def rule11_unique_source_targets(A: np.ndarray, state: IFCIBaseState) -> bool:
    """
    Rule 11 (unique source among targets).

    For every F node and every observed Y adjacent to it with Y ∉ tar(F),
    let T = {X observed : X ∈ Adj(Y) ∩ tar(F)}. When |T| >= 2 and exactly
    one member of T has no directed edge into it from another member of T
    (a unique source of the directed subgraph on T), orient that X -> Y.

    Y only has to be a non-target for this particular F; it may be targeted
    in other domains. Mutates ``A`` in place; returns True if any mark
    changed.
    """
    changed = False
    n_obs = state.n_obs

    for f in range(A.shape[0]):
        if not _is_F_idx(state, f):
            continue
        tar_f = _tar_of_F_idx(state, f)
        if not tar_f:
            continue

        for Y in neighbors(A, f):
            if Y >= n_obs or Y in tar_f:
                continue  # Y must be observed and a non-target w.r.t. this F

            T = [X for X in neighbors(A, Y) if X < n_obs and X in tar_f]
            if len(T) < 2:
                continue

            # In-degree of each member within the directed subgraph on T.
            indeg = {x: 0 for x in T}
            for u in T:
                for v in T:
                    if u != v and is_directed(A, u, v):  # u -> v
                        indeg[v] += 1

            sources = [x for x, d in indeg.items() if d == 0]
            if len(sources) == 1:
                X = sources[0]
                if adjacent(A, X, Y):
                    changed |= orient_directed(A, X, Y)

    return changed


def fast_soft_fci(
    datasets: Dict[int, Dict[str, object]],
    targets: Dict[int, FrozenSet[int]],
    *,
    max_cond_set_size: int = 3,
    ci_method: str = "chisq",
    alpha: float = 0.01,
    seed: Optional[int] = None,
    max_path_len: Optional[int] = None,
) -> IFCIBaseState:
    """
    Fast local variant: run ``ifci_base`` with phase 3 enabled and the local
    Rules 9–11 supplied as extra orientation rules.

    Every argument is forwarded unchanged to ``ifci_base``. The rule functions
    are looked up by name when this function is called.

    NOTE: ``fast_soft_fci`` is defined again later in this module; that later
    definition rebinds the name at import time.
    """
    local_rules = [
        rule9_local_transit,
        rule10_propagate_transit,
        rule11_unique_source_targets,
    ]
    return ifci_base(
        datasets=datasets,
        targets=targets,
        max_cond_set_size=max_cond_set_size,
        ci_method=ci_method,
        alpha=alpha,
        seed=seed,
        max_path_len=max_path_len,
        run_phase3=True,
        extra_orientation_rules=local_rules,
    )



# ============================================================
# Local rules (Rules 9–11) and FAST-SOFT-FCI
# ============================================================

def _tar_of_F_idx(state: IFCIState, f_idx: int) -> Set[int]:
    """
    Return tar(F) for the node at expanded index ``f_idx``.

    For F = ("F", di, dj), tar(F) is the symmetric difference
    targets[di] Δ targets[dj]. A domain missing from ``state.targets_dict``
    contributes the empty set. A non-F index yields an empty set.
    """
    node = state.idx2node[f_idx]
    if is_fnode(node):
        _tag, dom_a, dom_b = node
        tar_a = state.targets_dict.get(dom_a, frozenset())
        tar_b = state.targets_dict.get(dom_b, frozenset())
        return set(tar_a.symmetric_difference(tar_b))
    return set()


def rule9_local_transit(A: np.ndarray, state: IFCIState) -> bool:
    """
    Rule 9 (corrected): local transit detection.

    For every F node f and every observed neighbor Y of f with Y ∉ tar(f):
      TF(Y) = observed neighbors X of Y with X ∈ tar(f) whose endpoint mark
              at X (stored at A[Y, X]) is not an arrowhead.
      If TF(Y) has exactly one member X, record (X, Y) in
      ``state.transit_pairs`` and orient X -> Y.

    Y may be targeted in other domains; only Y ∉ tar(f) is required.

    Returns True if a new transit pair was recorded or any ``orient_directed``
    call reported a change.
    """
    n_obs = state.n_obs
    changed = False

    for f in range(A.shape[0]):
        if not state.is_F_idx(f):
            continue
        tar_f = _tar_of_F_idx(state, f)

        for Y in neighbors(A, f):
            if Y >= n_obs or Y in tar_f:
                continue

            non_arrow_tars = [
                X
                for X in neighbors(A, Y)
                if X < n_obs and X in tar_f and mark_at(A, Y, X) != ARROW
            ]
            if len(non_arrow_tars) != 1:
                continue

            (X,) = non_arrow_tars
            pair = (X, Y)
            # A newly recorded pair counts as progress even when the edge is
            # already oriented (it feeds Rule 10).
            if pair not in state.transit_pairs:
                state.transit_pairs.add(pair)
                changed = True
            changed |= orient_directed(A, X, Y)

    return changed


def rule10_propagate_transit(A: np.ndarray, state: IFCIState) -> bool:
    """
    Rule 10: propagate orientation away from transit pairs.

    For every recorded transit pair (X, Y) over observed nodes, orient X -> Z
    for every observed neighbor Z of X (Z != Y) that is not adjacent to Y.

    Returns True if any ``orient_directed`` call reported a change. Returns
    False immediately when no transit pairs are recorded.
    """
    if not state.transit_pairs:
        return False

    n_obs = state.n_obs
    changed = False
    # Iterate over a snapshot of the recorded pairs.
    for X, Y in list(state.transit_pairs):
        if max(X, Y) >= n_obs:
            continue
        for Z in neighbors(A, X):
            if Z < n_obs and Z != Y and not adjacent(A, Y, Z):
                changed |= orient_directed(A, X, Z)

    return changed


def rule11_unique_source_targets(A: np.ndarray, state: IFCIState) -> bool:
    """
    Rule 11 (corrected): orient the unique source of tar(F) ∩ Adj(Y) into Y.

    For every F node f and every observed neighbor Y of f with Y ∉ tar(f):
      T = observed neighbors of Y that lie in tar(f).
      Count, for each node of T, its incoming edges among the edges that are
      already directed inside T. If exactly one node X has none, orient X -> Y.

    Notes:
      - T need not be fully oriented; only currently directed edges count.
        They may come from FCI closure or from Rules 9–10.
      - A single-element T trivially has a unique source.
      - Y may be targeted elsewhere; only Y ∉ tar(f) is required.

    Returns True if any ``orient_directed`` call reported a change.
    """
    n_obs = state.n_obs
    changed = False

    for f in range(A.shape[0]):
        if not state.is_F_idx(f):
            continue
        tar_f = _tar_of_F_idx(state, f)

        for Y in neighbors(A, f):
            if Y >= n_obs or Y in tar_f:
                continue

            T = [X for X in neighbors(A, Y) if X < n_obs and X in tar_f]
            if not T:
                continue

            in_degree = {
                v: sum(1 for u in T if u != v and is_directed(A, u, v))
                for v in T
            }
            sources = [x for x, d in in_degree.items() if d == 0]
            if len(sources) == 1:
                changed |= orient_directed(A, sources[0], Y)

    return changed


def fast_soft_fci(
    datasets: Dict[int, Dict[str, object]],
    targets: Dict[int, FrozenSet[int]],
    *,
    max_cond_set_size: int = 3,
    ci_method: str = "chisq",
    alpha: float = 0.01,
    seed: Optional[int] = None,
    max_path_len: Optional[int] = None,
) -> IFCIState:
    """
    FAST-SOFT-FCI: IFCI-base plus the local Rules 9–11.

    The rules are handed to ``ifci_base`` as extra orientation rules, with
    phase 3 enabled. All other arguments are forwarded unchanged.
    """
    forwarded = dict(
        datasets=datasets,
        targets=targets,
        max_cond_set_size=max_cond_set_size,
        ci_method=ci_method,
        alpha=alpha,
        seed=seed,
        max_path_len=max_path_len,
    )
    return ifci_base(
        **forwarded,
        run_phase3=True,
        extra_orientation_rules=[
            rule9_local_transit,
            rule10_propagate_transit,
            rule11_unique_source_targets,
        ],
    )


# ============================================================
# Completion via MAG listing + IMAG realizability filtering
# ============================================================

def _marks_to_admg_mats(M: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split a fully oriented mark matrix into directed/bidirected ADMG matrices.

    ``M[i, j]`` is the mark at ``j`` on edge i-j (0 = no edge). For i < j:
      - i -> j  (arrow at j, tail at i): A_di[i, j] = 1, A_di[j, i] = 2
      - j -> i  (tail at j, arrow at i): A_di[j, i] = 1, A_di[i, j] = 2
      - i <-> j (arrows at both ends):   A_bi[i, j] = A_bi[j, i] = 1

    Raises
    ------
    ValueError
        On any circle mark, on a tail-tail (undirected) edge, or on any other
        mark pair, including an edge present at only one end.
    """
    M = np.asarray(M, dtype=np.int8)
    size = M.shape[0]
    directed = np.zeros((size, size), dtype=np.int8)
    bidirected = np.zeros((size, size), dtype=np.int8)

    for i, j in itertools.combinations(range(size), 2):
        at_j = int(M[i, j])
        at_i = int(M[j, i])
        pair = (at_j, at_i)

        if pair == (0, 0):
            continue
        if CIRCLE in pair:
            raise ValueError("marks_to_admg_mats requires no circles (3).")
        if pair == (TAIL, TAIL):
            raise ValueError("Undirected edges (tail-tail) are not allowed in this setting.")

        if pair == (ARROW, TAIL):
            directed[i, j], directed[j, i] = 1, 2
        elif pair == (TAIL, ARROW):
            directed[j, i], directed[i, j] = 1, 2
        elif pair == (ARROW, ARROW):
            bidirected[i, j] = bidirected[j, i] = 1
        else:
            raise ValueError(f"Invalid endpoint marks on edge {i}-{j}: ({at_j},{at_i})")

    np.fill_diagonal(directed, 0)
    np.fill_diagonal(bidirected, 0)
    return directed, bidirected


def _is_mag_via_admg_to_mag(M: np.ndarray, *, max_cond_set_size: Optional[int] = None) -> bool:
    """
    Round-trip test for MAG-ness of a fully oriented mark matrix.

    ``M`` is split into ADMG matrices, projected back with ``admg_to_mag`` and
    re-encoded with ``mag_two_mats_to_marks``. ``M`` is accepted iff the round
    trip reproduces it exactly. ``max_cond_set_size`` is forwarded to
    ``admg_to_mag``. A ValueError from ``_marks_to_admg_mats`` (e.g. on circle
    marks) propagates.
    """
    directed, bidirected = _marks_to_admg_mats(M)
    mag_di, mag_bi = admg_to_mag(directed, bidirected, max_cond_set_size=max_cond_set_size)
    round_trip = mag_two_mats_to_marks(mag_di, mag_bi)
    return np.array_equal(round_trip, M)


def list_mags_from_pag_bruteforce(
    P: np.ndarray,
    state: IFCIState,
    *,
    max_mags: Optional[int] = None,
    check_maximality: bool = True,
    max_cond_set_size: Optional[int] = None,
    seed: Optional[int] = None,
) -> List[np.ndarray]:
    """
    Brute-force MAG listing from a partially oriented mixed graph P (0/1/2/3 codes).

    Enumerate assignments of circle endpoints (3) to {ARROW, TAIL},
    disallow undirected edges (TAIL,TAIL) (no selection bias),
    enforce F->* on edges adjacent to F nodes, then keep only MAG candidates.

    Parameters
    ----------
    P : mark matrix; ``P[u, v]`` is the mark at ``v`` on edge u-v.
    state : provides ``is_F_idx`` to identify F nodes.
    max_mags : stop after this many results (None = no limit).
    check_maximality : if True, keep only candidates that pass
        ``_is_mag_via_admg_to_mag``.
    max_cond_set_size : forwarded to the maximality check.
    seed : seeds the shuffle of the variable-edge order. This affects which
        results are found first when ``max_mags`` truncates the listing.

    Returns
    -------
    A list of fully oriented mark matrices (0/1/2). Every result is acyclic
    and has no bidirected edge between ancestrally related nodes. The list is
    empty when P is inconsistent (F-F adjacency, an edge present at only one
    end, no admissible orientation, or a directed cycle among forced edges).
    """
    rng = np.random.default_rng(seed)
    P = np.asarray(P, dtype=np.int8).copy()
    n = P.shape[0]
    np.fill_diagonal(P, 0)

    # Working completion. Forced edges are written once; variable edges are
    # assigned and undone during backtracking.
    C = P.copy()

    # Directed edges a -> b committed so far (forced and currently chosen).
    dir_adj: List[Set[int]] = [set() for _ in range(n)]

    def add_dir_edge(a: int, b: int) -> bool:
        """Record a -> b unless b already reaches a (that would close a directed cycle)."""
        stack = [b]
        seen = set()
        while stack:
            v = stack.pop()
            if v == a:
                return False
            for w in dir_adj[v]:
                if w not in seen:
                    seen.add(w)
                    stack.append(w)
        dir_adj[a].add(b)
        return True

    # Edges with several admissible orientations: (u, v, [(mark_at_v, mark_at_u), ...]).
    var_edges: List[Tuple[int, int, List[Tuple[int, int]]]] = []

    for u in range(n):
        for v in range(u + 1, n):
            if not adjacent(P, u, v):
                continue

            uF = state.is_F_idx(u)
            vF = state.is_F_idx(v)

            if uF and vF:
                return []  # F-F adjacency cannot appear in any completion

            mu_v = int(P[u, v])  # mark at v
            mv_u = int(P[v, u])  # mark at u
            if mu_v == 0 or mv_u == 0:
                return []  # edge present at only one end: inconsistent input

            allowed_uv = [mu_v] if mu_v in (ARROW, TAIL) else ([ARROW, TAIL] if mu_v == CIRCLE else [])
            allowed_vu = [mv_u] if mv_u in (ARROW, TAIL) else ([ARROW, TAIL] if mv_u == CIRCLE else [])

            # F nodes only emit edges: force F -> v, overriding whatever P says.
            if uF and (not vF):
                allowed_uv = [ARROW]
                allowed_vu = [TAIL]
            elif vF and (not uF):
                allowed_uv = [TAIL]
                allowed_vu = [ARROW]

            options: List[Tuple[int, int]] = []
            for a in allowed_uv:
                for b in allowed_vu:
                    if a == CIRCLE or b == CIRCLE or a == 0 or b == 0:
                        continue
                    if a == TAIL and b == TAIL:
                        continue  # no undirected edges
                    options.append((a, b))

            if not options:
                return []

            if len(options) == 1:
                a, b = options[0]
                # Always write the forced marks. They normally equal P's marks,
                # but the F override above can replace non-circle marks in P
                # (e.g. F <-> v), and C must agree with the edge recorded in dir_adj.
                C[u, v] = a
                C[v, u] = b
                if a == ARROW and b == TAIL:
                    if not add_dir_edge(u, v):
                        return []
                elif a == TAIL and b == ARROW:
                    if not add_dir_edge(v, u):
                        return []
            else:
                var_edges.append((u, v, options))

    rng.shuffle(var_edges)

    results: List[np.ndarray] = []

    def finalize_and_store(C_full: np.ndarray) -> None:
        """Apply the ancestral (and optionally maximality) checks, then store a copy."""
        # Transitive closure of dir_adj: anc[a, b] means a is a proper ancestor of b.
        anc = np.zeros((n, n), dtype=bool)
        for src in range(n):
            stack = list(dir_adj[src])
            seen = set()
            while stack:
                w = stack.pop()
                if w in seen:
                    continue
                seen.add(w)
                anc[src, w] = True
                stack.extend(dir_adj[w])

        # Reject a bidirected edge between ancestrally related nodes (almost directed cycle).
        for i in range(n):
            for j in range(i + 1, n):
                if int(C_full[i, j]) == ARROW and int(C_full[j, i]) == ARROW:
                    if anc[i, j] or anc[j, i]:
                        return

        if check_maximality:
            if not _is_mag_via_admg_to_mag(C_full, max_cond_set_size=max_cond_set_size):
                return

        results.append(C_full.copy())

    def backtrack(k: int) -> None:
        """Assign an orientation to var_edges[k:], undoing each choice afterwards."""
        if max_mags is not None and len(results) >= max_mags:
            return
        if k == len(var_edges):
            if (C == CIRCLE).any():
                return
            finalize_and_store(C)
            return

        u, v, options = var_edges[k]
        old_uv, old_vu = int(C[u, v]), int(C[v, u])

        for (a, b) in options:
            C[u, v] = a
            C[v, u] = b

            added: Optional[Tuple[int, int]] = None
            ok = True
            if a == ARROW and b == TAIL:
                ok = add_dir_edge(u, v)
                if ok:
                    added = (u, v)
            elif a == TAIL and b == ARROW:
                ok = add_dir_edge(v, u)
                if ok:
                    added = (v, u)

            if ok:
                backtrack(k + 1)

            if added is not None:
                dir_adj[added[0]].remove(added[1])

            C[u, v] = old_uv
            C[v, u] = old_vu

            if max_mags is not None and len(results) >= max_mags:
                return

    backtrack(0)
    return results


def IMAG_realize(
    M: np.ndarray,
    state: IFCIState,
    *,
    max_cond_set_size: Optional[int] = None,
    max_backtracking: int = 50000,
) -> Tuple[bool, dict]:
    """
    IMAG realizability:
      For each (F,Y) adjacency where Y ∉ tar(F),
        choose transit X ∈ tar(F) with X -> Y in M,
        add bidirected X <-> Y to ADMG(M),
      accept iff MAG(augmented ADMG) == M.

    Parameters
    ----------
    M : fully oriented mark matrix (must contain no circle marks).
    state : provides F-node identification, tar(F) and ``n_obs``.
    max_cond_set_size : forwarded to ``admg_to_mag``.
    max_backtracking : cap on the number of search nodes visited.

    Returns
    -------
    (ok, info). On an early rejection ``info`` holds a ``"reason"`` key.
    Otherwise it holds:
      - ``"explored"``: search nodes visited;
      - ``"transit_map"``: (f, y) -> chosen x, filled only when ok;
      - ``"budget_exhausted"``: True when the search hit ``max_backtracking``.
        In that case ``ok=False`` is inconclusive rather than a proof of
        non-realizability.
    """
    M = np.asarray(M, dtype=np.int8)
    if (M == CIRCLE).any():
        return False, {"reason": "M contains circle marks; needs full orientation."}

    n = M.shape[0]

    # Requirements (f, y) and their candidate transits x ∈ tar(f) with x -> y in M.
    req: List[Tuple[int, int]] = []
    cand: Dict[Tuple[int, int], List[int]] = {}

    for f in range(n):
        if not state.is_F_idx(f):
            continue
        tarF = _tar_of_F_idx(state, f)

        for y in neighbors(M, f):
            if y >= state.n_obs or y in tarF:
                continue
            Xs = [x for x in tarF if (x < state.n_obs and is_directed(M, x, y))]
            if not Xs:
                return False, {"reason": "no transit", "f": f, "y": y, "tarF": sorted(tarF)}
            req.append((f, y))
            cand[(f, y)] = Xs

    # Most-constrained requirement first, to prune the search early.
    req.sort(key=lambda fy: len(cand[fy]))

    try:
        base_di, base_bi = _marks_to_admg_mats(M)
    except ValueError as e:
        return False, {"reason": f"invalid marks for ADMG conversion: {e}"}

    explored = 0
    budget_exhausted = False
    best_map: Dict[Tuple[int, int], int] = {}

    def check_with_added(transit_pairs: Set[Tuple[int, int]]) -> bool:
        """True iff ADMG(M) plus x <-> y for each transit pair maps back to exactly M."""
        A_di = base_di.copy()
        A_bi = base_bi.copy()
        for (x, y) in transit_pairs:
            if x == y:
                continue
            A_bi[x, y] = 1
            A_bi[y, x] = 1
        M_di, M_bi = admg_to_mag(A_di, A_bi, max_cond_set_size=max_cond_set_size)
        M_chk = mag_two_mats_to_marks(M_di, M_bi)
        return np.array_equal(M_chk, M)

    def dfs(i: int, mapping: Dict[Tuple[int, int], int]) -> bool:
        """Choose a transit for req[i:], depth-first; True once a full choice is accepted."""
        nonlocal explored, best_map, budget_exhausted
        explored += 1
        if explored > max_backtracking:
            budget_exhausted = True
            return False
        if i == len(req):
            # Derive the pair set from the complete mapping. Several requirements
            # can share the same (x, y) pair, so an add/remove set would drop a
            # pair on backtrack while another requirement still relies on it.
            chosen = {(x, yy) for (_f, yy), x in mapping.items()}
            if check_with_added(chosen):
                best_map = dict(mapping)
                return True
            return False

        f, y = req[i]
        for x in cand[(f, y)]:
            mapping[(f, y)] = x
            if dfs(i + 1, mapping):
                return True
            del mapping[(f, y)]
        return False

    ok = dfs(0, {})
    info = {
        "explored": explored,
        "transit_map": best_map,
        "budget_exhausted": budget_exhausted,
    }
    return ok, info


def complete_soft_fci_from_state(
    state: IFCIState,
    *,
    listing_max_mags: Optional[int] = None,
    listing_check_maximality: bool = True,
    listing_max_cond_set_size: Optional[int] = None,
    imag_max_cond_set_size: Optional[int] = None,
    imag_max_backtracking: int = 50000,
    seed: Optional[int] = None,
) -> IFCIState:
    """
    Completion step on an existing state.

    1. List MAG completions of ``state.A`` with ``list_mags_from_pag_bruteforce``.
    2. Keep those accepted by ``IMAG_realize``.
    3. Replace each circle mark (off the diagonal) with the value on which all
       accepted MAGs agree. Circles with disagreeing values are left unchanged.

    Counters are written to ``state.extras``:
      - ``complete_listed_mags``;
      - ``complete_imag_realizable``;
      - ``complete_imag_infos_first5``.
    If no MAG is accepted, ``state.A`` is left unchanged. When
    ``listing_max_mags`` truncates the listing, agreement is only over the
    listed subset. Returns the same (mutated) state object.
    """
    P = np.asarray(state.A, dtype=np.int8)

    candidate_mags = list_mags_from_pag_bruteforce(
        P,
        state,
        max_mags=listing_max_mags,
        check_maximality=listing_check_maximality,
        max_cond_set_size=listing_max_cond_set_size,
        seed=seed,
    )
    state.extras["complete_listed_mags"] = len(candidate_mags)

    accepted: List[Tuple[np.ndarray, dict]] = []
    for mag in candidate_mags:
        is_ok, detail = IMAG_realize(
            mag,
            state,
            max_cond_set_size=imag_max_cond_set_size,
            max_backtracking=imag_max_backtracking,
        )
        if is_ok:
            accepted.append((mag, detail))

    state.extras["complete_imag_realizable"] = len(accepted)
    state.extras["complete_imag_infos_first5"] = [detail for _, detail in accepted[:5]]

    if not accepted:
        return state

    # (k, n, n) stack of accepted MAGs; a position is invariant when every
    # MAG carries the same mark there as the first one.
    stacked = np.stack([mag for mag, _ in accepted])
    reference = stacked[0]
    unanimous = np.all(stacked == reference, axis=0)

    A_new = P.copy()
    fillable = (A_new == CIRCLE) & unanimous
    np.fill_diagonal(fillable, False)
    A_new[fillable] = reference[fillable]

    state.A = A_new
    return state


def complete_soft_fci(
    datasets: Dict[int, Dict[str, object]],
    targets: Dict[int, FrozenSet[int]],
    *,
    max_cond_set_size: int = 3,
    ci_method: str = "chisq",
    alpha: float = 0.01,
    seed: Optional[int] = None,
    max_path_len: Optional[int] = None,
    listing_max_mags: Optional[int] = None,
    listing_check_maximality: bool = True,
    listing_max_cond_set_size: Optional[int] = None,
    imag_max_cond_set_size: Optional[int] = None,
    imag_max_backtracking: int = 50000,
) -> IFCIState:
    """
    Full pipeline: IFCI-base, then completion by brute-force MAG listing plus
    IMAG realizability filtering.

    ``ifci_base`` runs with phase 3 enabled and no extra orientation rules.
    Its state is then passed to ``complete_soft_fci_from_state`` together with
    the listing/IMAG options. ``seed`` is forwarded to both stages.
    """
    base_state = ifci_base(
        datasets=datasets,
        targets=targets,
        max_cond_set_size=max_cond_set_size,
        ci_method=ci_method,
        alpha=alpha,
        seed=seed,
        max_path_len=max_path_len,
        run_phase3=True,
        extra_orientation_rules=None,
    )
    completion_options = dict(
        listing_max_mags=listing_max_mags,
        listing_check_maximality=listing_check_maximality,
        listing_max_cond_set_size=listing_max_cond_set_size,
        imag_max_cond_set_size=imag_max_cond_set_size,
        imag_max_backtracking=imag_max_backtracking,
        seed=seed,
    )
    return complete_soft_fci_from_state(base_state, **completion_options)
