import pytest

from frame.productions.concepts.size import SizeRule
from frame.knowledge_base.entities import (
    Concept,
    ConceptType,
    ExampleType,
    ExampleStructure,
)

from frame.knowledge_base.entities import (
    ConceptApplication,
    Nat,
    Var,
)

from frame.knowledge_base.demonstrations import (
    addition,
)


@pytest.fixture
def size_rule():
    """Provide a newly constructed SizeRule to each test that requests it."""
    rule = SizeRule()
    return rule


class TestSizeRuleBasics:
    """Smoke tests for SizeRule: declared input types, can_apply and apply."""

    def test_get_input_types(self, size_rule):
        """get_input_types declares exactly one Concept input slot."""
        declared = size_rule.get_input_types()
        assert len(declared) == 1
        only_slot = declared[0]
        assert only_slot[0] == Concept
        assert only_slot[1] == [ConceptType.PREDICATE, ConceptType.FUNCTION]

    def test_can_apply_valid_set_of_divisors_concept(
        self, size_rule, set_of_divisors_concept
    ):
        """A set concept is accepted only when no indices are supplied."""
        accepted_bare = size_rule.can_apply(set_of_divisors_concept, verbose=False)
        assert accepted_bare
        # Set concepts take no indices_to_quantify; passing one must be rejected.
        accepted_with_indices = size_rule.can_apply(
            set_of_divisors_concept, indices_to_quantify=[0], verbose=False
        )
        assert not accepted_with_indices

    def test_can_apply_valid_predicate(self, size_rule, divides_predicate):
        """Every non-empty choice of indices is accepted for the predicate."""
        for indices in ([0], [1], [0, 1]):
            assert size_rule.can_apply(
                divides_predicate, indices_to_quantify=indices, verbose=False
            )

    def test_basic_rule_application_set(self, size_rule, set_of_divisors_concept):
        """Applying the rule to a set concept yields a computable size concept."""
        divisor_count = size_rule.apply(set_of_divisors_concept)
        assert divisor_count.name == "size_of_(set_of_divisors_indices_None)"
        structure = divisor_count.examples.example_structure
        assert structure.input_arity == 1
        # 6 has divisors 1, 2, 3, 6; 7 has divisors 1, 7.
        for n, expected in ((6, 4), (7, 2)):
            assert divisor_count.compute(n) == expected

    def test_basic_rule_application_predicate(self, size_rule, divides_predicate):
        """Applying the rule to a predicate names the result after the indices."""
        counted = size_rule.apply(divides_predicate, indices_to_quantify=[0])
        assert counted.name == "size_of_(divides_indices_[0])"
        structure = counted.examples.example_structure
        assert structure.input_arity == 1
        # No compute() checks: there is no computational implementation for now.


class TestSizeRuleParameterizations:
    """Checks on what get_valid_parameterizations proposes for each concept."""

    def test_set_of_divisors_concept_parameterizations(
        self, size_rule, set_of_divisors_concept
    ):
        """A set concept gets a single, empty parameterization."""
        proposals = size_rule.get_valid_parameterizations(set_of_divisors_concept)
        assert len(proposals) == 1
        # Set concepts need no arguments, hence the empty dict.
        assert proposals[0] == {}

    def test_predicate_parameterizations(self, size_rule, divides_predicate):
        """The divides predicate gets one proposal per non-empty index subset."""
        proposals = size_rule.get_valid_parameterizations(divides_predicate)
        assert len(proposals) == 3
        # Compare as sets so the order of indices inside a proposal is irrelevant.
        index_sets = [
            set(proposal.get("indices_to_quantify", [])) for proposal in proposals
        ]
        for wanted in ({0}, {1}, {0, 1}):
            assert wanted in index_sets

    def test_multi_arg_parameterizations(self, size_rule, multi_arg_predicate):
        """The multi-argument predicate gets all seven non-empty index subsets."""
        proposals = size_rule.get_valid_parameterizations(multi_arg_predicate)
        # 2^3 - 1 = 7 possible non-empty subsets of {0, 1, 2}.
        assert len(proposals) == 7

        index_sets = [
            set(proposal.get("indices_to_quantify", [])) for proposal in proposals
        ]
        every_subset = ({0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}, {0, 1, 2})
        for wanted in every_subset:
            assert wanted in index_sets

    def test_invalid_parameterizations(self, size_rule):
        """A numeric function (neither set nor predicate) gets no proposals."""
        # Numeric -> numeric function: not a set-valued concept, not a predicate.
        numeric_structure = ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=1,
        )
        successor_like = Concept(
            name="invalid",
            description="A function that doesn't return a set",
            symbolic_definition=lambda x: ConceptApplication(
                addition, Var("x"), Nat(1)
            ),
            computational_implementation=lambda x: x + 1,
            example_structure=numeric_structure,
        )
        proposals = size_rule.get_valid_parameterizations(successor_like)
        assert len(proposals) == 0  # nothing valid to offer


class TestSetSizeRule:
    """Tests for applying the size rule to set-valued concepts."""

    def test_set_size_computation(self, size_rule, set_of_divisors_concept):
        """The derived concept computes the number of divisors of its input."""
        size_of_divisors = size_rule.apply(set_of_divisors_concept)

        # Test various inputs
        assert size_of_divisors.compute(1) == 1  # 1 has 1 divisor: 1
        assert size_of_divisors.compute(2) == 2  # 2 has 2 divisors: 1, 2
        assert size_of_divisors.compute(3) == 2  # 3 has 2 divisors: 1, 3
        assert size_of_divisors.compute(4) == 3  # 4 has 3 divisors: 1, 2, 4
        assert size_of_divisors.compute(6) == 4  # 6 has 4 divisors: 1, 2, 3, 6
        assert size_of_divisors.compute(8) == 4  # 8 has 4 divisors: 1, 2, 4, 8
        assert size_of_divisors.compute(12) == 6  # 12 has 6 divisors: 1, 2, 3, 4, 6, 12

    # TODO: This test is failing for some reason, fix later.
    # Kept as a skipped test (rather than commented-out code) so it stays
    # syntax-checked and shows up in the test report until it is fixed.
    @pytest.mark.skip(reason="Known failure, cause not yet diagnosed (see TODO)")
    def test_set_example_transformation(self, size_rule, set_of_divisors_concept):
        """Test example transformation for set concepts."""
        size_of_divisors = size_rule.apply(set_of_divisors_concept)

        # Check examples
        examples = [ex.value for ex in size_of_divisors.examples.get_examples()]
        assert (1, 1) in examples  # From (1, {1})
        assert (2, 2) in examples  # From (2, {1, 2})
        assert (3, 2) in examples  # From (3, {1, 3})
        assert (4, 3) in examples  # From (4, {1, 2, 4})
        assert (6, 4) in examples  # From (6, {1, 2, 3, 6})
        assert (8, 4) in examples  # From (8, {1, 2, 4, 8})
        assert (12, 6) in examples  # From (12, {1, 2, 3, 4, 6, 12})


class TestPredicateSizeRule:
    """Tests for applying the size rule to predicate concepts."""

    def test_single_arg_quantification(self, size_rule, divides_predicate):
        """Quantifying over either single argument yields a unary concept."""
        # Quantify over the first argument (the original author calls this tau,
        # the divisor-counting function).
        tau = size_rule.apply(divides_predicate, indices_to_quantify=[0])
        assert tau.name == "size_of_(divides_indices_[0])"
        assert tau.examples.example_structure.input_arity == 1
        # has no computational implementation for now

        # Quantify over the second argument instead.
        # NOTE(review): originally described as "sigma_0 (counts numbers that
        # divide n)"; which argument of the predicate is the divisor is not
        # visible from this file - confirm against the fixture definition.
        sigma_0 = size_rule.apply(divides_predicate, indices_to_quantify=[1])
        assert sigma_0.name == "size_of_(divides_indices_[1])"
        assert sigma_0.examples.example_structure.input_arity == 1
        # has no computational implementation for now

    def test_multi_arg_quantification(self, size_rule, multi_arg_predicate):
        """Each quantified index removes one input from the resulting concept."""
        # Quantify over first two arguments
        size_ab = size_rule.apply(multi_arg_predicate, indices_to_quantify=[0, 1])
        assert size_ab.name == "size_of_(sum_equals_product_indices_[0, 1])"
        assert size_ab.examples.example_structure.input_arity == 1
        # has no computational implementation for now

        # Quantify over all arguments: nothing is left as an input.
        size_abc = size_rule.apply(multi_arg_predicate, indices_to_quantify=[0, 1, 2])
        assert size_abc.name == "size_of_(sum_equals_product_indices_[0, 1, 2])"
        assert size_abc.examples.example_structure.input_arity == 0
        # has no computational implementation for now

    # TODO: This test is failing for some reason, fix later.
    # Kept as a skipped test (rather than commented-out code) so it stays
    # syntax-checked and shows up in the test report until it is fixed.
    @pytest.mark.skip(reason="Known failure, cause not yet diagnosed (see TODO)")
    def test_predicate_example_transformation(self, size_rule, divides_predicate):
        """Test example transformation for predicate concepts."""
        tau = size_rule.apply(divides_predicate, indices_to_quantify=[0])

        # Check examples
        examples = [ex.value for ex in tau.examples.get_examples()]
        assert (1, 1) in examples  # 1 has 1 divisor
        assert (2, 2) in examples  # 2 has 2 divisors
        assert (3, 2) in examples  # 3 has 2 divisors
        assert (4, 3) in examples  # 4 has 3 divisors
        assert (6, 4) in examples  # 6 has 4 divisors
        assert (12, 6) in examples  # 12 has 6 divisors


class TestSizeRuleErrors:
    """Rejection paths: can_apply must say no and apply must raise ValueError."""

    def test_invalid_concept_types(self, size_rule):
        """A numeric function (neither set nor predicate) is refused outright."""
        # Numeric -> numeric function: not a set-valued concept, not a predicate.
        numeric_structure = ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=1,
        )
        successor_like = Concept(
            name="invalid",
            description="A function that doesn't return a set",
            symbolic_definition=lambda x: ConceptApplication(
                addition, Var("x"), Nat(1)
            ),
            computational_implementation=lambda x: x + 1,
            example_structure=numeric_structure,
        )

        # The applicability check refuses it...
        applicable = size_rule.can_apply(successor_like, verbose=False)
        assert not applicable

        # ...and forcing the application raises.
        with pytest.raises(ValueError):
            size_rule.apply(successor_like)

    def test_invalid_indices_to_quantify(
        self, size_rule, divides_predicate, set_of_divisors_concept
    ):
        """Bad index choices are refused by can_apply and raise from apply."""
        rejected_cases = (
            (divides_predicate, (2,)),  # index out of range
            (divides_predicate, ()),  # empty indices for a predicate
            (set_of_divisors_concept, (0,)),  # indices given for a set concept
        )
        for concept, indices in rejected_cases:
            # A fresh list per call, so neither call can observe the other's input.
            assert not size_rule.can_apply(
                concept, indices_to_quantify=list(indices), verbose=False
            )
            with pytest.raises(ValueError):
                size_rule.apply(concept, indices_to_quantify=list(indices))

    def test_insufficient_arity(
        self, size_rule, prime_of_form_n2_plus_n_plus_1_predicate
    ):
        """Asking for more indices than the predicate has arguments is refused."""
        unary_predicate = prime_of_form_n2_plus_n_plus_1_predicate
        # The predicate has arity 1, so quantifying over two indices is impossible.
        applicable = size_rule.can_apply(
            unary_predicate,
            indices_to_quantify=[0, 1],
            verbose=False,
        )
        assert not applicable
        with pytest.raises(ValueError):
            size_rule.apply(unary_predicate, indices_to_quantify=[0, 1])


class TestSizeRuleIntegration:
    """End-to-end scenarios built on top of SizeRule.apply."""

    def test_tau_function(self, size_rule, set_of_divisors_concept):
        """The size of the divisor set behaves like the divisor-counting function."""
        tau = size_rule.apply(set_of_divisors_concept)
        tau.name = "tau"  # friendlier name for the derived concept

        # (n, expected number of divisors) for a handful of small inputs.
        small_inputs = ((1, 1), (2, 2), (3, 2), (4, 3), (6, 4), (8, 4), (12, 6))
        for n, expected in small_inputs:
            assert tau.compute(n) == expected

        # tau(p) = 2 for every prime p.
        for prime in (2, 3, 5, 7, 11, 13):
            assert tau.compute(prime) == 2

        # tau(p^k) = k + 1 for prime p: powers of 2, then powers of 3.
        prime_powers = ((2, 2), (4, 3), (8, 4), (16, 5), (3, 2), (9, 3), (27, 4))
        for n, expected in prime_powers:
            assert tau.compute(n) == expected

    def test_complex_quantification(self, size_rule, multi_arg_predicate):
        """Quantifying over two of three arguments leaves a unary concept."""
        scenarios = (
            # Count solutions for each value of c.
            ((0, 1), "size_of_(sum_equals_product_indices_[0, 1])"),
            # Count solutions for each value of a.
            ((1, 2), "size_of_(sum_equals_product_indices_[1, 2])"),
        )
        for indices, expected_name in scenarios:
            counted = size_rule.apply(
                multi_arg_predicate, indices_to_quantify=list(indices)
            )
            assert counted.name == expected_name
            assert counted.examples.example_structure.input_arity == 1
            # No compute() checks: there is no computational implementation for now.

    def test_composition_with_other_rules(self, size_rule, equals_predicate):
        """Quantifying over either side of equality leaves a unary concept."""
        scenarios = (
            # Quantify over the second argument of equals.
            ((1,), "size_of_(equals_indices_[1])"),
            # Quantify over the first argument of equals.
            ((0,), "size_of_(equals_indices_[0])"),
        )
        for indices, expected_name in scenarios:
            counted = size_rule.apply(
                equals_predicate, indices_to_quantify=list(indices)
            )
            assert counted.name == expected_name
            assert counted.examples.example_structure.input_arity == 1
            # No compute() checks: there is no computational implementation for now.
