import math

import pytest

from frame.productions.concepts.exists import ExistsRule
from frame.tools.z3_template import Z3Template
from frame.knowledge_base.entities import Nat

# ==============================
# Fixtures
# ==============================


@pytest.fixture
def exists_rule():
    """Provide a brand-new ExistsRule; function-scoped, so every test gets its own."""
    rule = ExistsRule()
    return rule


# ==============================
# Basic Rule Tests
# ==============================


class TestExistsRuleBasics:
    """Core API surface of ExistsRule: input types, parameterizations and can_apply."""

    def test_get_input_types(self, exists_rule):
        """get_input_types yields a non-empty list of 2-tuples whose second item is a list."""
        combos = exists_rule.get_input_types()
        assert isinstance(combos, list), "Should return a list"
        assert len(combos) > 0, "Should return at least one input type combination"

        # Expected shape of each entry: (Type, [ConceptType, ...])
        for combo in combos:
            assert isinstance(combo, tuple), "Each combination should be a tuple"
            assert len(combo) == 2, "Each type spec should have 2 elements"
            _, concept_types = combo
            assert isinstance(concept_types, list), "Second element should be a list"

    def test_get_valid_parameterizations(self, exists_rule, square, multiply):
        """Parameterizations exist for unary and binary functions and never cover the output."""
        # Unary function: square
        unary_params = exists_rule.get_valid_parameterizations(square)
        assert isinstance(unary_params, list), "Should return a list"
        assert len(unary_params) > 0, "Should return at least one parameterization"
        every_has_indices = all("indices_to_quantify" in p for p in unary_params)
        assert every_has_indices, "Each parameterization should include indices_to_quantify"

        # Binary function: multiply
        binary_params = exists_rule.get_valid_parameterizations(multiply)
        more_than_unary = len(binary_params) > len(unary_params)
        assert more_than_unary, "Binary function should have more parameterizations than unary"

        # Index 0, index 1, or both may be quantified -- but never index 2 (the output).
        seen = {tuple(p["indices_to_quantify"]) for p in binary_params}
        assert (0,) in seen, "Should include quantifying over first argument"
        assert (1,) in seen, "Should include quantifying over second argument"
        assert (0, 1) in seen, "Should include quantifying over both arguments"
        assert (0, 1, 2) not in seen, "Should not include existential conjecture"

    def test_can_apply_valid_predicate(self, exists_rule, divides_and_even):
        """can_apply accepts quantifying the second argument of a binary predicate."""
        ok = exists_rule.can_apply(divides_and_even, indices_to_quantify=[1])
        assert ok, "Should allow existential quantification over second argument of a predicate"

    def test_can_apply_valid_function(self, exists_rule, square):
        """can_apply accepts quantifying the sole argument of a unary function."""
        ok = exists_rule.can_apply(square, indices_to_quantify=[0])
        assert ok, "Should allow existential quantification over the argument of a function"


# ==============================
# Predicate Quantification Tests
# ==============================


class TestPredicateQuantification:
    """Existential quantification applied to predicate concepts."""

    def test_quantify_over_predicate_argument(self, exists_rule, divides_and_even):
        """Quantify away the second argument of a binary predicate."""
        # has_even_divisor(a) = ∃b. divides_and_even(a, b)
        has_even_divisor = exists_rule.apply(divides_and_even, indices_to_quantify=[1])

        # Only even candidates from 2 up to a itself are searched.
        has_even_divisor.set_computational_implementation(
            lambda a: any(divides_and_even.compute(a, b) for b in range(2, a + 1, 2))
        )

        # Derived concept name encodes the source concept and quantified indices
        assert has_even_divisor.name == "exists_(divides_and_even_indices_[1])"

        holds = [
            (4, "4 has even divisor 2"),
            (6, "6 has even divisor 2"),
            (8, "8 has even divisors 2,4"),
            (12, "12 has even divisors 2,4,6"),
        ]
        for n, message in holds:
            assert has_even_divisor.compute(n), message

        fails = [
            (1, "1 has no even divisors"),
            (3, "3 has no even divisors"),
            (5, "5 has no even divisors"),
            (7, "7 has no even divisors"),
        ]
        for n, message in fails:
            assert not has_even_divisor.compute(n), message

        # Examples of the source predicate should be carried over (projected to `a`)
        observed = has_even_divisor.examples.get_examples()
        for wanted in [(4,), (6,), (8,), (12,)]:
            found = any(ex.value == wanted for ex in observed)
            assert found, f"Should have example {wanted}"

    def test_quantify_over_multi_arg_predicate(self, exists_rule, multi_arg_predicate):
        """Existentially bind the middle argument of a ternary predicate."""
        # exists_multi_arg(a, c) = ∃b. multi_arg_predicate(a, b, c)
        exists_multi_arg = exists_rule.apply(
            multi_arg_predicate, indices_to_quantify=[1]
        )

        # Brute-force search for a witness b in range(10).
        exists_multi_arg.set_computational_implementation(
            lambda a, c: any(multi_arg_predicate.compute(a, b, c) for b in range(10))
        )

        cases = [
            ((1, 3), True, "∃b. 1 + b + 3 = 1 * b * 3 (true for b=2)"),
            ((2, 2), False, "∃b. 2 + b + 2 = 2 * b * 2 (false for all integers b)"),
            ((3, 3), False, "∃b. 3 + b + 3 = 3 * b * 3 (false for all integers b)"),
        ]
        for args, expected, message in cases:
            assert bool(exists_multi_arg.compute(*args)) is expected, message


# ==============================
# Function Quantification Tests
# ==============================


class TestFunctionQuantification:
    """Test existential quantification of functions."""

    def test_quantify_over_function_argument(self, exists_rule, square):
        """Test quantifying over the argument of a unary function to create a predicate."""
        # Create is_square(n) = ∃x. square(x) = n
        is_square = exists_rule.apply(square, indices_to_quantify=[0])

        # Set computational implementation. math.isqrt is exact for arbitrarily
        # large ints, whereas int(n ** 0.5) goes through a float and can be off
        # by one once n exceeds float precision.
        is_square.set_computational_implementation(lambda n: math.isqrt(n) ** 2 == n)

        # Test the concept name
        assert is_square.name == "exists_(square_indices_[0])"

        # Test computation for various cases
        # True cases (perfect squares)
        assert is_square.compute(0), "0 is a perfect square"
        assert is_square.compute(1), "1 is a perfect square"
        assert is_square.compute(4), "4 is a perfect square"
        assert is_square.compute(9), "9 is a perfect square"
        assert is_square.compute(16), "16 is a perfect square"

        # False cases (non-squares)
        assert not is_square.compute(2), "2 is not a perfect square"
        assert not is_square.compute(3), "3 is not a perfect square"
        assert not is_square.compute(5), "5 is not a perfect square"
        assert not is_square.compute(10), "10 is not a perfect square"

    def test_quantify_over_binary_function_first_arg(self, exists_rule, multiply):
        """Test quantifying over the first argument of a binary function."""
        # Create has_factor(y, product) = ∃x. multiply(x, y) = product
        has_factor = exists_rule.apply(multiply, indices_to_quantify=[0])

        # Set computational implementation. Zero products are deliberately
        # excluded; checking y != 0 before the modulo avoids ZeroDivisionError
        # (for y == 0 and product != 0 no x exists anyway).
        has_factor.set_computational_implementation(
            lambda y, product: product != 0 and y != 0 and product % y == 0
        )

        # Test computation for various cases
        assert has_factor.compute(2, 4), "2 is a factor of 4"
        assert has_factor.compute(3, 6), "3 is a factor of 6"
        assert not has_factor.compute(2, 5), "2 is not a factor of 5"
        assert has_factor.compute(2, 10), "2 is a factor of 10"

    def test_quantify_over_binary_function_second_arg(self, exists_rule, multiply):
        """Test quantifying over the second argument of a binary function."""
        # Create can_multiply_to(x, product) = ∃y. multiply(x, y) = product
        can_multiply_to = exists_rule.apply(multiply, indices_to_quantify=[1])

        # Set computational implementation (same zero handling as has_factor:
        # zero products excluded, x == 0 guarded before the modulo).
        can_multiply_to.set_computational_implementation(
            lambda x, product: product != 0 and x != 0 and product % x == 0
        )

        # Test computation for various cases
        assert can_multiply_to.compute(2, 4), "2 can multiply to 4"
        assert can_multiply_to.compute(3, 6), "3 can multiply to 6"
        assert not can_multiply_to.compute(2, 5), "2 cannot multiply to 5"
        assert can_multiply_to.compute(4, 8), "4 can multiply to 8"

    def test_quantify_over_both_function_args(self, exists_rule, multiply):
        """Test quantifying over both arguments of a binary function."""
        # Create is_product(product) = ∃x,y. multiply(x, y) = product
        is_product = exists_rule.apply(multiply, indices_to_quantify=[0, 1])

        # Set computational implementation. Every positive integer n is
        # trivially 1 * n, so a divisor search is unnecessary; 0 is excluded
        # by convention for this implementation.
        is_product.set_computational_implementation(lambda product: product > 0)

        # Test computation for various cases
        assert is_product.compute(1), "1 is a product (1*1)"
        assert is_product.compute(4), "4 is a product (2*2)"
        assert is_product.compute(6), "6 is a product (2*3)"
        assert not is_product.compute(
            0
        ), "0 is not considered a valid product for our implementation"

    def test_multi_output_function_quantification(self, exists_rule, add_multiply):
        """Test quantifying over arguments of a function with multiple outputs."""
        # Create can_add_multiply_to(y, sum, product) = ∃x. add_multiply(x, y) = (sum, product)
        can_add_multiply_to = exists_rule.apply(add_multiply, indices_to_quantify=[0])

        # The only candidate witness is x = target_sum - y; it must be
        # non-negative and also satisfy the product equation.
        can_add_multiply_to.set_computational_implementation(
            lambda y, target_sum, target_product: target_sum - y >= 0
            and (target_sum - y) * y == target_product
        )

        # Test computation for various cases
        assert can_add_multiply_to.compute(2, 5, 6), "∃x. x+2=5, x*2=6 (true for x=3)"
        assert can_add_multiply_to.compute(3, 8, 15), "∃x. x+3=8, x*3=15 (true for x=5)"
        assert not can_add_multiply_to.compute(
            2, 5, 7
        ), "∃x. x+2=5, x*2=7 (false for all x)"


# ==============================
# Error Case Tests
# ==============================


class TestExistsRuleErrors:
    """Invalid uses of ExistsRule must be rejected."""

    def test_quantify_over_all_predicate_args(self, exists_rule, divides_and_even):
        """Binding every argument of a predicate is rejected by can_apply."""
        allowed = exists_rule.can_apply(divides_and_even, indices_to_quantify=[0, 1])
        assert not allowed, "Should not allow quantifying over all arguments of a predicate"

    def test_quantify_over_function_output(self, exists_rule, square):
        """An index pointing past the inputs (i.e. at the output) is rejected."""
        output_index = [1]  # square has a single input at index 0
        allowed = exists_rule.can_apply(square, indices_to_quantify=output_index)
        assert not allowed, "Should not allow quantifying over function outputs"

    def test_no_indices_specified(self, exists_rule, square):
        """An empty index list is rejected."""
        allowed = exists_rule.can_apply(square, indices_to_quantify=[])
        assert not allowed, "Should not allow empty list of indices to quantify"

    def test_invalid_concept_type(self, exists_rule, zero_concept):
        """Constant concepts cannot be existentially quantified."""
        allowed = exists_rule.can_apply(zero_concept, indices_to_quantify=[0])
        assert not allowed, "Should not allow application to constant concepts"

    def test_apply_with_invalid_input(self, exists_rule, divides_and_even):
        """apply raises ValueError for a parameterization can_apply would reject."""
        bad_indices = [0, 1]  # every argument of the predicate
        with pytest.raises(ValueError):
            exists_rule.apply(divides_and_even, indices_to_quantify=bad_indices)


# ==============================
# Complex Case Tests
# ==============================


class TestExistsRuleComplexCases:
    """Test complex cases for the ExistsRule class."""

    # TODO: This test was erroneous as the rule application was incorrect. Now
    # that proper divisors count is a function it also needs to be amended with
    # specialize first.
    # def test_is_prime_quantification(self, exists_rule, proper_divisors_count):
    #     """Test creating is_prime from proper_divisors_count using exists."""
    #     # Create is_prime(n) = ∃k. proper_divisors_count(n, k) ∧ k=1
    #     # We're testing if there exists exactly 1 proper divisor

    #     # First we'll create a concept that checks if n has exactly 1 proper divisor
    #     has_exactly_one_divisor = exists_rule.apply(
    #         proper_divisors_count, indices_to_quantify=[1]
    #     )

    #     # Set computational implementation
    #     has_exactly_one_divisor.set_computational_implementation(
    #         lambda n: len([i for i in range(1, n) if n % i == 0]) == 1
    #     )

    #     # Test computation for various cases
    #     assert has_exactly_one_divisor.compute(
    #         2
    #     ), "2 has exactly one proper divisor (1)"
    #     assert has_exactly_one_divisor.compute(
    #         3
    #     ), "3 has exactly one proper divisor (1)"
    #     assert has_exactly_one_divisor.compute(
    #         5
    #     ), "5 has exactly one proper divisor (1)"
    #     assert has_exactly_one_divisor.compute(
    #         7
    #     ), "7 has exactly one proper divisor (1)"
    #     assert not has_exactly_one_divisor.compute(4), "4 has two proper divisors (1,2)"
    #     assert not has_exactly_one_divisor.compute(
    #         6
    #     ), "6 has three proper divisors (1,2,3)"

    def test_nested_existential_quantification(self, exists_rule, multiply, is_even):
        """Test the first (inner) step towards a nested existential quantification.

        Only one quantification is exercised so far: multiply_to_even(n, product)
        = ∃m. multiply(n, m) = product, with a computational implementation that
        additionally requires ``product`` to be even. The outer quantification is
        not implemented yet (see the TODO below).
        """
        # Step 1: quantify over the second argument of multiply.
        multiply_to_even = exists_rule.apply(multiply, indices_to_quantify=[1])

        # n * m == product has a solution m iff n divides product; for n == 0
        # only product == 0 is reachable. Branching on n avoids the
        # ZeroDivisionError a bare `product % n` would raise for n == 0.
        multiply_to_even.set_computational_implementation(
            lambda n, product: is_even.compute(product)
            and (product % n == 0 if n != 0 else product == 0)
        )

        # TODO: Step 2 -- apply exists_rule again to obtain a genuinely nested
        # quantification, e.g. the originally planned
        # has_even_multiple_witness(n) = ∃k. has_even_multiple(k) ∧ k > n.

        # Test the intermediate concept
        assert multiply_to_even.compute(2, 4), "2 can multiply to even 4"
        assert multiply_to_even.compute(3, 6), "3 can multiply to even 6"
        assert not multiply_to_even.compute(2, 5), "2 cannot multiply to odd 5"


# ==============================
# Integration Tests
# ==============================


class TestExistsRuleIntegration:
    """Test integration of ExistsRule with other rules and concepts."""

    def test_integration_with_match_rule(self, exists_rule, multiply):
        """Test integrating MatchRule and ExistsRule to create is_square."""
        from frame.productions.concepts.match import MatchRule

        # Step 1: Use MatchRule to create square from multiplication
        match_rule = MatchRule()
        square = match_rule.apply(multiply, indices_to_match=[0, 1])

        # Step 2: Use ExistsRule to create is_square from square. Previously this
        # test built the concept but asserted nothing about it.
        assert exists_rule.can_apply(
            square, indices_to_quantify=[0]
        ), "Should allow quantifying over the argument of a MatchRule-derived function"
        is_square = exists_rule.apply(square, indices_to_quantify=[0])

        # Derived name follows the exists_(<concept>_indices_<indices>) scheme
        # used for the other exists concepts in this file.
        assert is_square.name == f"exists_({square.name}_indices_[0])"

        # The derived predicate accepts an implementation and computes with it.
        is_square.set_computational_implementation(lambda n: math.isqrt(n) ** 2 == n)
        assert is_square.compute(9), "9 is a perfect square"
        assert not is_square.compute(8), "8 is not a perfect square"




# ==============================
# Z3 Translation Tests
# ==============================


class TestExistsRuleZ3:
    """Test Z3 translation of exists concepts."""

    def test_exists_function_translation(self, exists_rule, multiply):
        """Test the Z3 translation of quantifying over the first argument of multiply."""
        # Create has_factor(y, product) = ∃x. multiply(x, y) = product
        has_factor = exists_rule.apply(multiply, indices_to_quantify=[0])

        # ∃x. x * 1 = 2 holds (x = 2), so it should be proved.
        result = has_factor.to_z3(Nat(1), Nat(2)).run()
        assert result.proved, "∃x. x * 1 = 2 should be provable (x = 2)"

        # ∃x. x * 3 = 5 has no integer solution, so it should not be proved.
        result = has_factor.to_z3(Nat(3), Nat(5)).run()
        assert not result.proved, "∃x. x * 3 = 5 should not be provable"

    def test_exists_predicate_translation(self, exists_rule, divides_and_even):
        """Test the Z3 translation of quantifying over the second argument of a predicate."""
        # Create has_even_divisor(a) = ∃b. divides_and_even(a, b)
        has_even_divisor = exists_rule.apply(divides_and_even, indices_to_quantify=[1])

        # 2 has the even divisor 2.
        result = has_even_divisor.to_z3(Nat(2)).run()
        assert result.proved, "2 has even divisor 2"

        # 3 has no even divisor.
        result = has_even_divisor.to_z3(Nat(3)).run()
        assert not result.proved, "3 has no even divisors"
