import pytest
from frame.knowledge_base.entities import (
    Concept,
    ConceptType,
    ExampleType,
    ExampleStructure,
    Nat,
)

from frame.productions.concepts.compose import ComposeRule
from frame.tools.z3_template import Z3Template
# ==============================
# Fixtures
# ==============================


@pytest.fixture
def compose_rule():
    """Provide every test with its own newly constructed ComposeRule."""
    rule = ComposeRule()
    return rule


# ==============================
# Basic Rule Tests
# ==============================


class TestComposeRuleBasics:
    """Smoke tests for ComposeRule's input-type listing and can_apply checks."""

    def test_get_input_types(self, compose_rule):
        """get_input_types must return a non-empty list of lists of 2-tuples."""
        combos = compose_rule.get_input_types()
        assert isinstance(combos, list), "Should return a list"
        assert len(combos) > 0, "Should return at least one input type combination"

        # Validate the shape of every combination and of each entry inside it.
        for combo in combos:
            assert isinstance(combo, list), "Each combination should be a list"
            for spec in combo:
                assert isinstance(spec, tuple), "Each type spec should be a tuple"
                assert len(spec) == 2, "Each type spec should have 2 elements"

    def test_can_apply_valid_function_composition(self, compose_rule, add3, add2):
        """can_apply accepts add3 with add2 when given output_to_input_map={0: 0}."""
        allowed = compose_rule.can_apply(add3, add2, output_to_input_map={0: 0})
        assert allowed, "Should allow composition of add3 ∘ add2"

    def test_can_apply_valid_predicate_composition(
        self, compose_rule, is_even, is_square
    ):
        """can_apply accepts two predicates joined through shared_vars={0: 0}."""
        allowed = compose_rule.can_apply(is_even, is_square, shared_vars={0: 0})
        assert allowed, "Should allow composition of is_even AND is_square"

    def test_can_apply_valid_function_to_predicate(self, compose_rule, square, is_even):
        """can_apply accepts feeding square's output into is_even."""
        allowed = compose_rule.can_apply(square, is_even, output_to_input_map={0: 0})
        assert allowed, "Should allow composition of is_even(square(x))"


# ==============================
# Function Composition Tests
# ==============================


class TestFunctionComposition:
    """Test composition of functions.

    The assertions below show that
    ``apply(first, second, output_to_input_map={0: 0})`` feeds output 0 of
    ``first`` into input 0 of ``second``; the resulting concept therefore
    computes ``second(first(x))`` (``first`` is the inner function).
    """

    @staticmethod
    def _assert_has_examples(concept, expected_examples):
        """Assert that every expected tuple is among the concept's examples.

        Args:
            concept: Composed concept exposing ``examples.get_examples()``.
            expected_examples: Iterable of tuples, each of which must equal
                the ``value`` of at least one stored example.
        """
        examples = concept.examples.get_examples()
        for expected in expected_examples:
            assert any(
                ex.value == expected for ex in examples
            ), f"Should have example {expected}"

    def test_simple_unary_composition(self, compose_rule, add2, add3):
        """Test composing two unary functions (add3, then add2, gives add5)."""
        # add3 is applied first and its output is fed into add2.
        add5 = compose_rule.apply(add3, add2, output_to_input_map={0: 0})

        # Test the composition name and computation
        assert add5.name == "compose_(add3_with_add2_output_to_input_map={0: 0})"
        assert add5.compute(0) == 5, "0 -> 3 -> 5"
        assert add5.compute(1) == 6, "1 -> 4 -> 6"
        assert add5.compute(2) == 7, "2 -> 5 -> 7"
        assert add5.compute(3) == 8, "3 -> 6 -> 8"

        # Test that examples were properly transformed
        # For each input x, if add3(x) = y and add2(y) = z, then add5(x) = z
        self._assert_has_examples(
            add5,
            [
                (0, 5),  # 0 -> 3 -> 5
                (1, 6),  # 1 -> 4 -> 6
                (2, 7),  # 2 -> 5 -> 7
                (3, 8),  # 3 -> 6 -> 8
            ],
        )

    def test_binary_function_composition(self, compose_rule, multiply):
        """Test composing multiplication with itself to create triple product."""
        # Map output of first multiplication to first input of second multiplication
        triple_product = compose_rule.apply(
            multiply, multiply, output_to_input_map={0: 0}
        )

        # Test the composition
        assert (
            triple_product.name
            == "compose_(multiply_with_multiply_output_to_input_map={0: 0})"
        )
        assert triple_product.compute(2, 3, 4) == 24, "2 * 3 * 4 should be 24"
        assert triple_product.compute(1, 2, 3) == 6, "1 * 2 * 3 should be 6"
        assert triple_product.compute(2, 2, 2) == 8, "2 * 2 * 2 should be 8"

        # Test example transformation
        self._assert_has_examples(
            triple_product,
            [
                (2, 3, 4, 24),  # (2 * 3) * 4 = 24
                (1, 2, 3, 6),  # (1 * 2) * 3 = 6
                (2, 2, 2, 8),  # (2 * 2) * 2 = 8
            ],
        )

    def test_nested_function_composition(self, compose_rule, add2, add3, square):
        """Test nested function compositions (square applied to add3-then-add2)."""
        # First build add5 from add3 and add2.
        # Note: the FIRST argument to apply() is the inner function (applied
        # first) and the SECOND is the outer one; the square step below
        # demonstrates this, since apply(add5, square) computes square(add5(x)).
        add5 = compose_rule.apply(add3, add2, output_to_input_map={0: 0})

        # Test the intermediate composition
        assert add5.name == "compose_(add3_with_add2_output_to_input_map={0: 0})"
        assert add5.compute(0) == 5, "0 -> 3 -> 5"
        assert add5.compute(1) == 6, "1 -> 4 -> 6"

        # Then feed add5's output into square: square(add5(x))
        square_of_add5 = compose_rule.apply(add5, square, output_to_input_map={0: 0})

        # Test the final composition
        assert (
            square_of_add5.name
            == "compose_(compose_(add3_with_add2_output_to_input_map={0: 0})_with_square_output_to_input_map={0: 0})"
        )

        # Verify the computation: square(add5(x))
        # For x=0: add5(0)=5, square(5)=25
        assert square_of_add5.compute(0) == 25, "0 -> 5 -> 25"
        # For x=1: add5(1)=6, square(6)=36
        assert square_of_add5.compute(1) == 36, "1 -> 6 -> 36"
        # For x=2: add5(2)=7, square(7)=49
        assert square_of_add5.compute(2) == 49, "2 -> 7 -> 49"
        # For x=3: add5(3)=8, square(8)=64
        assert square_of_add5.compute(3) == 64, "3 -> 8 -> 64"

        # Test example transformation
        self._assert_has_examples(
            square_of_add5,
            [
                (0, 25),  # 0 -> 5 -> 25
                (1, 36),  # 1 -> 6 -> 36
                (2, 49),  # 2 -> 7 -> 49
                (3, 64),  # 3 -> 8 -> 64
            ],
        )


# ==============================
# Predicate Composition Tests
# ==============================


class TestPredicateComposition:
    """Tests for combining two predicates over a shared variable."""

    def test_simple_predicate_composition(self, compose_rule, is_even, is_square):
        """Compose is_even with is_square; check name, results and (non)examples."""
        even_square = compose_rule.apply(is_even, is_square, shared_vars={0: 0})

        expected_name = "compose_(is_even_with_is_square_shared_vars={0: 0})"
        assert even_square.name == expected_name

        # Inputs for which the composed predicate must hold.
        holds_for = [
            (0, "0 is even and square"),
            (4, "4 is even and square"),
            (16, "16 is even and square"),
        ]
        for value, reason in holds_for:
            assert even_square.compute(value), reason

        # Inputs for which the composed predicate must not hold.
        fails_for = [
            (2, "2 is even but not square"),
            (9, "9 is square but not even"),
            (7, "7 is neither even nor square"),
            (25, "25 is square but not even"),
        ]
        for value, reason in fails_for:
            assert not even_square.compute(value), reason

        # Every value satisfying both predicates should be a stored example.
        stored_examples = even_square.examples.get_examples()
        for expected in [(0,), (4,), (16,)]:
            found = any(item.value == expected for item in stored_examples)
            assert found, f"Should have example {expected}"

        # Values failing at least one predicate should be stored nonexamples:
        # (2,) even but not square, (9,) and (25,) square but not even,
        # (3,) neither.
        stored_nonexamples = even_square.examples.get_nonexamples()
        for expected in [(2,), (9,), (3,), (25,)]:
            found = any(item.value == expected for item in stored_nonexamples)
            assert found, f"Should have nonexample {expected}"


# ==============================
# Function-to-Predicate Composition Tests
# ==============================


class TestFunctionToPredicateComposition:
    """Test composition of functions into predicates.

    In every case below ``apply(first, second, output_to_input_map={0: 0})``
    feeds the output of ``first`` into ``second``, so the result computes
    ``second(first(x))``.
    """

    def test_function_to_predicate_composition(self, compose_rule, square, is_even):
        """Test composing a function into a predicate (is_even(square(x)))."""
        # Create a predicate that tests if the square of a number is even
        square_is_even = compose_rule.apply(square, is_even, output_to_input_map={0: 0})

        # Test the composition name
        assert (
            square_is_even.name
            == "compose_(square_with_is_even_output_to_input_map={0: 0})"
        )

        # Test computation for various cases
        # True cases (square is even)
        assert square_is_even.compute(0), "0^2 = 0 is even"
        assert square_is_even.compute(2), "2^2 = 4 is even"
        assert square_is_even.compute(4), "4^2 = 16 is even"
        assert square_is_even.compute(6), "6^2 = 36 is even"

        # False cases (square is odd)
        assert not square_is_even.compute(1), "1^2 = 1 is odd"
        assert not square_is_even.compute(3), "3^2 = 9 is odd"
        assert not square_is_even.compute(5), "5^2 = 25 is odd"

        # Test example transformation
        examples = square_is_even.examples.get_examples()
        expected_examples = [(0,), (2,), (4,), (6,)]  # Numbers whose square is even
        for expected in expected_examples:
            assert any(
                ex.value == expected for ex in examples
            ), f"Should have example {expected}"

        # Test nonexample transformation
        nonexamples = square_is_even.examples.get_nonexamples()
        expected_nonexamples = [(1,), (3,), (5,)]  # Numbers whose square is odd
        for expected in expected_nonexamples:
            assert any(
                ex.value == expected for ex in nonexamples
            ), f"Should have nonexample {expected}"

    def test_nested_function_to_predicate(self, compose_rule, add2, square, is_even):
        """Test composing multiple functions with a predicate (is_even(square(add2(x))))."""
        # First build square(add2(x)).
        # Note: the FIRST argument to apply() is the inner function (applied
        # first) and the SECOND is the outer one, as the computed values
        # asserted below confirm.
        square_of_add2 = compose_rule.apply(add2, square, output_to_input_map={0: 0})

        # Test the intermediate composition
        assert (
            square_of_add2.name
            == "compose_(add2_with_square_output_to_input_map={0: 0})"
        )

        # Verify the computation: square(add2(x))
        # For x=0: add2(0)=2, square(2)=4
        assert square_of_add2.compute(0) == 4, "0 -> 2 -> 4"
        # For x=1: add2(1)=3, square(3)=9
        assert square_of_add2.compute(1) == 9, "1 -> 3 -> 9"
        # For x=2: add2(2)=4, square(4)=16
        assert square_of_add2.compute(2) == 16, "2 -> 4 -> 16"
        # For x=3: add2(3)=5, square(5)=25
        assert square_of_add2.compute(3) == 25, "3 -> 5 -> 25"

        # Then feed the squared value into is_even.
        is_even_square_of_add2 = compose_rule.apply(
            square_of_add2, is_even, output_to_input_map={0: 0}
        )
        assert (
            is_even_square_of_add2.name
            == "compose_(compose_(add2_with_square_output_to_input_map={0: 0})_with_is_even_output_to_input_map={0: 0})"
        )

        # Test computation for various cases
        # True cases (square of add2 is even); messages show x -> square(add2(x)) -> result
        assert is_even_square_of_add2.compute(0), "0 -> 4 -> True"
        assert is_even_square_of_add2.compute(2), "2 -> 16 -> True"
        assert is_even_square_of_add2.compute(4), "4 -> 36 -> True"

        # False cases (square of add2 is odd); without these a predicate that
        # always returns True would pass this test.
        assert not is_even_square_of_add2.compute(1), "1 -> 9 -> False"
        assert not is_even_square_of_add2.compute(3), "3 -> 25 -> False"


# ==============================
# Error Case Tests
# ==============================


class TestComposeRuleErrors:
    """Checks that ComposeRule refuses ill-formed composition requests."""

    def test_invalid_function_composition_types(self, compose_rule, add2, is_even):
        """can_apply must reject add2 -> is_even when the map is {1: 0}."""
        # The concepts themselves are compatible; only the mapping is wrong
        # (presumably because add2 has no output at index 1 - confirm against
        # ComposeRule).
        accepted = compose_rule.can_apply(add2, is_even, output_to_input_map={1: 0})
        assert not accepted, "Should reject invalid output-to-input mapping"

    def test_invalid_predicate_composition(self, compose_rule, is_even, add2):
        """can_apply must reject shared_vars when one operand is a function."""
        accepted = compose_rule.can_apply(is_even, add2, shared_vars={0: 0})
        assert not accepted, "Should reject predicate composition with non-predicate"

    def test_missing_mapping(self, compose_rule, add2, add3):
        """Without any mapping, can_apply is falsy and apply raises ValueError."""
        accepted = compose_rule.can_apply(add2, add3)
        assert not accepted, "Should reject composition without mapping"

        # apply() itself must refuse to build the concept as well.
        with pytest.raises(ValueError):
            compose_rule.apply(add2, add3)

    def test_invalid_function_to_predicate(self, compose_rule, is_even, add2):
        """can_apply must reject feeding the predicate is_even into add2."""
        accepted = compose_rule.can_apply(is_even, add2, output_to_input_map={0: 0})
        assert not accepted, "Should reject invalid function-to-predicate composition"


# ==============================
# Edge Case Tests
# ==============================


class TestComposeRuleEdgeCases:
    """Edge-case coverage for ComposeRule: self-composition and example-free concepts."""

    def test_identity_function_composition(self, compose_rule, add2):
        """Chaining add2 into itself should behave like adding four."""
        add4 = compose_rule.apply(add2, add2, output_to_input_map={0: 0})

        assert add4.name == "compose_(add2_with_add2_output_to_input_map={0: 0})"

        # (input, expected output, trace shown on failure)
        checks = [
            (0, 4, "0 -> 2 -> 4"),
            (1, 5, "1 -> 3 -> 5"),
            (2, 6, "2 -> 4 -> 6"),
        ]
        for arg, want, trace in checks:
            assert add4.compute(arg) == want, trace

    def test_empty_examples(self, compose_rule):
        """Concepts created without adding any examples can still be composed."""
        # A unary numeric function (x + 1); no examples are added to it here.
        func_structure = ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=1,
        )
        empty_func = Concept(
            name="empty_func",
            description="Function with no examples",
            symbolic_definition=lambda x: x + 1,
            computational_implementation=lambda x: x + 1,
            example_structure=func_structure,
        )

        # A unary numeric predicate (x > 0); no examples are added to it here.
        pred_structure = ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,),
            input_arity=1,
        )
        empty_pred = Concept(
            name="empty_pred",
            description="Predicate with no examples",
            symbolic_definition=lambda x: x > 0,
            computational_implementation=lambda x: x > 0,
            example_structure=pred_structure,
        )

        # Function fed into itself.
        twice = compose_rule.apply(empty_func, empty_func, output_to_input_map={0: 0})
        expected_func_name = (
            "compose_(empty_func_with_empty_func_output_to_input_map={0: 0})"
        )
        assert twice.name == expected_func_name
        assert twice.compute(0) == 2, "0 -> 1 -> 2"

        # Function fed into the predicate.
        positive_after_inc = compose_rule.apply(
            empty_func, empty_pred, output_to_input_map={0: 0}
        )
        expected_pred_name = (
            "compose_(empty_func_with_empty_pred_output_to_input_map={0: 0})"
        )
        assert positive_after_inc.name == expected_pred_name
        assert positive_after_inc.compute(0), "0 -> 1 -> True"
        assert not positive_after_inc.compute(-1), "-1 -> 0 -> False"


# ==============================
# Z3 Translation Tests
# ==============================


class TestComposeRuleZ3:
    """Checks that composed concepts translate to Z3 programs that run as expected."""

    def test_function_composition_translation(self, compose_rule, add2, add3):
        """Embed the composed function's DSL in a template asserting its value is 6."""
        composed = compose_rule.apply(add2, add3, output_to_input_map={0: 0})
        composed_program = composed.to_z3(Nat(1)).program

        # Wrap the translated function as f_0 and state the predicate f_0(x_0) == 6.
        template_source = f"""
            params 1;
            bounded params 0;
            f_0 := Func(
            {composed_program.dsl()}
            );
            ReturnExpr None;
            ReturnPred f_0(x_0=x_0) == 6;
            """
        checker = Z3Template(template_source)

        # Expected to be proved for argument 1 ...
        checker.set_args(Nat(1))
        assert checker.run().proved

        # ... and not for argument 2.
        checker.set_args(Nat(2))
        assert not checker.run().proved

    def test_predicate_composition_translation(self, compose_rule, is_even, is_square):
        """The composed is_even/is_square predicate is proved for 4 but not for 5."""
        is_even_square = compose_rule.apply(is_even, is_square, shared_vars={0: 0})

        outcome_for_4 = is_even_square.to_z3(Nat(4)).run()
        assert outcome_for_4.proved

        outcome_for_5 = is_even_square.to_z3(Nat(5)).run()
        assert not outcome_for_5.proved

    def test_function_to_predicate_composition_translation(
        self, compose_rule, square, is_even
    ):
        """The square-into-is_even predicate is proved for 2 but not for 5."""
        square_is_even = compose_rule.apply(square, is_even, output_to_input_map={0: 0})

        outcome_for_2 = square_is_even.to_z3(Nat(2)).run()
        assert outcome_for_2.proved

        outcome_for_5 = square_is_even.to_z3(Nat(5)).run()
        assert not outcome_for_5.proved