import os
os.environ["OPENBLAS_NUM_THREADS"] = "1"
import itertools
from copy import deepcopy
from utils import gm_to_nx_Digraph
import cliquepicking as cp
import networkx as nx
from collections import defaultdict
import pandas as pd
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer
from scipy.special import logsumexp
import time 
import pyAgrum as gum
from multiprocessing import Pool
from functools import partial
from scipy.stats import chi2_contingency, chi2
from scipy.stats import entropy
from causallearn.search.PermutationBased.GRaSP import grasp
import numpy as np
import pandas as pd
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from dataclasses import dataclass
from typing import Dict, Optional

############################
def _infer_cardinality(col: pd.Series) -> int:
    """Infer the number of discrete states of a column.

    A column with an integer dtype and no negative values is assumed to be
    coded ``0..K-1``, so the result is ``max + 1`` (codes that never occur in
    the data are still counted). Any other column falls back to the number of
    distinct values reported by ``Series.nunique()``.

    Args:
        col: The column to inspect.

    Returns:
        The inferred number of states; ``0`` for an empty column.
    """
    vals = col.to_numpy()
    # ndarray.min()/max() raise ValueError on size-0 input, so an empty column
    # must skip the integer fast path and fall through to nunique() == 0.
    if vals.size and np.issubdtype(vals.dtype, np.integer):
        if int(vals.min()) >= 0:
            return int(vals.max()) + 1
    return int(col.nunique())

def dirichlet_postpred_rowwise(
    df: pd.DataFrame,
    node: str,
    parents: list[str],
    alpha_star: float = 5.0,
    cardinalities: Optional[Dict[str, int]] = None,
    # order: only 'as_is' (df's row order) is implemented; any other value is
    # accepted for backward compatibility but currently behaves identically.
    order: str = 'as_is',
) -> np.ndarray:
    """
    Return per-row posterior-predictive probabilities under a symmetric Dirichlet(alpha_star/K)
    for the family p(node | parents). The product over rows equals the integrated marginal likelihood.

    Rows are processed in ``df`` order. For a row with node value ``x`` the
    returned probability is::

        (#earlier rows with the same parent configuration and value x + alpha_star / K)
        / (#earlier rows with the same parent configuration + alpha_star)

    Args:
        df: Data; ``df[node]`` must be castable to int codes in ``0..K-1``.
        node: Name of the child column.
        parents: Names of the parent columns (may be empty).
        alpha_star: Total Dirichlet pseudo-count, split evenly over K states.
        cardinalities: Optional ``{column: K}`` map; when ``node`` is absent
            K is inferred from the data with ``_infer_cardinality``.
        order: See the note on the parameter above.

    Returns:
        Float array of length ``len(df)`` aligned with the rows of ``df``.

    Raises:
        IndexError: if a node value lies outside ``0..K-1``.
    """
    n_rows = len(df)
    if n_rows == 0:
        return np.empty(0, dtype=float)

    # Only infer K when the caller did not supply it (avoids a needless scan).
    if cardinalities is not None and node in cardinalities:
        K = cardinalities[node]
    else:
        K = _infer_cardinality(df[node])
    alpha0 = alpha_star / K

    x_vals = df[node].astype(int).to_numpy()
    # Negative codes used to wrap around silently when indexing the count
    # vector; reject every out-of-range code explicitly instead.
    if x_vals.min() < 0 or x_vals.max() >= K:
        raise IndexError(
            f"values of '{node}' must lie in 0..{K - 1} for cardinality {K}"
        )

    # Sequential counts expressed as cumulative counts:
    #   total[i] = earlier rows sharing row i's parent configuration
    #   seen[i]  = earlier rows sharing both that configuration and node value
    # dropna=False keeps rows whose parents contain NaN as their own
    # configuration, so every output slot is always written.
    positions = pd.Series(np.arange(n_rows))
    parent_keys = [df[p].to_numpy() for p in parents]
    if parent_keys:
        total = (
            positions.groupby(parent_keys, sort=False, dropna=False)
            .cumcount()
            .to_numpy()
        )
    else:
        total = np.arange(n_rows)
    seen = (
        positions.groupby(parent_keys + [x_vals], sort=False, dropna=False)
        .cumcount()
        .to_numpy()
    )

    return (seen + alpha0) / (total + alpha_star)

def discrete_likelihood_fn_dirichlet(
    df,
    node,
    parents,
    alpha_star: float = 5.0,
    cardinalities: Optional[Dict[str, int]] = None,
):
    """Row-wise Dirichlet posterior-predictive likelihood for one family.

    Thin adapter that forwards every argument to
    ``dirichlet_postpred_rowwise`` with ``order`` pinned to ``'as_is'``.
    """
    forwarded = {
        "df": df,
        "node": node,
        "parents": parents,
        "alpha_star": alpha_star,
        "cardinalities": cardinalities,
        "order": 'as_is',
    }
    return dirichlet_postpred_rowwise(**forwarded)
# ---- Precompute global moments once ----
class GaussianMoments:
    """Container for the mean vector and covariance matrix of a set of columns."""

    def __init__(self, cols, mean, cov):
        # Position of every column name inside `mean` / `cov`.
        positions = {}
        for position, name in enumerate(cols):
            positions[name] = position

        self.cols = cols
        self.col_to_idx = positions
        self.mean = mean  # shape (d,)
        self.cov = cov    # shape (d,d)

def precompute_gaussian_moments(joint_df: pd.DataFrame, cols=None) -> GaussianMoments:
    """Compute the mean vector and MLE covariance matrix of ``joint_df[cols]``.

    Args:
        joint_df: Data with one column per variable.
        cols: Columns to include; defaults to every column of ``joint_df``.

    Returns:
        A ``GaussianMoments`` holding ``cols``, the column means (shape ``(d,)``)
        and the covariance matrix (shape ``(d, d)``).
    """
    if cols is None:
        cols = list(joint_df.columns)
    X = joint_df[cols].to_numpy(dtype=float)
    mu = X.mean(axis=0)
    # ddof=0 → MLE covariance.
    # np.cov returns a 0-d array when there is a single variable; force a
    # (d, d) matrix so cov[i, j] indexing works for d == 1 as well.
    Sigma = np.atleast_2d(np.cov(X, rowvar=False, ddof=0))
    return GaussianMoments(cols, mu, Sigma)

# ---- Fast conditional Gaussian logpdf per row ----
def gaussian_conditional_logpdf_rows(joint_df: pd.DataFrame, node: str, parents: list, gm: "GaussianMoments",
                                     ridge=1e-8) -> pd.Series:
    """Per-row log density of ``node`` given ``parents`` under a joint Gaussian.

    The conditional mean and variance are derived from the precomputed
    moments in ``gm``:

        mu_cond  = mu_x + Sigma_xz Sigma_zz^{-1} (z - mu_z)
        var_cond = var_xx - Sigma_xz Sigma_zz^{-1} Sigma_zx

    With no parents this reduces to the marginal ``N(mu_x, var_xx)``.

    Args:
        joint_df: Data containing ``node`` and every parent column.
        node: Name of the child column.
        parents: Sequence of parent column names (may be empty).
        gm: Object exposing ``col_to_idx``, ``mean`` and ``cov``.
        ridge: Floor for variances; also scales the diagonal loading that is
            added to ``Sigma_zz`` before it is factorized.

    Returns:
        Series of log densities indexed like ``joint_df``.
    """
    # Accept any sequence: a tuple would otherwise be treated by pandas as a
    # single (multi-level) column key in joint_df[parents].
    parents = list(parents)

    idx_x = gm.col_to_idx[node]
    idx_z = [gm.col_to_idx[p] for p in parents]

    x_vals = joint_df[[node]].to_numpy(dtype=float)[:, 0]    # (n,)
    mu_x = gm.mean[idx_x]
    var_xx = gm.cov[idx_x, idx_x]

    if len(idx_z) == 0:
        # p(x) ~ N(mu_x, var_xx); clamp the variance to stay positive.
        var = max(var_xx, ridge)
        resid = x_vals - mu_x
        ll = -0.5 * np.log(2 * np.pi * var) - 0.5 * (resid ** 2) / var
        return pd.Series(ll, index=joint_df.index, name=f'logpdf_{node}')

    k = len(idx_z)
    Zvals = joint_df[parents].to_numpy(dtype=float)          # (n,k)
    mu_z = gm.mean[idx_z]                                    # (k,)
    sigma_xz = np.asarray(gm.cov[idx_x, idx_z], dtype=float).reshape(-1)  # (k,)
    Sigma_zz = gm.cov[np.ix_(idx_z, idx_z)].copy()           # (k,k)

    # Numerical stabilization: load the diagonal proportionally to the
    # average variance of the parents.
    trace = np.trace(Sigma_zz)
    lam = ridge if trace == 0 else ridge * trace / k
    Sigma_zz.flat[::k + 1] += lam

    # One linear solve gives beta = Sigma_zz^{-1} Sigma_zx, which serves both
    # the conditional mean and the conditional variance; the explicit inverse
    # is never formed.
    try:
        L = np.linalg.cholesky(Sigma_zz)
        beta = np.linalg.solve(L.T, np.linalg.solve(L, sigma_xz))
    except np.linalg.LinAlgError:
        # Not positive definite even after loading: fall back to the pseudo-inverse.
        beta = np.linalg.pinv(Sigma_zz) @ sigma_xz

    # Conditional mean/variance
    mu_cond = mu_x + (Zvals - mu_z) @ beta                   # (n,)
    var_cond = float(max(var_xx - sigma_xz @ beta, ridge))

    # Log-density per row
    resid = x_vals - mu_cond
    ll = -0.5 * np.log(2 * np.pi * var_cond) - 0.5 * (resid ** 2) / var_cond
    return pd.Series(ll, index=joint_df.index, name=f'logpdf_{node}|{",".join(parents)}')

# ---- Convenience wrapper with caching across families ----
_moments_cache = {}
def rowwise_loglik_gaussian(joint_df, node, parents, cols=None, cache_key="__all__"):
    """Per-row Gaussian conditional log-likelihood with memoized moments.

    The moments are looked up in the module-level ``_moments_cache`` under
    ``cache_key`` only; they are computed from ``joint_df`` (restricted to
    ``cols`` when given) the first time a key is seen. Callers that switch
    datasets or column subsets must therefore pass a different ``cache_key``.
    """
    moments = _moments_cache.get(cache_key)
    if moments is None:
        source = joint_df[cols] if cols is not None else joint_df
        moments = precompute_gaussian_moments(source)
        _moments_cache[cache_key] = moments
    return gaussian_conditional_logpdf_rows(joint_df, node, parents, moments)

@dataclass
class ConditionalKDE:
    """Conditional density p(x | z) expressed as a ratio of two KDEs.

    With parents, ``log p(x | z) = log p(x, z) - log p(z)``; both densities are
    evaluated on inputs transformed by their own scaler. Without parents only
    ``scaler_joint`` / ``kde_joint`` are used and they model ``p(x)``.
    """
    x_name: str
    z_names: list
    scaler_joint: StandardScaler | None
    scaler_z: StandardScaler | None
    kde_joint: KernelDensity        # KDE on [x, z] if z exists, else KDE on x
    kde_z: KernelDensity | None     # KDE on z (None if no parents)

    def logpdf(self, x_val, z_val=None):
        """
        Log-density log p(x | z) at one point (log p(x) when there are no parents).
        """
        has_parents = len(self.z_names) > 0
        if has_parents:
            z_point = np.asarray(z_val, dtype=float).reshape(1, -1)
            x_point = np.asarray([x_val], dtype=float).reshape(1, 1)
            joint_point = np.hstack([x_point, z_point])

            joint_scaled = self.scaler_joint.transform(joint_point)
            z_scaled = self.scaler_z.transform(z_point)

            log_joint = self.kde_joint.score_samples(joint_scaled)[0]
            log_marginal_z = self.kde_z.score_samples(z_scaled)[0]
            return log_joint - log_marginal_z

        x_point = np.asarray([x_val], dtype=float).reshape(1, 1)
        x_scaled = self.scaler_joint.transform(x_point)
        return self.kde_joint.score_samples(x_scaled)[0]

    def pdf(self, x_val, z_val=None):
        """Density p(x | z) at one point, i.e. ``exp(logpdf(...))``."""
        log_density = self.logpdf(x_val, z_val)
        return np.exp(log_density)

    def logpdf_rows(self, df: pd.DataFrame):
        """
        Log-density for every row of ``df`` at once, returned as a Series
        indexed like ``df`` (log p(x) when there are no parents).
        """
        has_parents = len(self.z_names) > 0
        if has_parents:
            joint_values = df[[self.x_name] + self.z_names].to_numpy(dtype=float)
            z_values = df[self.z_names].to_numpy(dtype=float)

            joint_scaled = self.scaler_joint.transform(joint_values)
            z_scaled = self.scaler_z.transform(z_values)

            log_joint = self.kde_joint.score_samples(joint_scaled)
            log_marginal_z = self.kde_z.score_samples(z_scaled)
            return pd.Series(log_joint - log_marginal_z, index=df.index,
                             name=f'logpdf_{self.x_name}|{",".join(self.z_names)}')

        x_values = df[[self.x_name]].to_numpy(dtype=float)
        x_scaled = self.scaler_joint.transform(x_values)
        log_density = self.kde_joint.score_samples(x_scaled)
        return pd.Series(log_density, index=df.index,
                         name=f'logpdf_{self.x_name}')

    def pdf_rows(self, df: pd.DataFrame):
        """Density for every row of ``df``: element-wise exp of ``logpdf_rows``."""
        log_densities = self.logpdf_rows(df)
        return log_densities.apply(np.exp)

def _fit_kde(X_std, bandwidth_grid=(0.1, 0.2, 0.5, 1.0, 2.0)):
    """Fit a Gaussian KDE to ``X_std``, choosing the bandwidth from
    ``bandwidth_grid`` by 5-fold cross-validated grid search, and return the
    best estimator found."""
    cv_search = GridSearchCV(
        KernelDensity(kernel='gaussian'),
        {'bandwidth': bandwidth_grid},
        cv=5,
        n_jobs=-1,
    )
    cv_search.fit(X_std)
    return cv_search.best_estimator_

def estimate_conditional_pdf(joint_df: pd.DataFrame, x: str, ls_of_parents: list,
                             bandwidth_grid_joint=(0.1, 0.2, 0.5, 1.0, 2.0),
                             bandwidth_grid_z=(0.1, 0.2, 0.5, 1.0, 2.0)) -> ConditionalKDE:
    """
    Build a conditional KDE estimator for continuous data.
    Handles the no-parents case by fitting a marginal KDE p(x).

    Args:
        joint_df: Data containing ``x`` and every parent column. Rows with a
            missing value in any of these columns are dropped before fitting.
        x: Name of the target column.
        ls_of_parents: Sequence of parent column names (may be empty).
        bandwidth_grid_joint: Candidate bandwidths for the KDE over ``[x, parents]``
            (or over ``x`` alone when there are no parents).
        bandwidth_grid_z: Candidate bandwidths for the KDE over the parents.

    Returns:
        A fitted ``ConditionalKDE``.

    Raises:
        ValueError: if ``x`` or a parent is not a column of ``joint_df``.
    """
    # Take a private list copy: list concatenation below needs a list (a tuple
    # would raise TypeError), and the returned estimator must not alias a list
    # the caller may mutate later.
    parent_names = list(ls_of_parents)

    # safety checks
    if x not in joint_df.columns:
        raise ValueError(f"Column '{x}' not found in joint_df.")
    for c in parent_names:
        if c not in joint_df.columns:
            raise ValueError(f"Column '{c}' not found in joint_df.")

    cols_joint = [x] + parent_names
    data_joint = joint_df[cols_joint].dropna().to_numpy(dtype=float)

    if not parent_names:
        # Fit marginal p(x)
        X = data_joint[:, [0]]  # shape (n,1)
        scaler_x = StandardScaler().fit(X)
        kde_x = _fit_kde(scaler_x.transform(X), bandwidth_grid_joint)
        return ConditionalKDE(
            x_name=x,
            z_names=[],
            scaler_joint=scaler_x,  # scaler over x only
            scaler_z=None,
            kde_joint=kde_x,        # stores marginal KDE
            kde_z=None,
        )

    # parents present: fit p(x,z) and p(z), each on its own standardized copy
    Z = data_joint[:, 1:]  # parents only
    scaler_joint = StandardScaler().fit(data_joint)
    scaler_z = StandardScaler().fit(Z)

    kde_joint = _fit_kde(scaler_joint.transform(data_joint), bandwidth_grid_joint)
    kde_z = _fit_kde(scaler_z.transform(Z), bandwidth_grid_z)

    return ConditionalKDE(
        x_name=x,
        z_names=parent_names,
        scaler_joint=scaler_joint,
        scaler_z=scaler_z,
        kde_joint=kde_joint,
        kde_z=kde_z,
    )

def mutual_info(df, x_col, y_col):
    """
    Compute mutual information (in nats) between two columns in a dataframe.

    The joint distribution is the normalized contingency table of the two
    columns; cells with zero probability contribute nothing to the sum.

    Args:
        df (pd.DataFrame): DataFrame containing the columns
        x_col (str): Name of first column
        y_col (str): Name of second column

    Returns:
        float: Mutual information between x_col and y_col (0.0 when the
        contingency table is empty)
    """
    # Work on plain arrays: indexing the marginals as Series would look values
    # up by category LABEL rather than by position, which fails whenever the
    # labels are not exactly 0..n-1.
    counts = pd.crosstab(df[x_col], df[y_col]).to_numpy(dtype=float)
    total = counts.sum()
    if total == 0:
        return 0.0

    p_xy = counts / total
    p_x = p_xy.sum(axis=1, keepdims=True)   # (rows, 1)
    p_y = p_xy.sum(axis=0, keepdims=True)   # (1, cols)
    independent = p_x * p_y                 # outer product via broadcasting

    support = p_xy > 0                      # avoid log(0)
    terms = p_xy[support] * np.log(p_xy[support] / independent[support])
    return float(terms.sum())


def conditional_chi2_test(df, x_col, y_col, z_cols=None):
    """
    Conducts a conditional chi-square test for X and Y given multiple Zs.

    If z_cols is empty (or None), does a regular chi-square test. Otherwise
    the data are split by every observed configuration of the Z columns, a
    chi-square test (without continuity correction) is run per stratum, and
    the statistics and degrees of freedom are summed. Strata whose X-by-Y
    table has fewer than two rows or columns are skipped.

    Args:
        df (pd.DataFrame): DataFrame containing X, Y, and (optionally) Zs.
        x_col (str): Name of the X column.
        y_col (str): Name of the Y column.
        z_cols (list[str] | None): Names of Z columns (can be empty).

    Returns:
        chi2_stat_total, dof_total, p_value_total — or (None, None, None)
        when no usable contingency table exists.
    """
    if z_cols is None:
        z_cols = []

    if len(z_cols) == 0:
        # No conditioning: regular chi-square test
        contingency = pd.crosstab(df[x_col], df[y_col])
        if contingency.shape[0] < 2 or contingency.shape[1] < 2:
            return None, None, None
        chi2_stat, p, dof, _ = chi2_contingency(contingency, correction=False)
        return chi2_stat, dof, p

    # Conditioning on Z: accumulate over strata
    chi2_stat_total = 0
    dof_total = 0
    for _, group in df.groupby(z_cols):
        contingency = pd.crosstab(group[x_col], group[y_col])
        if contingency.shape[0] < 2 or contingency.shape[1] < 2:
            continue
        chi2_stat, _, dof, _ = chi2_contingency(contingency, correction=False)
        chi2_stat_total += chi2_stat
        dof_total += dof

    if dof_total == 0:
        return None, None, None

    # Survival function instead of 1 - cdf: the latter rounds to exactly 0
    # once the cdf is within machine epsilon of 1.
    p_value_total = chi2.sf(chi2_stat_total, dof_total)
    return chi2_stat_total, dof_total, p_value_total
    
def build_cpt(df, node, parents, alpha=1e-6):
    """
    Empirical conditional probability table of `node` given `parents`.

    The result maps
       (parents_vals_tuple, node_val) -> P(node=node_val | parents=parents_vals_tuple)
    using relative frequencies in `df`. With no parents the key's first
    element is the empty tuple. `alpha` is accepted but not used.
    """
    if not parents:
        # just a prior P(node)
        marginal = df[node].value_counts(normalize=True).to_dict()
        return {((), value): prob for value, prob in marginal.items()}

    family = parents + [node]
    family_counts = df.groupby(family).size().rename("count").reset_index()
    parent_counts = df.groupby(parents).size().rename("total").reset_index()
    table = family_counts.merge(parent_counts, on=parents)

    cpt = {}
    for _, record in table.iterrows():
        parent_values = tuple(record[p] for p in parents)
        cpt[(parent_values, record[node])] = record["count"]/record["total"]
    return cpt

def _is_adjacent_pdag(pdg, a, b) -> bool:
    """
    Report whether `a` and `b` are connected in the PDAG by any kind of link.

    Both queries are asked in both argument orders: `pdg.has_edge` first,
    then `pdg.has_arc`. The first truthy answer is returned.
    """
    linked_by_edge = pdg.has_edge(a, b) or pdg.has_edge(b, a)
    if linked_by_edge:
        return linked_by_edge
    return pdg.has_arc(a, b) or pdg.has_arc(b, a)

def has_new_unshielded_collider_at(pdg_completed, node, original_parents_at_node):
    """
    Detect an unshielded collider a->node<-b at `node` that is "new".

    Every pair of current parents of `node` that are not adjacent to each
    other forms an unshielded collider. Such a pair counts as new unless both
    of its members appear in `original_parents_at_node`.
    """
    parents_now = list(pdg_completed.parents_of(node))
    for first, second in itertools.combinations(parents_now, 2):
        if _is_adjacent_pdag(pdg_completed, first, second):
            # shielded: the two parents are linked, so no collider here
            continue
        both_original = (
            first in original_parents_at_node
            and second in original_parents_at_node
        )
        if not both_original:
            return True
    return False

def isValidConfiguration(pdg, original_parents_at_node, node) -> bool:
    """
    Decide whether a candidate orientation of the undirected edges at `node`
    is acceptable. Working on a copy of `pdg`:
      1) the copy is completed with `to_complete_pdag()`,
      2) its directed arcs must form an acyclic graph,
      3) no NEW unshielded collider may appear *at node* (local check).
    """
    completed = pdg.copy()
    completed.to_complete_pdag()

    # Acyclicity is checked on the directed arcs only.
    arc_graph = nx.DiGraph()
    arc_graph.add_edges_from(completed.arcs)
    if not nx.is_directed_acyclic_graph(arc_graph):
        return False

    # The collider check is local to `node`.
    introduces_collider = has_new_unshielded_collider_at(
        completed, node, original_parents_at_node
    )
    return not introduces_collider

def undirected_neighbors_of(cpdag, node):
    """List the other endpoint of every pair in `cpdag.edges` that touches
    `node`, in first-seen order and without repeats."""
    # cpdag.edges is iterated as (u, v) pairs
    touching = [
        v if u == node else u
        for (u, v) in cpdag.edges
        if u == node or v == node
    ]
    # dict preserves insertion order, so this dedupes while keeping order
    return list(dict.fromkeys(touching))
    
def debug_config_rejections(cpdag, node):
    """Print a tally of why candidate orientations at `node` are rejected.

    Every assignment of a direction to each undirected neighbor of `node` is
    tried on a copy of `cpdag` and counted under exactly one of: an edge was
    no longer undirected when its turn came, the completed graph has a
    directed cycle, a new unshielded collider appears at `node`, or "ok".
    Nothing is returned.
    """
    nbrs = undirected_neighbors_of(cpdag, node)
    baseline_parents = list(cpdag.parents_of(node))
    tally = {"acyclicity":0, "new_collider_at_node":0, "not_undirected":0, "ok":0}

    if not nbrs:
        print(f"[DBG] Node {node}: no undirected neighbors.")
        return

    for orientation in itertools.product([0,1], repeat=len(nbrs)):
        candidate = cpdag.copy()
        for bit, nbr in zip(orientation, nbrs):
            still_undirected = candidate.has_edge(node, nbr) or candidate.has_edge(nbr, node)
            if not still_undirected:
                tally["not_undirected"] += 1
                break
            # bit 0 orients node -> nbr, bit 1 orients nbr -> node
            arc = (node, nbr) if bit == 0 else (nbr, node)
            candidate.replace_edge_with_arc(arc)
        else:
            # every edge was oriented; now validate the completed graph
            completed = candidate.copy()
            completed.to_complete_pdag()
            arc_graph = nx.DiGraph()
            arc_graph.add_edges_from(completed.arcs)

            if not nx.is_directed_acyclic_graph(arc_graph):
                tally["acyclicity"] += 1
            elif has_new_unshielded_collider_at(completed, node, baseline_parents):
                tally["new_collider_at_node"] += 1
            else:
                tally["ok"] += 1

    print(f"[DBG] Node {node} combos: {2**len(nbrs)} → {tally}")
    
def getConfigurations(cpdag, node):
    """Enumerate the valid ways of orienting the undirected edges at `node`.

    Returns a list of copies of `cpdag`. When `node` has no undirected
    neighbors the list holds one unmodified copy. Otherwise each copy has
    every undirected edge at `node` replaced by an arc, and only copies
    accepted by `isValidConfiguration` are kept. Candidates are generated in
    `itertools.product([0, 1], ...)` order, where 0 means node -> neighbor
    and 1 means neighbor -> node.
    """
    # undirected neighbors of node, and the parents it already has
    nbrs = list(cpdag._undirected_neighbors[node])
    baseline_parents = list(cpdag.parents_of(node))

    if not nbrs:
        return [cpdag.copy()]

    # A single neighbor needs no special handling: the product below then
    # yields exactly the two directions, node -> nbr first.
    accepted = []
    for orientation in itertools.product([0, 1], repeat=len(nbrs)):
        candidate = cpdag.copy()
        for toward_node, nbr in zip(orientation, nbrs):
            arc = (nbr, node) if toward_node else (node, nbr)
            candidate.replace_edge_with_arc(arc)

        if isValidConfiguration(candidate, baseline_parents, node):
            accepted.append(candidate)

    return accepted

def convert_graphical_model_object_to_nx_Digraph(cpdag, name_to_id_dict):
    """Translate `cpdag` into a networkx DiGraph whose nodes are the ids that
    `name_to_id_dict` assigns to the node names.

    Each pair in `cpdag.edges` becomes two opposite directed edges; each pair
    in `cpdag.arcs` becomes a single directed edge.
    """
    digraph = nx.DiGraph()
    digraph.add_nodes_from([name_to_id_dict[name] for name in cpdag.nodes])

    # undirected edges: one directed edge in each direction
    for src, dst in cpdag.edges:
        src_id = name_to_id_dict[src]
        dst_id = name_to_id_dict[dst]
        digraph.add_edges_from([(src_id, dst_id), (dst_id, src_id)])

    # directed arcs keep their direction
    for src, dst in cpdag.arcs:
        src_id = name_to_id_dict[src]
        dst_id = name_to_id_dict[dst]
        digraph.add_edge(src_id, dst_id)

    return digraph

def sampleAugmentedGraphs(cpdag, nodenames, potential_root_causes, k=np.inf):
    """Sample one augmented DAG per valid edge configuration of each candidate root.

    For every node in `potential_root_causes`, the valid orientations of its
    undirected edges are enumerated with `getConfigurations` (in a process
    pool). Each configuration is completed with `to_complete_pdag()`,
    converted to integer node ids, given an extra node 'FNODE' with an arc
    FNODE -> root, and then passed to `cp.mec_size` and `cp.MecSampler` to
    obtain a size and one sampled DAG.

    Args:
        cpdag: Graph object providing copy / edges / arcs / nodes and the
            methods used by `getConfigurations`.
        nodenames: Node names; their enumeration order defines the integer ids.
        potential_root_causes: Nodes to treat as the child of 'FNODE'.
        k: Stop as soon as this many DAGs have been sampled in total.

    Returns:
        (augmented_dags, mec_sizes): two defaultdict(list) keyed by root; the
        i-th entries are the sampled `nx.DiGraph` (with node names) and the
        value returned by `cp.mec_size` for the i-th valid configuration.
    """
    count = 0
    augmented_dags = defaultdict(list)
    mec_sizes = defaultdict(list)

    name_to_id = {nodename: i for i, nodename in enumerate(nodenames)}
    # FNODE gets the next free id; default=-1 keeps this valid for empty nodenames.
    name_to_id['FNODE'] = max(name_to_id.values(), default=-1) + 1
    id_to_name = {node_id: name for name, node_id in name_to_id.items()}
    fnode_id = name_to_id['FNODE']

    # Parallelize configuration generation across roots.
    # os.cpu_count() may return None when the count cannot be determined.
    num_cores_to_use = min(4, os.cpu_count() or 1)
    with Pool(num_cores_to_use) as pool:
        # one list of cpdag copies per root, in the order of potential_root_causes
        configs_per_root = pool.map(partial(getConfigurations, cpdag), potential_root_causes)

    for root, all_configs_of_root in zip(potential_root_causes, configs_per_root):
        rootid = name_to_id[root]
        for config_of_root in all_configs_of_root:
            # complete the partially directed graph (Meek rules)
            config_of_root.to_complete_pdag()

            # re-index nodes from names to integers and attach FNODE -> root
            nx_graph = convert_graphical_model_object_to_nx_Digraph(config_of_root, name_to_id)
            nx_graph.add_node(fnode_id)
            nx_graph.add_edge(fnode_id, rootid)

            edge_list = list(nx_graph.edges)
            size = cp.mec_size(edge_list)
            sampled_augmented_dag = cp.MecSampler(edge_list).sample_dag()

            # map the integer ids back to node names
            named_edges = [(id_to_name[e1], id_to_name[e2]) for e1, e2 in sampled_augmented_dag]
            augmented_dags[root].append(nx.DiGraph(named_edges))
            mec_sizes[root].append(size)

            count += 1
            if count == k:
                return augmented_dags, mec_sizes

    return augmented_dags, mec_sizes


def compute_conditional_prob(row, ie, target_var, conditioning_vars):
    """Probability that the inference engine `ie` assigns to the value of
    `target_var` observed in `row`, given the row's values of
    `conditioning_vars` as evidence (no evidence when the list is empty).

    The engine's evidence is overwritten as a side effect.
    """
    # start from a clean slate, then install this row's evidence if any
    ie.setEvidence({})
    if conditioning_vars:
        ie.setEvidence({name: int(row[name]) for name in conditioning_vars})

    distribution = ie.posterior(target_var)
    observed_state = int(row[target_var])
    return distribution[observed_state]




def compute_conditional_probs_cached(df, ie, target_var,
                                     conditioning_vars=None,
                                     alpha=1e-6):
    """
    Computes P(target_var = y | conditioning_vars = x) for each row in df.

    Inference is run once per distinct evidence configuration found in df
    (the posterior depends only on the evidence, not on the target value),
    and the resulting probabilities are scattered back to the matching rows.
    All values are cast to int before being handed to the engine.

    Parameters:
    - df: pandas DataFrame containing all relevant variables
    - ie: inference engine exposing setEvidence(dict) and posterior(name)
      (e.g., gum.LazyPropagation(bn))
    - target_var: name of the target variable (str)
    - conditioning_vars: list of variable names to condition on (can be empty or None)
    - alpha: value used for any row that is not assigned a probability
      (default: 1e-6)

    Returns:
    - A NumPy array of conditional probabilities, aligned with df rows

    Side effects:
    - The engine's evidence is overwritten.
    """
    if conditioning_vars is None:
        conditioning_vars = []

    # Case 1: No conditioning variables → compute once (prior)
    if not conditioning_vars:
        ie.setEvidence({})
        posterior = ie.posterior(target_var)
        # map each row's value to its prior
        return df[target_var].astype(int).map(lambda val: posterior[val]).values

    # Case 2: With conditioning → one inference per distinct evidence combination
    cond = list(conditioning_vars)
    probs = np.full(len(df), alpha, dtype=float)
    if len(df) == 0:
        return probs

    # Integer codes are taken straight from df, so no merge on possibly
    # mismatched dtypes (e.g. string-coded columns) is needed.
    evidence_frame = df[cond].astype(int)
    evidence_vals = evidence_frame.to_numpy()
    target_vals = df[target_var].astype(int).to_numpy()

    # .indices maps each evidence configuration to the positions of its rows.
    row_groups = evidence_frame.groupby(cond, sort=False).indices
    for rows in row_groups.values():
        combo = evidence_vals[rows[0]]
        ie.setEvidence({var: int(val) for var, val in zip(cond, combo)})
        posterior = ie.posterior(target_var)

        group_targets = target_vals[rows]
        for state in np.unique(group_targets):
            probs[rows[group_targets == state]] = posterior[int(state)]

    return probs



def _empirical_marginal_probs(df, node):
    """Map every row of df[node] to the relative frequency of its value in df."""
    freqs = df[node].value_counts(normalize=True).to_dict()
    return df[node].map(freqs).values


def _empirical_conditional_probs(df, node, parents, alpha):
    """Per-row relative frequency of node's value within its parent configuration.

    Rows that match no counted (parents, node) combination — e.g. rows with a
    missing key, which groupby drops — receive `alpha`.
    """
    family = parents + [node]
    joint_counts = df.groupby(family).size().reset_index(name='count')
    parent_totals = df.groupby(parents).size().reset_index(name='total')
    table = pd.merge(joint_counts, parent_totals, on=parents)
    table['prob'] = table['count'] / table['total']
    # A left merge keeps df's row order, so the result stays aligned with df.
    merged = pd.merge(df, table[family + ['prob']], on=family, how='left')
    # Not `merged['prob'].fillna(alpha, inplace=True)`: an in-place fill on a
    # column selection does not reach `merged` under pandas copy-on-write.
    return merged['prob'].fillna(alpha).values


def _combine_regimes(df, obs_dist, int_dist, alpha):
    """Scatter the FNODE == 0 and FNODE == 1 probabilities back into df's row order.

    Rows whose FNODE is neither 0 nor 1 receive `alpha`.
    """
    out = np.full(len(df), alpha, dtype=float)
    regime = df['FNODE'].to_numpy()
    out[regime == 0] = obs_dist
    out[regime == 1] = int_dist
    return out


def discrete_likelihood_fn(df, node, parents, obs_ie, int_ie, obs_ground_truth, int_ground_truth, alpha=1e-6):
    """
    Compute the per-row likelihood P(node | parents) for a discrete node.

    Depending on the flags, probabilities are either estimated from `df` as
    relative frequencies or queried from an inference engine via
    `compute_conditional_probs_cached`. When 'FNODE' is among the parents and
    an engine is used, rows with FNODE == 0 and FNODE == 1 are scored
    separately, conditioning on the remaining parents only.

    Parameters:
        df (pd.DataFrame): The dataset.
        node (str): The name of the current node.
        parents (list): A list of parent node names.
        obs_ie: Inference engine used when `obs_ground_truth` is true.
        int_ie: Inference engine used when `int_ground_truth` is true.
        obs_ground_truth (bool): Use `obs_ie` instead of frequencies from df.
        int_ground_truth (bool): Use `int_ie` for the family with parent FNODE.
        alpha (float): A small number to assign to unseen parent-child combinations.

    Returns:
        np.ndarray: A vector of likelihoods for the node, aligned with the rows of df.
    """
    # ---- no parents -------------------------------------------------------
    if not parents:
        if obs_ground_truth and node != 'FNODE':
            return compute_conditional_probs_cached(df, obs_ie, node, conditioning_vars=None, alpha=alpha)
        # FNODE itself, or no ground-truth engine: frequencies from the data
        return _empirical_marginal_probs(df, node)

    # ---- parents without FNODE ---------------------------------------------
    if 'FNODE' not in parents:
        if obs_ground_truth:
            return compute_conditional_probs_cached(df, obs_ie, node, conditioning_vars=parents, alpha=alpha)
        return _empirical_conditional_probs(df, node, parents, alpha)

    # ---- FNODE is a parent ---------------------------------------------------
    pa_without_F = [pa for pa in parents if pa != 'FNODE']

    if not obs_ground_truth and not int_ground_truth:
        # purely data-driven: FNODE is treated like any other parent
        return _empirical_conditional_probs(df, node, parents, alpha)

    n_df = df[df['FNODE'] == 0]
    a_df = df[df['FNODE'] == 1]

    if obs_ground_truth:
        if not pa_without_F:
            # FNODE is the only parent: marginal frequencies within the FNODE == 1 rows
            int_dist = _empirical_marginal_probs(a_df, node)
        elif int_ground_truth:
            int_dist = compute_conditional_probs_cached(a_df, int_ie, node, conditioning_vars=pa_without_F, alpha=alpha)
        else:
            # get the FNODE == 1 distribution from data
            int_dist = _empirical_conditional_probs(a_df, node, pa_without_F, alpha)
        obs_dist = compute_conditional_probs_cached(n_df, obs_ie, node, conditioning_vars=pa_without_F, alpha=alpha)
    else:
        int_dist = compute_conditional_probs_cached(a_df, int_ie, node, conditioning_vars=pa_without_F, alpha=alpha)
        # NOTE(review): the FNODE == 0 rows are scored with int_ie here as well —
        # confirm this is intended rather than a data-driven estimate.
        obs_dist = compute_conditional_probs_cached(n_df, int_ie, node, conditioning_vars=pa_without_F, alpha=alpha)

    # Write each regime back to its own rows; concatenating the two blocks
    # would only line up with df if all FNODE == 0 rows came first.
    return _combine_regimes(df, obs_dist, int_dist, alpha)


# cache: dict to avoid refitting the same family
_cond_cache = {}

def rowwise_likelihoods(joint_df, node, parents, return_log=False, cache=_cond_cache):
    """Per-row (log-)density of `node` given `parents` from a cached conditional KDE.

    The estimator for a family is fitted on `joint_df` the first time the
    family is requested and then reused. The cache key is
    `(node, tuple(parents))` only, so a shared cache must not be reused
    across different datasets.

    Args:
        joint_df: Data containing `node` and every parent column.
        node: Name of the target column.
        parents: Sequence of parent column names (may be empty).
        return_log: Return log-densities instead of densities.
        cache: Dict used for memoization; defaults to the module-level
            `_cond_cache`, which is shared by all callers on purpose.

    Returns:
        pandas Series indexed like `joint_df`.
    """
    # Private list copy: the cached estimator must not alias a list the
    # caller can mutate later, and list input is what the estimator expects.
    parents = list(parents)
    key = (node, tuple(parents))
    if key not in cache:
        cache[key] = estimate_conditional_pdf(joint_df, x=node, ls_of_parents=parents)
    cond = cache[key]
    return cond.logpdf_rows(joint_df) if return_log else cond.pdf_rows(joint_df)


# Memo of per-row likelihood vectors, keyed by
# (node, parents, obs_ground_truth, int_ground_truth).
# The key does not identify the dataframe, so the cache must be cleared
# whenever the data changes (brcd() clears it on entry).
_local_factor_cache = {}
def compute_local_likelihoods(dag, df, ie,  obs_ground_truth, int_ground_truth):
    """
    Compute the per-row log joint likelihood of ``df`` under ``dag``.

    For every node of the DAG the per-row conditional likelihood
    p(node | parents) is obtained from ``discrete_likelihood_fn``, moved to
    log space, and the log factors are summed over nodes. Working in log
    space prevents numerical underflow.

    Parameters:
        dag (networkx.DiGraph): The DAG with nodes representing variables.
        df (pd.DataFrame): The discrete dataset where each column is a variable.
        ie: Forwarded unchanged to ``discrete_likelihood_fn`` (presumably an
            inference engine -- confirm against that function).
        obs_ground_truth: Forwarded to ``discrete_likelihood_fn`` and used in
            the cache key, so it must be hashable.
        int_ground_truth: Forwarded to ``discrete_likelihood_fn`` and used in
            the cache key, so it must be hashable.

    Returns:
        np.ndarray: A 1D array of shape (n_samples,); entry i is the log joint
        likelihood of row i. A DAG without nodes yields all zeros (the log of
        an empty product).

    Raises:
        ValueError: If a local factor does not have one entry per row of ``df``.
    """
    factors = []

    # Process nodes in topological order to respect dependency order
    for node in nx.topological_sort(dag):
        # Column names are strings, so parents are stringified before lookup
        parents = [str(pa) for pa in dag.predecessors(node)]

        # Both ground-truth flags are passed to the likelihood function, so both
        # must be part of the key; otherwise a factor computed under one setting
        # would be silently reused under another.
        cache_key = (node, tuple(parents), obs_ground_truth, int_ground_truth)
        if cache_key not in _local_factor_cache:
            _local_factor_cache[cache_key] = discrete_likelihood_fn(
                df,
                str(node),
                parents,
                ie,
                obs_ground_truth,
                int_ground_truth
            )

        local_factor = _local_factor_cache[cache_key]

        if len(local_factor) != len(df):
            raise ValueError(f"Likelihood for node '{node}' returned an array of incorrect length.")

        # Add a tiny constant so that a zero probability maps to a large
        # negative number instead of -inf
        factors.append(np.log(local_factor + 1e-300))

    if not factors:
        # No nodes: np.column_stack would fail on an empty list
        return np.zeros(len(df))

    # Summing the log factors is equivalent to multiplying the original factors
    return np.sum(np.column_stack(factors), axis=1)



def brcd(normal_df,
         anomalous_df,
         cpdag,
         obs_bn=None,
         int_bn=None,
         obs_ground_truth=False,
         k=np.inf,
         discretize=False,
         int_ground_truth=False,
         version='brcd_c'):
    """
    Rank candidate root causes by an (unnormalised) posterior score.

    The normal and anomalous samples are stacked into one frame with an extra
    binary column ``FNODE`` (0 for normal rows, 1 for anomalous rows). Augmented
    DAGs are obtained from ``sampleAugmentedGraphs`` for every candidate root
    cause, each (node, parents) family is scored once with
    ``discrete_likelihood_fn_dirichlet``, and the per-row likelihoods are mixed
    over the DAGs of each root using weights derived from ``mec_sizes``.

    Parameters:
        normal_df (pd.DataFrame): Samples from normal operation. Its columns
            are the candidate root causes.
        anomalous_df (pd.DataFrame): Samples from the anomalous regime.
        cpdag: Graph passed through to ``sampleAugmentedGraphs``.
        obs_bn, int_bn: Networks wrapped in ``gum.LazyPropagation`` when the
            corresponding ``*_ground_truth`` flag is truthy. The resulting
            engines are not consulted by the Dirichlet likelihood used here.
        obs_ground_truth, int_ground_truth (bool): See ``obs_bn`` / ``int_bn``.
        k: Passed through to ``sampleAugmentedGraphs`` (presumably a cap on the
            number of sampled graphs -- confirm against that function).
        discretize (bool): If True, bin every column into 5 quantile bins
            fitted on ``normal_df`` only.
            NOTE(review): when False the data reach the discrete likelihood
            unbinned -- confirm inputs are already discrete in that case.
        version (str): With ``'brcd_c'`` and ``discretize=True`` the prior over
            root causes is proportional to the chi-square statistic between
            ``FNODE`` and each column; otherwise the prior is uniform.

    Returns:
        tuple: ``(sorted_root_causes, sorted_posterior, elapsed)`` where the
        first two are lists ordered by decreasing score and ``elapsed`` is the
        time in seconds spent on likelihood and posterior computation. Scores
        are rescaled so the best one equals 1; they are not normalised to sum
        to 1.

    Raises:
        ValueError: If ``sampleAugmentedGraphs`` returns no root causes.
    """
    # ───────────────────────────────────────────────────────────────
    # 1) Prepare data + inference engine
    _local_factor_cache.clear()
    # Built for the ground-truth likelihood path; the Dirichlet likelihood
    # below does not use them.
    obs_ie = gum.LazyPropagation(obs_bn) if obs_ground_truth else None
    int_ie = gum.LazyPropagation(int_bn) if int_ground_truth else None

    columns = list(normal_df.columns)

    if discretize:
        # Bin edges come from the normal data only, so anomalous rows are
        # measured against the normal regime's quantiles.
        kbd = KBinsDiscretizer(n_bins=5, encode='ordinal', strategy='quantile')
        kbd.fit(normal_df.to_numpy())
        df_obs = pd.DataFrame(kbd.transform(normal_df.to_numpy()),
                              columns=normal_df.columns).astype(int)
        df_int = pd.DataFrame(kbd.transform(anomalous_df.to_numpy()),
                              columns=normal_df.columns).astype(int)
    else:
        df_obs = normal_df
        df_int = anomalous_df

    # concat allocates a new frame, so the callers' frames are never modified
    joint_df = pd.concat([df_obs, df_int], ignore_index=True)
    joint_df['FNODE'] = np.r_[np.zeros(len(df_obs), dtype=int), np.ones(len(df_int), dtype=int)]

    # Prior over candidate root causes, in column order
    uniform_prior = np.ones(len(columns)) / len(columns)
    if version == 'brcd_c' and discretize:
        chisq_weights = []
        for col in columns:
            ch, _, _ = conditional_chi2_test(joint_df, 'FNODE', col, [])
            if ch is None or not np.isfinite(ch):
                ch = 0
            chisq_weights.append(ch)
        chisq_weights = np.array(chisq_weights, dtype=float)
        total_weight = chisq_weights.sum()
        # If no column shows any association the weights carry no information;
        # dividing by zero would turn every posterior into NaN.
        prior = chisq_weights / total_weight if total_weight > 0 else uniform_prior
    else:
        prior = uniform_prior

    # Name -> prior lookup so the prior can be matched to roots by name later
    prior_by_root = dict(zip(columns, prior))
    potential_root_causes = list(columns)

    # ───────────────────────────────────────────────────────────────
    # 2) Sample augmented DAGs
    sampled_augmented, mec_sizes = sampleAugmentedGraphs(cpdag, list(normal_df.columns),
                                                         potential_root_causes,
                                                         k)
    if not sampled_augmented:
        raise ValueError("sampleAugmentedGraphs returned no candidate root causes.")

    # ───────────────────────────────────────────────────────────────
    # 3) Build a cache of every unique (node, parents) family *once*
    unique_families = {}
    for dags in sampled_augmented.values():
        for dag in dags:
            for node in dag.nodes():
                parents = tuple(sorted(dag.predecessors(node)))
                unique_families.setdefault((node, parents), None)
    start_time = time.time()

    # Compute & store their log-likelihood vectors (keys are only overwritten,
    # never added, so iterating the dict while assigning is safe)
    for (node, parents) in unique_families:
        probs = discrete_likelihood_fn_dirichlet(
                    joint_df,
                    str(node),
                    [str(p) for p in parents],
                    alpha_star=5.0,                  # tune: 1..10 common; larger = stronger smoothing
                    cardinalities=None               # or pass a precomputed {var: K} dict
                )
        unique_families[(node, parents)] = np.log(probs + 1e-300)

    # ───────────────────────────────────────────────────────────────
    # 4) For each root r, assemble per-DAG joint logs, add log-prior, then log-sum-exp
    log_p_data_given_R = []
    root_causes = []
    for r, dags in sampled_augmented.items():
        sizes = np.array(mec_sizes[r], dtype=float)
        log_p_G = np.log(sizes / sizes.sum() + 1e-300)

        # build an (n_samples × num_dags) array where each column = log P(data|G) + log P(G)
        cols = []
        for i, dag in enumerate(dags):
            # sum cached log-factors over that DAG's families
            log_joint = sum(
                unique_families[(node, tuple(sorted(dag.predecessors(node))))]
                for node in dag.nodes()
            )
            cols.append(log_joint + log_p_G[i])

        matrix = np.column_stack(cols)
        # log P(data | root=r) = logsumexp over DAGs
        log_p_data_given_R.append(logsumexp(matrix, axis=1))
        root_causes.append(r)

    # stack over roots → shape = (num_roots, num_samples)
    log_p_data_given_R = np.stack(log_p_data_given_R, axis=0)

    # ───────────────────────────────────────────────────────────────
    # 5) Sum over samples to get the log-likelihood per root, add the log-prior,
    #    and rescale so the best root scores 1 (scores are not normalised)
    log_likelihood = log_p_data_given_R.sum(axis=1)            # shape=(num_roots,)

    # Match each root with its own prior by name; the iteration order of
    # sampled_augmented is not guaranteed to follow the column order.
    if all(r in prior_by_root for r in root_causes):
        root_prior = np.array([prior_by_root[r] for r in root_causes], dtype=float)
    else:
        # Roots are not keyed by column name: keep the positional pairing
        root_prior = prior

    # A zero prior legitimately maps to -inf (score 0); silence the warning
    with np.errstate(divide='ignore'):
        log_posterior = log_likelihood + np.log(root_prior)

    posterior = np.exp(log_posterior - log_posterior.max())
    elapsed = time.time() - start_time

    # Indices that sort the scores in descending order
    sorted_indices = np.argsort(-posterior)
    sorted_root_causes = [root_causes[i] for i in sorted_indices]
    sorted_posterior = [posterior[i] for i in sorted_indices]
    return sorted_root_causes, sorted_posterior, elapsed


    
   
   
   