# ocmb_variants.py
# OCMB: Ordering-Constrained Markov Blanket Algorithm Variants
# Supports different Ordering Backbones: Oracle, Random, SciNO, CaPS, DiffAN
#
# Experimental Objective: Validate the generalization ability of the OCMB framework,
# proving it's not "only effective for specific backbones"

import numpy as np
import pandas as pd
import time
import sys
import os
from typing import Optional, Dict, List, Tuple, Any
import warnings
warnings.filterwarnings('ignore')

# Add path
OCMB_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, OCMB_DIR)
sys.path.insert(0, os.path.join(OCMB_DIR, 'utils'))
sys.path.insert(0, os.path.join(OCMB_DIR, 'caps'))
sys.path.insert(0, os.path.join(OCMB_DIR, 'scino'))

import networkx as nx
from sklearn.preprocessing import StandardScaler

# Import LightGBM (for caps_lgbm backbone)
try:
    from lightgbm import LGBMRegressor
    HAS_LGBM = True
except ImportError:
    HAS_LGBM = False

# Import KnnCMI (using unified loader with cache control)
# Environment variable: KNNCMI_USE_CACHE=0 disables cache (final experiment), =1 enables cache (debug acceleration)
try:
    from KnnCMI_loader import cmi, invalidate_cmi_cache, set_use_cache, get_use_cache
except ImportError:
    # Fallback: direct import
    cmi = None
    invalidate_cmi_cache = None
    try:
        from numba import cuda as numba_cuda
        if numba_cuda.is_available():
            try:
                from KnnCMI_cuda_cached import cmi as _cmi_cuda_cached
                from KnnCMI_cuda_cached import invalidate_cmi_cache as _invalidate_cache
                cmi = _cmi_cuda_cached
                invalidate_cmi_cache = _invalidate_cache
                print("[OCMB] Using CUDA-accelerated KnnCMI with caching")
            except ImportError:
                from KnnCMI_cuda import cmi as _cmi_cuda
                cmi = _cmi_cuda
                print("[OCMB] Using CUDA-accelerated KnnCMI (no cache)")
    except Exception:
        pass

    if cmi is None:
        try:
            from KnnCMI import cmi as _cmi_cpu
            cmi = _cmi_cpu
            print("[OCMB] Using CPU KnnCMI")
        except Exception:
            def cmi(x, y, z, k, data):
                return np.random.random() * 0.1
            print("[OCMB] Using fallback CMI")

    if invalidate_cmi_cache is None:
        def invalidate_cmi_cache():
            pass

    def set_use_cache(use_cache: bool):
        pass

    def get_use_cache() -> bool:
        return False


# =============================================================================
# Utility Functions
# =============================================================================

def get_topological_order_from_dag(B: np.ndarray) -> np.ndarray:
    """Get topological order from ground truth DAG (Oracle ordering)"""
    G = nx.DiGraph(B)
    try:
        order = list(nx.topological_sort(G))
        return np.array(order)
    except nx.NetworkXUnfeasible:
        return np.arange(B.shape[0])


def get_random_order(n_nodes: int, seed: int = 42) -> np.ndarray:
    """Generate random topological order"""
    np.random.seed(seed)
    order = np.arange(n_nodes)
    np.random.shuffle(order)
    return order


def order_divergence(order: np.ndarray, B: np.ndarray) -> float:
    """
    Calculate divergence between predicted topological order and ground truth DAG
    Lower value indicates better ordering quality

    Calculation method: proportion of edges violating topological constraints
    """
    n = len(order)
    order_pos = {node: i for i, node in enumerate(order)}
    n_edges = int(np.sum(B))

    if n_edges == 0:
        return 0.0

    violations = 0
    for i in range(n):
        for j in range(n):
            if B[i, j] == 1:  # i -> j is a true edge
                if order_pos[i] > order_pos[j]:  # but in ordering, j comes before i
                    violations += 1

    return violations / n_edges


def order_kendall_tau(order: np.ndarray, B: np.ndarray) -> float:
    """
    Calculate Kendall tau correlation coefficient between predicted and true topological orders

    Return value ranges from [-1, 1], where 1 means perfect agreement and -1 means complete disagreement

    Parameters
    ----------
    order : np.ndarray
        Predicted topological order (array of node indices)
    B : np.ndarray
        True adjacency matrix (d, d), where B[i,j]=1 means i->j

    Returns
    -------
    float
        Kendall tau correlation coefficient
    """
    from scipy.stats import kendalltau

    # Get a topological order from the true DAG
    true_order = get_topological_order_from_dag(B)

    # Convert both orders to ranks (positions)
    n = len(order)
    pred_rank = np.zeros(n, dtype=int)
    true_rank = np.zeros(n, dtype=int)

    for i, node in enumerate(order):
        pred_rank[node] = i
    for i, node in enumerate(true_order):
        true_rank[node] = i

    # Calculate Kendall tau
    tau, _ = kendalltau(pred_rank, true_rank)

    # Handle NaN (can occur when all values are identical)
    if np.isnan(tau):
        tau = 0.0

    return float(tau)


def calculate_metrics(true_graph: np.ndarray, pred_graph: np.ndarray) -> Dict[str, float]:
    """Calculate evaluation metrics"""
    from sklearn.metrics import precision_score, recall_score, f1_score, average_precision_score

    y_true = true_graph.flatten()
    y_pred = pred_graph.flatten()

    # SHD
    diff = np.abs(true_graph - pred_graph)
    shd = int(np.sum(diff))

    # Precision, Recall, F1
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # AUPR
    try:
        aupr = average_precision_score(y_true, y_pred)
    except:
        aupr = 0.0

    # TP, FP, FN
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))

    return {
        'SHD': shd,
        'Precision': precision,
        'Recall': recall,
        'F1': f1,
        'AUPR': aupr,
        'TP': tp,
        'FP': fp,
        'FN': fn
    }


def get_true_parents(B: np.ndarray) -> Dict[int, set]:
    """Get true parent set for each node from ground truth DAG"""
    n = B.shape[0]
    parents = {}
    for j in range(n):
        parents[j] = set(np.where(B[:, j] == 1)[0])
    return parents


def get_true_markov_blankets(B: np.ndarray) -> Dict[int, set]:
    """Get true Markov Blanket for each node from ground truth DAG"""
    n = B.shape[0]
    mb = {}
    for i in range(n):
        # Parents
        parents = set(np.where(B[:, i] != 0)[0])
        # Children
        children = set(np.where(B[i, :] != 0)[0])
        # Spouses (other parents of children)
        spouses = set()
        for c in children:
            other_parents = set(np.where(B[:, c] != 0)[0])
            other_parents.discard(i)
            spouses |= other_parents
        mb[i] = parents | children | spouses
    return mb


# =============================================================================
# OCMB Base Class
# =============================================================================

class OCMB_Base:
    """
    OCMB: Ordering-Constrained Markov Blanket

    Core Idea: Use global ordering signal to constrain search space for local Markov Blanket learning
    - Obtain topological order and candidate parent sets from ordering backbone
    - Run constrained IAMB to learn Markov Blanket within candidate set
    - Orient edges using topological constraints

    Subclasses implement different ordering backbones
    """

    def __init__(self,
                 # Constrained-IAMB parameters
                 max_parents: int = 10,
                 k_mb: int = 5,
                 alpha_mb: float = 0.01,
                 symmetry: str = 'delete',
                 use_spouse_closure: bool = True,
                 # Score threshold parameters (critical fix: suppress random baseline)
                 score_threshold: float = 0.0,  # τ: minimum score threshold, 0 means disabled
                 score_threshold_quantile: float = None,  # auto-compute τ by quantile (e.g., 0.8 means 80th percentile)
                 # General parameters
                 verbose: bool = True):
        """
        Initialize OCMB base class

        Args:
            max_parents: Maximum number of candidate parent nodes per node (K)
            k_mb: k-nearest neighbor parameter for Markov Blanket learning
            alpha_mb: CMI threshold for Markov Blanket learning
            symmetry: MB symmetry strategy ('delete' or 'add')
            use_spouse_closure: Whether to perform spouse closure in CandNei
            score_threshold: Minimum score threshold τ for candidate parents, below which they won't enter CandPa
            score_threshold_quantile: Auto-compute τ by score quantile (takes precedence over score_threshold)
            verbose: Whether to output detailed information
        """
        self.max_parents = max_parents
        self.k_mb = k_mb
        self.alpha_mb = alpha_mb
        self.symmetry = symmetry
        self.use_spouse_closure = use_spouse_closure
        self.score_threshold = score_threshold
        self.score_threshold_quantile = score_threshold_quantile
        self.verbose = verbose

        # Result storage
        self.order_ = None
        self.parent_scores_ = None
        self.CandPa_ = None
        self.CandNei_ = None
        self.MB_ = None
        self.graph_ = None
        self.timings_ = {}
        self.ordering_divergence_ = None
        self.ordering_kendall_tau_ = None  # New: Kendall tau correlation coefficient
        # New: CI test statistic as edge score (used for AUPR and other ranking metrics)
        # ci_scores_[i,j] records max CMI(i;j|S) observed during algorithm execution
        self.ci_scores_ = None
        # New: final CMI score (re-evaluate CMI with final MB after IAMB convergence)
        # ci_scores_final_[i,j] records max CMI(i;j|MB_final\{j}) (taking max across targets for skeleton/ordered AUPR)
        self.ci_scores_final_ = None
        # New: candidate coverage statistics (diagnostic metric required by reviewers)
        self.cand_stats_ = {}
        # New: CMI call counter (complexity evidence required by reviewers)
        self.n_cmi_calls_ = 0
        # New: backbone source and error (for auditing SciNO/CaPS fallback)
        self.backbone_used_ = None        # e.g., "scino", "caps", "random", "fallback_random"
        self.backbone_error_ = None       # error string if any

    def _log(self, msg: str):
        if self.verbose:
            print(f"[OCMB] {msg}")

    def _set_backbone_status(self, used: str, err: str = None):
        """Record actual backbone source and error information (for auditability)"""
        self.backbone_used_ = used
        self.backbone_error_ = err

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get topological order and parent score matrix

        Subclasses must implement this method

        Args:
            X: Data matrix (n_samples, n_nodes)
            true_adj: True adjacency matrix (used only by Oracle)

        Returns:
            order: Topological order array
            parent_scores: Parent score matrix (n_nodes, n_nodes)
        """
        raise NotImplementedError("Subclass must implement _get_ordering")

    def fit(self, X: np.ndarray, true_adj: np.ndarray = None) -> 'OCMB_Base':
        """
        Fit OCMB model

        Args:
            X: Data matrix (n_samples, n_nodes)
            true_adj: True adjacency matrix (for computing ordering divergence and Oracle methods)

        Returns:
            self
        """
        X = np.asarray(X)
        n_samples, n_nodes = X.shape
        self._log(f"Input data: {n_samples} samples, {n_nodes} nodes")

        # Reset CMI call counter
        self.n_cmi_calls_ = 0
        # Reset CI score matrix (for AUPR)
        self.ci_scores_ = np.zeros((n_nodes, n_nodes), dtype=float)
        self.ci_scores_final_ = np.zeros((n_nodes, n_nodes), dtype=float)

        # Clear CMI cache (at the start of new data)
        invalidate_cmi_cache()

        total_start = time.time()

        # Step 1: Get topological order and parent scores
        self._log("Step 1/4: Getting topological order from backbone...")
        start = time.time()
        self.order_, self.parent_scores_ = self._get_ordering(X, true_adj)
        self.timings_['ordering'] = time.time() - start
        self._log(f"  Ordering completed in {self.timings_['ordering']:.2f}s")

        # Calculate ordering quality metrics (if ground truth available)
        if true_adj is not None:
            self.ordering_divergence_ = order_divergence(self.order_, true_adj)
            self.ordering_kendall_tau_ = order_kendall_tau(self.order_, true_adj)
            self._log(f"  Ordering divergence: {self.ordering_divergence_:.3f}, Kendall tau: {self.ordering_kendall_tau_:.3f}")

        # Step 2: Construct candidate parent and neighbor sets
        self._log("Step 2/4: Building candidate parent/neighbor sets...")
        start = time.time()
        self.CandPa_, self.CandNei_ = self._get_candidate_sets(n_nodes)
        self.timings_['candidates'] = time.time() - start
        avg_cand_pa = np.mean([len(v) for v in self.CandPa_.values()])
        avg_cand_nei = np.mean([len(v) for v in self.CandNei_.values()])
        self._log(f"  Average CandPa size: {avg_cand_pa:.1f}, CandNei size: {avg_cand_nei:.1f}")

        # Calculate candidate coverage (key diagnostic metric required by reviewers)
        if true_adj is not None:
            true_pa = get_true_parents(true_adj)
            true_mb = get_true_markov_blankets(true_adj)
            self._compute_candidate_coverage(true_pa, true_mb)
            self._log(f"  Candidate coverage - Pa: {self.cand_stats_['covPa_mean']:.3f}, MB: {self.cand_stats_['covMB_mean']:.3f}")

        # Step 3: Constrained Markov Blanket learning
        self._log("Step 3/4: Learning constrained Markov Blankets...")
        start = time.time()
        self.MB_ = self._get_MB_constrained(X)
        self.timings_['mb'] = time.time() - start
        avg_mb = np.mean([len(v) for v in self.MB_.values()])
        self._log(f"  MB learning completed in {self.timings_['mb']:.2f}s")
        self._log(f"  Average MB size: {avg_mb:.1f}")

        # Step 4: Orient edges based on topological order
        self._log("Step 4/4: Orienting edges based on topological order...")
        start = time.time()
        self.graph_ = self._orient_edges(n_nodes)
        self.timings_['orientation'] = time.time() - start
        n_edges = int(np.sum(self.graph_))
        self._log(f"  Orientation completed: {n_edges} directed edges")

        self.timings_['total'] = time.time() - total_start
        self._log(f"OCMB completed in {self.timings_['total']:.2f}s")
        self._log(f"  Total CMI calls: {self.n_cmi_calls_} (avg {self.n_cmi_calls_/n_nodes:.1f} per node)")

        # Clear CMI cache (release GPU memory)
        invalidate_cmi_cache()

        return self

    def _update_ci_score(self, i: int, j: int, val: float):
        """
        Update CI-based edge score matrix with a new test statistic value.

        We record max CMI magnitude across all tested conditioning sets:
          score(i,j) = max_{tests} max(0, CMI(i; j | S))
        """
        if self.ci_scores_ is None:
            return
        try:
            v = float(val)
        except Exception:
            return
        if not np.isfinite(v):
            return
        if v < 0:
            v = 0.0
        # CMI is symmetric; keep matrix symmetric to support skeleton evaluation.
        if v > self.ci_scores_[i, j]:
            self.ci_scores_[i, j] = v
        if v > self.ci_scores_[j, i]:
            self.ci_scores_[j, i] = v

    def _update_ci_final_score(self, i: int, j: int, val: float):
        """
        Update CI-based edge score matrix using the *final* IAMB conditioning set.

        This is the recommended AUPR scoring for OCMB:
          score(i,j) = CMI(i; j | MB_final(i) \\ {j})
        Since CMI is symmetric but conditioning sets differ by target, we keep a symmetric
        matrix by taking max over updates from either target (useful for skeleton / ordered AUPR).
        """
        if self.ci_scores_final_ is None:
            return
        try:
            v = float(val)
        except Exception:
            return
        if not np.isfinite(v):
            return
        if v < 0:
            v = 0.0
        if v > self.ci_scores_final_[i, j]:
            self.ci_scores_final_[i, j] = v
        if v > self.ci_scores_final_[j, i]:
            self.ci_scores_final_[j, i] = v

    def _get_candidate_sets(self, n_nodes: int) -> Tuple[Dict[int, set], Dict[int, set]]:
        """
        Construct candidate parent and neighbor sets from topological order and score matrix

        CandPa[j]: Among nodes preceding j in the order, select top-K by parent_scores
                   and score >= τ (score_threshold)
        CandNei[i]:
          - use_spouse_closure=True:  CandPa[i] ∪ CandCh[i] ∪ (∪_{c in CandCh[i]} CandPa[c])
          - use_spouse_closure=False: CandPa[i] ∪ CandCh[i]
        """
        CandPa: Dict[int, set] = {}
        CandNei: Dict[int, set] = {}

        order_pos = {node: i for i, node in enumerate(self.order_)}

        # Use fixed rng to break ties, ensuring reproducibility
        rng = np.random.default_rng(0)

        # Calculate score threshold τ (critical fix: suppress random baseline)
        tau = self.score_threshold
        if self.score_threshold_quantile is not None and self.parent_scores_ is not None:
            # Auto-compute τ by quantile
            non_zero_scores = self.parent_scores_[self.parent_scores_ > 0]
            if len(non_zero_scores) > 0:
                tau = np.quantile(non_zero_scores, self.score_threshold_quantile)
                self._log(f"  Auto τ from {self.score_threshold_quantile:.0%} quantile: {tau:.4f}")

        # Record actually used τ
        self.effective_tau_ = tau

        # 1) build CandPa with score threshold
        for j in range(n_nodes):
            candidates = []
            for i in range(n_nodes):
                if i == j:
                    continue
                if order_pos[i] < order_pos[j]:
                    score = self.parent_scores_[i, j] if self.parent_scores_ is not None else 1.0
                    # Critical fix: only keep candidates with score >= τ
                    if score >= tau:
                        candidates.append((i, float(score)))

            # Critical fix: shuffle first to avoid index bias when scores are equal
            rng.shuffle(candidates)
            candidates.sort(key=lambda x: x[1], reverse=True)

            CandPa[j] = set([c[0] for c in candidates[: self.max_parents]])

        # 2) build candidate children lists CandCh efficiently
        CandCh: Dict[int, set] = {i: set() for i in range(n_nodes)}
        for child in range(n_nodes):
            for p in CandPa[child]:
                CandCh[p].add(child)

        # 3) build CandNei with spouse closure
        for i in range(n_nodes):
            nei = set()
            nei |= CandPa.get(i, set())     # parents
            nei |= CandCh.get(i, set())     # children

            if self.use_spouse_closure:
                # spouse closure: for each candidate child c, add other candidate parents of c
                for c in CandCh.get(i, set()):
                    nei |= CandPa.get(c, set())

            nei.discard(i)
            CandNei[i] = nei

        return CandPa, CandNei

    def _compute_candidate_coverage(self, true_pa: Dict[int, set], true_mb: Dict[int, set]):
        """
        Calculate candidate coverage statistics (key diagnostic metric required by reviewers)

        - covPa: coverage rate of true parents in candidate parent set
        - covMB: coverage rate of true Markov Blanket in candidate neighbor set

        These metrics explain the success/failure mechanisms of OCMB
        """
        n_nodes = len(self.CandPa_)

        # Calculate coverage rate for each node
        covPa_list = []
        covMB_list = []

        for i in range(n_nodes):
            # Parent set coverage rate
            true_pa_i = true_pa.get(i, set())
            cand_pa_i = self.CandPa_.get(i, set())
            if len(true_pa_i) > 0:
                covered_pa = len(true_pa_i & cand_pa_i)
                covPa_list.append(covered_pa / len(true_pa_i))
            else:
                covPa_list.append(1.0)  # 100% coverage when there are no true parents

            # MB coverage rate
            true_mb_i = true_mb.get(i, set())
            cand_nei_i = self.CandNei_.get(i, set())
            if len(true_mb_i) > 0:
                covered_mb = len(true_mb_i & cand_nei_i)
                covMB_list.append(covered_mb / len(true_mb_i))
            else:
                covMB_list.append(1.0)  # 100% coverage when there is no true MB

        # Summary statistics
        self.cand_stats_ = {
            'covPa_mean': float(np.mean(covPa_list)),
            'covPa_std': float(np.std(covPa_list)),
            'covMB_mean': float(np.mean(covMB_list)),
            'covMB_std': float(np.std(covMB_list)),
            'avg_CandPa_size': float(np.mean([len(v) for v in self.CandPa_.values()])),
            'avg_CandNei_size': float(np.mean([len(v) for v in self.CandNei_.values()])),
            'avg_true_Pa_size': float(np.mean([len(v) for v in true_pa.values()])),
            'avg_true_MB_size': float(np.mean([len(v) for v in true_mb.values()])),
            # Detailed per-node coverage rates (for analysis)
            'covPa_per_node': covPa_list,
            'covMB_per_node': covMB_list,
        }

    def _IAMB_constrained(self, T: int, data_df: pd.DataFrame, CandNei_T: set) -> List[int]:
        """Constrained IAMB algorithm: search for Markov Blanket only within candidate neighbors"""
        CMB = []
        CMB_prev = None
        candidates = set(CandNei_T) - {T}

        # Forward phase
        while CMB != CMB_prev:
            CMB_prev = CMB.copy()
            best_x, best_val = None, 0.0

            for x in list(candidates):
                val = cmi([T], [x], CMB, self.k_mb, data=data_df)
                self.n_cmi_calls_ += 1  # Counter: record CMI call count
                self._update_ci_score(T, x, val)
                if val >= best_val:
                    best_val = val
                    best_x = x

            if best_val >= self.alpha_mb and best_x is not None:
                CMB.append(best_x)
                candidates.remove(best_x)

        # Backward phase
        for x in CMB.copy():
            CMB_x = [z for z in CMB if z != x]
            val = cmi([T], [x], CMB_x if CMB_x else [], self.k_mb, data=data_df)
            self.n_cmi_calls_ += 1  # Counter: record CMI call count
            self._update_ci_score(T, x, val)
            if val <= self.alpha_mb:
                CMB.remove(x)

        # Final-score phase (Scheme A): use the converged MB as conditioning set, and compute
        # a continuous CMI score for all candidate neighbors (better ranking than binary graph).
        if self.ci_scores_final_ is not None:
            mb_final = sorted(CMB)
            for x in sorted(set(CandNei_T) - {T}):
                cond = [z for z in mb_final if z != x]
                val_final = cmi([T], [x], cond if cond else [], self.k_mb, data=data_df)
                self.n_cmi_calls_ += 1
                self._update_ci_final_score(T, x, val_final)

        return CMB

    def _get_MB_constrained(self, X: np.ndarray) -> Dict[int, List[int]]:
        """Get constrained Markov Blanket for all nodes"""
        data_df = pd.DataFrame(X)
        d = X.shape[1]
        MB = {}

        for T in range(d):
            if self.verbose and T % max(1, d // 5) == 0:
                print(f"    MB progress: {T}/{d}")
            MB[T] = self._IAMB_constrained(T, data_df, self.CandNei_[T])

        # Symmetry pruning
        if self.symmetry == 'delete':
            for T in range(d):
                for X_node in MB[T].copy():
                    if T not in MB.get(X_node, []):
                        MB[T].remove(X_node)
        elif self.symmetry == 'add':
            for T in range(d):
                for X_node in MB[T]:
                    if T not in MB.get(X_node, []):
                        MB[X_node].append(T)

        return MB

    def _orient_edges(self, n_nodes: int) -> np.ndarray:
        """Orient edges based on topological order"""
        graph = np.zeros((n_nodes, n_nodes), dtype=int)
        order_pos = {node: i for i, node in enumerate(self.order_)}

        for T in range(n_nodes):
            for X_node in self.MB_[T]:
                # If X_node precedes T in topological order, then X_node -> T
                if order_pos[X_node] < order_pos[T]:
                    # Additional check: X_node must be in T's candidate parent set
                    if X_node in self.CandPa_.get(T, set()):
                        graph[X_node, T] = 1

        return graph

    def get_adjacency_matrix(self) -> np.ndarray:
        """Get adjacency matrix"""
        return self.graph_

    def get_timings(self) -> Dict[str, float]:
        """Get timing for each stage"""
        return self.timings_

    def get_ordering_divergence(self) -> Optional[float]:
        """Get ordering divergence"""
        return self.ordering_divergence_

    def get_n_cmi_calls(self) -> int:
        """Get total number of CMI calls (for complexity analysis)"""
        return self.n_cmi_calls_

    def get_effective_tau(self) -> float:
        """Get actually used score threshold τ"""
        return getattr(self, 'effective_tau_', 0.0)


# =============================================================================
# OCMB Variant Implementations
# =============================================================================

class OCMB_Oracle(OCMB_Base):
    """
    OCMB with Oracle Ordering

    Uses true topological order as ordering backbone
    For validating the theoretical upper bound of OCMB algorithm
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._log("Initialized OCMB-Oracle (ground-truth ordering)")

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        if true_adj is None:
            raise ValueError("OCMB-Oracle requires true_adj (ground truth DAG)")

        order = get_topological_order_from_dag(true_adj)

        # Fix: don't transpose. true_adj[i,j]=1 means i->j, which is exactly parent_scores[i,j]
        # parent_scores[i,j] being high means i is likely a parent of j
        parent_scores = true_adj.copy().astype(float)
        np.fill_diagonal(parent_scores, 0.0)

        # Mark backbone status
        self._set_backbone_status("oracle", None)

        return order, parent_scores


class OCMB_Random(OCMB_Base):
    """
    OCMB with Random Ordering

    Uses random topological order as ordering backbone
    For ablation experiments, proving that ordering information is indeed crucial
    """

    def __init__(self, seed: int = 42, **kwargs):
        super().__init__(**kwargs)
        self.seed = seed
        self._log(f"Initialized OCMB-Random (random ordering, seed={seed})")

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        n_nodes = X.shape[1]
        order = get_random_order(n_nodes, self.seed)

        # Critical fix: use truly random scores instead of uniform constants
        # This allows τ threshold to filter candidates (uniform constants make all scores equal, τ can't cut)
        rng = np.random.default_rng(self.seed)
        parent_scores = rng.random((n_nodes, n_nodes))
        np.fill_diagonal(parent_scores, 0)

        # Mark backbone status
        self._set_backbone_status("random", None)

        return order, parent_scores


class OCMB_Precomputed(OCMB_Base):
    """
    OCMB with Precomputed Ordering + Parent Scores

    For "unified one-time pipeline": cache and reuse Stage B (ordering/score) output,
    avoiding repeated training of SciNO/CaPS in K-sweep / τ-sweep / ablation.
    """

    def __init__(self, order: np.ndarray, parent_scores: np.ndarray, **kwargs):
        super().__init__(**kwargs)
        self._provided_order = np.asarray(order, dtype=int)
        self._provided_parent_scores = np.asarray(parent_scores, dtype=float)
        self._log("Initialized OCMB-Precomputed (fixed order+scores)")

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        n_nodes = X.shape[1]

        if self._provided_order.shape[0] != n_nodes:
            raise ValueError(f"Provided order length mismatch: {self._provided_order.shape[0]} vs {n_nodes}")

        if self._provided_parent_scores.shape != (n_nodes, n_nodes):
            raise ValueError(
                f"Provided parent_scores shape mismatch: {self._provided_parent_scores.shape} vs {(n_nodes, n_nodes)}"
            )

        # Mark backbone status (precomputed status is set externally, here just mark the source)
        self._set_backbone_status("precomputed", None)

        return self._provided_order, self._provided_parent_scores


class OCMB_OracleOrderOnly(OCMB_Base):
    """
    OCMB with Oracle Ordering Only

    Uses only true topological order, but candidate parent sets are still constructed by score matrix (using uniform scores)
    For validating: contribution of ordering information alone to OCMB

    Difference from OCMB_Oracle:
    - OCMB_Oracle: uses true ordering + true parent_scores
    - OCMB_OracleOrderOnly: uses true ordering + uniform parent_scores
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._log("Initialized OCMB-OracleOrderOnly (true ordering, uniform scores)")

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        if true_adj is None:
            raise ValueError("OCMB-OracleOrderOnly requires true_adj (ground truth DAG)")

        order = get_topological_order_from_dag(true_adj)
        n_nodes = len(order)

        # Only use true ordering, parent_scores use uniform distribution (don't leverage true parent information)
        parent_scores = np.ones((n_nodes, n_nodes)) / n_nodes
        np.fill_diagonal(parent_scores, 0)

        # Mark backbone status
        self._set_backbone_status("oracle_order_only", None)

        return order, parent_scores


class OCMB_OracleCandPa(OCMB_Base):
    """
    OCMB with Oracle Candidate Parent Sets

    Uses true topological order + true parent sets as candidate parent sets
    For validating: theoretical upper bound of OCMB with perfect candidate parent sets

    CandPa[j] = true_Pa[j] (directly uses true parent set)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._log("Initialized OCMB-OracleCandPa (true parents as candidates)")
        self._true_adj = None

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        if true_adj is None:
            raise ValueError("OCMB-OracleCandPa requires true_adj (ground truth DAG)")

        self._true_adj = true_adj
        order = get_topological_order_from_dag(true_adj)

        # Fix: don't transpose. Use true values as parent_scores
        parent_scores = true_adj.copy().astype(float)
        np.fill_diagonal(parent_scores, 0.0)

        # Mark backbone status
        self._set_backbone_status("oracle_cand_pa", None)

        return order, parent_scores

    def _get_candidate_sets(self, n_nodes: int) -> Tuple[Dict[int, set], Dict[int, set]]:
        """Directly use true parent sets as candidate parent sets, with spouse closure"""
        CandPa = get_true_parents(self._true_adj)

        # Build candidate children
        CandCh: Dict[int, set] = {i: set() for i in range(n_nodes)}
        for child in range(n_nodes):
            for p in CandPa.get(child, set()):
                CandCh[p].add(child)

        # Neighbor set = parent set + children set + spouse closure
        CandNei = {}
        for i in range(n_nodes):
            nei = set()
            nei |= CandPa.get(i, set())  # parents
            nei |= CandCh.get(i, set())  # children
            if self.use_spouse_closure:
                # spouse closure
                for c in CandCh.get(i, set()):
                    nei |= CandPa.get(c, set())
            nei.discard(i)
            CandNei[i] = nei

        return CandPa, CandNei


class OCMB_OracleCandNei(OCMB_Base):
    """
    OCMB with Oracle Candidate Neighbor Sets (True Markov Blanket)

    Uses true topological order + true Markov Blanket as candidate neighbor sets
    For validating: theoretical upper bound of OCMB with perfect candidate neighbor sets

    CandNei[i] = true_MB[i] (directly uses true Markov Blanket)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._log("Initialized OCMB-OracleCandNei (true MB as candidates)")
        self._true_adj = None

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        if true_adj is None:
            raise ValueError("OCMB-OracleCandNei requires true_adj (ground truth DAG)")

        self._true_adj = true_adj
        order = get_topological_order_from_dag(true_adj)

        # Fix: don't transpose. Use true values as parent_scores
        parent_scores = true_adj.copy().astype(float)
        np.fill_diagonal(parent_scores, 0.0)

        # Mark backbone status
        self._set_backbone_status("oracle_cand_nei", None)

        return order, parent_scores

    def _get_candidate_sets(self, n_nodes: int) -> Tuple[Dict[int, set], Dict[int, set]]:
        """Directly use true Markov Blanket as candidate neighbor sets"""
        # Candidate parent set: still based on ordering and max_parents constraint
        CandPa = get_true_parents(self._true_adj)

        # Candidate neighbor set: directly use true Markov Blanket
        CandNei = get_true_markov_blankets(self._true_adj)

        return CandPa, CandNei


class OCMB_SciNO(OCMB_Base):
    """
    OCMB with SciNO Ordering

    Uses SciNO neural operator as ordering backbone
    """

    def __init__(self,
                 # SciNO parameters
                 score_hidden: int = 256,
                 score_epochs: int = 100,
                 score_lr: float = 1e-4,
                 score_batch_size: int = 128,
                 op_width: int = 128,
                 op_modes: tuple = (32, 16, 8),
                 op_depth: int = 4,
                 op_epochs: int = 60,
                 op_lr: float = 3e-4,
                 op_batch_size: int = 64,
                 masking: bool = True,
                 residue: bool = False,
                 n_votes: int = 3,
                 # Performance optimization parameters (pass through to SciNO)
                 device: Optional[str] = None,
                 num_workers: int = 0,
                 pin_memory: bool = False,
                 compile_model: bool = False,
                 use_amp: bool = True,
                 grad_clip: float = 1.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.scino_params = {
            'score_hidden': score_hidden,
            'score_epochs': score_epochs,
            'score_lr': score_lr,
            'score_batch_size': score_batch_size,
            'op_width': op_width,
            'op_modes': op_modes,
            'op_depth': op_depth,
            'op_epochs': op_epochs,
            'op_lr': op_lr,
            'op_batch_size': op_batch_size,
            'masking': masking,
            'residue': residue,
            'n_votes': n_votes,
            'device': device,
            'num_workers': num_workers,
            'pin_memory': pin_memory,
            'compile_model': compile_model,
            'use_amp': use_amp,
            'grad_clip': grad_clip,
        }
        self._log("Initialized OCMB-SciNO (SciNO backbone)")

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        n_nodes = X.shape[1]

        try:
            from sciNO.scino import SciNO

            scino = SciNO(
                n_nodes=n_nodes,
                cutoff=0.0,  # 不做剪枝，我们自己处理
                pruning_method='none',
                verbose=self.verbose,
                **self.scino_params
            )

            W, order = scino.fit(X)
            parent_scores = np.abs(W)

            # Mark: actually used SciNO
            self._set_backbone_status("scino", None)

            return np.array(order), parent_scores

        except Exception as e:
            # Mark: degraded to random (and record error)
            err = f"{type(e).__name__}: {e}"
            self._set_backbone_status("fallback_random", err)

            self._log(f"Warning: SciNO failed ({err}), falling back to random ordering")

            # Use seed=2026 to avoid being identical to random baseline (seed=42)
            order_fallback = get_random_order(n_nodes, seed=2026)
            # Critical fix: use random scores instead of uniform constants (otherwise τ can't cut)
            rng = np.random.default_rng(2026)
            scores_fallback = rng.random((n_nodes, n_nodes))
            np.fill_diagonal(scores_fallback, 0.0)

            return order_fallback, scores_fallback


class OCMB_CaPS(OCMB_Base):
    """
    OCMB with CaPS Ordering

    Uses CaPS (Causal order Prediction in Structures) as ordering backbone
    """

    def __init__(self,
                 # CaPS parameters
                 eta_G: float = 0.001,
                 eta_H: float = 0.001,
                 dispersion: str = 'mean',
                 device: str = 'cuda:0',
                 **kwargs):
        super().__init__(**kwargs)
        self.caps_params = {
            'eta_G': eta_G,
            'eta_H': eta_H,
            'dispersion': dispersion,
            'device': device,
        }
        self._log("Initialized OCMB-CaPS (CaPS backbone)")

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        n_nodes = X.shape[1]

        try:
            from caps_backend import run_caps

            # Data standardization
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(X)

            order, parent_scores = run_caps(
                X_scaled,
                eta_G=self.caps_params['eta_G'],
                eta_H=self.caps_params['eta_H'],
                dispersion=self.caps_params['dispersion'],
                device=self.caps_params['device']
            )

            # Mark: actually used CaPS
            self._set_backbone_status("caps", None)

            return np.array(order), parent_scores

        except Exception as e:
            # Mark: degraded to random (and record error)
            err = f"{type(e).__name__}: {e}"
            self._set_backbone_status("fallback_random", err)

            self._log(f"Warning: CaPS failed ({err}), falling back to random ordering")

            # Use seed=2026 to avoid being identical to random baseline (seed=42)
            order_fallback = get_random_order(n_nodes, seed=2026)
            # Critical fix: use random scores instead of uniform constants (otherwise τ can't cut)
            rng = np.random.default_rng(2026)
            scores_fallback = rng.random((n_nodes, n_nodes))
            np.fill_diagonal(scores_fallback, 0.0)

            return order_fallback, scores_fallback


class OCMB_CaPS_LGBMScore(OCMB_Base):
    """
    OCMB with CaPS ordering + LightGBM-based parent scoring

    - Ordering: CaPS
    - Parent scores: LightGBM feature importance (gain), trained only on predecessors in the order

    Motivation:
      CaPS may require large K in high dimension if its parent_scores are not sharp enough.
      Replacing the scoring with a stronger nonlinear screener (LightGBM) can increase covPa@smallK
      and reduce the need for K=0.4d.

    Notes:
      Optionally uses sample splitting:
        - score split: train LightGBM scores
        - CI split: run constrained IAMB (CMI tests)
    """

    def __init__(self,
                 # CaPS params
                 eta_G: float = 0.001,
                 eta_H: float = 0.001,
                 dispersion: str = 'mean',
                 device: str = 'cuda:0',

                 # LightGBM params (optimized for speed)
                 lgbm_n_estimators: int = 100,  # reduced from 300 for speed
                 lgbm_learning_rate: float = 0.1,  # increased for faster convergence
                 lgbm_num_leaves: int = 31,
                 lgbm_min_child_samples: int = 20,  # reduced for faster training
                 lgbm_subsample: float = 0.8,
                 lgbm_colsample_bytree: float = 0.8,
                 lgbm_reg_lambda: float = 1.0,
                 lgbm_n_jobs: int = -1,
                 lgbm_importance_type: str = 'gain',
                 lgbm_fast_mode: bool = True,  # use faster settings

                 # sample splitting
                 use_sample_splitting: bool = True,
                 score_split_ratio: float = 0.5,
                 split_seed: int = 0,

                 **kwargs):
        super().__init__(**kwargs)

        self.caps_params = dict(
            eta_G=eta_G, eta_H=eta_H, dispersion=dispersion, device=device
        )

        # Apply fast mode settings if enabled
        if lgbm_fast_mode:
            lgbm_n_estimators = min(lgbm_n_estimators, 50)  # cap at 50 trees
            lgbm_num_leaves = 15  # smaller trees
            lgbm_min_child_samples = 30  # prevent overfitting with small trees

        self.lgbm_params = dict(
            n_estimators=lgbm_n_estimators,
            learning_rate=lgbm_learning_rate,
            num_leaves=lgbm_num_leaves,
            min_child_samples=lgbm_min_child_samples,
            subsample=lgbm_subsample,
            colsample_bytree=lgbm_colsample_bytree,
            reg_lambda=lgbm_reg_lambda,
            n_jobs=lgbm_n_jobs,
            random_state=split_seed,
            verbosity=-1,  # suppress LightGBM warnings
        )
        self.lgbm_importance_type = lgbm_importance_type
        self.lgbm_fast_mode = lgbm_fast_mode

        self.use_sample_splitting = use_sample_splitting
        self.score_split_ratio = score_split_ratio
        self.split_seed = split_seed

        self._log(f"Initialized OCMB-CaPS_LGBMScore (CaPS order + LGBM score, fast_mode={lgbm_fast_mode})")

    def _caps_order(self, X: np.ndarray) -> np.ndarray:
        """Run CaPS to get topological order."""
        n_nodes = X.shape[1]
        from caps_backend import run_caps

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        order, _ = run_caps(
            X_scaled,
            eta_G=self.caps_params['eta_G'],
            eta_H=self.caps_params['eta_H'],
            dispersion=self.caps_params['dispersion'],
            device=self.caps_params['device']
        )
        order = np.array(order, dtype=int)
        if order.shape[0] != n_nodes:
            raise ValueError(f"CaPS returned invalid order length: {len(order)} vs {n_nodes}")
        return order

    def _lgbm_scores_given_order(self, X: np.ndarray, order: np.ndarray) -> np.ndarray:
        """Train per-node LightGBM on predecessors to compute parent_scores."""
        if not HAS_LGBM:
            raise ImportError("lightgbm is not installed. Please `pip install lightgbm`.")

        n, d = X.shape
        order_pos = {node: i for i, node in enumerate(order)}
        S = np.zeros((d, d), dtype=float)

        # Standardize X for stability
        X_mean = X.mean(axis=0, keepdims=True)
        X_std = X.std(axis=0, keepdims=True) + 1e-8
        Xz = (X - X_mean) / X_std

        for j in range(d):
            preds = [i for i in range(d) if order_pos[i] < order_pos[j]]
            if len(preds) == 0:
                continue

            Xj = Xz[:, preds]
            y = Xz[:, j]

            model = LGBMRegressor(**self.lgbm_params)
            model.fit(Xj, y)

            booster = model.booster_
            imp = booster.feature_importance(importance_type=self.lgbm_importance_type).astype(float)
            imp = np.maximum(imp, 0.0)
            if imp.sum() > 0:
                imp = imp / imp.sum()

            # map back to global score matrix
            for k, i in enumerate(preds):
                S[i, j] = imp[k]

        np.fill_diagonal(S, 0.0)
        return S

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        For compatibility with OCMB_Base.fit() timing wrapper.
        If you use sample splitting, we override fit(), so this method is only used when splitting is off.
        """
        order = self._caps_order(X)
        parent_scores = self._lgbm_scores_given_order(X, order)
        return order, parent_scores

    def fit(self, X: np.ndarray, true_adj: np.ndarray = None) -> 'OCMB_CaPS_LGBMScore':
        """
        Override fit to support sample splitting:
          - score split -> CaPS order + LGBM scores
          - CI split -> constrained IAMB / orientation
        """
        X = np.asarray(X)
        n_samples, n_nodes = X.shape
        self._log(f"Input data: {n_samples} samples, {n_nodes} nodes")

        # Reset CMI call counter and clear cache
        self.n_cmi_calls_ = 0
        # Initialize CI score matrix (for AUPR and other ranking metrics)
        self.ci_scores_ = np.zeros((n_nodes, n_nodes), dtype=float)
        self.ci_scores_final_ = np.zeros((n_nodes, n_nodes), dtype=float)
        invalidate_cmi_cache()

        total_start = time.time()

        # split indices
        if self.use_sample_splitting:
            rng = np.random.default_rng(self.split_seed)
            idx = np.arange(n_samples)
            rng.shuffle(idx)
            n_score = int(n_samples * self.score_split_ratio)
            idx_score = idx[:n_score]
            idx_ci = idx[n_score:]
            # fallback if too small
            if len(idx_ci) < 10 or len(idx_score) < 10:
                idx_score = idx
                idx_ci = idx
            self._log(f"  Sample splitting: score={len(idx_score)}, CI={len(idx_ci)}")
        else:
            idx_score = np.arange(n_samples)
            idx_ci = np.arange(n_samples)

        X_score = X[idx_score]
        X_ci = X[idx_ci]

        # Step 1: ordering + score
        self._log("Step 1/4: Getting order (CaPS) + scores (LightGBM)...")
        start = time.time()
        caps_failed = False
        lgbm_failed = False
        try:
            self.order_ = self._caps_order(X_score)
        except Exception as e:
            caps_failed = True
            caps_err = f"{type(e).__name__}: {e}"
            self._log(f"Warning: CaPS failed ({caps_err}), falling back to random ordering")
            self.order_ = get_random_order(n_nodes, self.split_seed)

        score_start = time.time()
        try:
            self.parent_scores_ = self._lgbm_scores_given_order(X_score, self.order_)
        except Exception as e:
            lgbm_failed = True
            lgbm_err = f"{type(e).__name__}: {e}"
            self._log(f"Warning: LightGBM scoring failed ({lgbm_err}), falling back to random scores")
            # Critical fix: use random scores instead of uniform constants (otherwise τ can't cut)
            rng = np.random.default_rng(self.split_seed)
            self.parent_scores_ = rng.random((n_nodes, n_nodes))
            np.fill_diagonal(self.parent_scores_, 0.0)

        # Set backbone status
        if caps_failed:
            self._set_backbone_status("fallback_random", caps_err)
        elif lgbm_failed:
            self._set_backbone_status("caps_lgbm_partial", lgbm_err)  # CaPS succeeded but LGBM failed
        else:
            self._set_backbone_status("caps_lgbm", None)

        self.timings_['ordering'] = time.time() - start
        self.timings_['score'] = time.time() - score_start
        self._log(f"  Ordering+Score completed in {self.timings_['ordering']:.2f}s (score part {self.timings_['score']:.2f}s)")

        # ordering quality metrics
        if true_adj is not None:
            self.ordering_divergence_ = order_divergence(self.order_, true_adj)
            self.ordering_kendall_tau_ = order_kendall_tau(self.order_, true_adj)
            self._log(f"  Ordering divergence: {self.ordering_divergence_:.3f}, Kendall tau: {self.ordering_kendall_tau_:.3f}")

        # Step 2: candidate sets
        self._log("Step 2/4: Building candidate parent/neighbor sets...")
        start = time.time()
        self.CandPa_, self.CandNei_ = self._get_candidate_sets(n_nodes)
        self.timings_['candidates'] = time.time() - start
        avg_cand_pa = np.mean([len(v) for v in self.CandPa_.values()])
        avg_cand_nei = np.mean([len(v) for v in self.CandNei_.values()])
        self._log(f"  Average CandPa size: {avg_cand_pa:.1f}, CandNei size: {avg_cand_nei:.1f}")

        # candidate coverage stats
        if true_adj is not None:
            true_pa = get_true_parents(true_adj)
            true_mb = get_true_markov_blankets(true_adj)
            self._compute_candidate_coverage(true_pa, true_mb)
            self._log(f"  Candidate coverage - Pa: {self.cand_stats_['covPa_mean']:.3f}, MB: {self.cand_stats_['covMB_mean']:.3f}")

        # Step 3: MB learning on CI split
        self._log("Step 3/4: Learning constrained Markov Blankets...")
        # Clear cache (because X_ci is a different data subset)
        invalidate_cmi_cache()
        start = time.time()
        self.MB_ = self._get_MB_constrained(X_ci)
        self.timings_['mb'] = time.time() - start
        avg_mb = np.mean([len(v) for v in self.MB_.values()])
        self._log(f"  MB learning completed in {self.timings_['mb']:.2f}s, avg MB size: {avg_mb:.1f}")

        # Step 4: orientation
        self._log("Step 4/4: Orienting edges based on topological order...")
        start = time.time()
        self.graph_ = self._orient_edges(n_nodes)
        self.timings_['orientation'] = time.time() - start
        n_edges = int(np.sum(self.graph_))
        self._log(f"  Orientation completed: {n_edges} directed edges")

        self.timings_['total'] = time.time() - total_start
        self._log(f"OCMB completed in {self.timings_['total']:.2f}s")
        self._log(f"  Total CMI calls: {self.n_cmi_calls_} (avg {self.n_cmi_calls_/n_nodes:.1f} per node)")

        # Clear CMI cache (release GPU memory)
        invalidate_cmi_cache()

        return self


class OCMB_DiffAN(OCMB_Base):
    """
    OCMB with DiffAN Ordering

    Uses DiffAN (Diffusion-based Algorithm for Causal Networks) as ordering backbone
    """

    def __init__(self,
                 # DiffAN parameters
                 epochs: int = 2000,
                 batch_size: int = 1024,
                 learning_rate: float = 0.001,
                 masking: bool = True,
                 residue: bool = True,
                 pruning_method: str = 'auto',
                 **kwargs):
        super().__init__(**kwargs)
        self.diffan_params = {
            'epochs': epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'masking': masking,
            'residue': residue,
            'pruning_method': pruning_method,
        }
        self._log("Initialized OCMB-DiffAN (DiffAN backbone)")

    def _get_ordering(self, X: np.ndarray, true_adj: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        n_nodes = X.shape[1]

        try:
            from diffan.diffan import DiffAN
            from diffan.utils import full_DAG

            # Data standardization
            X_norm = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)

            diffan = DiffAN(
                n_nodes=n_nodes,
                **self.diffan_params
            )

            adj_matrix, order = diffan.fit(X_norm)

            # Construct parent_scores
            W = full_DAG(order).astype(float)
            parent_scores = W * 0.5 + adj_matrix * 0.5

            # Mark: actually used DiffAN
            self._set_backbone_status("diffan", None)

            return np.array(order), parent_scores

        except Exception as e:
            # Mark: degraded to random (and record error)
            err = f"{type(e).__name__}: {e}"
            self._set_backbone_status("fallback_random", err)

            self._log(f"Warning: DiffAN failed ({err}), falling back to random ordering")

            # Use seed=2026 to avoid being identical to random baseline (seed=42)
            order_fallback = get_random_order(n_nodes, seed=2026)
            # Critical fix: use random scores instead of uniform constants (otherwise τ can't cut)
            rng = np.random.default_rng(2026)
            scores_fallback = rng.random((n_nodes, n_nodes))
            np.fill_diagonal(scores_fallback, 0.0)

            return order_fallback, scores_fallback


# =============================================================================
# Convenience Functions
# =============================================================================

def run_ocmb(X: np.ndarray, backbone: str = 'oracle', true_adj: np.ndarray = None, **kwargs) -> Tuple[np.ndarray, OCMB_Base]:
    """
    Convenience function: run OCMB algorithm with specified backbone

    Args:
        X: Data matrix (n_samples, n_nodes)
        backbone: ordering backbone type:
            - 'oracle': true ordering + true parent_scores (full Oracle)
            - 'oracle_order_only': only true ordering, uniform parent_scores
            - 'oracle_cand_pa': use true parent sets as CandPa
            - 'oracle_cand_nei': use true MB as CandNei
            - 'random': random ordering (baseline)
            - 'scino': SciNO neural operator
            - 'caps': CaPS algorithm
            - 'caps_lgbm': CaPS order + LightGBM score (sharper candidate parent ranking)
            - 'diffan': DiffAN algorithm
            - 'precomputed': reuse externally cached order + parent_scores (for pipeline acceleration)
        true_adj: True adjacency matrix (required for Oracle classes, optional for computing divergence in others)
        **kwargs: OCMB parameters

    Returns:
        graph: Adjacency matrix
        ocmb: OCMB object (containing detailed results)
    """
    backbone_classes = {
        'oracle': OCMB_Oracle,
        'oracle_order_only': OCMB_OracleOrderOnly,
        'oracle_cand_pa': OCMB_OracleCandPa,
        'oracle_cand_nei': OCMB_OracleCandNei,
        'random': OCMB_Random,
        'scino': OCMB_SciNO,
        'caps': OCMB_CaPS,
        'caps_lgbm': OCMB_CaPS_LGBMScore,  # CaPS order + LightGBM score
        'diffan': OCMB_DiffAN,
        'precomputed': OCMB_Precomputed,
    }

    if backbone not in backbone_classes:
        raise ValueError(f"Unknown backbone: {backbone}. Choose from {list(backbone_classes.keys())}")

    ocmb = backbone_classes[backbone](**kwargs)
    ocmb.fit(X, true_adj=true_adj)

    return ocmb.graph_, ocmb


# =============================================================================
# Testing
# =============================================================================

if __name__ == '__main__':
    print("=" * 80)
    print("OCMB Variants Module Self-Test")
    print("=" * 80)

    # Generate simple test data
    np.random.seed(42)
    n_samples = 500
    n_nodes = 5

    # Simple causal structure: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3, 3 -> 4
    X = np.zeros((n_samples, n_nodes))
    e = np.random.randn(n_samples, n_nodes)
    X[:, 0] = e[:, 0]
    X[:, 1] = 0.8 * X[:, 0] + e[:, 1]
    X[:, 2] = 0.6 * X[:, 0] + e[:, 2]
    X[:, 3] = 0.5 * X[:, 1] + 0.5 * X[:, 2] + e[:, 3]
    X[:, 4] = 0.7 * X[:, 3] + e[:, 4]

    true_graph = np.array([
        [0, 1, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0]
    ])

    print(f"Test data: {n_samples} samples, {n_nodes} nodes")
    print(f"True edges: {int(np.sum(true_graph))}")
    print()

    # Test each backbone (including new Oracle variants)
    test_backbones = [
        'oracle',
        'oracle_order_only',
        'oracle_cand_pa',
        'oracle_cand_nei',
        'random'
    ]

    for backbone in test_backbones:
        print(f"\n{'='*40}")
        print(f"Testing OCMB-{backbone.upper()}")
        print(f"{'='*40}")

        try:
            graph, ocmb = run_ocmb(X, backbone=backbone, true_adj=true_graph, verbose=True)
            metrics = calculate_metrics(true_graph, graph)

            print(f"\nResults:")
            print(f"  SHD: {metrics['SHD']}")
            print(f"  F1: {metrics['F1']:.3f}")
            print(f"  Precision: {metrics['Precision']:.3f}")
            print(f"  Recall: {metrics['Recall']:.3f}")
            if ocmb.ordering_divergence_ is not None:
                print(f"  Ordering Divergence: {ocmb.ordering_divergence_:.3f}")

            # Output candidate coverage statistics (key diagnostic metric)
            if ocmb.cand_stats_:
                print(f"  Candidate Stats:")
                print(f"    covPa: {ocmb.cand_stats_['covPa_mean']:.3f} ± {ocmb.cand_stats_['covPa_std']:.3f}")
                print(f"    covMB: {ocmb.cand_stats_['covMB_mean']:.3f} ± {ocmb.cand_stats_['covMB_std']:.3f}")
                print(f"    avg|CandPa|: {ocmb.cand_stats_['avg_CandPa_size']:.1f}")
                print(f"    avg|CandNei|: {ocmb.cand_stats_['avg_CandNei_size']:.1f}")

            print(f"  Total time: {ocmb.timings_['total']:.2f}s")
        except Exception as e:
            print(f"  Error: {e}")

    print("\n" + "=" * 80)
    print("Self-test completed!")
    print("=" * 80)
