import pytest

from frame.productions.concepts.map_iterate import MapIterateRule
from frame.knowledge_base.entities import (
    Concept,
    ConceptType,
    Nat,
)


@pytest.fixture
def map_iterate_rule():
    """Provide a fresh MapIterateRule to each test that requests it."""
    rule = MapIterateRule()
    return rule


@pytest.fixture
def iterated_successor(map_iterate_rule, successor_concept):
    """Iterate the successor function; the resulting concept should act as addition."""
    iterated = map_iterate_rule.apply(successor_concept)
    return iterated


@pytest.fixture
def iterated_addition(map_iterate_rule, add_two_numbers, zero_concept):
    """Iterate addition seeded with zero; the result should act as multiplication."""
    iterated = map_iterate_rule.apply(add_two_numbers, zero_concept)
    return iterated


@pytest.fixture
def iterated_multiplication(map_iterate_rule, multiply, one_concept):
    """Iterate multiplication seeded with one; the result should act as power."""
    iterated = map_iterate_rule.apply(multiply, one_concept)
    return iterated


class TestMapIterateRuleBasics:
    """Basic MapIterateRule behaviour: declared input types and can_apply."""

    def test_get_input_types(self, map_iterate_rule):
        """get_input_types lists a unary-function form and a binary+accumulator form."""
        alternatives = map_iterate_rule.get_input_types()
        assert len(alternatives) == 2, "Should return two input type alternatives"

        unary_form = alternatives[0]
        binary_form = alternatives[1]

        # Form 1: exactly one input, which must be a function concept.
        assert len(unary_form) == 1
        sole_input = unary_form[0]
        assert sole_input[0] == Concept
        assert sole_input[1] == ConceptType.FUNCTION

        # Form 2: a function followed by an accumulator that may be a
        # function or a constant (the accepted types are given as a list).
        assert len(binary_form) == 2
        func_spec = binary_form[0]
        acc_spec = binary_form[1]
        assert func_spec[0] == Concept
        assert func_spec[1] == ConceptType.FUNCTION
        assert acc_spec[0] == Concept
        accepted_acc_types = acc_spec[1]
        assert isinstance(accepted_acc_types, list)
        assert ConceptType.FUNCTION in accepted_acc_types
        assert ConceptType.CONSTANT in accepted_acc_types

    def test_can_apply_unary_function(self, map_iterate_rule, successor_concept):
        """A lone unary function is an acceptable input."""
        applicable = map_iterate_rule.can_apply(successor_concept)
        assert applicable

    def test_can_apply_binary_function_with_accumulator(
        self, map_iterate_rule, add_two_numbers, zero_concept
    ):
        """A binary function paired with an accumulator is an acceptable input."""
        applicable = map_iterate_rule.can_apply(add_two_numbers, zero_concept)
        assert applicable

    def test_can_apply_invalid_inputs(
        self, map_iterate_rule, add_two_numbers, successor_concept, zero_concept
    ):
        """can_apply rejects every malformed combination of inputs."""
        rejected_inputs = [
            # Binary function but no accumulator.
            (add_two_numbers,),
            # Unary function must not be given an accumulator.
            (successor_concept, zero_concept),
            # Not a function at all.
            (zero_concept,),
            # More inputs than either form allows.
            (add_two_numbers, zero_concept, successor_concept),
        ]
        for args in rejected_inputs:
            assert not map_iterate_rule.can_apply(*args)


class TestUnaryFunctionIteration:
    """Iterating a unary function a given number of times."""

    def test_successor_iteration(self, map_iterate_rule, successor_concept):
        """Iterating successor should yield addition."""
        iterated = map_iterate_rule.apply(successor_concept)

        assert iterated.name == "iterate_(successor)"
        assert iterated.description == "Applies successor iteratively"
        assert iterated.get_input_arity() == 2

        # (start value, iteration count) -> expected start + count
        expected_results = {
            (3, 4): 7,
            (5, 0): 5,
            (0, 7): 7,
            (10, 20): 30,
        }
        for (start, times), total in expected_results.items():
            assert iterated.compute(start, times) == total

    def test_add2_iteration(self, map_iterate_rule, add2):
        """Iterating add2 adds two per iteration."""
        iterated = map_iterate_rule.apply(add2)

        assert iterated.name == "iterate_(add2)"
        assert iterated.get_input_arity() == 2

        # (start value, iteration count) -> expected start + 2 * count
        expected_results = {
            (3, 2): 7,
            (5, 0): 5,
            (0, 3): 6,
        }
        for (start, times), total in expected_results.items():
            assert iterated.compute(start, times) == total


class TestBinaryFunctionIteration:
    """Iterating a binary function seeded with an accumulator concept."""

    def test_addition_with_zero(self, map_iterate_rule, add_two_numbers, zero_concept):
        """Folding addition starting from zero should give multiplication."""
        product = map_iterate_rule.apply(add_two_numbers, zero_concept)

        assert product.name == "iterate_(add_two_numbers_with_zero)"
        assert product.description == "Applies add_two_numbers iteratively"
        assert product.get_input_arity() == 2

        # ((a, b), a * b)
        cases = [
            ((3, 4), 12),
            ((5, 0), 0),
            ((0, 7), 0),
            ((6, 7), 42),
        ]
        for args, expected in cases:
            assert product.compute(*args) == expected

    def test_multiplication_with_one(self, map_iterate_rule, multiply, one_concept):
        """Folding multiplication starting from one should give exponentiation."""
        exponent = map_iterate_rule.apply(multiply, one_concept)

        assert exponent.name == "iterate_(multiply_with_one)"
        assert exponent.description == "Applies multiply iteratively"
        assert exponent.get_input_arity() == 2

        # ((base, exp), base ** exp)
        cases = [
            ((2, 3), 8),
            ((3, 2), 9),
            ((5, 0), 1),
            ((0, 5), 0),
            ((1, 10), 1),
        ]
        for args, expected in cases:
            assert exponent.compute(*args) == expected


class TestExampleTransformation:
    """Examples carried by the concept that the rule produces."""

    def test_unary_function_examples(self, map_iterate_rule, successor_concept):
        """Examples of iterated successor are (start, count, result) triples."""
        iterated = map_iterate_rule.apply(successor_concept)

        stored = iterated.examples.get_examples()
        values = {ex.value for ex in stored}

        # Iteration counts 0, 1 and 2 each appear in at least one example.
        for count in (0, 1, 2):
            assert any(ex.value[1] == count for ex in stored)

        # Starting from 0, applying successor `count` times yields `count`.
        for count in (0, 1, 2):
            assert (0, count, count) in values

    def test_binary_function_examples(
        self, map_iterate_rule, add_two_numbers, zero_concept
    ):
        """Examples of iterated addition (seeded with zero) match multiplication."""
        iterated = map_iterate_rule.apply(add_two_numbers, zero_concept)

        stored = iterated.examples.get_examples()
        values = {ex.value for ex in stored}

        # Iteration counts 0, 1 and 2 each appear in at least one example.
        for count in (0, 1, 2):
            assert any(ex.value[1] == count for ex in stored)

        # Zero iterations leave the zero accumulator untouched: i * 0 = 0.
        for i in range(10):
            assert (i, 0, 0) in values

        # For a smaller range, one and two iterations give i and 2i.
        for i in range(5):
            assert (i, 1, i) in values
            assert (i, 2, i * 2) in values


class TestMapIterateRuleErrors:
    """Inputs MapIterateRule must reject, and invalid runtime arguments."""

    def test_invalid_concept_types(self, map_iterate_rule, is_even):
        """A predicate is not a function, so the rule refuses it."""
        rule = map_iterate_rule
        assert not rule.can_apply(is_even)

        # apply() on the same input must raise instead of building a concept.
        with pytest.raises(ValueError):
            rule.apply(is_even)

    def test_missing_accumulator(self, map_iterate_rule, add_two_numbers):
        """A binary function alone, without an accumulator, is rejected."""
        rule = map_iterate_rule
        assert not rule.can_apply(add_two_numbers)

        with pytest.raises(ValueError):
            rule.apply(add_two_numbers)

    def test_invalid_accumulator(self, map_iterate_rule, add_two_numbers, is_even):
        """Accumulators must be constants or 0-arity functions; a predicate is not."""
        rule = map_iterate_rule
        assert not rule.can_apply(add_two_numbers, is_even)

        with pytest.raises(ValueError):
            rule.apply(add_two_numbers, is_even)

    def test_negative_iterations(self, iterated_successor):
        """A negative iteration count makes compute raise ValueError."""
        bad_count = -1
        with pytest.raises(ValueError):
            iterated_successor.compute(3, bad_count)


class TestValidParameterizations:
    """Test the get_valid_parameterizations method."""

    def test_unary_function_parameterizations(
        self, map_iterate_rule, successor_concept
    ):
        """A unary function yields exactly one parameterization, and it is empty."""
        params = map_iterate_rule.get_valid_parameterizations(successor_concept)
        assert len(params) == 1
        assert params[0] == {}  # No parameters needed

    def test_binary_function_parameterizations(
        self, map_iterate_rule, add_two_numbers, zero_concept
    ):
        """A binary function plus accumulator yields one empty parameterization."""
        params = map_iterate_rule.get_valid_parameterizations(
            add_two_numbers, zero_concept
        )
        assert len(params) == 1
        assert params[0] == {}  # No parameters needed

    def test_invalid_parameterizations(
        self, map_iterate_rule, add_two_numbers, is_even
    ):
        """Invalid input combinations yield no parameterizations at all."""
        # The `one_concept` fixture was requested here before but never used;
        # it has been dropped so that only the fixtures under test are set up.

        # Binary function without accumulator
        params = map_iterate_rule.get_valid_parameterizations(add_two_numbers)
        assert len(params) == 0

        # Non-function input
        params = map_iterate_rule.get_valid_parameterizations(is_even)
        assert len(params) == 0

        # Invalid accumulator
        params = map_iterate_rule.get_valid_parameterizations(add_two_numbers, is_even)
        assert len(params) == 0


class TestVerificationCapabilities:
    """Checks for MapIterateRule.determine_verification_capabilities."""

    def test_unary_function_capabilities(self, map_iterate_rule, successor_concept):
        """In the unary case, capabilities are copied from the input concept."""
        capabilities = map_iterate_rule.determine_verification_capabilities(
            successor_concept
        )
        examples_ok, nonexamples_ok = capabilities

        assert examples_ok == successor_concept.can_add_examples
        assert nonexamples_ok == successor_concept.can_add_nonexamples

    def test_binary_function_capabilities(
        self, map_iterate_rule, add_two_numbers, zero_concept
    ):
        """In the binary case, a capability holds only if both inputs have it."""
        capabilities = map_iterate_rule.determine_verification_capabilities(
            add_two_numbers, zero_concept
        )
        examples_ok, nonexamples_ok = capabilities

        # `and` (not all()) so the expected value keeps the operands' own type.
        assert examples_ok == (
            add_two_numbers.can_add_examples and zero_concept.can_add_examples
        )
        assert nonexamples_ok == (
            add_two_numbers.can_add_nonexamples and zero_concept.can_add_nonexamples
        )


class TestMapIterateIntegration:
    """Integration tests for the MapIterateRule."""

    def test_successor_to_addition(self, map_iterate_rule, successor_concept):
        """Iterating successor reproduces addition on a small grid of naturals."""
        addition = map_iterate_rule.apply(successor_concept)

        for a in range(5):
            for b in range(5):
                assert addition.compute(a, b) == a + b

        # The symbolic definition is expected to be expressed via a fold.
        assert "Fold" in str(addition.symbolic(Nat(3), Nat(4)))

    def test_addition_to_multiplication(
        self, map_iterate_rule, add_two_numbers, zero_concept
    ):
        """Iterating addition from zero reproduces multiplication on a small grid."""
        multiplication = map_iterate_rule.apply(add_two_numbers, zero_concept)

        for a in range(5):
            for b in range(5):
                assert multiplication.compute(a, b) == a * b

        # The symbolic definition is expected to be expressed via a fold.
        assert "Fold" in str(multiplication.symbolic(Nat(3), Nat(4)))

    def test_multiplication_to_power(self, map_iterate_rule, multiply, one_concept):
        """Iterating multiplication from one reproduces exponentiation on a small grid."""
        power = map_iterate_rule.apply(multiply, one_concept)

        # No special case is needed for b == 0: Python defines a**0 == 1 for
        # every int a, including 0**0, which matches the accumulator value.
        for a in range(5):
            for b in range(5):
                assert power.compute(a, b) == a**b

        # The symbolic definition is expected to be expressed via a fold.
        assert "Fold" in str(power.symbolic(Nat(3), Nat(4)))

    def test_nested_iteration(
        self,
        map_iterate_rule,
        iterated_successor,
        zero_concept,
    ):
        """Test nested application of the rule (successor → addition → multiplication)."""
        # First we apply the rule to get addition from successor
        addition = iterated_successor

        # Then we apply the rule to get multiplication from addition
        multiplication = map_iterate_rule.apply(addition, zero_concept)

        # Spot checks, including the zero edge cases on either side.
        assert multiplication.compute(3, 4) == 12  # 3 * 4 = 12
        assert multiplication.compute(5, 0) == 0  # 5 * 0 = 0
        assert multiplication.compute(0, 7) == 0  # 0 * 7 = 0

        # Same grid as test_addition_to_multiplication, so the nested result is
        # held to the same standard as the directly-built one.
        for a in range(5):
            for b in range(5):
                assert multiplication.compute(a, b) == a * b
