import pytest

from frame.productions.concepts.match import MatchRule
from frame.productions.concepts.compose import ComposeRule
from frame.productions.concepts.specialize import SpecializeRule
from frame.productions.concepts.size import SizeRule
from frame.knowledge_base.entities import (
    ConceptApplication,
    Concept,
    ExampleType,
    ExampleStructure,
    ConceptType,
    Set,
    NatDomain,
    Equals,
    Exists,
    And,
    Lambda,
    Zero,
    Succ,
    Nat,
    Var,
    GroupElement,
    Not,
)
from frame.knowledge_base.demonstrations import (
    multiplication,
    addition,
    SetCardinality,
    proper_divisors,
    divides,
    less_than,
    is_prime,
)
from frame.tools.z3_template import Z3Template

@pytest.fixture
def square():
    """Return a unary FUNCTION concept mapping a number n to n * n.

    Examples cover every n in 0..19, so the largest stored example is (19, 361).
    """
    # Note(_; 3/30): test_constant.py requires largest example (19, 361)
    structure = ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=1,
    )
    concept = Concept(
        name="square",
        description="Square a number",
        symbolic_definition=lambda x: ConceptApplication(multiplication, x, x),
        computational_implementation=lambda x: x * x,
        example_structure=structure,
        z3_translation=lambda x: Z3Template(
            """
            params 1;
            bounded params 0;
            ReturnExpr x_0 * x_0;
            ReturnPred None;
            """,
            x,
        ),
    )
    for value, squared in ((n, n * n) for n in range(20)):
        concept.add_example((value, squared))
    return concept


@pytest.fixture
def is_square():
    """Create a concept that tests if a number is a perfect square.

    Examples are the squares of 0..19 (0, 1, 4, ..., 361). Nonexamples are
    i*i + 1 for i in 1..19 (2, 5, ..., 362). i starts at 1 because
    0*0 + 1 == 1 is itself a square.
    """
    import math

    def _is_perfect_square(n):
        # math.isqrt is exact for arbitrarily large integers. int(n**0.5) can
        # round wrongly once n exceeds float precision, and it produces a
        # complex number for negative n (which are never squares).
        if n < 0:
            return False
        root = math.isqrt(int(n))
        return root * root == n

    concept = Concept(
        name="is_square",
        description="Test if a number is a perfect square",
        symbolic_definition=lambda n: Exists(
            "k",
            NatDomain(),
            Equals(n, ConceptApplication(multiplication, Var("k"), Var("k"))),
        ),
        computational_implementation=_is_perfect_square,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,),
            input_arity=1,
        ),
        z3_translation=lambda n: Z3Template(
            """
            params 1;
            bounded params 1;
            ReturnExpr None;
            ReturnPred Exists([b_0], x_0 == b_0 * b_0);
            """,
            n,
        ),
    )
    # Examples: the perfect squares 0, 1, 4, ..., 361 (i * i for i in 0..19).
    for i in range(20):
        concept.add_example((i * i,))

    # Nonexamples: one past each nonzero square, which is never a square.
    for i in range(1, 20):
        concept.add_nonexample((i * i + 1,))

    return concept


@pytest.fixture
def is_even():
    """Create a concept that tests if a number is even.

    Examples are the even numbers 0..18; nonexamples are the odd numbers 1..19.
    """
    concept = Concept(
        name="is_even",
        description="A number divisible by 2",
        symbolic_definition=lambda n: ConceptApplication(divides, Nat(2), n),
        computational_implementation=lambda n: n % 2 == 0,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,),
            input_arity=1,
        ),
        z3_translation=lambda n: Z3Template(
            """
            params 1;
            bounded params 1;
            ReturnExpr None;
            ReturnPred Exists([b_0], b_0 * 2 == x_0);
            """,
            n,
        ),
    )
    # Examples: the even numbers 0, 2, ..., 18.
    # Note(_; 3/30): test_constant.py requires largest example (18,)
    for i in range(0, 20, 2):
        concept.add_example((i,))

    # Nonexamples: the odd numbers 1, 3, ..., 19.
    for i in range(1, 20, 2):
        concept.add_nonexample((i,))

    return concept


@pytest.fixture
def is_odd():
    """Create a concept that tests if a number is odd.

    Examples are the odd numbers 1..19; nonexamples are the even numbers 0..18.
    """
    concept = Concept(
        name="is_odd",
        description="A number not divisible by 2",
        symbolic_definition=lambda n: Not(ConceptApplication(divides, Nat(2), n)),
        computational_implementation=lambda n: n % 2 == 1,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,),
            input_arity=1,
        ),
        z3_translation=lambda n: Z3Template(
            """
            params 1;
            bounded params 1;
            ReturnExpr None;
            ReturnPred Not(Exists([b_0], b_0 * 2 == x_0));
            """,
            n,
        ),
    )
    # Examples: the odd numbers 1, 3, ..., 19.
    for i in range(1, 20, 2):
        concept.add_example((i,))

    # Nonexamples: the even numbers 0, 2, ..., 18.
    for i in range(0, 20, 2):
        concept.add_nonexample((i,))

    return concept


@pytest.fixture
def proper_divisors_count():
    """Create a FUNCTION concept n -> number of proper divisors of n.

    Examples are (n, count) for n in 1..19, with the count taken from
    ``proper_divisors.compute``. A few wrong (n, k) pairs are nonexamples.
    """
    concept = Concept(
        name="proper_divisors_count",
        description="Checks if k is the number of proper divisors of n",
        # The concept is a FUNCTION with input_arity=1, so (like `square` and
        # `tau`) its symbolic definition takes only the input and returns the
        # output term. Previously it was a 2-ary Equals(k, ...) predicate.
        symbolic_definition=lambda n: SetCardinality(
            ConceptApplication(proper_divisors, n)
        ),
        computational_implementation=lambda n: len(proper_divisors.compute(n)),
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=1,
        ),
    )

    # Add examples
    for i in range(1, 20):
        concept.add_example((i, len(proper_divisors.compute(i))))

    # Add nonexamples (not exhaustive)
    concept.add_nonexample((4, 3))  # 4 doesn't have three proper divisors
    concept.add_nonexample((6, 2))  # 6 doesn't have two proper divisors
    concept.add_nonexample((8, 2))  # 8 doesn't have two proper divisors
    concept.add_nonexample((7, 2))  # 7 doesn't have two proper divisors

    return concept


@pytest.fixture
def tau_function_concept():
    """Return the divisor-counting function tau(n) = |{d > 0 : d divides n}|.

    Examples pair each n in 1..19 with its number of positive divisors.
    """

    def count_divisors(n):
        return sum(1 for d in range(1, n + 1) if n % d == 0)

    def symbolic(n):
        positive_divisor_of_n = Lambda(
            "d",
            And(
                ConceptApplication(divides, Var("d"), n),
                ConceptApplication(less_than, Zero(), Var("d")),
            ),
        )
        return SetCardinality(
            Set(domain=NatDomain(), predicate=positive_divisor_of_n)
        )

    concept = Concept(
        name="tau",
        description="The number of divisors of n",
        symbolic_definition=symbolic,
        computational_implementation=count_divisors,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=1,
        ),
    )

    for n in range(1, 20):
        concept.add_example((n, count_divisors(n)))

    return concept


@pytest.fixture
def greater_than():
    """Create a binary predicate for greater than.

    Examples are every pair (i, j) in 1..9 with i > j, plus a few
    hand-picked nonexamples.
    """
    concept = Concept(
        name="greater_than",
        description="a > b if not (a ≤ b)",
        # a > b  <=>  there is a natural m with a = b + (m + 1).
        # This mirrors the Z3 template below (x_0 == b_0 + 1 + x_1). The former
        # form "b = a + m" encoded a <= b, i.e. the opposite relation.
        symbolic_definition=lambda a, b: Exists(
            "m",
            NatDomain(),
            Equals(a, ConceptApplication(addition, b, Succ(Var("m")))),
        ),
        computational_implementation=lambda a, b: a > b,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
        lean4_translation=lambda a, b: f"({a} > {b})",
        prolog_translation=lambda a, b: f"({a} > {b})",
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 1;
            ReturnExpr None;
            ReturnPred Exists([b_0], x_0 == b_0 + 1 + x_1);
            """,
            a,
            b,
        ),
    )

    # Add examples
    for i in range(1, 10):
        for j in range(1, 10):
            if i > j:
                concept.add_example((i, j))

    # Add nonexamples (not exhaustive)
    concept.add_nonexample((3, 5))  # 3 is not > 5
    concept.add_nonexample((2, 2))  # 2 is not > 2
    concept.add_nonexample((1, 7))  # 1 is not > 7

    return concept


@pytest.fixture
def zero_concept():
    """Return the CONSTANT concept 0, holding the single example (0,)."""
    value = 0
    structure = ExampleStructure(
        concept_type=ConceptType.CONSTANT,
        component_types=(ExampleType.NUMERIC,),
        input_arity=0,
    )
    concept = Concept(
        name="zero",
        description="The number zero",
        symbolic_definition=lambda: Nat(value),
        computational_implementation=lambda: value,
        example_structure=structure,
    )
    concept.add_example((value,))
    return concept


@pytest.fixture
def one_concept():
    """Return the CONSTANT concept 1, holding the single example (1,)."""
    value = 1
    structure = ExampleStructure(
        concept_type=ConceptType.CONSTANT,
        component_types=(ExampleType.NUMERIC,),
        input_arity=0,
    )
    concept = Concept(
        name="one",
        description="The number one",
        symbolic_definition=lambda: Nat(value),
        computational_implementation=lambda: value,
        example_structure=structure,
    )
    concept.add_example((value,))
    return concept


@pytest.fixture
def two_concept():
    """Return the CONSTANT concept 2 (with a Z3 translation) and example (2,)."""
    value = 2
    structure = ExampleStructure(
        concept_type=ConceptType.CONSTANT,
        component_types=(ExampleType.NUMERIC,),
        input_arity=0,
    )
    concept = Concept(
        name="two",
        description="The number two",
        symbolic_definition=lambda: Nat(value),
        computational_implementation=lambda: value,
        example_structure=structure,
        z3_translation=lambda: Z3Template(
            """
            params 0;
            bounded params 0;
            ReturnExpr 2;
            ReturnPred None;
            """,
        ),
    )
    concept.add_example((value,))
    return concept


@pytest.fixture
def three_concept():
    """Return the CONSTANT concept 3, holding the single example (3,)."""
    value = 3
    structure = ExampleStructure(
        concept_type=ConceptType.CONSTANT,
        component_types=(ExampleType.NUMERIC,),
        input_arity=0,
    )
    concept = Concept(
        name="three",
        description="The number three",
        symbolic_definition=lambda: Nat(value),
        computational_implementation=lambda: value,
        example_structure=structure,
    )
    concept.add_example((value,))
    return concept


@pytest.fixture
def multi_output_function():
    """Return a 1-input, 2-output FUNCTION x -> (x, x + 1) for error-case tests.

    Examples are (i, i, i + 1) for i in 0..4.
    """
    pair = lambda x: (x, x + 1)  # noqa: E731 - shared by both definitions
    concept = Concept(
        name="multi_output",
        description="A function that returns multiple values",
        symbolic_definition=pair,
        computational_implementation=pair,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC,) * 3,
            input_arity=1,
        ),
    )
    for i in range(5):
        concept.add_example((i,) + pair(i))
    return concept


@pytest.fixture
def successor_concept():
    """Return the successor FUNCTION n -> n + 1 with examples for n in 0..9."""
    structure = ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=1,
    )
    concept = Concept(
        name="successor",
        description="The successor function n -> n+1",
        symbolic_definition=lambda n: Succ(n),
        computational_implementation=lambda n: n + 1,
        example_structure=structure,
    )

    for n, n_plus_one in zip(range(10), range(1, 11)):
        concept.add_example((n, n_plus_one))

    return concept


@pytest.fixture
def multiply():
    """Return binary multiplication (a, b) -> a * b for testing.

    Examples are (i, j, i*j) and nonexamples (i, j, i*j + 1) over 0..9 x 0..9.
    """
    concept = Concept(
        name="multiply",
        description="Multiplication of two numbers",
        symbolic_definition=lambda a, b: ConceptApplication(multiplication, a, b),
        computational_implementation=lambda a, b: a * b,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC,) * 3,
            input_arity=2,
        ),
        lean4_translation=lambda a, b: f"({a} * {b})",
        prolog_translation=lambda a, b: f"times({a}, {b}, Result)",
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr x_0 * x_1;
            ReturnPred None;
            """,
            a,
            b,
        ),
    )

    grid = [(i, j) for i in range(10) for j in range(10)]
    for i, j in grid:
        concept.add_example((i, j, i * j))
    # Off-by-one products are never correct.
    for i, j in grid:
        concept.add_nonexample((i, j, i * j + 1))

    return concept


@pytest.fixture
def multiply3():
    """Create a ternary multiplication concept (a, b, c) -> a * b * c.

    Examples cover 0..9 for each argument. Nonexamples (product + 1) cover 0..4.
    """
    concept = Concept(
        name="multiply3",
        description="Multiplication of three numbers",
        symbolic_definition=lambda a, b, c: ConceptApplication(
            multiplication,
            ConceptApplication(multiplication, a, b),
            c,
        ),
        computational_implementation=lambda a, b, c: a * b * c,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(
                ExampleType.NUMERIC,
                ExampleType.NUMERIC,
                ExampleType.NUMERIC,
                ExampleType.NUMERIC,
            ),
            input_arity=3,
        ),
        lean4_translation=lambda a, b, c: f"({a} * {b} * {c})",
        prolog_translation=lambda a, b, c: f"times3({a}, {b}, {c}, Result)",
        # Use a Z3Template like every other fixture in this file. The previous
        # raw SMT-LIB string "(* a (* b c))" is not a Z3Template.
        z3_translation=lambda a, b, c: Z3Template(
            """
            params 3;
            bounded params 0;
            ReturnExpr x_0 * x_1 * x_2;
            ReturnPred None;
            """,
            a,
            b,
            c,
        ),
    )

    # Add examples
    for i in range(10):
        for j in range(10):
            for k in range(10):
                concept.add_example((i, j, k, i * j * k))

    # Add nonexamples (off-by-one products on a smaller grid)
    for i in range(5):
        for j in range(5):
            for k in range(5):
                concept.add_nonexample((i, j, k, i * j * k + 1))

    return concept


@pytest.fixture
def add2():
    """Return the FUNCTION x -> x + 2 with examples for x in 0..9."""
    offset = 2
    structure = ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=1,
    )
    concept = Concept(
        name="add2",
        description="Add 2 to a number",
        symbolic_definition=lambda x: ConceptApplication(addition, x, Nat(offset)),
        computational_implementation=lambda x: x + offset,
        example_structure=structure,
        z3_translation=lambda x: Z3Template(
            """
            params 1;
            bounded params 0;
            ReturnExpr x_0 + 2;
            ReturnPred None;
            """,
            x,
        ),
    )
    for x in range(10):
        concept.add_example((x, x + offset))
    return concept


@pytest.fixture
def add3():
    """Return the FUNCTION x -> x + 3 with examples for x in 0..9."""
    offset = 3
    structure = ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=1,
    )
    concept = Concept(
        name="add3",
        description="Add 3 to a number",
        symbolic_definition=lambda x: ConceptApplication(addition, x, Nat(offset)),
        computational_implementation=lambda x: x + offset,
        example_structure=structure,
        z3_translation=lambda x: Z3Template(
            """
            params 1;
            bounded params 0;
            ReturnExpr x_0 + 3;
            ReturnPred None;
            """,
            x,
        ),
    )
    for x in range(10):
        concept.add_example((x, x + offset))

    return concept


@pytest.fixture
def add_two_numbers():
    """Return binary addition (a, b) -> a + b for testing.

    Examples are (i, j, i+j) and nonexamples (i, j, i+j + 1) over 0..9 x 0..9.
    """
    concept = Concept(
        name="add_two_numbers",
        description="Addition of two numbers",
        symbolic_definition=lambda a, b: ConceptApplication(addition, a, b),
        computational_implementation=lambda a, b: a + b,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC,) * 3,
            input_arity=2,
        ),
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr x_0 + x_1;
            ReturnPred None;
            """,
            a,
            b,
        ),
    )

    grid = [(i, j) for i in range(10) for j in range(10)]
    for i, j in grid:
        concept.add_example((i, j, i + j))
    # Off-by-one sums are never correct.
    for i, j in grid:
        concept.add_nonexample((i, j, i + j + 1))

    return concept


@pytest.fixture
def add_three_numbers():
    """Return ternary addition (a, b, c) -> a + b + c for testing.

    Examples are (i, j, k, i+j+k) and nonexamples (i, j, k, i+j+k + 1) over
    0..9 for each argument.
    """
    concept = Concept(
        name="add_three_numbers",
        description="Addition of three numbers",
        symbolic_definition=lambda a, b, c: ConceptApplication(
            addition,
            ConceptApplication(addition, a, b),
            c,
        ),
        computational_implementation=lambda a, b, c: a + b + c,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC,) * 4,
            input_arity=3,
        ),
        z3_translation=lambda a, b, c: Z3Template(
            """
            params 3;
            bounded params 0;
            ReturnExpr x_0 + x_1 + x_2;
            ReturnPred None;
            """,
            a,
            b,
            c,
        ),
    )

    cube = [(i, j, k) for i in range(10) for j in range(10) for k in range(10)]
    for i, j, k in cube:
        concept.add_example((i, j, k, i + j + k))
    for i, j, k in cube:
        concept.add_nonexample((i, j, k, i + j + k + 1))

    return concept


@pytest.fixture
def many_args_concept():
    """Return an 8-input FUNCTION that collects its arguments into a set.

    No examples are attached; the fixture only exercises high arities.
    """
    structure = ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC,) * 8 + (ExampleType.SET,),
        input_arity=8,
    )

    def symbolic(a, b, c, d, e, f, g, h):
        members = [a, b, c, d, e, f, g, h]
        return Set(domain=NatDomain(), elements=members)

    def computational(a, b, c, d, e, f, g, h):
        return set((a, b, c, d, e, f, g, h))

    return Concept(
        name="many_args",
        description="A function with many arguments",
        symbolic_definition=symbolic,
        computational_implementation=computational,
        example_structure=structure,
    )


@pytest.fixture
def mixed_types_concept():
    """Create a concept with mixed argument types for testing type compatibility.

    Computes a + |b| + c, where a and c are numbers and b is a set.
    """
    concept = Concept(
        name="mixed_types",
        description="A function with mixed argument types",
        # (a + |b|) + c. The outer application previously used an inner
        # ConceptApplication as its "concept" and dropped `addition`.
        # SetCardinality is applied directly, as elsewhere in this file.
        symbolic_definition=lambda a, b, c: ConceptApplication(
            addition,
            ConceptApplication(addition, a, SetCardinality(b)),
            c,
        ),
        computational_implementation=lambda a, b, c: a + len(b) + c,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(
                ExampleType.NUMERIC,
                ExampleType.SET,
                ExampleType.NUMERIC,
                ExampleType.NUMERIC,
            ),
            input_arity=3,
        ),
    )
    return concept


@pytest.fixture
def equals_predicate():
    """Create a predicate concept for testing: (a, b) holds when a == b.

    Examples are the diagonal pairs (i, i) for i in 0..9. Nonexamples are all
    off-diagonal pairs in 0..9 x 0..9.
    """
    concept = Concept(
        name="equals",
        description="Equality of two numbers",
        symbolic_definition=lambda a, b: Equals(a, b),
        computational_implementation=lambda a, b: a == b,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
    )

    # Each diagonal pair is added once. The former inner `j` loop was unused
    # and added every (i, i) ten times.
    for i in range(10):
        concept.add_example((i, i))

    for i in range(10):
        for j in range(10):
            if i != j:
                concept.add_nonexample((i, j))

    return concept


@pytest.fixture
def set_of_divisors_concept():
    """Return the FUNCTION n -> {d in 1..n : d divides n} with a few examples."""
    concept = Concept(
        name="set_of_divisors",
        description="Returns the set of divisors of a number",
        symbolic_definition=lambda n: Set(
            domain=NatDomain(),
            predicate=Lambda("x", ConceptApplication(divides, Var("x"), n)),
        ),
        computational_implementation=lambda n: {
            d for d in range(1, n + 1) if n % d == 0
        },
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC, ExampleType.SET),
            input_arity=1,
        ),
    )

    known_divisors = {
        1: {1},
        2: {1, 2},
        3: {1, 3},
        4: {1, 2, 4},
        6: {1, 2, 3, 6},
        8: {1, 2, 4, 8},
        12: {1, 2, 3, 4, 6, 12},
    }
    for n, divisor_set in known_divisors.items():
        concept.add_example((n, divisor_set))

    return concept


@pytest.fixture
def divides_predicate():
    """Create a divides predicate for testing: (a, b) holds when a divides b.

    Examples and nonexamples cover every pair (a, b) with 1 <= a, b <= 19.
    """
    concept = Concept(
        name="divides",
        description="Checks if one number divides another",
        # The divides application is itself the truth value. It was previously
        # wrapped in a one-argument Equals(...), which is not a complete equality.
        symbolic_definition=lambda a, b: ConceptApplication(divides, a, b),
        # Treat a == 0 as "does not divide" instead of raising ZeroDivisionError.
        # This is the same zero guard divides_and_even and divisor_mod_k_equals_r
        # use.
        computational_implementation=lambda a, b: a != 0 and b % a == 0,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
    )

    # Add some examples for testing
    for i in range(1, 20):
        for j in range(1, 20):
            if j % i == 0:
                concept.add_example((i, j))
            else:
                concept.add_nonexample((i, j))

    return concept


@pytest.fixture
def prime_of_form_n2_plus_n_plus_1_predicate():
    """Return the unary predicate "p is prime and p == n^2 + n + 1 for some n".

    Every p in 2..19 is labelled as an example or a nonexample using the same
    check as the computational implementation.
    """

    def is_prime_of_form(p):
        if not p > 1:
            return False
        bound = int(p**0.5) + 1
        if not all(p % i != 0 for i in range(2, bound)):
            return False
        return any(p == n**2 + n + 1 for n in range(bound))

    def symbolic(p):
        n_squared_plus_n_plus_1 = ConceptApplication(
            addition,
            ConceptApplication(
                addition,
                ConceptApplication(multiplication, Var("n"), Var("n")),
                Var("n"),
            ),
            Nat(1),
        )
        return Exists(
            "n",
            NatDomain(),
            And(
                ConceptApplication(is_prime, p),
                Equals(n_squared_plus_n_plus_1, p),
            ),
        )

    concept = Concept(
        name="is_prime_of_form_n2_plus_n_plus_1",
        description="Checks if a number is prime and of the form n^2 + n + 1",
        symbolic_definition=symbolic,
        computational_implementation=is_prime_of_form,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,),
            input_arity=1,
        ),
    )

    for candidate in range(2, 20):
        record = concept.add_example if is_prime_of_form(candidate) else concept.add_nonexample
        record((candidate,))

    return concept


@pytest.fixture
def multi_arg_predicate():
    """Return the ternary predicate a + b + c == a * b * c.

    Every triple in 0..9 x 0..9 x 0..9 is labelled as an example or a
    nonexample, so quantification tests see the full table.
    """

    def sum_equals_product(a, b, c):
        return a + b + c == a * b * c

    concept = Concept(
        name="sum_equals_product",
        description="Checks if a + b + c = a * b * c",
        symbolic_definition=lambda a, b, c: Equals(
            ConceptApplication(addition, ConceptApplication(addition, a, b), c),
            ConceptApplication(
                multiplication, ConceptApplication(multiplication, a, b), c
            ),
        ),
        computational_implementation=sum_equals_product,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,) * 3,
            input_arity=3,
        ),
    )

    triples = ((i, j, k) for i in range(10) for j in range(10) for k in range(10))
    for triple in triples:
        if sum_equals_product(*triple):
            concept.add_example(triple)
        else:
            concept.add_nonexample(triple)

    return concept


@pytest.fixture
def divides_and_even():
    """Create a concept that tests if b divides a and b is even.

    Every pair (a, b) with 1 <= a, b <= 19 is labelled as an example or a
    nonexample.
    """
    concept = Concept(
        name="divides_and_even",
        description="Tests if b divides a and b is even",
        symbolic_definition=lambda a, b: And(
            ConceptApplication(divides, b, a),
            # "b is even" is spelled out as divides(2, b), the same definition
            # the is_even fixture uses. Inside this fixture the bare name
            # `is_even` resolves to the pytest fixture function, not a Concept.
            ConceptApplication(divides, Nat(2), b),
        ),
        computational_implementation=lambda a, b: b != 0 and a % b == 0 and b % 2 == 0,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr None;
            ReturnPred And(And(x_0 % x_1 == 0, x_1 % 2 == 0), Not(x_1 == 0));
            """,
            a,
            b,
        ),
    )

    # Add examples
    for a in range(1, 20):
        for b in range(1, 20):
            if a % b == 0 and b % 2 == 0:
                concept.add_example((a, b))
            else:
                concept.add_nonexample((a, b))

    return concept


@pytest.fixture
def add_multiply():
    """Return a 2-input, 2-output FUNCTION f(x, y) = (x + y, x * y).

    Examples are (i, j, i+j, i*j) for i, j in 1..19.
    """

    def symbolic(x, y):
        total = ConceptApplication(addition, x, y)
        product = ConceptApplication(multiplication, x, y)
        return (total, product)

    # Component layout: input x, input y, output sum, output product.
    structure = ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC,) * 4,
        input_arity=2,
    )
    concept = Concept(
        name="add_multiply",
        description="Function that returns both sum and product of inputs",
        symbolic_definition=symbolic,
        computational_implementation=lambda x, y: (x + y, x * y),
        example_structure=structure,
    )

    for x in range(1, 20):
        for y in range(1, 20):
            concept.add_example((x, y, x + y, x * y))

    return concept


@pytest.fixture
def is_proper_divisor():
    """Create a concept that tests if d is a proper divisor of n.

    Every pair (n, d) with 1 <= n, d <= 19 is labelled as an example or a
    nonexample.
    NOTE(review): the symbolic form uses d < n while the computational form
    uses d != n. They differ only when n == 0, which the examples never use.
    """
    concept = Concept(
        name="is_proper_divisor",
        description="Tests if d is a proper divisor of n (d divides n and d != n)",
        symbolic_definition=lambda n, d: And(
            ConceptApplication(divides, d, n),
            ConceptApplication(less_than, d, n),
        ),
        # Guard d == 0 (as the Z3 template's Not(x_1 == 0) does) so the check
        # returns False instead of raising ZeroDivisionError.
        computational_implementation=lambda n, d: d != 0 and n != d and n % d == 0,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
        z3_translation=lambda n, d: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr None;
            ReturnPred And(And(x_0 % x_1 == 0, Not(x_0 == x_1)), Not(x_1 == 0));
            """,
            n,
            d,
        ),
    )

    # Add examples for is_proper_divisor
    for n in range(1, 20):
        for d in range(1, 20):
            if n != d and n % d == 0:
                concept.add_example((n, d))
            else:
                concept.add_nonexample((n, d))

    return concept


@pytest.fixture
def divisor_mod_k_equals_r():
    """Return the 4-ary predicate "d divides n and d % k == r".

    Every tuple with n in 1..9, d in 0..n, k in 1..9 and r in 0..k-1 is
    labelled as an example or a nonexample. d == 0 is always a nonexample.
    """

    def holds(n, d, k, r):
        return d != 0 and n % d == 0 and d % k == r

    concept = Concept(
        name="divisor_mod_k_equals_r",
        description="Tests if d divides n and d % k == r",
        # NOTE(review): the modulo step is a raw Python lambda passed as the
        # "concept" of a ConceptApplication rather than a knowledge-base
        # concept. Confirm that ConceptApplication accepts plain callables.
        symbolic_definition=lambda n, d, k, r: And(
            ConceptApplication(divides, d, n),
            Equals(ConceptApplication(lambda x, y: x % y, d, k), r),
        ),
        computational_implementation=holds,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,) * 4,
            input_arity=4,
        ),
        z3_translation=lambda n, d, k, r: Z3Template(
            """
            params 4;
            bounded params 0;
            ReturnExpr None;
            ReturnPred And(And(x_0 % x_1 == 0, x_1 % x_2 == x_3), Not(x_1 == 0));
            """,
            n,
            d,
            k,
            r,
        ),
    )

    candidates = (
        (n, d, k, r)
        for n in range(1, 10)
        for d in range(n + 1)
        for k in range(1, 10)
        for r in range(k)
    )
    for candidate in candidates:
        if holds(*candidate):
            concept.add_example(candidate)
        else:
            concept.add_nonexample(candidate)

    return concept


@pytest.fixture
def commutes():
    """Return the predicate commutes(G, a, b): G.op(a, b) == G.op(b, a).

    Examples come from the abelian groups Z2 and Z3 and from S3. S3 also
    contributes one non-commuting pair as a nonexample.
    """
    from frame.knowledge_base.entities import Equals, Apply, Z2, Z3, S3

    def commute_check(G, a, b):
        return G.op(a, b) == G.op(b, a)

    concept = Concept(
        name="commutes",
        description="Tests if two elements commute under group operation",
        symbolic_definition=lambda G, a, b: Equals(
            Apply(G.op, a, b), Apply(G.op, b, a)
        ),
        computational_implementation=commute_check,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(
                ExampleType.GROUP,
                ExampleType.GROUPELEMENT,
                ExampleType.GROUPELEMENT,
            ),
            input_arity=3,
        ),
    )

    # Z2 is abelian: every ordered pair of elements commutes.
    for left in (0, 1):
        for right in (0, 1):
            concept.add_example((Z2(), GroupElement(left), GroupElement(right)))

    # Z3 is abelian as well; two representative pairs.
    for left, right in ((0, 1), (1, 2)):
        concept.add_example((Z3(), GroupElement(left), GroupElement(right)))

    # S3 is non-abelian, so only some pairs commute.
    s3 = S3()
    identity, swap_first_two, swap_last_two = (0, 1, 2), (1, 0, 2), (0, 2, 1)
    # The identity commutes with everything.
    concept.add_example((s3, GroupElement(identity), GroupElement(swap_first_two)))
    # Every element commutes with itself.
    concept.add_example(
        (s3, GroupElement(swap_first_two), GroupElement(swap_first_two))
    )
    # These two transpositions do not commute.
    concept.add_nonexample(
        (s3, GroupElement(swap_first_two), GroupElement(swap_last_two))
    )

    return concept
