import random

import numpy as np
import torch

from causal_profiler import CausalProfiler, ErrorMetric
from causal_profiler.space_of_interest import (
    MechanismFamily,
    NoiseDistribution,
    QueryType,
    SpaceOfInterest,
    VariableDataType,
)

from verification.verify_query_estimator_ci import (
    verify_query_estimator_conditional_independence,
)


def create_test_scm_and_estimator(expected_edges, seed=42):
    """
    Build a small random discrete SCM, sample data from it, and return the pieces.

    Seeds every RNG the pipeline may touch (``random``, NumPy, torch CPU and
    CUDA) and forces deterministic cuDNN so that repeated calls with the same
    arguments produce the same SCM and data. It then builds a fixed
    SpaceOfInterest and runs a CausalProfiler over it. The SpaceOfInterest has:
      - 4 Markovian variables;
      - discrete variables with 2 categories each;
      - tabular mechanisms and uniform(-1, 1) noise;
      - one conditional query and 50 data points.

    Args:
        expected_edges: Forwarded to ``SpaceOfInterest.expected_edges``. The
            tests in this module pass 0 to obtain an edgeless (fully
            independent) graph.
        seed: Seed applied to all RNGs before sampling. Defaults to 42.

    Returns:
        Tuple ``(data_dict, index_to_var, scm, query_estimator)``.
    """
    # Seed everything up front so the sampled SCM and data are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    space = SpaceOfInterest(
        number_of_nodes=(4, 4),
        expected_edges=expected_edges,
        variable_dimensionality=(1, 1),
        mechanism_family=MechanismFamily.TABULAR,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],
        number_of_noise_regions="N",
        # discrete variables with 2 categories each
        number_of_categories=2,
        variable_type=VariableDataType.DISCRETE,
        markovian=True,
        semi_markovian=False,
        causal_graph=None,
        control_positivity=False,
        number_of_queries=1,
        query_type=QueryType.CONDITIONAL,
        number_of_data_points=50,
    )

    profiler = CausalProfiler(
        space_of_interest=space,
        metric=ErrorMetric.L2,
        return_adjacency_matrix=False,
    )

    (data_dict, (_, _), (_, index_to_var)) = profiler.generate_samples_and_queries()
    # NOTE(review): reads the sampler's private ``_scm`` attribute; no public
    # accessor is visible from this file.
    scm = profiler.sampler._scm
    query_estimator = profiler.sampler.query_estimator

    return data_dict, index_to_var, scm, query_estimator


# def test_ci_js_max_strict():
#     """
#     Test the query-estimator conditional-independence check using:
#       - 'max' aggregation
#       - a strict threshold (0.01)
#       - use_multi_query = False
#     """
#     # Create random SCM, get data
#     data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
#         expected_edges=0  # No edges, everything is independent
#     )

#     # Decide which variables are A, B, and which are in C
#     # For simplicity, let A=X1, B=X2, C=[X3, X4] because there are 4 variables
#     A, B, *C_vars = index_to_var

#     result = check_query_estimator_conditional_independence_js(
#         query_estimator=query_estimator,
#         scm=scm,
#         A=A,
#         B=B,
#         C_set=C_vars,
#         data_dict=data_dict,  # so we use the observed values
#         use_multi_query=False,
#         js_threshold=0.01,  # strict
#         aggregation_method="max",  # we care if *any* c-combo violates independence
#     )

#     print("===== test_ci_js_max_strict =====")
#     print("Variables:", A, B, " | C:", C_vars)
#     print("JS divergences:", result["js_divergences"])
#     print("Skipped c-values:", result["skipped_c_values"])
#     print("Aggregated JS:", result["aggregation"])
#     print("Independence accepted?", result["independence_accepted"])
#     print("================================================\n")
#     assert result["independence_accepted"], "Should accept independence!"


def test_ci_js_mean_moderate():
    """
    Check that independence is accepted on an edgeless SCM using 'mean' aggregation.

    Uses a moderate JS threshold of 0.08 (the inline history suggests 0.05 was
    the original target). Passes ``use_multi_query=False`` to see whether that
    setting behaves differently.
    """
    js_threshold = 0.08  # previously 0.05

    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        expected_edges=0  # No edges, everything independent
    )

    # First two variables are the pair under test; the rest form the conditioning set.
    A, B, *C_vars = index_to_var

    result = verify_query_estimator_conditional_independence(
        query_estimator=query_estimator,
        scm=scm,
        A=A,
        B=B,
        C_set=C_vars,
        data_dict=data_dict,
        use_multi_query=False,  # each query re-samples data (per original note)
        js_threshold=js_threshold,
        aggregation_method="mean",
    )

    print("===== test_ci_js_mean_moderate =====")
    print("Variables:", A, B, " | C:", C_vars)
    print("JS divergences:", result["js_divergences"])
    print("Skipped c-values:", result["skipped_c_values"])
    print("Aggregated JS:", result["aggregation"])
    print("Independence accepted?", result["independence_accepted"])
    print("================================================\n")
    assert result["independence_accepted"], (
        "Should accept independence! "
        f"aggregated JS={result['aggregation']} (threshold {js_threshold})"
    )


def test_ci_js_median_lenient():
    """
    Check that independence is accepted on an edgeless SCM using 'median' aggregation.

    Uses a more lenient JS threshold of 0.1 and ``use_multi_query=False``.
    """
    js_threshold = 0.1

    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        expected_edges=0  # No edges, everything independent
    )

    # First two variables are the pair under test; the rest form the conditioning set.
    A, B, *C_vars = index_to_var

    result = verify_query_estimator_conditional_independence(
        query_estimator=query_estimator,
        scm=scm,
        A=A,
        B=B,
        C_set=C_vars,
        data_dict=data_dict,
        use_multi_query=False,
        js_threshold=js_threshold,
        aggregation_method="median",
    )

    print("===== test_ci_js_median_lenient =====")
    print("Variables:", A, B, " | C:", C_vars)
    print("JS divergences:", result["js_divergences"])
    print("Skipped c-values:", result["skipped_c_values"])
    print("Aggregated JS:", result["aggregation"])
    print("Independence accepted?", result["independence_accepted"])
    print("================================================\n")
    assert result["independence_accepted"], (
        "Should accept independence! "
        f"aggregated JS={result['aggregation']} (threshold {js_threshold})"
    )
