import pytest

from frame.productions.concepts.negate import NegateRule
from frame.knowledge_base.entities import Conjecture, ConceptType, Var, Nat

# ==============================
# Fixtures
# ==============================


@pytest.fixture
def negate_rule():
    """Provide an independent NegateRule to every test that requests it.

    The fixture is declared without an explicit scope, so pytest uses its
    default (function) scope and builds a new rule for each test.
    """
    rule = NegateRule()
    return rule


# ==============================
# Basic Rule Tests
# ==============================


class TestNegateRuleBasics:
    """Test basic functionality of the NegateRule class."""

    def test_get_input_types(self, negate_rule):
        """Test that get_input_types returns a non-empty list of (type, PREDICATE) pairs."""
        input_types = negate_rule.get_input_types()
        assert isinstance(input_types, list), "Should return a list"
        assert len(input_types) > 0, "Should return at least one input type combination"

        # Each input type should be a 2-tuple of (entity type, concept type),
        # and negation only accepts predicates.
        for type_combo in input_types:
            assert isinstance(type_combo, tuple), "Each combination should be a tuple"
            assert len(type_combo) == 2, "Each type spec should have 2 elements"
            assert (
                type_combo[1] == ConceptType.PREDICATE
            ), "Second element should be predicate"

    def test_get_valid_parameterizations(self, negate_rule, is_even):
        """Test that get_valid_parameterizations returns exactly one parameterization.

        Negation has nothing to parameterize, so the rule is expected to
        offer a single (empty) parameterization for a unary predicate.
        """
        params_even = negate_rule.get_valid_parameterizations(is_even)
        assert isinstance(params_even, list), "Should return a list"
        assert len(params_even) == 1, (
            "Negation has no parameters; expected exactly one (empty) parameterization"
        )

    def test_can_apply_valid_predicate(self, negate_rule, is_even):
        """Test that can_apply returns True for valid predicates."""
        assert negate_rule.can_apply(is_even), "Should return True for predicate"


# ==============================
# Predicate Negate Tests
# ==============================


class TestNegateRulePredicate:
    """Test negation of predicates."""

    def test_concept_metadata(self, negate_rule, is_even):
        """Test that concept name and description are properly set."""
        is_odd = negate_rule.apply(is_even)
        assert (
            is_odd.name == "not_(is_even)"
        ), "Negated concept should have correct name"
        assert (
            is_odd.description == "Negation of is_even"
        ), "Negated concept should have correct description"

    def test_negate_predicate(self, negate_rule, is_even):
        """Test that negation swaps examples/nonexamples and inverts compute()."""
        is_odd = negate_rule.apply(is_even)

        assert (
            is_odd.examples.example_structure == is_even.examples.example_structure
        ), "Negated concept should preserve example structure"

        # Test examples
        even_examples = is_even.examples.get_examples()
        even_nonexamples = is_even.examples.get_nonexamples()
        odd_examples = is_odd.examples.get_examples()
        odd_nonexamples = is_odd.examples.get_nonexamples()

        # Examples of the original must be nonexamples of the negation and
        # vice versa; compare the underlying values of the example entries.
        assert {e.value for e in even_examples} == {
            e.value for e in odd_nonexamples
        }, "Predicate examples should be negated predicate nonexamples."
        assert {e.value for e in even_nonexamples} == {
            e.value for e in odd_examples
        }, "Predicate nonexamples should be negated predicate examples."

        # Test computation: odd inputs satisfy the negation, even inputs do not.
        expected = {1: True, 2: False, 3: True, 4: False, 5: True}
        for n, want in expected.items():
            result = is_odd.compute(n)
            assert result == want, f"Computing negated predicate failed for n={n}"

    def test_multiple_negations(self, negate_rule, is_even):
        """Test that negating twice yields a concept that agrees with the original."""
        is_odd = negate_rule.apply(is_even)
        is_even_again = negate_rule.apply(is_odd)

        # Test computation
        for n in [1, 2, 3, 4, 5]:
            assert is_even_again.compute(n) == is_even.compute(
                n
            ), f"Double negation should return to original concept (n={n})"


# ==============================
# Error Case Tests
# ==============================


class TestNegateRuleErrors:
    """Test error cases for the NegateRule class."""

    def test_invalid_concept_type(self, negate_rule, square, zero_concept):
        """Test that applying to an invalid concept type is not allowed.

        Only predicates may be negated; functions and constants are rejected.
        """
        assert not negate_rule.can_apply(
            square
        ), "Should not allow application to functions"
        assert not negate_rule.can_apply(
            zero_concept
        ), "Should not allow application to constant concepts"

    def test_invalid_input_count(self, negate_rule, is_even):
        """Test that applying with the wrong number of inputs is not allowed."""
        assert not negate_rule.can_apply(), "Should not allow zero inputs"
        assert not negate_rule.can_apply(
            is_even, is_even
        ), "Should not allow multiple inputs"

    def test_invalid_input_type(self, negate_rule):
        """Test that applying to a non-Concept input (a Conjecture) is not allowed."""
        conjecture = Conjecture(
            name="test_conjecture",
            description="A test conjecture",
            symbolic_definition=lambda: True,
            example_structure=None,
        )
        assert not negate_rule.can_apply(
            conjecture
        ), "Should not allow non-Concept inputs"


# ==============================
# Integration Tests
# ==============================


class TestNegateRuleIntegration:
    """Test integration of NegateRule with other rules and concepts."""

    def test_integration_with_exists_rule(self, negate_rule, square):
        """Test combining ExistsRule and NegateRule to create not_is_square.

        ExistsRule quantifies over the first argument of ``square`` to build
        is_square; NegateRule then negates it, so perfect squares must
        compute to a falsy value.
        """
        from frame.productions.concepts.exists import ExistsRule

        # Step 1: Use ExistsRule to create is_square from square
        exists_rule = ExistsRule()
        is_square = exists_rule.apply(square, indices_to_quantify=[0])

        # Step 2: Use NegateRule to create not_is_square from is_square
        not_is_square = negate_rule.apply(is_square)

        # Test the resulting concept
        assert not not_is_square.compute(0), "0 is a square"
        assert not not_is_square.compute(1), "1 is a square"
        assert not not_is_square.compute(4), "4 is a square"
        assert not not_is_square.compute(9), "9 is a square"
        # TODO: also assert `not not_is_square.compute(16)` ("16 is a square");
        # it is currently skipped due to nondeterminism in ExistsRule.


# ==============================
# Z3 Translation Tests
# ==============================


class TestNegateRuleZ3:
    """Check the Z3 programs generated for negated concepts."""

    @staticmethod
    def _proved(concept, *args):
        """Translate ``concept`` applied to ``args`` to Z3, run it, and return ``.proved``."""
        outcome = concept.to_z3(*args).run()
        return outcome.proved

    def test_concrete_negation_translation(self, negate_rule, is_even):
        """Negated unary predicate: Z3 proves is_odd(1) and does not prove is_odd(2)."""
        is_odd = negate_rule.apply(is_even)

        assert self._proved(is_odd, Nat(1))
        assert not self._proved(is_odd, Nat(2))

    def test_negation_translation_multiple_args(self, negate_rule, greater_than):
        """Negated binary predicate: Z3 proves the (1, 2) case but not the (2, 1) case."""
        not_greater_than = negate_rule.apply(greater_than)

        assert self._proved(not_greater_than, Nat(1), Nat(2))
        assert not self._proved(not_greater_than, Nat(2), Nat(1))
