import random

import numpy as np
import pytest
import torch

from verification.verify_structural_counterfactual_axioms import (
    verify_structural_counterfactual_axioms,
)


def create_test_scm(expected_edges, seed=42):
    """
    Build a small, seeded SCM with CausalProfiler for the axiom tests.

    The generated space uses number_of_nodes=(4, 4), one-dimensional discrete
    variables with 2 categories each, tabular mechanisms, uniform noise with
    noise_args=[-1, 1], and a Markovian (not semi-Markovian) graph.

    Args:
        expected_edges: Forwarded to ``SpaceOfInterest(expected_edges=...)``;
            the tests use 0 for an edgeless graph and larger values for
            denser graphs.
        seed: Seed applied to Python's ``random``, NumPy and torch (CPU and
            all CUDA devices) before generation, so that repeated calls with
            the same arguments are intended to yield the same SCM.
            Defaults to 42.

    Returns:
        The SCM held by the profiler's sampler after
        ``generate_samples_and_queries()`` has run. Only the SCM is returned.
    """
    # Imported lazily, as in the original helper, so the project package is
    # only needed when an SCM is actually built.
    from causal_profiler import CausalProfiler, ErrorMetric
    from causal_profiler.space_of_interest import (
        MechanismFamily,
        NoiseDistribution,
        QueryType,
        SpaceOfInterest,
        VariableDataType,
    )

    # Seed the Python, NumPy and torch RNGs for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE: these cuDNN flags are process-global and remain set after return.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    space = SpaceOfInterest(
        number_of_nodes=(4, 4),
        expected_edges=expected_edges,
        variable_dimensionality=(1, 1),
        mechanism_family=MechanismFamily.TABULAR,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],
        number_of_noise_regions="N",
        # discrete variables with 2 categories each
        number_of_categories=2,
        variable_type=VariableDataType.DISCRETE,
        markovian=True,
        semi_markovian=False,
        causal_graph=None,
        control_positivity=False,
        number_of_queries=1,
        query_type=QueryType.CONDITIONAL,
        number_of_data_points=50,
    )

    profiler = CausalProfiler(
        space_of_interest=space,
        metric=ErrorMetric.L2,
        return_adjacency_matrix=False,
        n_samples=10000,
    )

    profiler.generate_samples_and_queries()
    # `_scm` is a private attribute of the sampler; this may break if
    # CausalProfiler's internals change.
    return profiler.sampler._scm


def test_counterfactual_axioms_no_edges():
    """
    Verify the structural counterfactual axioms on an edgeless SCM.

    The SCM is generated with expected_edges=0 (intended to be fully
    independent). Composition, effectiveness and reversibility must each hold
    in at least 99% of tested cases, and each must have been tested at least
    once.
    """
    axiom_names = ("composition", "effectiveness", "reversibility")

    results = verify_structural_counterfactual_axioms(
        create_test_scm(expected_edges=0),
        tolerance=1e-7,  # numeric tolerance for comparisons
        seed=123,  # fixed so the sampled checks are reproducible
    )

    # Success rates are asserted first, for every axiom ...
    for name in axiom_names:
        observed = results[f"{name}_success_rate"]
        assert observed >= 0.99, (
            f"{name.capitalize()}: Expected ~100% compliance on a simple "
            f"no-edges SCM, got {observed}"
        )

    # ... then that each axiom was actually exercised.
    for name in axiom_names:
        assert results[f"num_tests_{name}"] > 0, f"No {name} tests were run!"


def test_counterfactual_axioms_with_some_edges():
    """
    Verify the structural counterfactual axioms on an SCM with expected_edges=2.

    The generated graph is expected to contain some edges, so this exercises
    the verifier beyond the edgeless case. For each of composition,
    effectiveness and reversibility, at least one test must have run and the
    success rate must be at least 0.99.
    """
    scm = create_test_scm(expected_edges=2)

    results = verify_structural_counterfactual_axioms(scm, tolerance=1e-5, seed=999)

    for axiom in ("composition", "effectiveness", "reversibility"):
        # Check the count first: a rate can be None when no premise was
        # satisfied, and `None >= 0.99` would raise an opaque TypeError.
        num_tests = results[f"num_tests_{axiom}"]
        assert num_tests > 0, f"No {axiom} tests were run!"

        rate = results[f"{axiom}_success_rate"]
        assert rate is not None and rate >= 0.99, (
            f"{axiom.capitalize()}: Expected ~100% compliance on an SCM with "
            f"expected_edges=2, got {rate}"
        )


@pytest.mark.parametrize("expected_edges", [0, 2, 4])
@pytest.mark.parametrize("n_tests", [1, 2, 5])
@pytest.mark.parametrize("tolerance", [1e-5, 1e-7, 1e-9])
def test_counterfactual_axioms_varied_parameters(expected_edges, n_tests, tolerance):
    """
    Verify the counterfactual axioms across a grid of configurations.

    Parameters (pytest-parametrized, 27 combinations in total):
      - expected_edges: 0 (independent SCM), 2 (moderate), 4 (dense)
      - n_tests: forwarded to the verifier; described as deprecated but kept
        for completeness
      - tolerance: 1e-5 (loose), 1e-7 (strict), 1e-9 (very strict)

    For every combination, each of composition, effectiveness and
    reversibility must have been tested at least once and hold in at least
    99% of tested cases.
    """
    scm = create_test_scm(expected_edges=expected_edges)

    results = verify_structural_counterfactual_axioms(
        scm, n_tests=n_tests, tolerance=tolerance, seed=123
    )

    # Identify the failing parametrization in assertion messages.
    config = (
        f"expected_edges={expected_edges}, n_tests={n_tests}, "
        f"tolerance={tolerance}"
    )

    for axiom in ("composition", "effectiveness", "reversibility"):
        # Check the count first: a rate can be None when no premise was
        # satisfied, and `None >= 0.99` would raise an opaque TypeError.
        num_tests = results[f"num_tests_{axiom}"]
        assert num_tests > 0, f"No {axiom} tests were run! ({config})"

        rate = results[f"{axiom}_success_rate"]
        assert rate is not None and rate >= 0.99, (
            f"{axiom.capitalize()}: Expected ~100% compliance ({config}), "
            f"got {rate}"
        )
