"""
Defines basic initial concepts commonly used as starting points for theory building,
accessed via factory functions.
"""

from frame.knowledge_base.entities import (
    Concept, ExampleStructure, ExampleType, ConceptType,
    Zero, Succ, Nat, Fold, Lambda, Var, Exists, NatDomain, Equals,
    ConceptApplication
)
from frame.tools.z3_template import Z3Template

def create_successor_concept():
    """Build the successor concept, the unary function n -> n+1.

    Returns:
        A ``Concept`` named "successor" seeded with the examples
        ``(n, n + 1)`` for ``0 <= n < 20`` and a fixed list of pairs that
        are not successor relations as non-examples.
    """
    signature = ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=1,
    )
    concept = Concept(
        name="successor",
        description="The successor function n -> n+1",
        symbolic_definition=lambda n: Succ(n),
        computational_implementation=lambda n: n + 1,
        example_structure=signature,
        can_add_examples=True,
        can_add_nonexamples=True,
        z3_translation=lambda n: Z3Template(
            """
            params 1;
            bounded params 0;
            ReturnExpr x_0 + 1;
            ReturnPred None;
            """,
            n
        )
    )

    # Examples: every pair (n, n + 1) with 0 <= n < 20.
    for n in range(20):
        concept.add_example((n, n + 1))

    # Non-examples: jumps by two and fixed points, none of which satisfy
    # second == first + 1.
    wrong_pairs = (
        (3, 5), (0, 2),
        (0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5),
        (2, 4),
    )
    for pair in wrong_pairs:
        concept.add_nonexample(pair)
    return concept


def create_zero_concept():
    """Build the nullary constant concept for zero.

    Returns:
        A ``Concept`` named "zero" whose single example is ``(0,)`` and
        whose non-examples are ``(1,)`` through ``(4,)``.
    """
    signature = ExampleStructure(
        concept_type=ConceptType.CONSTANT,
        component_types=(ExampleType.NUMERIC,),
        input_arity=0,
    )
    concept = Concept(
        name="zero",
        description="The constant zero",
        symbolic_definition=lambda: Zero(),
        computational_implementation=lambda: 0,
        example_structure=signature,
        can_add_examples=True,
        can_add_nonexamples=True,
        z3_translation=lambda: Z3Template(
            """
            params 0;
            bounded params 0;
            ReturnExpr 0;
            ReturnPred None;
            """,
        )
    )
    concept.add_example((0,))
    # Every other small natural number is a non-example.
    for other in range(1, 5):
        concept.add_nonexample((other,))
    return concept


def create_one_concept():
    """Build the nullary constant concept for one, defined as Succ(Zero()).

    Returns:
        A ``Concept`` named "one" whose single example is ``(1,)`` and whose
        non-examples are ``(0,)``, ``(2,)``, ``(3,)`` and ``(4,)``.
    """
    signature = ExampleStructure(
        concept_type=ConceptType.CONSTANT,
        component_types=(ExampleType.NUMERIC,),
        input_arity=0,
    )
    concept = Concept(
        name="one",
        description="The constant one",
        symbolic_definition=lambda: Succ(Zero()),
        computational_implementation=lambda: 1,
        example_structure=signature,
        can_add_examples=True,
        can_add_nonexamples=True,
        z3_translation=lambda: Z3Template(
            """
            params 0;
            bounded params 0;
            ReturnExpr 1;
            ReturnPred None;
            """,
        )
    )
    concept.add_example((1,))
    # Small naturals other than 1 are non-examples.
    for other in (0, 2, 3, 4):
        concept.add_nonexample((other,))
    return concept

def create_two_concept():
    """Build the nullary constant concept for two, defined as Succ(Succ(Zero())).

    Returns:
        A ``Concept`` named "two" whose single example is ``(2,)`` and whose
        non-examples are ``(0,)``, ``(1,)``, ``(3,)`` and ``(4,)``.
    """
    signature = ExampleStructure(
        concept_type=ConceptType.CONSTANT,
        component_types=(ExampleType.NUMERIC,),
        input_arity=0,
    )
    concept = Concept(
        name="two",
        description="The constant two",
        symbolic_definition=lambda: Succ(Succ(Zero())),
        computational_implementation=lambda: 2,
        example_structure=signature,
        can_add_examples=True,
        can_add_nonexamples=True,
        z3_translation=lambda: Z3Template(
            """
            params 0;
            bounded params 0;
            ReturnExpr 2;
            ReturnPred None;
            """,
        )
    )
    concept.add_example((2,))
    # Small naturals other than 2 are non-examples.
    for other in (0, 1, 3, 4):
        concept.add_nonexample((other,))
    return concept



def create_addition_concept():
    """Build the binary addition concept on natural numbers.

    Symbolically, ``add(a, b)`` is ``Fold(b, a, x -> Succ(x))``, i.e. addition
    defined by repeated succession (per the concept's description).
    Examples and non-examples are ``(a, b, a + b)``-shaped triples.

    Returns:
        A ``Concept`` named "add" seeded with identity/commutativity examples
        and a list of incorrect sums as non-examples. Each triple is
        registered exactly once.
    """
    add = Concept(
        name="add",
        description="Addition of natural numbers defined by repeated succession",
        symbolic_definition=lambda a, b: Fold(b, a, Lambda("x", Succ(Var("x")))),
        computational_implementation=lambda a, b: a + b,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
        can_add_examples=True,
        can_add_nonexamples=True,
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr x_0 + x_1;
            ReturnPred None;
            """,
            a, b
        )
    )
    # Add examples (identity, commutativity, small values).
    for i in range(5):
        add.add_example((i, 0, i))
        # For i == 0 the left-identity triple is (0, 0, 0), which the
        # right-identity line above already added; skip the duplicate.
        if i != 0:
            add.add_example((0, i, i))
    add.add_example((2, 3, 5))
    add.add_example((3, 2, 5))
    add.add_example((4, 5, 9))
    # Add non-examples: wrong sums, each distinct triple listed once.
    wrong_sums = (
        (2, 3, 6), (1, 1, 3), (0, 2, 1),
        (0, 0, 1), (0, 1, 2), (1, 0, 2),
        (1, 2, 4), (2, 1, 4), (2, 2, 5),
        (3, 2, 6),
    )
    for triple in wrong_sums:
        add.add_nonexample(triple)
    return add


def create_multiplication_concept():
    """Build the binary multiplication concept on natural numbers.

    Symbolically, ``multiply(a, b)`` is
    ``Fold(b, Zero(), x -> add(x, a))``, i.e. multiplication defined by
    repeated addition (per the concept's description). ``addition_concept``
    is the module-level global; Python resolves that name when the lambda is
    called, not when this factory runs.

    Returns:
        A ``Concept`` named "multiply" seeded with zero/identity/commutativity
        examples and a few incorrect products as non-examples. Each triple is
        registered exactly once.
    """
    mult = Concept(
        name="multiply",
        description="Multiplication of natural numbers defined by repeated addition",
        symbolic_definition=lambda a, b: Fold(
            b, Zero(), Lambda("x", ConceptApplication(addition_concept, Var("x"), a))
        ),
        computational_implementation=lambda a, b: a * b,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
        can_add_examples=True,
        can_add_nonexamples=True,
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr x_0 * x_1;
            ReturnPred None;
            """,
            a, b
        )
    )
    # Add examples (zero, identity, commutativity, small values).
    # The four families overlap for i in {0, 1}: (0, 0, 0), (0, 1, 0),
    # (1, 0, 0) and (1, 1, 1) would each appear twice. An insertion-ordered
    # dict keeps the first occurrence of each triple, in the original order.
    seeded = dict.fromkeys(
        triple
        for i in range(5)
        for triple in ((i, 0, 0), (0, i, 0), (i, 1, i), (1, i, i))
    )
    for triple in seeded:
        mult.add_example(triple)
    mult.add_example((2, 3, 6))
    mult.add_example((3, 2, 6))
    mult.add_example((4, 5, 20))
    mult.add_example((5, 4, 20))
    # Add non-examples (wrong products).
    mult.add_nonexample((2, 3, 5))
    mult.add_nonexample((4, 1, 5))
    mult.add_nonexample((0, 0, 1))
    mult.add_nonexample((1, 0, 1))
    mult.add_nonexample((0, 1, 1))
    mult.add_nonexample((1, 1, 2))

    return mult


def create_divides_concept():
    """Build the binary "divides" predicate on natural numbers.

    Symbolically, ``divides(a, b)`` is ``Exists n in Nat: b == multiply(a, n)``.
    ``multiplication_concept`` is the module-level global; Python resolves
    that name when the lambda is called, not when this factory runs.
    Computationally, ``0`` divides only ``0``, and a nonzero ``a`` divides
    ``b`` when ``b % a == 0``.

    Returns:
        A ``Concept`` named "divides" seeded with reflexive, "1 divides all"
        and "a divides 0" examples, plus a few non-divisible pairs as
        non-examples. Each pair is registered exactly once.
    """
    div = Concept(
        name="divides",
        description="a divides b if there exists n such that b = a*n",
        symbolic_definition=lambda a, b: Exists(
            "n", NatDomain(), Equals(b, ConceptApplication(multiplication_concept, a, Var("n")))
        ),
        # The explicit a == 0 branch avoids ZeroDivisionError from b % 0.
        computational_implementation=lambda a, b: (a == 0 and b == 0) or (a != 0 and b % a == 0),
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
        can_add_examples=True,
        can_add_nonexamples=True,
        # NOTE(review): the predicate below compares with a single `=`
        # (`b_0 * x_0 = x_1`), whereas other templates in this module use
        # `==`. It also quantifies over `b_0` while declaring
        # `bounded params 0`. Confirm both against the Z3Template grammar.
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr None;
            ReturnPred Exists([b_0], b_0 * x_0 = x_1);
            """,
            a, b
        )
    )
    # Add examples (reflexive, 1 divides all, a divides 0, etc.).
    div.add_example((0, 0))
    for i in range(1, 6):
        div.add_example((i, i))
        # For i == 1 the "1 divides i" pair is (1, 1), already added as the
        # reflexive case on the line above; skip the duplicate.
        if i != 1:
            div.add_example((1, i))
        div.add_example((i, 0))
    div.add_example((2, 4))
    div.add_example((2, 6))
    div.add_example((3, 6))
    # Add non-examples (0 divides only 0; the others leave a remainder).
    div.add_nonexample((0, 5))
    div.add_nonexample((5, 1))
    div.add_nonexample((2, 5))
    div.add_nonexample((3, 5))
    div.add_nonexample((2, 3))
    return div


def create_leq_than_concept():
    """Build the binary less-than-or-equal predicate on natural numbers.

    Symbolically, ``leq_than(a, b)`` is ``Exists m in Nat: b == add(a, m)``.
    ``addition_concept`` is the module-level global; Python resolves that
    name when the lambda is called.

    Returns:
        A ``Concept`` named "leq_than" seeded with reflexive and
        "off by one" examples plus a few larger gaps, and with reversed
        pairs as non-examples.
    """
    signature = ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=2,
    )
    concept = Concept(
        name="leq_than",
        description="a ≤ b if there exists m such that b = a + m",
        symbolic_definition=lambda a, b: Exists(
            "m", NatDomain(), Equals(b, ConceptApplication(addition_concept, a, Var("m")))
        ),
        computational_implementation=lambda a, b: a <= b,
        example_structure=signature,
        can_add_examples=True,
        can_add_nonexamples=True,
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr None;
            ReturnPred (x_0 <= x_1);
            """,
            a, b
        )
    )

    # Examples: (k, k) then (k, k + 1) for each k in 0..4.
    for low in range(5):
        for high in (low, low + 1):
            concept.add_example((low, high))
    concept.add_example((2, 5))
    # Zero is below everything: (0, 2) .. (0, 5).
    for high in range(2, 6):
        concept.add_example((0, high))

    # Non-examples: (k + 1, k) for k in 1..4.
    for low in range(1, 5):
        concept.add_nonexample((low + 1, low))
    concept.add_nonexample((5, 2))
    # Nothing positive is <= 0: (2, 0) .. (5, 0).
    for high in range(2, 6):
        concept.add_nonexample((high, 0))
    return concept


def create_equality_concept():
    """Build the binary equality predicate on natural numbers.

    Returns:
        A ``Concept`` named "eq" seeded with the reflexive pairs
        ``(i, i)`` for ``0 <= i < 5`` as examples and a list of distinct
        pairs as non-examples. Each pair is registered exactly once.
    """
    eq = Concept(
        name="eq",
        description="a equals b",
        symbolic_definition=lambda a, b: Equals(a, b),
        computational_implementation=lambda a, b: a == b,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
            input_arity=2,
        ),
        can_add_examples=True,
        can_add_nonexamples=True,
        z3_translation=lambda a, b: Z3Template(
            """
            params 2;
            bounded params 0;
            ReturnExpr None;
            ReturnPred x_0 == x_1;
            """,
            a, b
        )
    )
    # Add examples (reflexive).
    for i in range(5):
        eq.add_example((i, i))
    # Add non-examples: unequal pairs, each distinct pair listed once.
    unequal_pairs = (
        (2, 3), (0, 1), (4, 0),
        (0, 2), (0, 3), (0, 4), (0, 5),
        (1, 2), (1, 3), (1, 4), (1, 5),
        (2, 4), (2, 5),
        (3, 4), (3, 5),
    )
    for pair in unequal_pairs:
        eq.add_nonexample(pair)
    return eq

# =============================================================================
# Instantiate concepts for easy import elsewhere (e.g., demonstrations.py)
# =============================================================================
# The symbolic definitions of multiplication, divides and leq_than refer to
# `addition_concept` / `multiplication_concept` by module-global name. Python
# resolves those names when the lambdas are called, so the globals below must
# keep these exact names.
zero_concept = create_zero_concept()
one_concept = create_one_concept()
# create_two_concept() was the only constant factory without a module-level
# instance; expose it alongside zero/one for consistency.
two_concept = create_two_concept()
successor_concept = create_successor_concept()
addition_concept = create_addition_concept()
multiplication_concept = create_multiplication_concept()
divides_concept = create_divides_concept()
leq_than_concept = create_leq_than_concept()
equality_concept = create_equality_concept()