import pytest
from frame.knowledge_base.entities import (
    Concept,
    ConceptType,
    ExampleType,
    ExampleStructure,
    Nat,
    NatDomain,
    Var,
    Equals,
    Exists,
    ConceptApplication,
    SetCardinality,
    Set,
    Lambda,
    And,
    Zero,
    Succ,
)

from frame.productions.concepts.specialize import SpecializeRule
from frame.tools.z3_template import Z3Template

# ==============================
# Fixtures
# ==============================


@pytest.fixture
def specialize_rule():
    """Provide a newly constructed SpecializeRule; the fixture is rebuilt for every test."""
    rule = SpecializeRule()
    return rule

@pytest.fixture
def complex_specialized_concept(specialize_rule, successor_concept, zero_concept):
    """
    Build the concept 'specialized_(size_of_(exists_(successor_indices_[0])_indices_[0])_output_eq_0)'.

    The fixture composes three steps:
    1. An existential expression ``Exists x in Nat. successor(x)``.
    2. A zero-arity FUNCTION concept whose symbolic value is the cardinality of
       that expression (computationally reported as ``float("inf")``).
    3. A specialization of that function at index 0 to ``zero_concept``.
    """
    bound_name = "x"

    # Existential quantification over the naturals, applying successor to the bound variable.
    quantified = Exists(
        bound_name,
        NatDomain(),
        ConceptApplication(successor_concept, Var(bound_name)),
    )

    cardinality_structure = ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC,),
        input_arity=0,
    )

    # Zero-arity function returning the size of the quantified set; the computational
    # side reports infinity because this set is infinite in size.
    size_concept = Concept(
        name="size_of_(exists_(successor_indices_[0])_indices_[0])",
        description="Size of the set defined by existential quantification over successor",
        symbolic_definition=lambda: SetCardinality(quantified),
        computational_implementation=lambda: float("inf"),
        example_structure=cardinality_structure,
    )

    # Index 0 is the output slot here because the function has no inputs
    # (elsewhere in this file the output index equals the input arity).
    return specialize_rule.apply(size_concept, zero_concept, index_to_specialize=0)


# ==============================
# Basic Rule Tests
# ==============================


class TestSpecializeRuleBasics:
    """Test basic functionality of the SpecializeRule class."""

    def test_get_input_types(self, specialize_rule):
        """Test that get_input_types returns the correct types."""
        input_types = specialize_rule.get_input_types()
        assert isinstance(input_types, list), "Should return a list"
        assert len(input_types) == 2, "Should return two input types"

        # First input (the concept being specialized) must be a function or predicate
        assert input_types[0][0] == Concept
        assert input_types[0][1] == [ConceptType.FUNCTION, ConceptType.PREDICATE]

        # Second input (the value that gets plugged in) must be a constant or a function
        assert input_types[1][0] == Concept
        assert input_types[1][1] == [
            ConceptType.CONSTANT,
            ConceptType.FUNCTION,
        ]

    def test_can_apply_valid_input_specialization(
        self, specialize_rule, proper_divisors_count, one_concept
    ):
        """Test that can_apply returns True for valid input specialization."""
        assert specialize_rule.can_apply(
            proper_divisors_count, one_concept, index_to_specialize=1, verbose=False
        ), "Should allow specializing proper_divisors_count with one_concept at index 1"

    def test_can_apply_valid_output_specialization(
        self, specialize_rule, add2, two_concept, one_concept, tau_function_concept
    ):
        """Test that can_apply returns True for valid output specialization."""
        assert specialize_rule.can_apply(
            add2, two_concept, index_to_specialize=1, verbose=False
        ), "Should allow specializing add2 output with two_concept"

        # verbose=False like every other can_apply call here, so the test does not
        # emit diagnostic output.
        assert specialize_rule.can_apply(
            tau_function_concept, one_concept, index_to_specialize=1, verbose=False
        ), "Should allow specializing tau_function_concept output with one_concept"


# ==============================
# Input Specialization Tests
# ==============================


class TestInputSpecialization:
    """Test specializing input arguments of functions and predicates."""

    def test_specialize_predicate_input(
        self, specialize_rule, proper_divisors_count, one_concept
    ):
        """Test specializing a predicate input (proper_divisors_count with k=1)."""
        # Create has_one_proper_divisor predicate
        has_one_proper_divisor = specialize_rule.apply(
            proper_divisors_count, one_concept, index_to_specialize=1
        )

        assert (
            "Specialization of proper_divisors_count"
            in has_one_proper_divisor.description
        )
        assert (
            has_one_proper_divisor.name
            == "specialized_(proper_divisors_count_output_eq_one)"
        )
        # Test computation for various cases
        assert has_one_proper_divisor.compute(7), "7 has exactly one proper divisor"
        assert not has_one_proper_divisor.compute(
            4
        ), "4 has two proper divisors, not one"
        assert not has_one_proper_divisor.compute(
            6
        ), "6 has three proper divisors, not one"

        # Test example transformation
        examples = has_one_proper_divisor.examples.get_examples()
        assert any(ex.value == (7,) for ex in examples), "Should have example (7,)"

        # Test nonexample transformation
        nonexamples = has_one_proper_divisor.examples.get_nonexamples()
        assert any(
            ex.value == (4,) for ex in nonexamples
        ), "Should have nonexample (4,)"
        assert any(
            ex.value == (6,) for ex in nonexamples
        ), "Should have nonexample (6,)"

    def test_specialize_binary_predicate_first_arg(self, specialize_rule, greater_than):
        """Test specializing first argument of a binary predicate (greater_than with a=5)."""
        # Create a concept for the number 5
        five_concept = Concept(
            name="five",
            description="The number five",
            symbolic_definition=lambda: Nat(5),
            computational_implementation=lambda: 5,
            example_structure=ExampleStructure(
                concept_type=ConceptType.CONSTANT,
                component_types=(ExampleType.NUMERIC,),
                input_arity=0,
            ),
        )
        five_concept.add_example((5,))

        # Fixing greater_than's first argument to 5 yields 5 > x, i.e. "x is less than 5"
        is_less_than_5 = specialize_rule.apply(
            greater_than, five_concept, index_to_specialize=0
        )

        assert is_less_than_5.name == "specialized_(greater_than_at_0_to_five)"

        # Test the specialized concept
        assert is_less_than_5.compute(3), "5 > 3 should be True"
        assert not is_less_than_5.compute(7), "5 > 7 should be False"
        assert not is_less_than_5.compute(5), "5 > 5 should be False"

    def test_specialize_function_input(
        self, specialize_rule, multiply, two_concept
    ):
        """Test specializing a function input (multiplication with first arg = 2)."""
        # Create double function
        double = specialize_rule.apply(
            multiply, two_concept, index_to_specialize=0
        )

        # Test the specialized concept
        assert double.compute(3) == 6, "2 * 3 should be 6"
        assert double.compute(4) == 8, "2 * 4 should be 8"
        assert double.compute(0) == 0, "2 * 0 should be 0"

        # Examples are (input, output) pairs of the new unary function
        examples = double.examples.get_examples()
        expected_examples = [(3, 6), (4, 8), (0, 0)]
        for expected in expected_examples:
            assert any(
                ex.value == expected for ex in examples
            ), f"Should have example {expected}"


# ==============================
# Output Specialization Tests
# ==============================


class TestOutputSpecialization:
    """Test specializing outputs of functions."""

    def test_specialize_function_output(self, specialize_rule, add2, two_concept):
        """Test specializing a function output (add2 == 2)."""
        # Pinning add2's output to 2 gives the predicate add2(x) == 2, which holds
        # at x == 0 (hence the name is_zero)
        is_zero = specialize_rule.apply(add2, two_concept, index_to_specialize=1)

        # Test the specialized concept
        assert is_zero.compute(0), "add2(0) == 2 should be True"
        assert not is_zero.compute(1), "add2(1) == 2 should be False"
        assert not is_zero.compute(2), "add2(2) == 2 should be False"

        # Verify the concept type changed from function to predicate
        assert (
            is_zero.examples.example_structure.concept_type == ConceptType.PREDICATE
        ), "Output specialization should create a predicate"

    def test_specialize_square_output(self, specialize_rule, square):
        """Test specializing square output to 4 (square == 4)."""
        # Create a concept for the number 4
        four_concept = Concept(
            name="four",
            description="The number four",
            symbolic_definition=lambda: Nat(4),
            computational_implementation=lambda: 4,
            example_structure=ExampleStructure(
                concept_type=ConceptType.CONSTANT,
                component_types=(ExampleType.NUMERIC,),
                input_arity=0,
            ),
        )
        four_concept.add_example((4,))

        # Create is_square_root_of_4 predicate
        is_square_root_of_4 = specialize_rule.apply(
            square, four_concept, index_to_specialize=1
        )

        # Test the specialized concept
        assert is_square_root_of_4.compute(2), "square(2) == 4 should be True"
        assert not is_square_root_of_4.compute(1), "square(1) == 4 should be False"
        assert not is_square_root_of_4.compute(3), "square(3) == 4 should be False"

    def test_specialize_tau_function_output(
        self, specialize_rule, tau_function_concept, two_concept
    ):
        """Test specializing tau_function_concept output to 2 (has_two_divisors, i.e. primes)."""
        has_two_divisors = specialize_rule.apply(
            tau_function_concept, two_concept, index_to_specialize=1
        )

        # Test the specialized concept
        assert has_two_divisors.compute(7), "7 has exactly two divisors"
        assert has_two_divisors.compute(11), "11 has exactly two divisors"
        assert not has_two_divisors.compute(4), "4 has three divisors, not two"
        assert not has_two_divisors.compute(1), "1 has exactly one divisor"


# ==============================
# Example Transformation Tests
# ==============================


class TestExampleTransformation:
    """Test how examples are transformed during specialization."""

    def test_input_specialization_examples(
        self, specialize_rule, proper_divisors_count, one_concept
    ):
        """Test example transformation for input specialization."""
        # Create has_one_proper_divisor predicate
        has_one_proper_divisor = specialize_rule.apply(
            proper_divisors_count, one_concept, index_to_specialize=1
        )

        # Get examples and nonexamples
        examples = has_one_proper_divisor.examples.get_examples()
        nonexamples = has_one_proper_divisor.examples.get_nonexamples()

        # Check examples (should only include numbers with exactly one proper divisor)
        assert any(ex.value == (7,) for ex in examples), "Should have example (7,)"

        # Check nonexamples (should include numbers with not exactly one proper divisor)
        assert any(
            ex.value == (4,) for ex in nonexamples
        ), "Should have nonexample (4,)"
        assert any(
            ex.value == (6,) for ex in nonexamples
        ), "Should have nonexample (6,)"
        assert any(
            ex.value == (8,) for ex in nonexamples
        ), "Should have nonexample (8,)"

    def test_output_specialization_examples(
        self, specialize_rule, tau_function_concept, two_concept
    ):
        """Test example transformation for output specialization."""
        # Create has_two_divisors predicate (tau(x) == 2)
        has_two_divisors = specialize_rule.apply(
            tau_function_concept, two_concept, index_to_specialize=1
        )

        # Get examples and nonexamples
        examples = has_two_divisors.examples.get_examples()
        nonexamples = has_two_divisors.examples.get_nonexamples()

        # Check examples (inputs where tau(x) == 2)
        assert any(ex.value == (2,) for ex in examples), "Should have example (2,)"
        assert any(ex.value == (3,) for ex in examples), "Should have example (3,)"
        assert any(ex.value == (5,) for ex in examples), "Should have example (5,)"
        assert any(ex.value == (7,) for ex in examples), "Should have example (7,)"
        assert any(ex.value == (11,) for ex in examples), "Should have example (11,)"

        # Check nonexamples (should include inputs where tau(x) != 2)
        assert any(
            ex.value == (1,) for ex in nonexamples
        ), "Should have nonexample (1,)"
        assert any(
            ex.value == (4,) for ex in nonexamples
        ), "Should have nonexample (4,)"
        assert any(
            ex.value == (6,) for ex in nonexamples
        ), "Should have nonexample (6,)"
        assert any(
            ex.value == (8,) for ex in nonexamples
        ), "Should have nonexample (8,)"


# ==============================
# Error Case Tests
# ==============================


class TestSpecializeRuleErrors:
    """Check that SpecializeRule rejects malformed or incompatible inputs."""

    def test_invalid_concept_types(self, specialize_rule, one_concept, two_concept):
        """Test that specialization fails with invalid concept types."""
        # A constant is not a valid concept to specialize: can_apply must refuse it ...
        constant_accepted = specialize_rule.can_apply(
            one_concept, two_concept, index_to_specialize=0
        )
        assert not constant_accepted, "Should reject specializing a constant"

        # ... and apply must raise ValueError instead of building a concept.
        with pytest.raises(ValueError):
            specialize_rule.apply(one_concept, two_concept, index_to_specialize=0)

    def test_invalid_specialization_index(
        self, specialize_rule, proper_divisors_count, one_concept
    ):
        """Test that specialization fails with invalid index."""
        # Out-of-range indices (negative, or at/after the predicate's arity) are rejected.
        rejected_cases = (
            (-1, "Should reject negative index"),
            (2, "Should reject index >= input_arity for predicate"),
        )
        for bad_index, failure_message in rejected_cases:
            assert not specialize_rule.can_apply(
                proper_divisors_count, one_concept, index_to_specialize=bad_index
            ), failure_message

        # apply with an out-of-range index must raise ValueError.
        with pytest.raises(ValueError):
            specialize_rule.apply(
                proper_divisors_count, one_concept, index_to_specialize=3
            )

    def test_type_mismatch(self, specialize_rule, proper_divisors_count):
        """Test that specialization fails with type mismatch."""
        # A constant whose single component type is ANY and whose value is a string.
        string_structure = ExampleStructure(
            concept_type=ConceptType.CONSTANT,
            component_types=(ExampleType.ANY,),
            input_arity=0,
        )
        string_concept = Concept(
            name="hello",
            description="The string 'hello'",
            symbolic_definition=lambda: "hello",
            computational_implementation=lambda: "hello",
            example_structure=string_structure,
        )

        mismatch_accepted = specialize_rule.can_apply(
            proper_divisors_count, string_concept, index_to_specialize=1
        )
        assert not mismatch_accepted, "Should reject type mismatch"

    def test_multi_output_function(
        self, specialize_rule, multi_output_function, one_concept
    ):
        """Test that output specialization fails for multi-output functions."""
        output_pin_accepted = specialize_rule.can_apply(
            multi_output_function, one_concept, index_to_specialize=1
        )
        assert (
            not output_pin_accepted
        ), "Should reject output specialization of multi-output function"


# ==============================
# Parameterization Tests
# ==============================


class TestValidParameterizations:
    """Exercise SpecializeRule.get_valid_parameterizations."""

    def test_predicate_parameterizations(
        self, specialize_rule, proper_divisors_count, one_concept
    ):
        """Test valid parameterizations for predicates."""
        params = specialize_rule.get_valid_parameterizations(
            proper_divisors_count, one_concept
        )

        # Either argument of the binary predicate may be fixed.
        for index, failure_message in (
            (0, "Should allow specializing first argument"),
            (1, "Should allow specializing second argument"),
        ):
            assert {"index_to_specialize": index} in params, failure_message
        assert len(params) == 2, "Should have exactly two valid parameterizations"

        # Queried with only the resulting arity-1 predicate, no parameterizations remain.
        unary_predicate = specialize_rule.apply(
            proper_divisors_count, one_concept, index_to_specialize=0
        )
        leftover = specialize_rule.get_valid_parameterizations(unary_predicate)
        assert (
            len(leftover) == 0
        ), "Should have no valid parameterization if input is predicate of arity 1"

    def test_constant_parameterizations(self, specialize_rule, one_concept):
        """Test valid parameterizations for constants."""
        found = specialize_rule.get_valid_parameterizations(one_concept)
        assert len(found) == 0, "Should have exactly zero valid parameterizations"

    def test_add2_function_parameterizations(self, specialize_rule, add2, two_concept):
        """Test valid parameterizations for functions."""
        params = specialize_rule.get_valid_parameterizations(add2, two_concept)

        # Index 0 is the single input; index 1 is the output.
        for index, failure_message in (
            (0, "Should allow specializing input"),
            (1, "Should allow specializing output"),
        ):
            assert {"index_to_specialize": index} in params, failure_message
        assert len(params) == 2, "Should have exactly two valid parameterizations"

    def test_binary_function_parameterizations(
        self, specialize_rule, multiply, two_concept
    ):
        """Test valid parameterizations for binary functions."""
        params = specialize_rule.get_valid_parameterizations(multiply, two_concept)

        # Indices 0 and 1 are the inputs; index 2 is the output.
        for index, failure_message in (
            (0, "Should allow specializing first input"),
            (1, "Should allow specializing second input"),
            (2, "Should allow specializing the output"),
        ):
            assert {"index_to_specialize": index} in params, failure_message
        assert len(params) == 3, "Should have exactly three valid parameterizations"

    def test_invalid_parameterizations(self, specialize_rule, add2, one_concept, two_concept):
        """Test that invalid inputs return empty parameterizations."""
        # A constant cannot be specialized at all.
        from_constant = specialize_rule.get_valid_parameterizations(one_concept, two_concept)
        assert len(from_constant) == 0, "Should return empty list for invalid inputs"

        # A string-valued constant does not match add2's numeric slots.
        string_concept = Concept(
            name="hello",
            description="The string 'hello'",
            symbolic_definition=lambda: "hello",
            computational_implementation=lambda: "hello",
            example_structure=ExampleStructure(
                concept_type=ConceptType.CONSTANT,
                component_types=(ExampleType.ANY,),
                input_arity=0,
            ),
        )

        from_mismatch = specialize_rule.get_valid_parameterizations(add2, string_concept)
        assert len(from_mismatch) == 0, "Should return empty list for type mismatch"


# ==============================
# Integration Tests
# ==============================


class TestSpecializeIntegration:
    """Exercise multi-step specialization scenarios."""

    def test_successor_specialization(
        self, specialize_rule, zero_concept, successor_concept
    ):
        """Test specializing successor with zero to create one."""
        assert specialize_rule.can_apply(
            successor_concept, zero_concept, index_to_specialize=0
        ), "Should allow specializing successor with zero"

        # Fixing successor's only input to 0 should leave a zero-arity function.
        one = specialize_rule.apply(
            successor_concept, zero_concept, index_to_specialize=0
        )
        assert one.compute() == 1, "successor(0) should be 1"

        structure = one.examples.example_structure
        assert structure.concept_type == ConceptType.FUNCTION
        assert structure.input_arity == 0, "Should have zero input arity"

    def test_nested_specialization(
        self, specialize_rule, multiply, two_concept, three_concept
    ):
        """Test applying specialization multiple times to create a nested specialized concept."""
        # Step 1: fix multiply's first argument to 2, giving a "double" function.
        double = specialize_rule.apply(multiply, two_concept, index_to_specialize=0)

        for arg, product in ((3, 6), (4, 8), (0, 0), (5, 10)):
            assert double.compute(arg) == product, f"2 * {arg} should be {product}"

        # Step 2: pin double's output (index 1) to 6, turning it into a predicate.
        six_concept = Concept(
            name="six",
            description="The number six",
            symbolic_definition=lambda: Nat(6),
            computational_implementation=lambda: 6,
            example_structure=ExampleStructure(
                concept_type=ConceptType.CONSTANT,
                component_types=(ExampleType.NUMERIC,),
                input_arity=0,
            ),
        )
        six_concept.add_example((6,))

        is_doubled_to_six = specialize_rule.apply(
            double, six_concept, index_to_specialize=1
        )

        assert is_doubled_to_six.compute(3), "double(3) == 6 should be True"
        for arg in (4, 5):
            assert not is_doubled_to_six.compute(
                arg
            ), f"double({arg}) == 6 should be False"

        predicate_type = is_doubled_to_six.examples.example_structure.concept_type
        assert (
            predicate_type == ConceptType.PREDICATE
        ), "Output specialization should create a predicate"

        # Alternative chain: multiply -> triple (first arg 3) -> triple's input fixed to 2.
        triple = specialize_rule.apply(multiply, three_concept, index_to_specialize=0)
        multiply_by_six = specialize_rule.apply(
            triple, two_concept, index_to_specialize=0
        )

        assert multiply_by_six.compute() == 6, "3 * 2 should be 6"
        assert (
            multiply_by_six.examples.example_structure.input_arity == 0
        ), "Should have zero input arity"
        assert (
            multiply_by_six.examples.example_structure.concept_type
            == ConceptType.FUNCTION
        ), "Should still be a function"

# ==============================
# Z3 Translation Tests
# ==============================

class TestSpecializeRuleZ3:
    """Test the Z3 translation of the SpecializeRule."""

    def test_predicate_input_specialization_z3(self, specialize_rule, greater_than, two_concept):
        """Test the Z3 translation of input specialization for the > predicate."""
        # Fixing greater_than's second argument to 2 gives x > 2 (strict, not >=).
        greater_than_two = specialize_rule.apply(greater_than, two_concept, index_to_specialize=1)

        program = greater_than_two.to_z3(Nat(1))
        result = program.run()
        assert not result.proved, "1 > 2 should be False"

        program = greater_than_two.to_z3(Nat(2))
        result = program.run()
        assert not result.proved, "2 > 2 should be False"

        program = greater_than_two.to_z3(Nat(3))
        result = program.run()
        assert result.proved, "3 > 2 should be True"

    def test_function_input_specialization_z3(self, specialize_rule, multiply, two_concept):
        """Test the Z3 translation of input specialization for the * function."""
        multiply_by_two = specialize_rule.apply(multiply, two_concept, index_to_specialize=0)
        program = multiply_by_two.to_z3(Nat(4)).program

        # Wrap the specialized function's DSL in a predicate program checking f_0(x_0) == 8.
        predicate_program = Z3Template(
            f"""
            params 1;
            bounded params 1;
            f_0 := Func(
            {program.dsl()}
            );
            ReturnExpr None;
            ReturnPred f_0(x_0=x_0) == 8;
            """
        )
        predicate_program.set_args(Nat(4))
        result = predicate_program.run()
        assert result.proved, "4 * 2 should be 8"

        predicate_program.set_args(Nat(5))
        result = predicate_program.run()
        assert not result.proved, "5 * 2 should not be 8"

    def test_function_output_specialization_z3(self, specialize_rule, multiply, two_concept):
        """Test the Z3 translation of output specialization."""
        # Index 2 is multiply's output slot, so this builds the predicate a * b == 2.
        multiply_eq_two = specialize_rule.apply(multiply, two_concept, index_to_specialize=2)
        program = multiply_eq_two.to_z3(Nat(2), Nat(4))
        result = program.run()
        assert not result.proved, "2 * 4 should not be 2"

        program = multiply_eq_two.to_z3(Nat(1), Nat(2))
        result = program.run()
        assert result.proved, "1 * 2 should be 2"