from itertools import product
from typing import Dict, List, Optional

import networkx as nx
import numpy as np
from scipy.stats import chi2_contingency
from statsmodels.stats.multitest import multipletests

if __name__ != "__main__":
    from causal_profiler.query import Query


def build_modified_graph(
    scm,
    intervention_variables: List[str] = None,
    observation_variables: List[str] = None,
) -> nx.DiGraph:
    """
    Build the mutilated DAG G_bar_X,underline_Z used by the do-calculus rules.

    The graph starts as the full causal DAG described by ``scm.children``.
    Arrows *into* every intervention variable are then cut (G_bar_X), and
    arrows *out of* every observation variable are cut (G_underline_Z).

    Parameters
    ----------
    scm : SCM
        The structural causal model; ``variables``, ``children`` and
        ``parents`` are read from it.
    intervention_variables : List[str], optional
        Variables that have been intervened on (X in the notation).
    observation_variables : List[str], optional
        Variables whose outgoing edges should be removed (Z in the notation).

    Returns
    -------
    nx.DiGraph
        A fresh DAG; the SCM itself is not modified.
    """
    graph = nx.DiGraph()
    graph.add_nodes_from(scm.variables.keys())
    graph.add_edges_from(
        (parent, child)
        for parent, kids in scm.children.items()
        for child in kids
    )

    if intervention_variables:
        # G_bar_X: an intervened variable no longer listens to its parents.
        incoming = [
            (parent, target)
            for target in intervention_variables
            for parent in scm.parents.get(target, [])
        ]
        graph.remove_edges_from(incoming)

    if observation_variables:
        # G_underline_Z: drop every arrow leaving an observation variable.
        outgoing = [
            (source, child)
            for source in observation_variables
            for child in scm.children.get(source, [])
        ]
        graph.remove_edges_from(outgoing)

    return graph


def extract_domain(
    scm, var_name: str, data_dict: Optional[Dict[str, np.ndarray]] = None
):
    """
    Return the list of values variable ``var_name`` can take.

    Parameters
    ----------
    scm : SCM
        The structural causal model (only consulted when ``data_dict`` is None).
    var_name : str
        The variable name.
    data_dict : Dict[str, np.ndarray], optional
        When given, the domain is the sorted set of distinct values observed
        for ``var_name`` in this dataset.

    Returns
    -------
    List
        Observed distinct values (sorted) if ``data_dict`` is provided,
        otherwise ``[0, 1, ..., num_discrete_values - 1]`` from the SCM.
    """
    if data_dict is None:
        # No data available: use the SCM's declared cardinality.
        cardinality = scm.variables[var_name].num_discrete_values
        return [*range(cardinality)]

    observed = data_dict[var_name].flatten()
    return sorted(np.unique(observed))


def construct_contingency_table(left_data, right_data, Y, Y_domain):
    """
    Build a 2 x m table of observed counts of Y for two estimator runs.

    Row 0 holds counts for ``left_data`` (P), row 1 for ``right_data`` (Q).
    Column ``i`` corresponds to ``Y_domain[i]``. For each category the raw
    samples are looked up under key ``f"{Y}_{y}"`` and then ``"Y_values"``;
    the cell is the number of those samples equal to ``y``. A missing entry
    contributes a count of 0.

    Parameters
    ----------
    left_data : dict
        Raw data of the first distribution (P).
    right_data : dict
        Raw data of the second distribution (Q).
    Y : str
        Target variable name (used to build the lookup key).
    Y_domain : List
        Domain of the target variable Y.

    Returns
    -------
    np.ndarray
        Float array of shape (2, len(Y_domain)).
    """
    table = np.zeros((2, len(Y_domain)))

    for col, y_val in enumerate(Y_domain):
        key = f"{Y}_{y_val}"
        for row, source in enumerate((left_data, right_data)):
            samples = source.get(key, {}).get("Y_values")
            # Absent data leaves the cell at 0, same as counting an empty array.
            if samples is not None:
                table[row, col] = np.sum(samples == y_val)

    return table


def check_koehler_criterion(
    contingency_table,
    min_expected: float = 1.0,
    max_low_fraction: float = 0.20,
):
    """
    Check if the Koehler criterion is satisfied for the given contingency table.

    The criterion (as applied here) states that no more than
    ``max_low_fraction`` (default 20%) of the expected cell frequencies may be
    below ``min_expected`` (default 1).

    Parameters
    ----------
    contingency_table : array_like
        2-D table of observed counts.
    min_expected : float, default 1.0
        Expected frequencies strictly below this value count as "low".
    max_low_fraction : float, default 0.20
        Largest allowed fraction of low expected frequencies.

    Returns
    -------
    bool
        True if the criterion is satisfied. False if the table is empty, has
        an all-zero row or column, is rejected by ``chi2_contingency`` (e.g.
        negative counts), or has too many low expected frequencies.
    """
    table = np.asarray(contingency_table)

    if table.size == 0:
        return False

    # An all-zero row or column yields zero expected frequencies, for which
    # the chi-square statistic is undefined.
    if np.any(table.sum(axis=0) == 0) or np.any(table.sum(axis=1) == 0):
        return False

    try:
        # Only the expected frequencies are needed.
        _, _, _, expected = chi2_contingency(table)
    except ValueError:
        # chi2_contingency rejects e.g. negative counts; treat as invalid.
        return False

    low_fraction = np.count_nonzero(expected < min_expected) / expected.size
    return bool(low_fraction <= max_low_fraction)


def perform_chi2_test(left_data, right_data, Y, Y_domain):
    """
    Perform a chi-square test to compare two distributions using original data points.

    Y categories never observed in either distribution are dropped before
    testing. The Koehler criterion is evaluated on that reduced table, the
    same table that is then passed to the chi-square test.

    Parameters
    ----------
    left_data : dict
        Raw data of the first distribution (P).
    right_data : dict
        Raw data of the second distribution (Q).
    Y : str
        Target variable.
    Y_domain : List
        Domain of the target variable Y.

    Returns
    -------
    Tuple[Optional[float], bool]
        ``(p_value, True)`` when the test is valid, otherwise ``(None, False)``.
    """
    contingency = construct_contingency_table(left_data, right_data, Y, Y_domain)

    # Drop Y categories with zero counts in both distributions; they carry no
    # information and would otherwise produce zero expected frequencies.
    observed_cols = np.any(contingency != 0, axis=0)
    filtered_contingency = contingency[:, observed_cols]

    # Validity must be judged on the table that is actually tested. Checking
    # the unfiltered table rejected every case with an unobserved category
    # (its all-zero column), which made the filtering above a no-op.
    if not check_koehler_criterion(filtered_contingency):
        return None, False

    try:
        _, p_val, _, _ = chi2_contingency(filtered_contingency)
    except ValueError:
        return None, False
    return p_val, True


def verify_rule1_insertion_deletion_observation(
    query_estimator,
    scm,
    Y: str,
    Z: List[str],
    X: List[str],
    W: List[str],
    data_dict: Optional[Dict[str, np.ndarray]] = None,
    alpha: float = 0.05,
    correction: str = "BH",
    n_samples: int = 10000,
):
    """
    Verify Rule 1 of do-calculus: Insertion/deletion of observation.

    If Y and Z are d-separated by X ∪ W in G_bar_X, then:
    P(Y | do(X), W, Z) = P(Y | do(X), W)

    For each combination of X, W, Z values, a chi-square test is performed to compare
    the distributions P(Y | do(X=x), W=w, Z=z) and P(Y | do(X=x), W=w).
    Multiple testing correction is applied across all combinations to control the error rate.
    The rule holds if fewer than 10% of the valid distributions show a significant difference.

    Parameters
    ----------
    query_estimator : object
        An estimator exposing:
            - evaluate_query(scm, query)
            - n_samples (set to ``n_samples``; not restored afterwards)
            - store_raw_data (temporarily set to True; restored on exit, even on error)
            - raw_data and clear_raw_data() (raw samples of evaluated queries)
    scm : SCM
        The structural causal model
    Y : str
        Target variable name
    Z : List[str]
        Observation variables to insert/delete
    X : List[str]
        Intervention variables
    W : List[str]
        Additional conditioning variables
    data_dict : Dict[str, np.ndarray], optional
        If provided, extract unique values from this dataset
    alpha : float, default 0.05
        Significance level for the chi-square tests
    correction : str, default "BH"
        Multiple-comparison correction method:
        - 'BH'         : Benjamini-Hochberg (controls false discovery rate)
        - 'Bonferroni' : Bonferroni correction (very conservative)
        - 'none'       : No correction is applied.
    n_samples : int, default 10000
        Number of samples for the query_estimator

    Returns
    -------
    Dict
        Dictionary with test results including:
        - d_separation_holds: whether Y and Z are d-separated by X ∪ W in G_bar_X
        - distribution_results: list of results for each distribution (X,W,Z combination)
        - alpha: significance level
        - correction: correction method
        - rule_holds: whether the rule holds (fewer than 10% of distributions rejected)
        - rejection_rate: fraction of distributions that were rejected
        - skipped_distributions: list of distributions that were skipped due to Koehler criterion

    Raises
    ------
    ValueError
        If ``correction`` is unknown. This is checked once d-separation has
        been established, before any query is evaluated.
    """
    query_estimator.n_samples = n_samples

    # Check d-separation in G_bar_X. networkx 3.3 deprecated ``d_separated``
    # in favour of ``is_d_separator`` (same arguments); use whichever exists.
    is_d_separator = getattr(nx, "is_d_separator", None) or nx.d_separated
    G_bar_X = build_modified_graph(scm, intervention_variables=X)
    d_separated = is_d_separator(G_bar_X, {Y}, set(Z), set(X + W))

    if not d_separated:
        # If d-separation doesn't hold, rule doesn't apply
        return {
            "d_separation_holds": False,
            "distribution_results": [],
            "alpha": alpha,
            "correction": correction,
            "rule_holds": None,
            "rejection_rate": None,
            "skipped_distributions": [],
        }

    # Fail fast on a bad correction method instead of after all the queries.
    if correction not in ("BH", "Bonferroni", "none"):
        raise ValueError(f"Unknown correction method: {correction}")

    # Extract domains
    Y_domain = extract_domain(scm, Y, data_dict)
    X_domains = [extract_domain(scm, x_var, data_dict) for x_var in X]
    W_domains = [extract_domain(scm, w_var, data_dict) for w_var in W]
    Z_domains = [extract_domain(scm, z_var, data_dict) for z_var in Z]

    # All value combinations; product() over no domains yields [()].
    X_combinations = list(product(*X_domains))
    W_combinations = list(product(*W_domains))
    Z_combinations = list(product(*Z_domains))

    distribution_results = []
    skipped_distributions = []

    # Raw samples are required to build contingency tables. The caller's
    # setting is restored in ``finally`` even if a query or test raises.
    original_store_raw_data = query_estimator.store_raw_data
    query_estimator.store_raw_data = True
    try:
        # Same iteration order as nested loops over X, then W, then Z.
        for x_vals, w_vals, z_vals in product(
            X_combinations, W_combinations, Z_combinations
        ):
            # Build and evaluate P(Y | do(X=x), W=w, Z=z)
            for y_val in Y_domain:
                left_query = Query.createInterventionalProb(
                    Y=[scm.variables[Y]],
                    T=[scm.variables[x] for x in X],
                    X=[scm.variables[w] for w in W] + [scm.variables[z] for z in Z],
                    Y_value=y_val,
                    T_value=list(x_vals),
                    X_value=list(w_vals) + list(z_vals),
                )
                query_estimator.evaluate_query(scm, left_query)
            left_data = query_estimator.raw_data
            query_estimator.clear_raw_data()

            # Build and evaluate P(Y | do(X=x), W=w)
            for y_val in Y_domain:
                right_query = Query.createInterventionalProb(
                    Y=[scm.variables[Y]],
                    T=[scm.variables[x] for x in X],
                    X=[scm.variables[w] for w in W],
                    Y_value=y_val,
                    T_value=list(x_vals),
                    X_value=list(w_vals),
                )
                query_estimator.evaluate_query(scm, right_query)
            right_data = query_estimator.raw_data
            query_estimator.clear_raw_data()

            p_val, test_valid = perform_chi2_test(
                left_data=left_data,
                right_data=right_data,
                Y=Y,
                Y_domain=Y_domain,
            )

            if test_valid:
                distribution_results.append(
                    {
                        "x_vals": x_vals,
                        "w_vals": w_vals,
                        "z_vals": z_vals,
                        "p_value": p_val,
                    }
                )
            else:
                skipped_distributions.append((x_vals, w_vals, z_vals))
    finally:
        query_estimator.store_raw_data = original_store_raw_data

    # If no valid distributions, return early
    if not distribution_results:
        print("Warning: All tests were skipped for Rule 1. No valid tests.")
        return {
            "d_separation_holds": True,
            "distribution_results": [],
            "skipped_distributions": skipped_distributions,
            "alpha": alpha,
            "correction": correction,
            "rule_holds": None,
            "rejection_rate": None,
        }

    # Multiple-comparison correction across all tested distributions
    p_values = [result["p_value"] for result in distribution_results]
    if correction == "none":
        p_values_corrected = p_values
        rejected = np.array(p_values) < alpha
    else:
        method = "fdr_bh" if correction == "BH" else "bonferroni"
        rejected, p_values_corrected, _, _ = multipletests(
            p_values, alpha=alpha, method=method
        )

    for i, result in enumerate(distribution_results):
        result["p_value_corrected"] = p_values_corrected[i]
        result["rejected"] = rejected[i]

    # Fraction of tested distributions that were rejected
    rejection_rate = np.mean(rejected)

    # Rule holds if rejection rate is low (< 10%)
    rule_holds = rejection_rate < 0.1

    # Warn if more tests were skipped than performed
    if len(skipped_distributions) > len(distribution_results):
        print(
            f"Warning: {len(skipped_distributions)} tests were skipped for Rule 1 due to Koehler criterion."
        )

    if rejection_rate > 0.2:
        print(f"Warning: Rule 1 has a high rejection rate of {rejection_rate:.2f}.")

    return {
        "d_separation_holds": True,
        "distribution_results": distribution_results,
        "skipped_distributions": skipped_distributions,
        "alpha": alpha,
        "correction": correction,
        "rule_holds": rule_holds,
        "rejection_rate": rejection_rate,
    }


def verify_rule2_action_observation_exchange(
    query_estimator,
    scm,
    Y: str,
    Z: List[str],
    X: List[str],
    W: List[str],
    data_dict: Optional[Dict[str, np.ndarray]] = None,
    alpha: float = 0.05,
    correction: str = "BH",
    n_samples: int = 10000,
):
    """
    Verify Rule 2 of do-calculus: Action/observation exchange.

    If Y and Z are d-separated by X ∪ W in G_bar_X,underline_Z, then:
    P(Y | do(X), do(Z), W) = P(Y | do(X), Z, W)

    For each combination of X, W, Z values, a chi-square test is performed to compare
    the distributions P(Y | do(X=x), do(Z=z), W=w) and P(Y | do(X=x), Z=z, W=w).
    Multiple testing correction is applied across all combinations to control the error rate.
    The rule holds if fewer than 10% of the valid distributions show a significant difference.

    Parameters
    ----------
    query_estimator : object
        An estimator exposing:
            - evaluate_query(scm, query)
            - n_samples (set to ``n_samples``; not restored afterwards)
            - store_raw_data (temporarily set to True; restored on exit, even on error)
            - raw_data and clear_raw_data() (raw samples of evaluated queries)
    scm : SCM
        The structural causal model
    Y : str
        Target variable name
    Z : List[str]
        Variables to be changed between intervention and observation
    X : List[str]
        Intervention variables
    W : List[str]
        Additional conditioning variables
    data_dict : Dict[str, np.ndarray], optional
        If provided, extract unique values from this dataset
    alpha : float, default 0.05
        Significance level for the chi-square tests
    correction : str, default "BH"
        Multiple-comparison correction method:
        - 'BH'         : Benjamini-Hochberg (controls false discovery rate)
        - 'Bonferroni' : Bonferroni correction (very conservative)
        - 'none'       : No correction is applied.
    n_samples : int, default 10000
        Number of samples for the query_estimator

    Returns
    -------
    Dict
        Dictionary with test results including:
        - d_separation_holds: whether Y and Z are d-separated by X ∪ W in G_bar_X,underline_Z
        - distribution_results: list of results for each distribution (X,W,Z combination)
        - alpha: significance level
        - correction: correction method
        - rule_holds: whether the rule holds (fewer than 10% of distributions rejected)
        - rejection_rate: fraction of distributions that were rejected
        - skipped_distributions: list of distributions that were skipped due to Koehler criterion

    Raises
    ------
    ValueError
        If ``correction`` is unknown. This is checked once d-separation has
        been established, before any query is evaluated.
    """
    query_estimator.n_samples = n_samples

    # Check d-separation in G_bar_X,underline_Z. networkx 3.3 deprecated
    # ``d_separated`` in favour of ``is_d_separator`` (same arguments).
    is_d_separator = getattr(nx, "is_d_separator", None) or nx.d_separated
    G_bar_X_underline_Z = build_modified_graph(
        scm, intervention_variables=X, observation_variables=Z
    )
    d_separated = is_d_separator(G_bar_X_underline_Z, {Y}, set(Z), set(X + W))

    if not d_separated:
        # If d-separation doesn't hold, rule doesn't apply
        return {
            "d_separation_holds": False,
            "distribution_results": [],
            "alpha": alpha,
            "correction": correction,
            "rule_holds": None,
            "rejection_rate": None,
            "skipped_distributions": [],
        }

    # Fail fast on a bad correction method instead of after all the queries.
    if correction not in ("BH", "Bonferroni", "none"):
        raise ValueError(f"Unknown correction method: {correction}")

    # Extract domains
    Y_domain = extract_domain(scm, Y, data_dict)
    X_domains = [extract_domain(scm, x_var, data_dict) for x_var in X]
    W_domains = [extract_domain(scm, w_var, data_dict) for w_var in W]
    Z_domains = [extract_domain(scm, z_var, data_dict) for z_var in Z]

    # All value combinations; product() over no domains yields [()].
    X_combinations = list(product(*X_domains))
    W_combinations = list(product(*W_domains))
    Z_combinations = list(product(*Z_domains))

    distribution_results = []
    skipped_distributions = []

    # Raw samples are required to build contingency tables. The caller's
    # setting is restored in ``finally`` even if a query or test raises.
    original_store_raw_data = query_estimator.store_raw_data
    query_estimator.store_raw_data = True
    try:
        # Same iteration order as nested loops over X, then W, then Z.
        for x_vals, w_vals, z_vals in product(
            X_combinations, W_combinations, Z_combinations
        ):
            # Build and evaluate P(Y | do(X=x), do(Z=z), W=w)
            for y_val in Y_domain:
                # Combine X and Z into a single intervention set
                combined_intervention_vars = [scm.variables[x] for x in X] + [
                    scm.variables[z] for z in Z
                ]
                combined_intervention_vals = list(x_vals) + list(z_vals)

                left_query = Query.createInterventionalProb(
                    Y=[scm.variables[Y]],
                    T=combined_intervention_vars,
                    X=[scm.variables[w] for w in W],
                    Y_value=y_val,
                    T_value=combined_intervention_vals,
                    X_value=list(w_vals),
                )
                query_estimator.evaluate_query(scm, left_query)
            left_data = query_estimator.raw_data
            query_estimator.clear_raw_data()

            # Build and evaluate P(Y | do(X=x), Z=z, W=w)
            for y_val in Y_domain:
                right_query = Query.createInterventionalProb(
                    Y=[scm.variables[Y]],
                    T=[scm.variables[x] for x in X],
                    X=[scm.variables[z] for z in Z] + [scm.variables[w] for w in W],
                    Y_value=y_val,
                    T_value=list(x_vals),
                    X_value=list(z_vals) + list(w_vals),
                )
                query_estimator.evaluate_query(scm, right_query)
            right_data = query_estimator.raw_data
            query_estimator.clear_raw_data()

            p_val, test_valid = perform_chi2_test(
                left_data=left_data,
                right_data=right_data,
                Y=Y,
                Y_domain=Y_domain,
            )

            if test_valid:
                distribution_results.append(
                    {
                        "x_vals": x_vals,
                        "w_vals": w_vals,
                        "z_vals": z_vals,
                        "p_value": p_val,
                    }
                )
            else:
                skipped_distributions.append((x_vals, w_vals, z_vals))
    finally:
        query_estimator.store_raw_data = original_store_raw_data

    # If no valid distributions, return early
    if not distribution_results:
        print("Warning: All tests were skipped for Rule 2. No valid tests.")
        return {
            "d_separation_holds": True,
            "distribution_results": [],
            "skipped_distributions": skipped_distributions,
            "alpha": alpha,
            "correction": correction,
            "rule_holds": None,
            "rejection_rate": None,
        }

    # Multiple-comparison correction across all tested distributions
    p_values = [result["p_value"] for result in distribution_results]
    if correction == "none":
        p_values_corrected = p_values
        rejected = np.array(p_values) < alpha
    else:
        method = "fdr_bh" if correction == "BH" else "bonferroni"
        rejected, p_values_corrected, _, _ = multipletests(
            p_values, alpha=alpha, method=method
        )

    for i, result in enumerate(distribution_results):
        result["p_value_corrected"] = p_values_corrected[i]
        result["rejected"] = rejected[i]

    # Fraction of tested distributions that were rejected
    rejection_rate = np.mean(rejected)

    # Rule holds if rejection rate is low (< 10%)
    rule_holds = rejection_rate < 0.1

    # Warn if more tests were skipped than performed
    if len(skipped_distributions) > len(distribution_results):
        print(
            f"Warning: {len(skipped_distributions)} tests were skipped for Rule 2 due to Koehler criterion."
        )

    if rejection_rate > 0.2:
        print(f"Warning: Rule 2 has a high rejection rate of {rejection_rate:.2f}.")

    return {
        "d_separation_holds": True,
        "distribution_results": distribution_results,
        "skipped_distributions": skipped_distributions,
        "alpha": alpha,
        "correction": correction,
        "rule_holds": rule_holds,
        "rejection_rate": rejection_rate,
    }


def verify_rule3_insertion_deletion_action(
    query_estimator,
    scm,
    Y: str,
    Z: List[str],
    X: List[str],
    W: List[str],
    data_dict: Optional[Dict[str, np.ndarray]] = None,
    alpha: float = 0.05,
    correction: str = "BH",
    n_samples: int = 10000,
):
    """
    Verify Rule 3 of do-calculus: Insertion/deletion of action.

    If Y and Z are d-separated by X ∪ W in G_bar_X,bar_Z(W), then:
    P(Y | do(X), do(Z), W) = P(Y | do(X), W)

    where Z(W) is the set of Z variables that are not ancestors of any W in G_bar_X.

    For each combination of X, W, Z values, a chi-square test is performed to compare
    the distributions P(Y | do(X=x), do(Z=z), W=w) and P(Y | do(X=x), W=w).
    Multiple testing correction is applied across all combinations to control the error rate.
    Unlike Rules 1 and 2, this rule holds only if NONE of the valid distributions
    shows a significant difference after correction.

    Parameters
    ----------
    query_estimator : object
        An estimator exposing:
            - evaluate_query(scm, query)
            - n_samples (set to ``n_samples``; not restored afterwards)
            - store_raw_data (temporarily set to True; restored on exit, even on error)
            - raw_data and clear_raw_data() (raw samples of evaluated queries)
    scm : SCM
        The structural causal model
    Y : str
        Target variable name
    Z : List[str]
        Variables to be inserted/deleted as interventions
    X : List[str]
        Intervention variables
    W : List[str]
        Additional conditioning variables
    data_dict : Dict[str, np.ndarray], optional
        If provided, extract unique values from this dataset
    alpha : float, default 0.05
        Significance level for the chi-square tests
    correction : str, default "BH"
        Multiple-comparison correction method:
        - 'BH'         : Benjamini-Hochberg (controls false discovery rate)
        - 'Bonferroni' : Bonferroni correction (very conservative)
        - 'none'       : No correction is applied.
    n_samples : int, default 10000
        Number of samples for the query_estimator

    Returns
    -------
    Dict
        Dictionary with test results including:
        - d_separation_holds: whether Y and Z are d-separated by X ∪ W in G_bar_X,bar_Z(W)
        - distribution_results: list of results for each distribution (X,W,Z combination)
        - alpha: significance level
        - correction: correction method
        - rule_holds: whether the rule holds (no distribution rejected)
        - rejection_rate: fraction of distributions that were rejected
        - skipped_distributions: list of distributions that were skipped due to Koehler criterion
        - z_w: the subset of Z that are not ancestors of W in G_bar_X (in Z's order)

    Raises
    ------
    ValueError
        If ``correction`` is unknown. This is checked once d-separation has
        been established, before any query is evaluated.
    """
    query_estimator.n_samples = n_samples

    G_bar_X = build_modified_graph(scm, intervention_variables=X)

    # Z(W): Z nodes that are not ancestors of any W node in G_bar_X (each W
    # counts as its own ancestor). Built in Z's order, with duplicates dropped,
    # so the returned ``z_w`` is deterministic. With W empty, Z(W) = Z.
    ancestors_of_W = set()
    for w in W:
        if w in G_bar_X:
            ancestors_of_W.update(nx.ancestors(G_bar_X, w))
            ancestors_of_W.add(w)
    Z_W = list(dict.fromkeys(z for z in Z if z not in ancestors_of_W))

    # G_bar_X,bar_Z(W): cut arrows into X and into Z(W).
    G_bar_X_bar_Z_W = build_modified_graph(scm, intervention_variables=X + Z_W)

    # networkx 3.3 deprecated ``d_separated`` in favour of ``is_d_separator``
    # (same arguments); use whichever exists.
    is_d_separator = getattr(nx, "is_d_separator", None) or nx.d_separated
    d_separated = is_d_separator(G_bar_X_bar_Z_W, {Y}, set(Z), set(X + W))

    if not d_separated:
        # If d-separation doesn't hold, rule doesn't apply
        return {
            "d_separation_holds": False,
            "distribution_results": [],
            "alpha": alpha,
            "correction": correction,
            "rule_holds": None,
            "rejection_rate": None,
            "skipped_distributions": [],
            "z_w": Z_W,
        }

    # Fail fast on a bad correction method instead of after all the queries.
    if correction not in ("BH", "Bonferroni", "none"):
        raise ValueError(f"Unknown correction method: {correction}")

    # Extract domains
    Y_domain = extract_domain(scm, Y, data_dict)
    X_domains = [extract_domain(scm, x_var, data_dict) for x_var in X]
    W_domains = [extract_domain(scm, w_var, data_dict) for w_var in W]
    Z_domains = [extract_domain(scm, z_var, data_dict) for z_var in Z]

    # All value combinations; product() over no domains yields [()].
    X_combinations = list(product(*X_domains))
    W_combinations = list(product(*W_domains))
    Z_combinations = list(product(*Z_domains))

    distribution_results = []
    skipped_distributions = []

    # Raw samples are required to build contingency tables. The caller's
    # setting is restored in ``finally`` even if a query or test raises.
    original_store_raw_data = query_estimator.store_raw_data
    query_estimator.store_raw_data = True
    try:
        # Same iteration order as nested loops over X, then W, then Z.
        for x_vals, w_vals, z_vals in product(
            X_combinations, W_combinations, Z_combinations
        ):
            # Build and evaluate P(Y | do(X=x), do(Z=z), W=w)
            for y_val in Y_domain:
                # Combine X and Z into a single intervention set
                combined_intervention_vars = [scm.variables[x] for x in X] + [
                    scm.variables[z] for z in Z
                ]
                combined_intervention_vals = list(x_vals) + list(z_vals)

                left_query = Query.createInterventionalProb(
                    Y=[scm.variables[Y]],
                    T=combined_intervention_vars,
                    X=[scm.variables[w] for w in W],
                    Y_value=y_val,
                    T_value=combined_intervention_vals,
                    X_value=list(w_vals),
                )
                query_estimator.evaluate_query(scm, left_query)
            left_data = query_estimator.raw_data
            query_estimator.clear_raw_data()

            # Build and evaluate P(Y | do(X=x), W=w)
            for y_val in Y_domain:
                right_query = Query.createInterventionalProb(
                    Y=[scm.variables[Y]],
                    T=[scm.variables[x] for x in X],
                    X=[scm.variables[w] for w in W],
                    Y_value=y_val,
                    T_value=list(x_vals),
                    X_value=list(w_vals),
                )
                query_estimator.evaluate_query(scm, right_query)
            right_data = query_estimator.raw_data
            query_estimator.clear_raw_data()

            p_val, test_valid = perform_chi2_test(
                left_data=left_data,
                right_data=right_data,
                Y=Y,
                Y_domain=Y_domain,
            )

            if test_valid:
                distribution_results.append(
                    {
                        "x_vals": x_vals,
                        "w_vals": w_vals,
                        "z_vals": z_vals,
                        "p_value": p_val,
                    }
                )
            else:
                skipped_distributions.append((x_vals, w_vals, z_vals))
    finally:
        query_estimator.store_raw_data = original_store_raw_data

    # If no valid distributions, return early
    if not distribution_results:
        print("Warning: All tests were skipped for Rule 3. No valid tests.")
        return {
            "d_separation_holds": True,
            "distribution_results": [],
            "skipped_distributions": skipped_distributions,
            "alpha": alpha,
            "correction": correction,
            "rule_holds": None,
            "rejection_rate": None,
            "z_w": Z_W,
        }

    # Multiple-comparison correction across all tested distributions
    p_values = [result["p_value"] for result in distribution_results]
    if correction == "none":
        p_values_corrected = p_values
        rejected = np.array(p_values) < alpha
    else:
        method = "fdr_bh" if correction == "BH" else "bonferroni"
        rejected, p_values_corrected, _, _ = multipletests(
            p_values, alpha=alpha, method=method
        )

    for i, result in enumerate(distribution_results):
        result["p_value_corrected"] = p_values_corrected[i]
        result["rejected"] = rejected[i]

    # Fraction of tested distributions that were rejected
    rejection_rate = np.mean(rejected)

    # Rule holds only if no distribution is rejected.
    # NOTE(review): Rules 1 and 2 accept a rejection rate < 0.1; confirm this
    # stricter threshold is intended for Rule 3.
    rule_holds = rejection_rate == 0

    # Warn if more tests were skipped than performed
    if len(skipped_distributions) > len(distribution_results):
        print(
            f"Warning: {len(skipped_distributions)} tests were skipped for Rule 3 due to Koehler criterion."
        )

    if rejection_rate > 0.2:
        print(f"Warning: Rule 3 has a high rejection rate of {rejection_rate:.2f}.")

    return {
        "d_separation_holds": True,
        "distribution_results": distribution_results,
        "skipped_distributions": skipped_distributions,
        "alpha": alpha,
        "correction": correction,
        "rule_holds": rule_holds,
        "rejection_rate": rejection_rate,
        "z_w": Z_W,
    }


def verify_do_calculus(
    query_estimator,
    scm,
    Y: str,
    Z: List[str],
    X: List[str],
    W: List[str],
    data_dict: Optional[Dict[str, np.ndarray]] = None,
    alpha: float = 0.05,
    correction: str = "BH",
    n_samples: int = 10000,
):
    """
    Verify all three rules of the do-calculus for the given variables.

    Runs the Rule 1, Rule 2 and Rule 3 verifiers in that order, using
    identical arguments, and bundles their results. Each verifier first checks
    the graphical (d-separation) condition of its rule. When the condition
    holds, the verifier compares the relevant distributions with statistical
    tests across combinations of X, W, Z values. It then applies the
    requested multiple-comparison correction across those tests.

    Whether a rule "holds" is decided by the individual verifier, not here.
    For example, Rule 3 reports ``rule_holds=True`` only when *no* tested
    distribution is rejected after correction. It reports ``None`` when its
    d-separation condition does not hold.

    Parameters
    ----------
    query_estimator : object
        Estimator used by the rule verifiers to evaluate queries (expected to
        expose ``evaluate_query(scm, query) -> float``).
    scm : SCM
        The structural causal model
    Y : str
        Target variable name
    Z : List[str]
        Variables for Z in do-calculus rules
    X : List[str]
        Intervention variables
    W : List[str]
        Additional conditioning variables
    data_dict : Dict[str, np.ndarray], optional
        If provided, extract unique values from this dataset
    alpha : float, default 0.05
        Significance level for the statistical tests
    correction : str, default "BH"
        Multiple-comparison correction method:
        - 'BH'         : Benjamini-Hochberg (controls false discovery rate)
        - 'Bonferroni' : Bonferroni correction (very conservative)
        - 'none'       : No correction is applied.
        Any other value is rejected by the rule verifiers with a ValueError
        (at least when a rule's d-separation condition holds).
    n_samples : int, default 10000
        Number of samples for the query_estimator

    Returns
    -------
    Dict
        Dictionary with test results for all three rules:
        - rule1_results: results for Rule 1 (insertion/deletion of observation)
        - rule2_results: results for Rule 2 (action/observation exchange)
        - rule3_results: results for Rule 3 (insertion/deletion of action)
        - variables: dictionary containing the variable names used

        Each per-rule entry is the verifier's own result dict. Rule 3's dict,
        for instance, has the keys 'd_separation_holds',
        'distribution_results', 'skipped_distributions', 'alpha',
        'correction', 'rule_holds', 'rejection_rate' and 'z_w'.
    """
    # Every rule verifier takes exactly the same keyword arguments.
    shared_kwargs = dict(
        query_estimator=query_estimator,
        scm=scm,
        Y=Y,
        Z=Z,
        X=X,
        W=W,
        data_dict=data_dict,
        alpha=alpha,
        correction=correction,
        n_samples=n_samples,
    )

    rule1_results = verify_rule1_insertion_deletion_observation(**shared_kwargs)
    rule2_results = verify_rule2_action_observation_exchange(**shared_kwargs)
    rule3_results = verify_rule3_insertion_deletion_action(**shared_kwargs)

    return {
        "rule1_results": rule1_results,
        "rule2_results": rule2_results,
        "rule3_results": rule3_results,
        "variables": {
            "Y": Y,
            "Z": Z,
            "X": X,
            "W": W,
        },
    }


if __name__ == "__main__":
    # Demo: sample a small discrete SCM with the CausalProfiler, run all three
    # do-calculus checks on it, and print a per-rule report followed by a
    # summary table.
    import os
    import random
    import sys

    import torch

    current_dir = os.path.dirname(os.path.realpath(__file__))
    # Directory holding the causal_profiler sources; put it on sys.path so its
    # modules (query, space_of_interest) import as top-level names.
    project_root = os.path.abspath(
        os.path.join(current_dir, os.pardir, "causal_profiler")
    )
    sys.path.insert(0, project_root)

    # NOTE: keep this import at module scope. When the file runs as a script
    # the guarded `from causal_profiler.query import Query` at the top of the
    # file is skipped, and the rule verifiers presumably use `Query` as a
    # module-level global.
    from query import Query  # noqa: F401
    from space_of_interest import (
        MechanismFamily,
        NoiseDistribution,
        QueryType,
        SpaceOfInterest,
        VariableDataType,
    )

    from causal_profiler import CausalProfiler, ErrorMetric

    # Set seed for reproducibility
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Generate an SCM, data, and a query estimator
    space = SpaceOfInterest(
        number_of_nodes=(4, 4),
        expected_edges=2,
        variable_dimensionality=(1, 1),
        mechanism_family=MechanismFamily.TABULAR,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],
        number_of_noise_regions="N",
        number_of_categories=2,
        variable_type=VariableDataType.DISCRETE,
        markovian=True,
        semi_markovian=False,
        causal_graph=None,
        control_positivity=False,
        number_of_queries=1,
        query_type=QueryType.CONDITIONAL,
        number_of_data_points=50,
    )
    profiler = CausalProfiler(
        space_of_interest=space,
        metric=ErrorMetric.L2,
        return_adjacency_matrix=False,
    )

    (data_dict, (_, _), (_, index_to_var)) = profiler.generate_samples_and_queries()
    scm = profiler.sampler._scm
    query_estimator = profiler.sampler.query_estimator

    # Set up variables for testing
    if len(index_to_var) >= 4:
        Y, Z, X, W = (
            index_to_var[0],
            [index_to_var[1]],
            [index_to_var[2]],
            [index_to_var[3]],
        )
    else:
        # If not enough variables, adjust accordingly
        Y = index_to_var[0]
        Z = [index_to_var[1]] if len(index_to_var) > 1 else []
        X = [index_to_var[2]] if len(index_to_var) > 2 else []
        W = []

    # Test the do-calculus verification
    results = verify_do_calculus(
        query_estimator=query_estimator,
        scm=scm,
        Y=Y,
        Z=Z,
        X=X,
        W=W,
        data_dict=data_dict,
        alpha=0.05,
        correction="BH",
        n_samples=10000,
    )

    print("DO-CALCULUS VERIFICATION RESULTS:")
    print(f"Variables: Y={Y}, Z={Z}, X={X}, W={W}")

    # Prepare summary table data
    table_data = []
    rule_names = [
        "Rule 1 (Insertion/deletion of observation)",
        "Rule 2 (Action/observation exchange)",
        "Rule 3 (Insertion/deletion of action)",
    ]

    rule_results = [
        results["rule1_results"],
        results["rule2_results"],
        results["rule3_results"],
    ]

    # Print detailed results for each rule
    for i, (rule_name, rule_result) in enumerate(zip(rule_names, rule_results)):
        print(f"\n{rule_name}:")
        print(f"D-separation holds: {rule_result['d_separation_holds']}")

        if not rule_result["d_separation_holds"]:
            # No statistical tests were run; add a row with NA values.
            table_data.append([rule_name, "NA", "NA", "NA", "NA", "NA"])
            continue

        # `or []` also covers keys that are present but set to None.
        dist_results = rule_result.get("distribution_results") or []
        skipped = len(rule_result.get("skipped_distributions") or [])
        valid_tests = len(dist_results)
        total_tests = valid_tests + skipped

        rejected = sum(1 for dist in dist_results if dist.get("rejected", False))
        passed = valid_tests - rejected
        rejection_rate = rejected / valid_tests if valid_tests > 0 else 0.0

        table_data.append(
            [
                rule_name,
                total_tests,
                passed,
                rejected,
                skipped,
                f"{rejection_rate:.4f}",
            ]
        )

        # Print detailed info. A missing *or* None rejection rate is shown as
        # 0.0 so that the float format spec cannot fail.
        reported_rate = rule_result.get("rejection_rate")
        if reported_rate is None:
            reported_rate = 0.0
        print(f"Rule holds: {rule_result.get('rule_holds')}")
        print(f"Rejection rate: {reported_rate:.4f}")
        print(f"Number of skipped tests: {skipped}")
        if i == 2 and "z_w" in rule_result:  # Only Rule 3 reports Z(W)
            print(f"Z(W): {rule_result['z_w']}")

    # Size the columns from their contents: the first column must fit the
    # longest rule name and the others the longest header ("Rejection Rate").
    header = ["Rule", "Total Tests", "Passed", "Failed", "Skipped", "Rejection Rate"]
    first_width = max(len(header[0]), *(len(name) for name in rule_names))
    other_width = max(12, *(len(title) for title in header[1:]))
    n_other = len(header) - 1
    row_format = f"{{:<{first_width}}}" + f" {{:<{other_width}}}" * n_other
    table_width = first_width + (other_width + 1) * n_other

    # Print summary table
    print("\n" + "=" * table_width)
    print("DO-CALCULUS SUMMARY TABLE")
    print("=" * table_width)

    print(row_format.format(*header))
    print("-" * table_width)

    for row in table_data:
        print(row_format.format(*[str(cell) for cell in row]))

    print("=" * table_width)