from pprint import pprint

import numpy as np

from verification.verify_data_ci import verify_data_conditional_independence


# Helper functions for generating datasets
def generate_independent_data(
    num_samples=5000, num_categories=5, num_c_values=4, seed=42
):
    """Generate a dataset where A and B are independent given C.

    All four columns are drawn separately with ``np.random.randint``, so no
    column is derived from another.

    Args:
        num_samples: Length of every returned array.
        num_categories: A and B take integer values in ``[0, num_categories)``.
        num_c_values: C1 and C2 take integer values in ``[0, num_c_values)``.
        seed: Value handed to ``np.random.seed`` before sampling. The default
            (42) reproduces the data produced by the previously hard-coded seed.

    Returns:
        Dict with keys "A", "B", "C1" and "C2", each mapping to a 1-D array of
        ``num_samples`` integers.

    Note:
        Reseeds NumPy's global random state as a side effect.
    """
    np.random.seed(seed)
    # The draw order (A, B, C1, C2) determines the generated values for a given
    # seed, so it must stay fixed to keep the dataset reproducible.
    return {
        "A": np.random.randint(0, num_categories, num_samples),
        "B": np.random.randint(0, num_categories, num_samples),
        "C1": np.random.randint(0, num_c_values, num_samples),
        "C2": np.random.randint(0, num_c_values, num_samples),
    }


def generate_dependent_data(
    num_samples=5000, num_categories=5, num_c_values=4, seed=42
):
    """Generate a dataset where A and B are dependent given C.

    C1, C2 and A are drawn with ``np.random.randint``; B is then computed as
    ``(A + C1 + C2) % num_categories``, so B is fully determined by A once
    C1 and C2 are fixed.

    Args:
        num_samples: Length of every returned array.
        num_categories: A takes integer values in ``[0, num_categories)``;
            B is reduced modulo this value.
        num_c_values: C1 and C2 take integer values in ``[0, num_c_values)``.
        seed: Value handed to ``np.random.seed`` before sampling. The default
            (42) reproduces the data produced by the previously hard-coded seed.

    Returns:
        Dict with keys "A", "B", "C1" and "C2", each mapping to a 1-D array of
        ``num_samples`` integers.

    Note:
        Reseeds NumPy's global random state as a side effect.
    """
    np.random.seed(seed)
    # The draw order (C1, C2, A) determines the generated values for a given
    # seed, so it must stay fixed to keep the dataset reproducible.
    C1 = np.random.randint(0, num_c_values, num_samples)
    C2 = np.random.randint(0, num_c_values, num_samples)
    A = np.random.randint(0, num_categories, num_samples)
    B = (A + C1 + C2) % num_categories  # Introduce dependence
    return {"A": A, "B": B, "C1": C1, "C2": C2}


# Tests
def test_independence_case():
    """Independent data from the generator must be accepted as independent."""
    conditioning_columns = ["C1", "C2"]
    dataset = generate_independent_data()
    outcome = verify_data_conditional_independence(
        dataset, A="A", B="B", C_set=conditioning_columns
    )
    accepted = outcome["independence_accepted"]
    assert accepted, "Should accept independence!"


def test_dependence_case():
    """Test correct rejection of dependence.

    The result must be an explicit rejection: other tests in this module show
    that "independence_accepted" can be None (all tests skipped), and a bare
    ``assert not ...`` would silently pass on that value.
    """
    data_dict = generate_dependent_data()
    results = verify_data_conditional_independence(
        data_dict, A="A", B="B", C_set=["C1", "C2"]
    )
    accepted = results["independence_accepted"]
    # Guard against a vacuous pass: None is falsy but means "no verdict".
    assert accepted is not None, "Should reject dependence, not skip all tests!"
    assert not accepted, "Should reject dependence!"


def test_small_dataset():
    """A three-row dataset must yield no verdict (None) instead of True/False."""
    column_a = np.array([0, 1, 2])
    column_b = np.array([1, 2, 0])
    column_c1 = np.array([0, 0, 1])
    tiny_dataset = {"A": column_a, "B": column_b, "C1": column_c1}

    outcome = verify_data_conditional_independence(
        tiny_dataset, A="A", B="B", C_set=["C1"]
    )

    verdict = outcome["independence_accepted"]
    assert verdict is None, "All tests should be skipped because of Koehler criterion!"


def test_large_dataset():
    """Independence must still be accepted with 100000 independent samples."""
    big_dataset = generate_independent_data(num_samples=100000)
    conditioning_columns = ["C1", "C2"]
    outcome = verify_data_conditional_independence(
        big_dataset, A="A", B="B", C_set=conditioning_columns
    )
    accepted = outcome["independence_accepted"]
    assert accepted, "Should accept independence for a large dataset!"


def test_identical_values():
    """Constant columns (every entry 1.0) must yield no verdict (None)."""
    row_count = 5000
    constant_dataset = {}
    for column in ("A", "B", "C1", "C2"):
        # Each column gets its own array of ones, as in a literal dict.
        constant_dataset[column] = np.ones(row_count)

    outcome = verify_data_conditional_independence(
        constant_dataset, A="A", B="B", C_set=["C1", "C2"]
    )
    # Dump the full result; pytest shows captured output when the test fails.
    pprint(outcome)

    verdict = outcome["independence_accepted"]
    assert (
        verdict is None
    ), "Identical values should be rejected by the Koehler criterion!"


def test_empty_dataset():
    """Zero-length columns must yield no verdict (None)."""
    empty_dataset = {column: np.array([]) for column in ("A", "B", "C1")}
    outcome = verify_data_conditional_independence(
        empty_dataset, A="A", B="B", C_set=["C1"]
    )
    verdict = outcome["independence_accepted"]
    assert verdict is None, "Empty dataset should return None!"


# NOTE: The ability to pass a custom threshold (max_fraction_rejected) has been
# removed for now, so the parametrized test below is disabled.
# @pytest.mark.parametrize("threshold,expected", [(0.01, False), (1.0, True)])
# def test_threshold_effect(threshold, expected):
#     """Test the effect of max_fraction_rejected parameter."""
#     data_dict = generate_dependent_data()
#     results = check_conditional_independence(
#         data_dict, A="A", B="B", C_set=["C1", "C2"], max_fraction_rejected=threshold
#     )
#     assert (
#         results["independence_accepted"] == expected
#     ), f"Threshold {threshold} should {'accept' if expected else 'reject'} dependence!"
