import random

import numpy as np
import torch

from causal_profiler import CausalProfiler, ErrorMetric
from causal_profiler.space_of_interest import (
    MechanismFamily,
    NoiseDistribution,
    QueryType,
    SpaceOfInterest,
    VariableDataType,
)

from verification.verify_do_calculus import (
    verify_do_calculus,
    verify_rule1_insertion_deletion_observation,
    verify_rule2_action_observation_exchange,
    verify_rule3_insertion_deletion_action,
)


def create_test_scm_and_estimator(expected_edges):
    """
    Sample a small random SCM plus the objects the do-calculus tests need.

    The SpaceOfInterest is built here with fixed settings (tabular
    mechanisms, uniform noise, discrete variables with 2 categories,
    50 data points); only the expected edge count is supplied by the caller.

    Parameters
    ----------
    expected_edges : int
        Expected number of edges in the sampled SCM.

    Returns
    -------
    data_dict, index_to_var, scm, query_estimator
        The generated data, the index-to-variable mapping returned by the
        profiler, and the SCM / query estimator held by the profiler's sampler.
    """
    soi_settings = dict(
        number_of_nodes=(4, 4),
        expected_edges=expected_edges,
        variable_dimensionality=(1, 1),
        mechanism_family=MechanismFamily.TABULAR,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],
        number_of_noise_regions="N",
        # Binary variables keep the tests simple.
        number_of_categories=2,
        variable_type=VariableDataType.DISCRETE,
        number_of_queries=1,
        query_type=QueryType.CONDITIONAL,
        # A small dataset is enough for these tests.
        number_of_data_points=50,
    )
    space = SpaceOfInterest(**soi_settings)

    profiler = CausalProfiler(
        space_of_interest=space,
        metric=ErrorMetric.L2,
        return_adjacency_matrix=False,
    )

    # Only the data and the index -> variable mapping are used from the
    # generated output; the remaining elements are discarded.
    generated = profiler.generate_samples_and_queries()
    data_dict, (_, _), (_, index_to_var) = generated

    # The SCM and estimator are read off the sampler after generation.
    sampler = profiler.sampler
    return data_dict, index_to_var, sampler._scm, sampler.query_estimator


def test_rule1_empty_graph():
    """Test rule 1 (insertion/deletion of observations) on a graph with no edges.

    An SCM is sampled with ``expected_edges=0`` and rule 1 is verified for
    Y = variable 0, Z = [variable 1], X = [variable 2], W = []. The test
    requires d-separation to hold, the rejection rate (when reported) to be
    below 0.05, and, when distribution tests were run, every p-value and the
    rejection rate to lie in [0, 1].
    """
    # Set seed for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # Create SCM with no edges
    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        expected_edges=0
    )

    # Choose variables for testing
    Y = index_to_var[0]  # Target variable
    Z = [index_to_var[1]]  # Observation to insert/delete
    X = [index_to_var[2]]  # Intervention
    W = []  # No extra conditioning

    # Test rule 1
    result = verify_rule1_insertion_deletion_observation(
        query_estimator=query_estimator,
        scm=scm,
        Y=Y,
        Z=Z,
        X=X,
        W=W,
        data_dict=data_dict,
        alpha=0.05,
        correction="BH",
        n_samples=50000,
    )

    # In an empty graph, all nodes should be d-separated
    assert result[
        "d_separation_holds"
    ], "Variables should be d-separated in empty graph"

    # The key is read below, so its presence has to be checked first;
    # otherwise a missing key surfaces as a KeyError and this assertion
    # could never fail.
    assert "rejection_rate" in result
    rejection_rate = result["rejection_rate"]

    # Sampling noise can still produce some rejections in an empty graph, so
    # the rate is only required to stay below 0.05 (the value also passed as
    # alpha above). A None rate is tolerated here.
    if rejection_rate is not None:
        assert (
            rejection_rate < 0.05
        ), f"Rule 1 rejection rate {rejection_rate} should be lower in empty graph"

    # Check that p-values were computed (if we have valid tests)
    if len(result["distribution_results"]) > 0:
        assert all(
            0.0 <= d["p_value"] <= 1.0 for d in result["distribution_results"]
        ), "P-values should be in [0,1]"

        # When tests were run, a rejection rate must be available; fail with
        # a clear message rather than a TypeError from comparing None.
        assert (
            rejection_rate is not None
        ), "Rejection rate should be available when distribution tests were run"
        assert 0 <= rejection_rate <= 1, "Rejection rate should be in [0,1]"


def test_rule2_chain_graph():
    """Test rule 2 (action/observation exchange) on a sparse random graph.

    The SCM is sampled with ``expected_edges=2``; the exact structure is not
    controlled, so up to three (Y, Z, X) assignments are tried. For each
    assignment where d-separation holds, the rule must hold (when a verdict
    is reported) and all p-values must lie in [0, 1]. The loop stops after
    the first d-separated assignment that produced distribution results.
    """
    # Fixed seeds so the sampled SCM is reproducible.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(43)

    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        expected_edges=2
    )

    # Arguments that do not change between attempts.
    shared_kwargs = dict(
        query_estimator=query_estimator,
        scm=scm,
        W=[],  # No extra conditioning
        data_dict=data_dict,
        alpha=0.05,
        correction="BH",
        n_samples=5000,
    )

    for i in range(3):
        # Rotate the roles over the four variables.
        result = verify_rule2_action_observation_exchange(
            Y=index_to_var[i],  # Target variable
            Z=[index_to_var[(i + 1) % 4]],  # Observation/action variable
            X=[index_to_var[(i + 2) % 4]],  # Intervention
            **shared_kwargs,
        )

        # Debug output for each attempt.
        print(
            f"Rule 2 test {i}: d-separation: {result['d_separation_holds']}, rule holds: {result['rule_holds']}"
        )

        # Nothing to check unless the graphical precondition is met.
        if not result["d_separation_holds"]:
            continue

        rule_verdict = result["rule_holds"]
        if rule_verdict is not None:
            assert rule_verdict, "Rule 2 should hold when d-separation holds"

        tests_run = result["distribution_results"]
        if len(tests_run) > 0:
            assert all(
                0.0 <= entry["p_value"] <= 1.0 for entry in tests_run
            ), "P-values should be in [0,1]"
            # One fully checked assignment is enough.
            break


def test_rule3_fork_graph():
    """Test rule 3 (insertion/deletion of actions) on a sparse random graph.

    The SCM is sampled with ``expected_edges=3``; up to three (Y, Z, X, W)
    assignments are tried. For each assignment where d-separation holds, the
    rule must hold (when a verdict is reported) and all p-values must lie in
    [0, 1]. The loop stops after the first d-separated assignment that
    produced distribution results.
    """
    # Fixed seeds so the sampled SCM is reproducible.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(44)

    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        expected_edges=3
    )

    # Arguments that do not change between attempts.
    shared_kwargs = dict(
        query_estimator=query_estimator,
        scm=scm,
        data_dict=data_dict,
        alpha=0.05,
        correction="BH",
        n_samples=5000,
    )

    for i in range(3):
        # Rotate the four roles over the four variables.
        result = verify_rule3_insertion_deletion_action(
            Y=index_to_var[i],  # Target variable
            Z=[index_to_var[(i + 1) % 4]],  # Action variable
            X=[index_to_var[(i + 2) % 4]],  # Intervention
            W=[index_to_var[(i + 3) % 4]],  # Conditioning
            **shared_kwargs,
        )

        # Debug output for each attempt.
        print(
            f"Rule 3 test {i}: d-separation: {result['d_separation_holds']}, rule holds: {result['rule_holds']}"
        )
        print(f"Z(W): {result.get('z_w', [])}")

        # Nothing to check unless the graphical precondition is met.
        if not result["d_separation_holds"]:
            continue

        rule_verdict = result["rule_holds"]
        if rule_verdict is not None:
            assert rule_verdict, "Rule 3 should hold when d-separation holds"

        tests_run = result["distribution_results"]
        if len(tests_run) > 0:
            assert all(
                0.0 <= entry["p_value"] <= 1.0 for entry in tests_run
            ), "P-values should be in [0,1]"
            # One fully checked assignment is enough.
            break


def test_full_do_calculus_verification():
    """Test the combined ``verify_do_calculus`` entry point.

    An SCM is sampled with ``expected_edges=1`` and all three rules are
    verified in one call. The test checks that a result is returned per rule
    and, for each rule whose d-separation precondition holds, that the
    reported rejection rate, p-values and rejection flags are well-formed and
    mutually consistent.
    """
    # Fixed seeds so the sampled SCM is reproducible.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(45)

    # A simple graph keeps the test fast.
    data_dict, index_to_var, scm, query_estimator = create_test_scm_and_estimator(
        expected_edges=1
    )

    results = verify_do_calculus(
        query_estimator=query_estimator,
        scm=scm,
        Y=index_to_var[0],  # Target variable
        Z=[index_to_var[1]],  # Variables to test in rules
        X=[index_to_var[2]],  # Intervention
        W=[],  # No extra conditioning
        data_dict=data_dict,
        alpha=0.05,
        correction="BH",
        n_samples=5000,
    )

    # One result entry is expected per rule.
    assert "rule1_results" in results
    assert "rule2_results" in results
    assert "rule3_results" in results

    print("Do-calculus verification results:")
    print(f"Rule 1 d-separation: {results['rule1_results']['d_separation_holds']}")
    print(f"Rule 2 d-separation: {results['rule2_results']['d_separation_holds']}")
    print(f"Rule 3 d-separation: {results['rule3_results']['d_separation_holds']}")

    per_rule = [
        (1, results["rule1_results"]),
        (2, results["rule2_results"]),
        (3, results["rule3_results"]),
    ]
    for rule_num, rule_results in per_rule:
        # The statistical checks only apply when d-separation holds.
        if not rule_results["d_separation_holds"]:
            continue

        reported_rate = rule_results.get("rejection_rate")
        if reported_rate is not None:
            # Just verify rejection rate is a reasonable value
            assert (
                0 <= reported_rate <= 1
            ), f"Rule {rule_num} rejection rate should be in [0,1]"

            # Rules 2 and 3 only warn (no failure) above 0.65, a deliberately
            # lenient cut-off for testing.
            if rule_num in [2, 3] and reported_rate > 0.65:
                print(
                    f"Warning: Rule {rule_num} has unexpectedly high rejection rate: {rule_results['rejection_rate']}"
                )

        # Remaining checks need at least one p-value.
        p_values = rule_results.get("p_values", [])
        if len(p_values) == 0:
            continue

        assert "rejected" in rule_results, f"Rule {rule_num} missing 'rejected' field"
        rejected = rule_results["rejected"]
        assert len(p_values) == len(
            rejected
        ), f"Rule {rule_num} p_values and rejected arrays should have same length"

        assert all(
            0 <= p <= 1 for p in p_values
        ), f"Rule {rule_num} p-values should be in [0,1]"

        # The reported rate must match the fraction of rejected tests.
        # NOTE(review): this reads "rejection_rate" without the None guard
        # used above - confirm it is always set when p-values are present.
        if len(rejected) > 0:
            calculated_rate = sum(rejected) / len(rejected)
            assert (
                abs(calculated_rate - rule_results["rejection_rate"]) < 1e-5
            ), f"Rule {rule_num} rejection rate inconsistent"
