import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency
from statsmodels.stats.multitest import multipletests


def verify_data_conditional_independence(
    data_dict, A, B, C_set, alpha=0.05, correction="BH"
):
    """
    Test whether A and B are conditionally independent given categorical variables C_set.

    The samples are split into strata, one per observed combination of C_set values, and
    a Chi-square test of independence between A and B (``scipy.stats.chi2_contingency``
    with its default settings) is run inside each stratum. A stratum is skipped when A or
    B takes fewer than two observed values in it, or when more than 20% of its expected
    cell counts are below 1 (the expected-count check this module calls the Koehler
    criterion). The per-stratum p-values are then combined according to ``correction``.

    Parameters
    ----------
    data_dict : dict
        Dictionary mapping variable names (str) to array-likes of shape [n_samples, 1]
        or [n_samples]. All variables used must have the same number of samples. Each
        variable keeps its own dtype, so stratum keys contain the original values.
        Example:
            {
                'X1': array([[0], [1], [2]]),
                'X2': array([[3], [4], [5]]),
                'X3': array([[6], [7], [8]])
            }
    A, B : str
        Names of the two categorical variables to test for conditional independence.
    C_set : list or tuple of str
        Conditioning variables. May be empty, in which case a single unconditional test
        of A vs. B is run and reported under the key ``()``.
    alpha : float, optional
        Significance level for the Chi-square tests. Default is 0.05.
    correction : {'BH', 'Bonferroni', 'fractional', 'none'}, optional
        How the per-stratum results are combined:
        - 'BH'         : Benjamini-Hochberg adjustment (statsmodels ``fdr_bh``);
                         independence is accepted only if no stratum is rejected.
        - 'Bonferroni' : Bonferroni adjustment (very conservative); independence is
                         accepted only if no stratum is rejected.
        - 'fractional' : strata are rejected when raw p < alpha; independence is accepted
                         when the fraction of rejected strata is <= max(alpha, 0.05).
        - 'none'       : strata are rejected when raw p < alpha; independence is accepted
                         when the fraction of rejected strata is <= alpha.

    Returns
    -------
    result_dict : dict
        - 'distribution_results': dict mapping each tested C tuple (values in C_set
          order) to a dict with keys 'p_value', 'p_value_corrected' (None unless
          'BH'/'Bonferroni'), 'rejected', 'chi2_statistic' and 'degrees_of_freedom'.
        - 'skipped_distributions': list of C tuples that were not tested.
        - 'alpha': significance level used.
        - 'correction': correction method used.
        - 'independence_accepted': bool, or None if every stratum was skipped.
        - 'rejection_rate': fraction of tested strata where independence was rejected,
          or None if every stratum was skipped.

    Raises
    ------
    ValueError
        If ``correction`` is not a supported method, or if the variables do not all
        have the same number of samples.
    KeyError
        If A, B or a member of C_set is missing from ``data_dict``.
    """
    # Public correction names -> statsmodels.multipletests method identifiers.
    multipletests_methods = {"BH": "fdr_bh", "Bonferroni": "bonferroni"}
    # Validate before doing any work so a typo cannot be silently ignored
    # (previously it went unnoticed whenever every stratum was skipped).
    if correction not in (*multipletests_methods, "fractional", "none"):
        raise ValueError(f"Unknown correction method: {correction}")

    # Accept any sequence of names; a tuple would otherwise be read by pandas as a
    # single (tuple-named) column key.
    c_cols = list(C_set)

    # Build the frame column by column so every variable keeps its own dtype.
    # Stacking the C arrays into one 2-D array first would coerce mixed types
    # (e.g. ints and strings) to a common dtype and corrupt the stratum keys.
    df = pd.DataFrame(
        {
            A: np.asarray(data_dict[A]).flatten(),
            B: np.asarray(data_dict[B]).flatten(),
        }
    )
    for c_var in c_cols:
        df[c_var] = np.asarray(data_dict[c_var]).flatten()

    # Observed (C1, C2, ...) combinations, in order of first appearance.
    if c_cols:
        c_frame = df[c_cols]  # hoisted: reused for every stratum filter below
        strata = list(c_frame.drop_duplicates().itertuples(index=False, name=None))
    else:
        # No conditioning variables: one unconditional test over all samples.
        c_frame = None
        strata = [()] if len(df) else []

    distribution_results = {}
    skipped_distributions = []
    failed_koehler_tests = 0
    total_tests = len(strata)

    # Raw p-values, in test order, for the multiple-testing step.
    raw_pvals = []
    c_value_to_index = {}  # C tuple -> position in raw_pvals

    for c_tuple in strata:
        if c_frame is None:
            sub_df = df
        else:
            in_stratum = (c_frame == pd.Series(c_tuple, index=c_cols)).all(axis=1)
            sub_df = df[in_stratum]

        contingency_table = pd.crosstab(sub_df[A], sub_df[B])

        # A Chi-square test needs at least two observed categories per variable.
        if contingency_table.shape[0] < 2 or contingency_table.shape[1] < 2:
            skipped_distributions.append(c_tuple)
            continue

        chi2_stat, p_val, dof, expected = chi2_contingency(contingency_table)

        # Expected-count check: skip strata where more than 20% of the expected
        # counts are below 1, since the Chi-square approximation is unreliable there.
        koehler_fraction = (expected < 1).sum() / expected.size
        if koehler_fraction > 0.20:
            failed_koehler_tests += 1
            skipped_distributions.append(c_tuple)
            continue

        raw_pvals.append(p_val)
        c_value_to_index[c_tuple] = len(raw_pvals) - 1

        distribution_results[c_tuple] = {
            "p_value": p_val,
            "p_value_corrected": None,  # filled in for 'BH' / 'Bonferroni'
            "rejected": None,  # filled in after the correction step
            "chi2_statistic": chi2_stat,
            "degrees_of_freedom": dof,
        }

    koehler_failure_rate = failed_koehler_tests / total_tests if total_tests > 0 else 0
    if koehler_failure_rate > 0.20 and total_tests >= 5:
        print(
            f"Warning: {koehler_failure_rate*100:.1f}% of tests failed the Koehler criterion. "
        )

    if not raw_pvals:
        print("Warning: All tests were skipped. ")
        return {
            "distribution_results": {},
            "skipped_distributions": skipped_distributions,
            "alpha": alpha,
            "correction": correction,
            "independence_accepted": None,  # nothing was tested
            "rejection_rate": None,
        }

    if correction in multipletests_methods:
        reject_mask, pvals_corrected, _, _ = multipletests(
            raw_pvals, alpha=alpha, method=multipletests_methods[correction]
        )
        for c_tuple, idx in c_value_to_index.items():
            distribution_results[c_tuple]["p_value_corrected"] = pvals_corrected[idx]
            distribution_results[c_tuple]["rejected"] = reject_mask[idx]
    else:
        # 'fractional' and 'none' both decide each stratum on its raw p-value.
        for c_tuple, idx in c_value_to_index.items():
            distribution_results[c_tuple]["rejected"] = raw_pvals[idx] < alpha

    rejections = sum(
        1 for result in distribution_results.values() if result["rejected"]
    )
    rejection_rate = rejections / len(distribution_results)

    if correction in multipletests_methods:
        # With an adjusted family-wise decision, any rejection rejects independence.
        independence_accepted = rejection_rate == 0
    elif correction == "fractional":
        # Tolerate up to max(alpha, 5%) rejected strata so that a very small alpha
        # does not let a tiny fraction of strata decide the outcome.
        independence_accepted = rejection_rate <= max(alpha, 0.05)
    else:  # "none"
        independence_accepted = rejection_rate <= alpha

    return {
        "distribution_results": distribution_results,
        "skipped_distributions": skipped_distributions,
        "alpha": alpha,
        "correction": correction,
        "independence_accepted": independence_accepted,
        "rejection_rate": rejection_rate,
    }


if __name__ == "__main__":
    # Demo on synthetic data: every variable is drawn by a separate, independent
    # randint call, so X1 and X2 are generated independently given (C1, C2).
    np.random.seed(42)
    num_samples = 5000

    data_dict = {
        "X1": np.random.randint(0, 5, num_samples),  # A has 5 categories
        "X2": np.random.randint(0, 5, num_samples),  # B has 5 categories
        "C1": np.random.randint(0, 4, num_samples),  # C1 has 4 categories
        "C2": np.random.randint(0, 3, num_samples),  # C2 has 3 categories
    }

    results = verify_data_conditional_independence(
        data_dict, A="X1", B="X2", C_set=["C1", "C2"], alpha=0.05, correction="BH"
    )

    # Strata are skipped either because A or B has fewer than two observed
    # categories in them, or because they fail the expected-count check.
    print(
        "Skipped C values (too few categories or Koehler criterion failed):",
        len(results["skipped_distributions"]),
    )

    # independence_accepted is None when no stratum could be tested; do not
    # report that as a rejection.
    accepted = results["independence_accepted"]
    if accepted is None:
        verdict = "UNDETERMINED (no stratum could be tested)"
    elif accepted:
        verdict = "ACCEPTED"
    else:
        verdict = "REJECTED"
    print("Final Decision: Conditional independence is", verdict)
