import random

import numpy as np
import pytest
import torch

from causal_profiler import CausalProfiler, ErrorMetric
from causal_profiler.constants import FunctionSampling
from causal_profiler.space_of_interest import (
    MechanismFamily,
    NoiseDistribution,
    QueryType,
    SpaceOfInterest,
    VariableDataType,
)

from verification.verify_query_estimator_conditionals import (
    verify_query_estimator_conditionals,
)


def create_test_scm_and_estimator(
    expected_edges, number_of_data_points, number_of_categories=2, seed=42
):
    """
    Build a random discrete SCM, sample data from it, and return its estimator.

    The SpaceOfInterest is constructed here with fixed settings (10 nodes of
    dimensionality 1, tabular mechanisms, uniform noise with
    ``noise_args=[-1, 1]``, Markovian graph, one conditional query); only the
    arguments below vary between calls.

    Args:
        expected_edges:
            Expected number of edges in the sampled SCM.
        number_of_data_points:
            Number of data points in the dataset sampled from the SCM
            (the bigger, the more accurate the estimates will be).
        number_of_categories:
            Number of categories of each discrete variable.
        seed:
            Seed passed to ``random``, NumPy and PyTorch (CPU and all CUDA
            devices) before sampling, for reproducibility.

    Returns:
        Tuple ``(data_dict, index_to_var, scm, query_estimator)``.
    """
    # Seed every RNG the profiler might draw from, for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE: these cuDNN flags are process-global and are not restored after
    # this helper returns.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    space = SpaceOfInterest(
        number_of_nodes=(10, 10),
        expected_edges=expected_edges,
        variable_dimensionality=(1, 1),
        mechanism_family=MechanismFamily.TABULAR,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],
        number_of_noise_regions="N",
        # discrete variables with some categories each
        number_of_categories=number_of_categories,
        variable_type=VariableDataType.DISCRETE,
        markovian=True,
        semi_markovian=False,
        causal_graph=None,
        control_positivity=False,
        number_of_queries=1,
        query_type=QueryType.CONDITIONAL,
        number_of_data_points=number_of_data_points,
    )

    profiler = CausalProfiler(
        space_of_interest=space,
        metric=ErrorMetric.L2,
        return_adjacency_matrix=False,
    )
    # Override the sampler's function-sampling strategy after construction.
    profiler.sampler.function_sampling = FunctionSampling.RANDOM

    # Only the sampled data and the index -> variable mapping are needed here;
    # the other returned components are discarded.
    (data_dict, (_, _), (_, index_to_var)) = profiler.generate_samples_and_queries()
    # NOTE(review): reaches into the sampler's private `_scm` attribute.
    scm = profiler.sampler._scm
    query_estimator = profiler.sampler.query_estimator

    return data_dict, index_to_var, scm, query_estimator


@pytest.mark.parametrize(
    "expected_edges,number_of_data_points",
    [
        (0, 100),
        (12, 10000),
    ],
)
def test_query_estimator_conditionals(expected_edges, number_of_data_points):
    """Check that conditional-query estimates agree with the SCM.

    Builds a seeded random SCM with the given edge count and sample size, then
    asks ``verify_query_estimator_conditionals`` to compare the estimator's
    conditional queries for target ``index_to_var[0]`` given
    ``index_to_var[1:3]`` (``js_threshold=0.05``, mean aggregation). The
    result's ``agreement_accepted`` flag must be truthy.
    """
    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        expected_edges=expected_edges, number_of_data_points=number_of_data_points
    )

    result = verify_query_estimator_conditionals(
        query_estimator=query_estimator,
        scm=scm,
        A=index_to_var[0],
        C=index_to_var[1:3],
        data_dict=data_dict,
        use_multi_query=False,
        js_threshold=0.05,
        aggregation_method="mean",
    )

    accepted = result["agreement_accepted"]
    # Assert truthiness rather than `== True` (PEP 8 / E712); a None result
    # (no verdict) still fails.
    assert accepted, (
        f"Test failed for expected_edges={expected_edges}, "
        f"number_of_data_points={number_of_data_points} "
        f"(agreement_accepted={accepted!r})"
    )


def test_small_dataset():
    """Edge case: far too few samples should make the estimates disagree.

    Only 5 data points are drawn from an edgeless SCM whose variables have
    5 categories each, so the empirical conditionals cannot match the SCM and
    the verifier must reject the agreement.
    """
    tiny_sample_size = 5
    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        expected_edges=0,
        number_of_data_points=tiny_sample_size,
        number_of_categories=5,
    )

    target_var = index_to_var[0]
    conditioning_vars = index_to_var[1:3]
    outcome = verify_query_estimator_conditionals(
        query_estimator=query_estimator,
        scm=scm,
        A=target_var,
        C=conditioning_vars,
        data_dict=data_dict,
        use_multi_query=False,
        js_threshold=0.05,
        aggregation_method="mean",
    )

    accepted = outcome["agreement_accepted"]
    assert not accepted, "Small dataset should break the estimates."


def test_empty_dataset():
    """Test handling of an empty dataset.

    A minimal SCM is built with a single data point. Every array in
    ``data_dict`` is then replaced in place with an empty array. With no data
    to compare against, the verifier is expected to report
    ``agreement_accepted is None`` rather than True or False.
    """
    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        # Sample the smallest non-empty dataset; it is emptied just below.
        expected_edges=0,
        number_of_data_points=1,
    )

    # Replace each NumPy array with an empty array (keys are unchanged, so
    # assigning while iterating over the keys is safe).
    for key in data_dict:
        data_dict[key] = np.array([])

    result = verify_query_estimator_conditionals(
        query_estimator=query_estimator,
        scm=scm,
        A=index_to_var[0],
        C=index_to_var[1:3],
        data_dict=data_dict,
        use_multi_query=False,
        js_threshold=0.05,
        aggregation_method="mean",
    )

    assert result["agreement_accepted"] is None, "Empty dataset should return None."
