import pytest

from frame.productions.concepts.match import MatchRule
from frame.knowledge_base.entities import (
    ConceptApplication,
    Concept,
    ExampleType,
    ExampleStructure,
    ConceptType,
    Nat,
)
from frame.knowledge_base.demonstrations import multiplication, addition, SetCardinality
from frame.tools.z3_template import Z3Template

@pytest.fixture
def match_rule():
    """Provide a fresh ``MatchRule`` instance to each test that requests it."""
    rule = MatchRule()
    return rule


class TestMatchRuleBasics:
    """Smoke tests covering the public surface of MatchRule."""

    def test_get_input_types(self, match_rule):
        """The rule declares one Concept input of function or predicate kind."""
        declared = match_rule.get_input_types()
        assert len(declared) == 1
        only_input = declared[0]
        assert only_input[0] == Concept
        assert only_input[1] == [ConceptType.FUNCTION, ConceptType.PREDICATE]

    def test_can_apply_valid_inputs(self, match_rule, multiply):
        """Matching both arguments of a binary function is permitted."""
        applicable = match_rule.can_apply(
            multiply, indices_to_match=[0, 1], verbose=False
        )
        assert applicable

    def test_basic_rule_application(self, match_rule, multiply):
        """Matching multiply on [0, 1] yields a unary squaring concept."""
        square = match_rule.apply(multiply, indices_to_match=[0, 1])
        assert square.name == "matched_(multiply_indices_[0, 1])"
        assert square.examples.example_structure.input_arity == 1
        for argument, expected in ((4, 16), (5, 25)):
            assert square.compute(argument) == expected


class TestMatchRuleParameterizations:
    """Checks on the parameterizations MatchRule proposes for a concept."""

    def test_binary_concept_parameterizations(self, match_rule, multiply):
        """A binary concept admits exactly one choice: match both inputs."""
        candidates = match_rule.get_valid_parameterizations(multiply)
        assert len(candidates) == 1
        assert candidates[0]["indices_to_match"] == [0, 1]

    def test_ternary_concept_parameterizations(self, match_rule, multiply3):
        """A ternary concept admits every subset of two or more input indices."""
        candidates = match_rule.get_valid_parameterizations(multiply3)
        assert len(candidates) == 4
        proposed = [set(candidate["indices_to_match"]) for candidate in candidates]
        # Expected subsets: the three pairs plus the full triple.
        for wanted in ({0, 1}, {0, 2}, {1, 2}, {0, 1, 2}):
            assert wanted in proposed

    def test_many_args_parameterizations(self, match_rule, many_args_concept):
        """High-arity concepts get a capped list that starts with the smallest subsets."""
        candidates = match_rule.get_valid_parameterizations(many_args_concept)
        # The rule is expected to cap its output at 100 parameterizations.
        assert len(candidates) == 100
        # Pairs (the smallest valid subsets) should be proposed first.
        assert len(candidates[0]["indices_to_match"]) == 2

    def test_invalid_parameterizations(self, match_rule, multiply):
        """A unary concept has nothing to match, so nothing is proposed."""
        unary = Concept(
            name="single_arg",
            description="A function with a single argument",
            symbolic_definition=lambda a: ConceptApplication(multiply, a, a),
            computational_implementation=lambda a: a * a,
            example_structure=ExampleStructure(
                concept_type=ConceptType.FUNCTION,
                component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
                input_arity=1,
            ),
        )
        candidates = match_rule.get_valid_parameterizations(unary)
        assert len(candidates) == 0


class TestBinaryMatching:
    """Matching the two inputs of a binary function."""

    def test_multiply_matching(self, match_rule, multiply):
        """Identifying both factors of multiply yields n -> n * n."""
        square = match_rule.apply(multiply, indices_to_match=[0, 1])
        for n in (2, 3, 4, 5):
            assert square.compute(n) == n * n

        # Diagonal parent examples (n, n, n*n) should survive as (n, n*n).
        observed = [example.value for example in square.examples.get_examples()]
        assert (3, 9) in observed  # parent example (3, 3, 9)
        assert (4, 16) in observed  # parent example (4, 4, 16)

    def test_addition_matching(self, match_rule, add_two_numbers):
        """Identifying both addends of addition yields n -> n + n."""
        double = match_rule.apply(add_two_numbers, indices_to_match=[0, 1])
        for n in (2, 3, 4, 5):
            assert double.compute(n) == n + n

        # Diagonal parent examples (n, n, 2n) should survive as (n, 2n).
        observed = [example.value for example in double.examples.get_examples()]
        assert (1, 2) in observed  # parent example (1, 1, 2)
        assert (4, 8) in observed  # parent example (4, 4, 8)


class TestTernaryMatching:
    """Matching subsets of the inputs of the ternary product a * b * c."""

    def test_matching_first_two_arguments(self, match_rule, multiply3):
        """Matching indices 0 and 1 gives (a, c) -> (a * a) * c."""
        square_times = match_rule.apply(multiply3, indices_to_match=[0, 1])
        expectations = [((2, 3), 12), ((3, 2), 18), ((4, 1), 16)]
        for arguments, expected in expectations:
            assert square_times.compute(*arguments) == expected

        observed = [item.value for item in square_times.examples.get_examples()]
        assert (2, 2, 8) in observed  # parent example (2, 2, 2, 8)

    def test_matching_last_two_arguments(self, match_rule, multiply3):
        """Matching indices 1 and 2 gives (a, b) -> a * (b * b)."""
        times_square = match_rule.apply(multiply3, indices_to_match=[1, 2])
        expectations = [((2, 3), 18), ((3, 2), 12), ((1, 4), 16)]
        for arguments, expected in expectations:
            assert times_square.compute(*arguments) == expected

        observed = [item.value for item in times_square.examples.get_examples()]
        assert (2, 2, 8) in observed  # parent example (2, 2, 2, 8)

    def test_matching_first_and_last_arguments(self, match_rule, multiply3):
        """Matching indices 0 and 2 gives (a, b) -> a * b * a."""
        first_last_match = match_rule.apply(multiply3, indices_to_match=[0, 2])
        expectations = [((2, 3), 12), ((3, 2), 18), ((4, 1), 16)]
        for arguments, expected in expectations:
            assert first_last_match.compute(*arguments) == expected

        observed = [item.value for item in first_last_match.examples.get_examples()]
        assert (2, 2, 8) in observed  # parent example (2, 2, 2, 8)

    def test_matching_all_three_arguments(self, match_rule, multiply3):
        """Matching every index gives the cube n -> n * n * n."""
        cube = match_rule.apply(multiply3, indices_to_match=[0, 1, 2])
        for n, expected in ((2, 8), (3, 27), (4, 64)):
            assert cube.compute(n) == expected

        observed = [item.value for item in cube.examples.get_examples()]
        assert (2, 8) in observed  # parent example (2, 2, 2, 8)
        assert (3, 27) in observed  # parent example (3, 3, 3, 27)


class TestExampleTransformation:
    """Verify how parent examples and nonexamples are filtered and projected."""

    def test_binary_example_filtering(self, match_rule, multiply):
        """Only multiply examples with equal factors become square examples."""
        square = match_rule.apply(multiply, indices_to_match=[0, 1])
        kept = [entry.value for entry in square.examples.get_examples()]

        # Diagonal examples are projected down to (n, n*n).
        for projected in ((3, 9), (4, 16)):
            assert projected in kept

        # Off-diagonal examples such as (2, 3, 6) and (2, 5, 10) must be dropped.
        for rejected in ((2, 6), (2, 10)):
            assert rejected not in kept

    def test_ternary_example_filtering(self, match_rule, multiply3):
        """Only triple-product examples with three equal factors become cube examples."""
        cube = match_rule.apply(multiply3, indices_to_match=[0, 1, 2])
        kept = [entry.value for entry in cube.examples.get_examples()]

        # Examples like (2, 2, 2, 8) and (3, 3, 3, 27) are projected to (n, n**3).
        for projected in ((2, 8), (3, 27)):
            assert projected in kept

        # Mixed-factor examples such as (2, 3, 4, 24) and (1, 2, 3, 6) must be dropped.
        for rejected in ((2, 24), (1, 6)):
            assert rejected not in kept

    def test_nonexample_filtering(self, match_rule, multiply):
        """Diagonal nonexamples of multiply carry over as nonexamples of square."""
        square = match_rule.apply(multiply, indices_to_match=[0, 1])
        rejected = [entry.value for entry in square.examples.get_nonexamples()]
        assert (2, 5) in rejected  # parent nonexample (2, 2, 5)


class TestMatchRuleErrors:
    """Tests that invalid inputs are rejected by can_apply and raise in apply."""

    def test_invalid_concept_types(self, match_rule):
        """A unary function concept cannot be matched on indices [0, 1]."""

        # A well-formed Concept, but with input_arity=1 there is no second
        # input for index 1 to refer to, so the match must be rejected.
        dummy = Concept(
            name="dummy",
            description="A dummy concept",
            symbolic_definition=lambda a: a,
            computational_implementation=lambda a: a,
            example_structure=ExampleStructure(
                concept_type=ConceptType.FUNCTION,
                input_arity=1,
                component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            ),
        )

        # can_apply must report the parameterization as invalid...
        assert not match_rule.can_apply(dummy, indices_to_match=[0, 1], verbose=False)

        # ...and apply must refuse it with a ValueError.
        with pytest.raises(ValueError):
            match_rule.apply(dummy, indices_to_match=[0, 1])

    def test_invalid_indices_to_match(self, match_rule, multiply):
        """Out-of-range indices and single-index matches are rejected."""
        # multiply has two inputs, so index 2 is out of range.
        assert not match_rule.can_apply(
            multiply, indices_to_match=[0, 2], verbose=False
        )
        with pytest.raises(ValueError):
            match_rule.apply(multiply, indices_to_match=[0, 2])

        # Matching needs at least two indices.
        assert not match_rule.can_apply(multiply, indices_to_match=[0], verbose=False)
        with pytest.raises(ValueError):
            match_rule.apply(multiply, indices_to_match=[0])

    def test_type_mismatch(self, match_rule, mixed_types_concept):
        """Inputs of different example types cannot be matched."""
        # Indices 0 and 1 of the fixture are expected to have different types
        # (NUMERIC and SET).
        assert not match_rule.can_apply(
            mixed_types_concept, indices_to_match=[0, 1], verbose=False
        )
        with pytest.raises(ValueError):
            match_rule.apply(mixed_types_concept, indices_to_match=[0, 1])

    def test_insufficient_arity(self, match_rule):
        """A unary concept cannot be matched, even by repeating its only index."""
        # Single-input concept; [0, 0] repeats the only available index.
        single_arg = Concept(
            name="single_arg",
            description="A function with a single argument",
            symbolic_definition=lambda a: ConceptApplication(single_arg, a),
            computational_implementation=lambda a: a * 2,
            example_structure=ExampleStructure(
                concept_type=ConceptType.FUNCTION,
                component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
                input_arity=1,
            ),
        )

        # can_apply must report the parameterization as invalid...
        assert not match_rule.can_apply(
            single_arg, indices_to_match=[0, 0], verbose=False
        )

        # ...and apply must refuse it with a ValueError.
        with pytest.raises(ValueError):
            match_rule.apply(single_arg, indices_to_match=[0, 0])


class TestMatchRuleIntegration:
    """Integration tests: chained matches and matching on predicates."""

    def test_nested_matching(self, match_rule, multiply3):
        """Test nested matching (match on already matched concept)."""
        # First match: a=b turns (a, b, c) -> a*b*c into (a, c) -> a*a*c.
        square_times = match_rule.apply(multiply3, indices_to_match=[0, 1])

        # A single index is never a valid match.
        assert not match_rule.can_apply(
            square_times,
            indices_to_match=[0],
            verbose=False,
        )

        # Second match: the matched concept still has two inputs, so they can
        # be matched to each other.
        assert match_rule.can_apply(
            square_times, indices_to_match=[0, 1], verbose=False
        )

        cube = match_rule.apply(square_times, indices_to_match=[0, 1])
        assert cube.compute(2) == 8  # 2*2*2 = 8
        assert cube.compute(3) == 27  # 3*3*3 = 27
        assert cube.compute(4) == 64  # 4*4*4 = 64

        # Check that examples were properly transformed
        examples = [ex.value for ex in cube.examples.get_examples()]
        assert (2, 8) in examples

        # The cube is unary, so there are no longer three inputs to match.
        assert not match_rule.can_apply(cube, indices_to_match=[0, 1, 2], verbose=False)

        # Build an independent ternary concept (a, b, c) -> a*a*b + c so a
        # match can again be applied to a three-input concept.
        square_times_plus = Concept(
            name="square_times_plus",
            description="Square of first argument times second argument plus third argument",
            symbolic_definition=lambda a, b, c: ConceptApplication(
                square_times_plus, a, b, c
            ),
            computational_implementation=lambda a, b, c: (a * a * b) + c,
            example_structure=ExampleStructure(
                concept_type=ConceptType.FUNCTION,
                component_types=(
                    ExampleType.NUMERIC,
                    ExampleType.NUMERIC,
                    ExampleType.NUMERIC,
                    ExampleType.NUMERIC,
                ),
                input_arity=3,
            ),
        )

        # Matching the first two arguments (a=b) gives (a, c) -> a**3 + c.
        assert match_rule.can_apply(
            square_times_plus, indices_to_match=[0, 1], verbose=False
        )
        cube_plus = match_rule.apply(square_times_plus, indices_to_match=[0, 1])

        # Test the resulting concept
        assert cube_plus.compute(2, 3) == 11  # (2*2*2) + 3 = 11
        assert cube_plus.compute(3, 1) == 28  # (3*3*3) + 1 = 28

    def test_with_predicate(self, match_rule, equals_predicate):
        """Matching both arguments of an equality predicate yields a tautology."""
        # The rule must accept predicates; assert it instead of just printing it,
        # so a regression in can_apply fails this test.
        assert match_rule.can_apply(
            equals_predicate, indices_to_match=[0, 1], verbose=False
        )
        # For equals_predicate, matching both arguments should create a tautology
        tautology = match_rule.apply(equals_predicate, indices_to_match=[0, 1])

        # The resulting concept should always return True
        assert tautology.compute(1)
        assert tautology.compute(2)
        assert tautology.compute(100)

        # Check examples
        examples = [ex.value for ex in tautology.examples.get_examples()]
        assert (2,) in examples  # From (2, 2)
        assert (5,) in examples  # From (5, 5)
        assert (0,) in examples  # From (0, 0)

        # There should be no nonexamples (since a == a is always true)
        nonexamples = [ex.value for ex in tautology.examples.get_nonexamples()]
        assert len(nonexamples) == 0


# ==============================
# Z3 Translation Tests
# ==============================


class TestMatchRuleZ3:
    """Test Z3 translation of match concepts."""

    def test_two_args_match_translation(self, match_rule, add_two_numbers):
        """Test Z3 translation of matching two arguments."""
        # Matching both addends of addition yields double: x -> x + x.
        double = match_rule.apply(add_two_numbers, indices_to_match=[0, 1])

        # Embed the translated program as f_0 and check f_0(x_0) == 2: this
        # should hold for x_0 = 1 and fail for x_0 = 2 (double(2) = 4).
        program = double.to_z3(Nat(1)).program
        predicate_program = Z3Template(
            f"""
            params 1;
            bounded params 0;
            f_0 := Func(
            {program.dsl()}
            );
            ReturnExpr None;
            ReturnPred f_0(x_0=x_0) == 2;
            """
        )
        predicate_program.set_args(Nat(1))
        result = predicate_program.run()
        assert result.proved

        predicate_program.set_args(Nat(2))
        result = predicate_program.run()
        assert not result.proved

    def test_three_args_match_translation(self, match_rule, add_three_numbers):
        """Test Z3 translation of matching three arguments."""
        # Matching the first and third addends gives (a, b) -> a + b + a,
        # assuming add_three_numbers computes a + b + c.
        matched_add_three_numbers = match_rule.apply(add_three_numbers, indices_to_match=[0, 2])

        # Embed the translated program as f_0 and check f_0(x_0, x_1) == 4:
        # (1, 2) -> 1 + 2 + 1 = 4 should hold; (1, 3) -> 5 should not.
        program = matched_add_three_numbers.to_z3(Nat(1), Nat(2)).program
        predicate_program = Z3Template(
            f"""
            params 2;
            bounded params 0;
            f_0 := Func(
            {program.dsl()}
            );
            ReturnExpr None;
            ReturnPred f_0(x_0=x_0, x_1=x_1) == 4;
            """
        )
        predicate_program.set_args(Nat(1), Nat(2))
        result = predicate_program.run()
        assert result.proved

        predicate_program.set_args(Nat(1), Nat(3))
        result = predicate_program.run()
        assert not result.proved
