import pytest
from frame.knowledge_base.entities import ConceptType, Nat
from frame.productions.concepts.forall import ForallRule, NUM_PARAMETERIZATIONS

# ==============================
# Fixtures
# ==============================


@pytest.fixture
def forall_rule():
    """Provide a newly constructed ForallRule to each test that requests it."""
    rule = ForallRule()
    return rule


# ==============================
# Basic Rule Tests
# ==============================


class TestForallRuleBasics:
    """Test basic functionality of the ForallRule class.

    Covers the declared input types, the parameterizations proposed for
    one- and two-predicate inputs, and the can_apply validity check.
    """

    def test_get_input_types(self, forall_rule):
        """Test that get_input_types returns the correct types."""
        input_types = forall_rule.get_input_types()
        assert isinstance(input_types, list), "Should return a list"
        assert len(input_types) == 2, "Should return two input type combinations"

        # Element [1] of every entry is compared against the concept type.

        # First combo is for single predicate case
        assert len(input_types[0]) == 1, "First combo should have one input type"
        assert (
            input_types[0][0][1] == ConceptType.PREDICATE
        ), "First combo should be a predicate"

        # Second combo is for two predicates case
        assert len(input_types[1]) == 2, "Second combo should have two input types"
        assert (
            input_types[1][0][1] == ConceptType.PREDICATE
        ), "First input should be a predicate"
        assert (
            input_types[1][1][1] == ConceptType.PREDICATE
        ), "Second input should be a predicate"

    def test_get_valid_parameterizations_single_predicate(
        self, forall_rule, is_proper_divisor
    ):
        """Test get_valid_parameterizations for a single predicate."""
        # is_proper_divisor has arity 2, so every non-empty subset of its
        # indices is expected: [0], [1] and [0, 1].
        valid_params = forall_rule.get_valid_parameterizations(is_proper_divisor)

        assert (
            len(valid_params) == 3
        ), "Should return 3 parameterizations for binary predicate"

        # Check that we have parameterizations for [0], [1] and [0, 1]
        index_options = [param["indices_to_quantify"] for param in valid_params]
        assert [0] in index_options, "Should include option to quantify over index 0"
        assert [1] in index_options, "Should include option to quantify over index 1"
        assert [0, 1] in index_options, "Should include option to quantify over both indices"

    def test_get_valid_parameterizations_two_predicates(
        self, forall_rule, is_proper_divisor, is_even
    ):
        """Test get_valid_parameterizations for two predicates."""
        # is_proper_divisor has arity 2 and is_even has arity 1.
        # Six parameterizations are expected in total.
        # NOTE(review): the exact breakdown of those six (mapping x quantified
        # indices) is not spelled out here -- confirm against ForallRule.
        valid_params = forall_rule.get_valid_parameterizations(
            is_proper_divisor, is_even
        )

        assert (
            len(valid_params) == 6
        ), "Should return parameterizations for each valid index"

        # Only the single-index quantification options are checked explicitly.
        index_options = [param["indices_to_quantify"] for param in valid_params]
        assert [0] in index_options, "Should include option to quantify over index 0"
        assert [1] in index_options, "Should include option to quantify over index 1"

    def test_get_valid_parameterizations_higher_arity(
        self, forall_rule, divisor_mod_k_equals_r
    ):
        """Test get_valid_parameterizations for a higher arity predicate."""
        # divisor_mod_k_equals_r has arity 4, so we can quantify over 1, 2, 3
        # or all 4 indices.
        valid_params = forall_rule.get_valid_parameterizations(divisor_mod_k_equals_r)

        # We should get combinations of size 1: [0], [1], [2], [3]
        # combinations of size 2: [0,1], [0,2], [0,3], [1,2], [1,3], [2,3]
        # combinations of size 3: [0,1,2], [0,1,3], [0,2,3], [1,2,3]
        # combinations of size 4: [0,1,2,3] (because we can make conjectures)
        # Total: 4 + 6 + 4 + 1 = 15
        assert (
            len(valid_params) == 15
        ), "Should return correct number of parameterizations"

        # Check a few specific combinations
        index_options = [tuple(param["indices_to_quantify"]) for param in valid_params]
        # A single-predicate parameterization must not carry an index mapping.
        assert all(
            "indices_to_map" not in param for param in valid_params
        ), "Should not map any indices"
        assert (0,) in index_options, "Should include option to quantify over index 0"
        assert (
            1,
            2,
        ) in index_options, "Should include option to quantify over indices 1,2"
        assert (
            0,
            2,
            3,
        ) in index_options, "Should include option to quantify over indices 0,2,3"

        assert (0, 1, 2, 3) in index_options, "Should include option to quantify over all indices"

    def test_get_valid_parameterizations_complex_case(
        self, forall_rule, commutes, multi_arg_predicate
    ):
        """Test get_valid_parameterizations for complex predicates match."""
        # commutes has arity 3, multi_arg_predicate has arity 3
        valid_params = forall_rule.get_valid_parameterizations(
            commutes, multi_arg_predicate
        )

        # The rule is expected to return exactly NUM_PARAMETERIZATIONS results.
        # NOTE(review): presumably NUM_PARAMETERIZATIONS caps how many
        # two-predicate parameterizations are produced -- confirm in
        # frame.productions.concepts.forall.
        assert (
            len(valid_params) == NUM_PARAMETERIZATIONS
        ), f"Should return {NUM_PARAMETERIZATIONS} parameterizations but got {len(valid_params)}"

    def test_can_apply_single_predicate(self, forall_rule, is_proper_divisor):
        """Test that can_apply returns True for valid single predicate case."""
        assert forall_rule.can_apply(
            is_proper_divisor, indices_to_quantify=[1], verbose=False
        ), "Should allow forall over single predicate"

    def test_can_apply_two_predicates(self, forall_rule, is_proper_divisor, is_even):
        """Test that can_apply returns True for valid two predicates case."""
        assert forall_rule.can_apply(
            is_proper_divisor, is_even, indices_to_map={0: 0}, indices_to_quantify=[0], verbose=False
        ), "Should allow forall over two predicates"


# ==============================
# Single Predicate Tests
# ==============================


class TestSinglePredicateForall:
    """Exercise ForallRule when it quantifies over indices of one predicate."""

    def test_divisor_property_quantification(self, forall_rule, divisor_mod_k_equals_r):
        """Quantifying one index of the divisor/modulus predicate yields a named arity-3 concept."""
        # Intended reading: "All divisors of n are congruent to r mod k",
        # obtained by quantifying over index 1 (the divisor d).
        quantified = forall_rule.apply(divisor_mod_k_equals_r, indices_to_quantify=[1])

        expected_name = "forall_(divisor_mod_k_equals_r_indices_to_quantify_[1])"
        assert quantified.name == expected_name

        structure = quantified.examples.example_structure
        assert structure.input_arity == 3

    def test_group_property_quantification(self, forall_rule, commutes):
        """Quantifying both element indices of commutes leaves a unary concept."""
        # Intended reading: "For all elements a, b in group G, a * b = b * a",
        # i.e. an abelian-ness check on the remaining argument.
        abelian_concept = forall_rule.apply(commutes, indices_to_quantify=[1, 2])

        assert abelian_concept.name == "forall_(commutes_indices_to_quantify_[1, 2])"

        structure = abelian_concept.examples.example_structure
        assert structure.input_arity == 1

    def test_invalid_indices(self, forall_rule, is_proper_divisor):
        """can_apply must refuse index lists that reach past the predicate's arguments."""
        # Index 2 alone is out of range for is_proper_divisor.
        accepted_single = forall_rule.can_apply(
            is_proper_divisor, indices_to_quantify=[2], verbose=True
        )
        assert not accepted_single, "Should reject quantification over non-existent index"

        # Index 2 mixed in with indices 0 and 1, this time with a mapping supplied as well.
        accepted_mixed = forall_rule.can_apply(
            is_proper_divisor, indices_to_map={0: 0, 1: 1}, indices_to_quantify=[0, 1, 2], verbose=True
        )
        assert not accepted_mixed, "Should reject quantification over indices that don't exist"


# ==============================
# Two Predicates Tests
# ==============================


class TestTwoPredicatesForall:
    """Exercise ForallRule when one predicate implies another under quantification."""

    def test_odd_divisors_implication(self, forall_rule, is_proper_divisor, is_odd):
        """Build "all proper divisors of n are odd" and check name, arity and some values."""
        # NOTE(review): this test quantifies over index [0] while the even
        # variant quantifies over [1] with the same mapping {1: 0} -- confirm
        # which index is meant to be the divisor position.
        odd_concept = forall_rule.apply(
            is_proper_divisor, is_odd, indices_to_map={1: 0}, indices_to_quantify=[0]
        )

        expected_name = (
            "forall_(is_proper_divisor_with_is_odd_indices_to_map_{1: 0}_indices_to_quantify_[0])"
        )
        assert odd_concept.name == expected_name
        assert odd_concept.examples.example_structure.input_arity == 1

        # Each of these numbers has the proper divisor 2, which is not odd.
        for number in (4, 6, 8, 10):
            assert not odd_concept.compute(number)

    def test_even_divisors_implication(self, forall_rule, is_proper_divisor, is_even):
        """Build "all proper divisors of n are even" and check name, arity and some values."""
        even_concept = forall_rule.apply(
            is_proper_divisor, is_even, indices_to_map={1: 0}, indices_to_quantify=[1]
        )

        expected_name = (
            "forall_(is_proper_divisor_with_is_even_indices_to_map_{1: 0}_indices_to_quantify_[1])"
        )
        assert even_concept.name == expected_name
        assert even_concept.examples.example_structure.input_arity == 1

        # Every number below has the proper divisor 1 (6 also has 3), none of
        # which is even, so the universal statement must be false.
        for number in (2, 6, 8, 16):
            assert not even_concept.compute(number)

    def test_nonpredicates_rejected(self, forall_rule, is_proper_divisor, square):
        """can_apply must refuse a function concept in either argument position."""
        # square (a function rather than a predicate) as the secondary concept.
        secondary_ok = forall_rule.can_apply(
            is_proper_divisor, square, indices_to_map={0: 0}, indices_to_quantify=[1], verbose=False
        )
        assert not secondary_ok, "Should reject non-predicate secondary concept"

        # square as the primary concept.
        primary_ok = forall_rule.can_apply(
            square, is_proper_divisor, indices_to_map={0: 0}, indices_to_quantify=[0], verbose=False
        )
        assert not primary_ok, "Should reject non-predicate primary concept"


# ==============================
# Error Case Tests
# ==============================


class TestForallRuleErrors:
    """Test error cases for the ForallRule class."""

    def test_missing_indices(self, forall_rule, is_proper_divisor):
        """Test that applying without indices raises an error."""
        with pytest.raises(ValueError):
            forall_rule.apply(is_proper_divisor)

    def test_wrong_number_of_inputs(
        self, forall_rule, is_proper_divisor, is_even, square
    ):
        """Test that wrong number of inputs is rejected."""
        # Try with three inputs
        assert not forall_rule.can_apply(
            is_proper_divisor, is_even, square, indices_to_map={0: 0}, indices_to_quantify=[1], verbose=False
        ), "Should reject three inputs"

    def test_empty_parameterizations(self, forall_rule):
        """Test that get_valid_parameterizations returns empty list for no inputs."""
        valid_params = forall_rule.get_valid_parameterizations()
        assert valid_params == [], "Should return empty list for no inputs"

    # NOTE(review): the test below is disabled and the reason is not recorded
    # here. Confirm whether non-predicate inputs are meant to yield [] and
    # either re-enable or delete it.
    # def test_invalid_inputs_parameterizations(self, forall_rule, square):
    #     """Test that get_valid_parameterizations returns empty list for invalid inputs."""
    #     # Try with a function (non-predicate)
    #     valid_params = forall_rule.get_valid_parameterizations(square)
    #     assert valid_params == [], "Should return empty list for non-predicate input"

    def test_parameterizations_with_invalid_indices(
        self, forall_rule, is_proper_divisor
    ):
        """Test that valid parameterizations extended with a bad index fail can_apply."""
        # Get all valid parameterizations
        valid_params = forall_rule.get_valid_parameterizations(is_proper_divisor)

        # Without this guard the loop below would pass vacuously on an empty list.
        assert valid_params, "Should have valid parameterizations"

        for param in valid_params:
            # Build a fresh list rather than appending in place, so the
            # objects returned by the rule are never mutated by the test.
            invalid_indices = [*param["indices_to_quantify"], 5]
            # Index 5 does not exist, so can_apply must now reject the request.
            assert not forall_rule.can_apply(
                is_proper_divisor, indices_to_quantify=invalid_indices, verbose=False
            ), "Should reject invalid indices"


# ==============================
# Integration Tests
# ==============================


class TestForallRuleIntegration:
    """Integration tests for the ForallRule.

    These tests apply the rule end to end and inspect the produced concept's
    name, input arity, stored (non)examples and computed values.
    """

    def test_example_transformation_implication(
        self, forall_rule, is_proper_divisor, is_even
    ):
        """Test that examples are properly transformed."""
        # Create a predicate: "All proper divisors of n are even"
        all_divisors_even = forall_rule.apply(is_proper_divisor, is_even, indices_to_map={1: 0}, indices_to_quantify=[1])

        # Check that non examples were properly generated
        # examples are not generated for this concept by default
        nonexamples = [ex.value for ex in all_divisors_even.examples.get_nonexamples()]

        # Nonexamples (numbers with at least one odd proper divisor)
        assert (6,) in nonexamples, "6 should be a nonexample (has odd divisor 3)"
        assert (9,) in nonexamples, "9 should be a nonexample (has odd divisor 3)"
        assert (12,) in nonexamples, "12 should be a nonexample (has odd divisor 3)"

    def test_nested_quantification(self, forall_rule, divisor_mod_k_equals_r):
        """Test applying multiple forall rules in sequence."""
        # First, quantify over the divisor d
        forall_divisors = forall_rule.apply(divisor_mod_k_equals_r, indices_to_quantify=[1])

        # Then, quantify over the modulus k (index 1 of the intermediate concept)
        forall_divisors_all_mod = forall_rule.apply(forall_divisors, indices_to_quantify=[1])

        # The final concept should check if for all k, all divisors of n are congruent to r mod k
        assert forall_divisors_all_mod.examples.example_structure.input_arity == 2

        # Test examples
        assert not forall_divisors_all_mod.compute(6, 2)
        assert not forall_divisors_all_mod.compute(3, 1)
        assert not forall_divisors_all_mod.compute(15, 5)

    def test_apply_using_valid_parameterizations(self, forall_rule, commutes):
        """Test applying rule using valid parameterizations."""
        # Get valid parameterizations for commutes
        valid_params = forall_rule.get_valid_parameterizations(commutes)

        # We expect multiple valid parameterizations since commutes has arity 3
        assert len(valid_params) > 0, "Should have valid parameterizations"

        # Try applying with each parameterization
        for param in valid_params:
            indices = param["indices_to_quantify"]

            # Apply the rule with this parameterization
            result_concept = forall_rule.apply(commutes, indices_to_quantify=indices)

            # Each quantified index removes one of the three inputs.
            expected_arity = 3 - len(indices)
            assert (
                result_concept.examples.example_structure.input_arity == expected_arity
            ), f"Result should have arity {expected_arity} for indices {indices}"

            # Check naming convention
            assert (
                result_concept.name == f"forall_(commutes_indices_to_quantify_{indices})"
            ), "Result should have correct name"

    def test_two_predicates_with_parameterizations(
        self, forall_rule, is_proper_divisor, is_even
    ):
        """Test applying two-predicate rule using valid parameterizations."""
        # Get valid parameterizations for is_proper_divisor and is_even
        valid_params = forall_rule.get_valid_parameterizations(
            is_proper_divisor, is_even
        )

        assert len(valid_params) > 0, "Should have valid parameterizations"

        # Apply the rule with each parameterization
        for param in valid_params:
            indices_to_map = param["indices_to_map"]
            indices_to_quantify = param["indices_to_quantify"]

            result_concept = forall_rule.apply(
                is_proper_divisor, is_even, indices_to_map=indices_to_map, indices_to_quantify=indices_to_quantify
            )

            # Combined arity of both predicates (2 + 1) minus the quantified
            # and the mapped indices. The failure message carries the whole
            # parameterization so a failing case can be identified without
            # debug printing.
            expected_arity = 2 + 1 - len(indices_to_quantify) - len(indices_to_map)
            assert (
                result_concept.examples.example_structure.input_arity == expected_arity
            ), f"Result should have arity {expected_arity} for parameterization {param}"

            # Check naming convention
            assert (
                result_concept.name
                == f"forall_(is_proper_divisor_with_is_even_indices_to_map_{indices_to_map}_indices_to_quantify_{indices_to_quantify})"
            ), "Result should have correct name"

            # Test special case: when quantifying over index 1 with mapping {1: 0}
            if indices_to_quantify == [1] and indices_to_map == {1: 0}:
                # Test known examples
                assert not result_concept.compute(
                    8
                ), "8 should not have all even divisors"
                assert not result_concept.compute(
                    6
                ), "6 should not have all even divisors"


# ==============================
# Z3 Translation Tests
# ==============================


class TestForallRuleZ3:
    """Check that concepts produced by ForallRule can be run through to_z3."""

    def test_z3_translate_forall_single_predicate(self, forall_rule, divisor_mod_k_equals_r):
        """A single-predicate forall concept, run via Z3 on (1, 2, 3), is not proved."""
        concept = forall_rule.apply(
            divisor_mod_k_equals_r,
            indices_to_quantify=[1],
        )

        z3_program = concept.to_z3(Nat(1), Nat(2), Nat(3))
        outcome = z3_program.run()
        assert not outcome.proved, "Should not prove the concept"

    def test_z3_translate_forall_two_predicates(self, forall_rule, is_proper_divisor, is_odd):
        """A two-predicate forall concept is proved for 5 and not proved for 6."""
        concept = forall_rule.apply(
            is_proper_divisor,
            is_odd,
            indices_to_quantify=[0],
            indices_to_map={1: 0},
        )

        outcome_for_five = concept.to_z3(Nat(5)).run()
        assert outcome_for_five.proved, "Should prove the concept"

        outcome_for_six = concept.to_z3(Nat(6)).run()
        assert not outcome_for_six.proved, "Should not prove the concept"

