"""
This module implements concrete mathematical concepts and demonstrations using the expression framework.

It contains implementations of:
- Basic arithmetic operations (addition, multiplication)
- Number theory concepts (divisibility, primality, perfect numbers)
- Comparison operators (less than, greater than, etc.)
- Set theory operations (union, intersection, difference, subset, power set)
- Group theory concepts (cardinality, commutativity, conjugation)
- Mathematical conjectures (infinitude of primes, twin primes, Goldbach)

Each concept is defined with:
- Symbolic representation using the expression framework
- Computational implementation for concrete evaluation
- Translation rules for different target languages (Lean 4, Prolog, Z3)
- Example management and testing infrastructure

The module also includes comprehensive tests and demonstrations that can be run directly.
"""

from frame.knowledge_base.entities import (
    Var, Not, And, Or, Implies,
    Forall, Exists, Set, SetDomain, In, Zero, Succ, Nat,
    NatDomain, Lambda, Equals, SetCardinality, SetSum,
    Fold, Concept, ExampleType, Conjecture, ConceptApplication,
    ExampleStructure, ConceptType,
)

# =============================================================================
# Arithmetic/Number Theory Concepts & Conjectures
# =============================================================================

# Binary addition on natural numbers.
addition = Concept(
    name="add",
    description="Addition of natural numbers defined by repeated succession",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # Two numeric inputs followed by the numeric result.
        component_types=(ExampleType.NUMERIC,) * 3,
        input_arity=2,
    ),
    # Symbolic form: a Fold over the successor step.
    # NOTE(review): presumably Fold(count, base, step) applies `step` to
    # `base` `count` times -- confirm against the Fold entity.
    symbolic_definition=lambda a, b: Fold(b, a, Lambda("x", Succ(Var("x")))),
    # Concrete evaluation uses Python's native `+`.
    computational_implementation=lambda a, b: a + b,
    # Target-language renderings of the same operation.
    lean4_translation=lambda a, b: f"({a} + {b})",
    prolog_translation=lambda a, b: f"plus({a}, {b}, Result)",
    z3_translation=lambda a, b: f"(+ {a} {b})",
)

# Binary multiplication on natural numbers, built on top of `addition`.
multiplication = Concept(
    name="multiply",
    description="Multiplication of natural numbers defined by repeated addition",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # Two numeric inputs followed by the numeric result.
        component_types=(ExampleType.NUMERIC,) * 3,
        input_arity=2,
    ),
    # Symbolic form: a Fold starting at Zero() whose step adds `a` to the
    # running value.
    symbolic_definition=lambda a, b: Fold(
        b,
        Zero(),
        Lambda("x", addition(Var("x"), a)),
    ),
    # Concrete evaluation uses Python's native `*`.
    computational_implementation=lambda a, b: a * b,
    # Target-language renderings of the same operation.
    lean4_translation=lambda a, b: f"({a} * {b})",
    prolog_translation=lambda a, b: f"times({a}, {b}, Result)",
    z3_translation=lambda a, b: f"(* {a} {b})",
)

# Divisibility predicate; the target-language translations are written by hand.
divides = Concept(
    name="divides",
    description="a divides b if there exists n such that b = a*n",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=2,
    ),
    # Symbolic form: there is a natural n with b = a * n.
    symbolic_definition=lambda a, b: Exists(
        "n",
        NatDomain(),
        Equals(b, multiplication(a, Var("n"))),
    ),
    # A non-zero divisor is checked with a remainder test; a zero divisor
    # divides only zero (and must not reach the modulo, which would raise).
    computational_implementation=lambda a, b: (
        (b % a == 0) if a != 0 else (b == 0)
    ),
    lean4_translation=lambda a, b: f"(∃ n : ℕ, {b} = {a} * n)",
    prolog_translation=lambda a, b: f"(0 is {b} mod {a})",
    z3_translation=lambda a, b: f"(exists ((n Int)) (= {b} (* {a} n)))",
)

# Strict ordering on natural numbers.
less_than = Concept(
    name="less_than",
    description="a < b if there exists m such that b = a + succ(m)",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=2,
    ),
    # Symbolic form: b is reached from a by adding a successor, i.e. a
    # strictly positive amount.
    symbolic_definition=lambda a, b: Exists(
        "m",
        NatDomain(),
        Equals(b, addition(a, Succ(Var("m")))),
    ),
    # Concrete evaluation uses Python's native comparison.
    computational_implementation=lambda a, b: a < b,
    # The Lean translation accepts an optional `ctx` argument that it does
    # not use.
    lean4_translation=lambda a, b, ctx=None: f"({a} < {b})",
    prolog_translation=lambda a, b: f"({a} < {b})",
    z3_translation=lambda a, b: f"(< {a} {b})",
)

# Proper divisors are expressed in terms of `divides` and `less_than`, so this
# definition has to come after both of them.
proper_divisors = Concept(
    name="proper_divisors",
    description="The set of proper divisors of n (positive divisors less than n)",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # One numeric input, one set-valued output.
        component_types=(ExampleType.NUMERIC, ExampleType.SET),
        input_arity=1,
    ),
    # Symbolic form: { d in Nat | d divides n, 0 < d, d < n }.
    symbolic_definition=lambda n: Set(
        domain=NatDomain(),
        predicate=Lambda(
            "d",
            And(
                And(
                    ConceptApplication(divides, Var("d"), n),
                    ConceptApplication(less_than, Zero(), Var("d")),
                ),
                ConceptApplication(less_than, Var("d"), n),
            ),
        ),
    ),
    # Concrete evaluation: collect every candidate in [1, n) that leaves no
    # remainder when dividing n.
    computational_implementation=lambda n: {
        candidate for candidate in range(1, n) if not n % candidate
    },
)

# Primality, stated symbolically through `less_than` and `divides`.
is_prime = Concept(
    name="is_prime",
    description="A number greater than 1 whose number of divisors is exactly 2",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC,),
        input_arity=1,
    ),
    # Symbolic form: 1 < n and the set of naturals dividing n has
    # cardinality 2.
    symbolic_definition=lambda n: And(
        ConceptApplication(less_than, Nat(1), n),
        Equals(
            SetCardinality(
                Set(
                    domain=NatDomain(),
                    predicate=Lambda("k", ConceptApplication(divides, Var("k"), n)),
                )
            ),
            Nat(2),
        ),
    ),
    # Concrete evaluation: trial division by every candidate from 2 up to
    # the integer part of sqrt(n); n is prime when none divides it.
    computational_implementation=lambda n: n > 1
    and not any(n % factor == 0 for factor in range(2, int(n**0.5) + 1)),
    lean4_translation=lambda n: f"(Nat.Prime {n})",
    prolog_translation=lambda n: f"is_prime({n})",
    # Unlike the other translations here, this one calls `to_z3()` on its
    # argument rather than interpolating it directly.
    z3_translation=lambda n: f"""
n := Exec({n.to_z3()})
x := NatVar();
return Not(Exists(x, And(And(x < n, x > 1), n % x == 0)))""",
)

# Perfect-number predicate.
#
# Perfect numbers are positive by definition. Without an explicit positivity
# guard, 0 qualifies vacuously: it has no proper divisors under this module's
# definition (0 < d < 0 is unsatisfiable), so the empty sum is 0 and equals n.
# Both the symbolic and the computational definitions therefore require n > 0.
is_perfect = Concept(
    name="is_perfect",
    description="A number that equals the sum of its proper divisors",
    # Symbolic form: 0 < n and n = sum(proper_divisors(n)).
    symbolic_definition=lambda n: And(
        ConceptApplication(less_than, Zero(), n),
        Equals(n, SetSum(ConceptApplication(proper_divisors, n))),
    ),
    # Concrete evaluation: reject non-positive inputs first, then compare n
    # with the sum of every divisor in [1, n).
    computational_implementation=lambda n: n > 0
    and n == sum(i for i in range(1, n) if n % i == 0),
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC,),
        input_arity=1,
    ),
)

# Evenness predicate, defined symbolically through `divides`.
is_even = Concept(
    name="is_even",
    description="A number divisible by 2",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC,),
        input_arity=1,
    ),
    # Symbolic form: 2 divides n.
    symbolic_definition=lambda n: ConceptApplication(divides, Nat(2), n),
    # Concrete evaluation: the remainder modulo 2 is zero.
    computational_implementation=lambda n: not n % 2,
)

# Oddness predicate, defined symbolically as the negation of `is_even`.
is_odd = Concept(
    name="is_odd",
    description="Tests if a number is not divisible by 2",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC,),
        input_arity=1,
    ),
    # Symbolic form: not is_even(n).
    symbolic_definition=lambda n: Not(ConceptApplication(is_even, n)),
    # Concrete evaluation: the remainder modulo 2 is exactly one.
    computational_implementation=lambda n: n % 2 == 1,
)

# Non-strict ordering on natural numbers.
leq_than = Concept(
    name="leq_than",
    description="a ≤ b if there exists m such that b = a + m",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=2,
    ),
    # Symbolic form: b is reached from a by adding some natural m (m may be
    # zero, which is what makes the ordering non-strict).
    symbolic_definition=lambda a, b: Exists(
        "m",
        NatDomain(),
        Equals(b, addition(a, Var("m"))),
    ),
    # Concrete evaluation uses Python's native comparison.
    computational_implementation=lambda a, b: a <= b,
    # The Lean translation accepts an optional `ctx` argument that it does
    # not use.
    lean4_translation=lambda a, b, ctx=None: f"({a} ≤ {b})",
    prolog_translation=lambda a, b: f"({a} =< {b})",
    z3_translation=lambda a, b: f"(<= {a} {b})",
)

# Strict "greater than", defined symbolically as the negation of `leq_than`.
greater_than = Concept(
    name="greater_than",
    description="a > b if not (a ≤ b)",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=2,
    ),
    # Symbolic form: not (a <= b).
    symbolic_definition=lambda a, b: Not(ConceptApplication(leq_than, a, b)),
    # Concrete evaluation uses Python's native comparison.
    computational_implementation=lambda a, b: a > b,
    # Target-language renderings of the same comparison.
    lean4_translation=lambda a, b: f"({a} > {b})",
    prolog_translation=lambda a, b: f"({a} > {b})",
    z3_translation=lambda a, b: f"(> {a} {b})",
)

# Non-strict "greater than or equal", defined symbolically as the negation of
# `less_than`.
geq_than = Concept(
    name="geq_than",
    description="a ≥ b if not (a < b)",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=2,
    ),
    # Symbolic form: not (a < b).
    symbolic_definition=lambda a, b: Not(ConceptApplication(less_than, a, b)),
    # Concrete evaluation uses Python's native comparison.
    computational_implementation=lambda a, b: a >= b,
    # Target-language renderings of the same comparison.
    lean4_translation=lambda a, b: f"({a} ≥ {b})",
    prolog_translation=lambda a, b: f"({a} >= {b})",
    z3_translation=lambda a, b: f"(>= {a} {b})",
)

# Greatest common divisor of two natural numbers.
gcd = Concept(
    name="gcd",
    description="Greatest common divisor - largest number that divides both inputs",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # Two numeric inputs followed by the numeric result.
        component_types=(ExampleType.NUMERIC,) * 3,
        input_arity=2,
    ),
    # Symbolic form: a predicate over a candidate d that holds when d divides
    # both inputs and every common divisor k satisfies k <= d.
    symbolic_definition=lambda a, b: Lambda(
        "d",
        And(
            And(
                ConceptApplication(divides, Var("d"), a),
                ConceptApplication(divides, Var("d"), b),
            ),
            Forall(
                "k",
                NatDomain(),
                Implies(
                    And(
                        ConceptApplication(divides, Var("k"), a),
                        ConceptApplication(divides, Var("k"), b),
                    ),
                    ConceptApplication(leq_than, Var("k"), Var("d")),
                ),
            ),
        ),
    ),
    # Concrete evaluation: Euclid's algorithm, recursing through the
    # concept's own `compute` until the first argument reaches zero.
    computational_implementation=lambda a, b: (
        gcd.compute(b % a, a) if a != 0 else b
    ),
)

# Least common multiple of two natural numbers.
lcm = Concept(
    name="lcm",
    description="Least common multiple - smallest number divisible by both inputs",
    # Symbolic form: a predicate over a candidate l that holds when both
    # inputs divide l and every common multiple k satisfies l <= k.
    # NOTE(review): under this module's `divides` (b = a * n with n = 0
    # allowed), l = 0 is a common multiple of any pair, so this predicate
    # appears to be satisfied by 0 rather than by the usual positive lcm --
    # confirm whether a positivity constraint on l and k is intended.
    symbolic_definition=lambda a, b: Lambda(
        "l",
        And(
            And(
                ConceptApplication(divides, a, Var("l")),
                ConceptApplication(divides, b, Var("l")),
            ),
            Forall(
                "k",
                NatDomain(),
                Implies(
                    And(
                        ConceptApplication(divides, a, Var("k")),
                        ConceptApplication(divides, b, Var("k")),
                    ),
                    ConceptApplication(leq_than, Var("l"), Var("k")),
                ),
            ),
        ),
    ),
    # Concrete evaluation: |a * b| / gcd(a, b). A zero operand is handled
    # up front: the result is 0 by convention, and for a == b == 0 the gcd is
    # 0, so the division would otherwise raise ZeroDivisionError.
    computational_implementation=lambda a, b: (
        0 if a == 0 or b == 0 else abs(a * b) // gcd.compute(a, b)
    ),
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=2,
    ),
)


# TODO(_; 2/14): Reconsider example_structure for conjectures. Different conjectures may need different
# example handling approaches. For instance, some conjectures might need counterexamples, others might
# need families of examples, and others might not need examples at all. Consider making example_structure
# optional for conjectures or developing a more flexible system for conjecture examples.

# Infinitude of primes, stated as "every n has a prime above it".
infinitude_of_primes = Conjecture(
    name="infinitude_of_primes",
    description="For every natural number n, there exists a prime number greater than n",
    # Symbolic form: forall n, exists p, p > n and is_prime(p).
    symbolic_definition=lambda: Forall(
        "n",
        NatDomain(),
        Exists(
            "p",
            NatDomain(),
            And(
                ConceptApplication(greater_than, Var("p"), Var("n")),
                ConceptApplication(is_prime, Var("p")),
            ),
        ),
    ),
    # Bounded check: for every n in [1, bound), look for a prime p with
    # n < p <= 2n. Bertrand's postulate guarantees such a prime exists, so
    # the search window is always sufficient.
    #
    # The window must include 2n itself (for n = 1 the only candidate is 2)
    # and must not be capped at `bound`; otherwise it is empty for n = 1 and
    # for n close to `bound`, and the check fails regardless of the
    # conjecture.
    computational_implementation=lambda bound=1000: all(
        any(is_prime.compute(p) for p in range(n + 1, 2 * n + 1))
        for n in range(1, bound)
    ),
)

# Twin prime conjecture, stated as "every n has a twin-prime pair above it".
twin_prime_conjecture = Conjecture(
    name="twin_prime_conjecture",
    description="There are infinitely many pairs of prime numbers that differ by 2",
    # Symbolic form: forall n, exists p, p > n and is_prime(p) and
    # is_prime(p + 2).
    symbolic_definition=lambda: Forall(
        "n",
        NatDomain(),
        Exists(
            "p",
            NatDomain(),
            And(
                ConceptApplication(greater_than, Var("p"), Var("n")),
                And(
                    ConceptApplication(is_prime, Var("p")),
                    ConceptApplication(is_prime, addition(Var("p"), Nat(2))),
                ),
            ),
        ),
    ),
    # Bounded check: for every n in [1, bound), look for a twin-prime pair
    # (p, p + 2) with n < p <= 2n + 1.
    #
    # The window is a heuristic, not a theorem: no analogue of Bertrand's
    # postulate is known for twin primes. A False result therefore means
    # "no pair found inside the window", not a disproof. The window must
    # reach 2n + 1 (for n = 1 the first pair starts at p = 3, for n = 5 at
    # p = 11) and must not be capped at `bound`; otherwise it is empty for
    # n = 1 and for n close to `bound`, and the check fails unconditionally.
    computational_implementation=lambda bound=1000: all(
        any(
            is_prime.compute(p) and is_prime.compute(p + 2)
            for p in range(n + 1, 2 * n + 2)
        )
        for n in range(1, bound)
    ),
)

# Goldbach's conjecture over the even numbers greater than 2.
goldbach_conjecture = Conjecture(
    name="goldbach_conjecture",
    description="Every even integer greater than 2 can be expressed as the sum of two primes",
    # Bounded check: every even number in [4, bound) must split as
    # first + (even_number - first) with both parts prime.
    computational_implementation=lambda bound=1000: all(
        any(
            is_prime.compute(first) and is_prime.compute(even_number - first)
            for first in range(2, even_number - 1)
        )
        for even_number in range(4, bound, 2)
    ),
    # Symbolic form: forall n, (n > 2 and is_even(n)) implies there exist
    # primes p and q with n = p + q.
    symbolic_definition=lambda: Forall(
        "n",
        NatDomain(),
        Implies(
            And(
                ConceptApplication(greater_than, Var("n"), Nat(2)),
                ConceptApplication(is_even, Var("n")),
            ),
            Exists(
                "p",
                NatDomain(),
                Exists(
                    "q",
                    NatDomain(),
                    And(
                        And(
                            ConceptApplication(is_prime, Var("p")),
                            ConceptApplication(is_prime, Var("q")),
                        ),
                        Equals(Var("n"), addition(Var("p"), Var("q"))),
                    ),
                ),
            ),
        ),
    ),
)

# =============================================================================
# Demonstrations: Set Theory Concepts & Conjectures
# =============================================================================

# Set union: elements belonging to at least one of the two operands.
set_union = Concept(
    name="union",
    description="Union of two sets (A ∪ B) = {x | x ∈ A ∨ x ∈ B}",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # Two set inputs followed by the set result.
        component_types=(ExampleType.SET,) * 3,
        input_arity=2,
    ),
    # Symbolic form: a Set over A's domain (B is assumed to share it) whose
    # membership predicate is the disjunction of the two memberships.
    symbolic_definition=lambda A, B: Set(
        domain=A.domain,
        predicate=Lambda(
            "x",
            Or(In(Var("x"), A), In(Var("x"), B)),
        ),
    ),
    # Concrete evaluation delegates to the operand's own `union` method.
    computational_implementation=lambda A, B: A.union(B),
    lean4_translation=lambda A, B: f"({A} ∪ {B})",
    prolog_translation=lambda A, B: f"union({A}, {B}, Result)",
    z3_translation=lambda A, B: f"(union {A} {B})",
)

# Set intersection: elements belonging to both operands.
set_intersection = Concept(
    name="intersection",
    description="Intersection of two sets (A ∩ B) = {x | x ∈ A ∧ x ∈ B}",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # Two set inputs followed by the set result.
        component_types=(ExampleType.SET,) * 3,
        input_arity=2,
    ),
    # Symbolic form: a Set over A's domain (B is assumed to share it) whose
    # membership predicate is the conjunction of the two memberships.
    symbolic_definition=lambda A, B: Set(
        domain=A.domain,
        predicate=Lambda(
            "x",
            And(In(Var("x"), A), In(Var("x"), B)),
        ),
    ),
    # Concrete evaluation delegates to the operand's own `intersection` method.
    computational_implementation=lambda A, B: A.intersection(B),
    lean4_translation=lambda A, B: f"({A} ∩ {B})",
    prolog_translation=lambda A, B: f"intersection({A}, {B}, Result)",
    z3_translation=lambda A, B: f"(inter {A} {B})",
)

# Set difference: elements of the first operand that are not in the second.
set_difference = Concept(
    name="difference",
    description="Set difference (A \\ B) = {x | x ∈ A ∧ x ∉ B}",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # Two set inputs followed by the set result.
        component_types=(ExampleType.SET,) * 3,
        input_arity=2,
    ),
    # Symbolic form: a Set over A's domain (B is assumed to share it) whose
    # predicate requires membership in A and non-membership in B.
    symbolic_definition=lambda A, B: Set(
        domain=A.domain,
        predicate=Lambda(
            "x",
            And(In(Var("x"), A), Not(In(Var("x"), B))),
        ),
    ),
    # Concrete evaluation delegates to the operand's own `difference` method.
    computational_implementation=lambda A, B: A.difference(B),
    lean4_translation=lambda A, B: f"({A} \\ {B})",
    prolog_translation=lambda A, B: f"subtract({A}, {B}, Result)",
    z3_translation=lambda A, B: f"(setminus {A} {B})",
)

# Symmetric difference: elements belonging to exactly one of the operands.
symmetric_difference = Concept(
    name="symmetric_difference",
    description="Symmetric difference of two sets (A △ B) = {x | (x ∈ A ∧ x ∉ B) ∨ (x ∈ B ∧ x ∉ A)}",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # Two set inputs followed by the set result.
        component_types=(ExampleType.SET,) * 3,
        input_arity=2,
    ),
    # Symbolic form: a Set over A's domain (B is assumed to share it) whose
    # predicate is "in A but not B" or "in B but not A".
    symbolic_definition=lambda A, B: Set(
        domain=A.domain,
        predicate=Lambda(
            "x",
            Or(
                And(In(Var("x"), A), Not(In(Var("x"), B))),
                And(In(Var("x"), B), Not(In(Var("x"), A))),
            ),
        ),
    ),
    # Concrete evaluation delegates to the operand's own
    # `symmetric_difference` method.
    computational_implementation=lambda A, B: A.symmetric_difference(B),
    # The Lean and Z3 renderings spell the operation out as the union of the
    # two one-sided differences.
    lean4_translation=lambda A, B: f"(({A} \\ {B}) ∪ ({B} \\ {A}))",
    prolog_translation=lambda A, B: f"symmetric_difference({A}, {B}, Result)",
    z3_translation=lambda A, B: f"(union (setminus {A} {B}) (setminus {B} {A}))",
)

# Subset predicate: every element of the first operand is in the second.
subset = Concept(
    name="subset",
    description="A is a subset of B (A ⊆ B) if every element of A is in B",
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.SET, ExampleType.SET),
        input_arity=2,
    ),
    # Symbolic form: forall x in A's domain, x in A implies x in B.
    symbolic_definition=lambda A, B: Forall(
        "x",
        A.domain,
        Implies(In(Var("x"), A), In(Var("x"), B)),
    ),
    # Concrete evaluation delegates to the operand's own `issubset` method.
    computational_implementation=lambda A, B: A.issubset(B),
    lean4_translation=lambda A, B: f"({A} ⊆ {B})",
    prolog_translation=lambda A, B: f"subset({A}, {B})",
    z3_translation=lambda A, B: f"(subset {A} {B})",
)

# Power set: the collection of all subsets of a set.
power_set = Concept(
    name="power_set",
    description="Power set of A (𝒫(A)) is the set of all subsets of A",
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # One set input, one set-valued output.
        component_types=(ExampleType.SET, ExampleType.SET),
        input_arity=1,
    ),
    # Symbolic form: a Set over SetDomain() (its members are themselves sets)
    # containing every B that is a subset of A.
    symbolic_definition=lambda A: Set(
        domain=SetDomain(),
        predicate=Lambda("B", ConceptApplication(subset, Var("B"), A)),
    ),
    # Concrete evaluation: enumerate every bitmask in [0, 2**len(A)); bit
    # `position` of the mask decides whether the element at that iteration
    # position is kept. Subsets are frozensets so they can be set members.
    computational_implementation=lambda A: {
        frozenset(
            member
            for position, member in enumerate(A)
            if (mask >> position) & 1
        )
        for mask in range(2 ** len(A))
    },
)

def _distributivity_sample_sets(bound):
    """Return a small deterministic family of distinct subsets of range(bound).

    The family holds the empty set plus every residue class of
    ``range(bound)`` modulo 1, 2 and 3 (the full range, evens/odds, and the
    three classes mod 3). The sets overlap in different ways, so an identity
    checked over all triples is exercised on genuinely different operands.
    """
    return [set()] + [
        set(range(start, bound, step))
        for step in (1, 2, 3)
        for start in range(step)
    ]


# Distributivity of union over intersection.
distributivity_union_over_intersection = Conjecture(
    name="distributivity_union_over_intersection",
    description="Union distributes over intersection: A ∪ (B ∩ C) = (A ∪ B) ∩ (A ∪ C)",
    # Symbolic form: for all sets A, B, C the two sides of the identity are
    # equal, expressed through the `set_union` / `set_intersection` concepts.
    symbolic_definition=lambda: Forall(
        "A",
        SetDomain(),
        Forall(
            "B",
            SetDomain(),
            Forall(
                "C",
                SetDomain(),
                Equals(
                    ConceptApplication(
                        set_union,
                        Var("A"),
                        ConceptApplication(set_intersection, Var("B"), Var("C")),
                    ),
                    ConceptApplication(
                        set_intersection,
                        ConceptApplication(set_union, Var("A"), Var("B")),
                        ConceptApplication(set_union, Var("A"), Var("C")),
                    ),
                ),
            ),
        ),
    ),
    # Bounded check: verify the identity for every triple drawn from the
    # sample family. Using three copies of the same full range (A == B == C)
    # would make both sides trivially equal and test nothing.
    computational_implementation=lambda bound=10: all(
        A.union(B.intersection(C)) == A.union(B).intersection(A.union(C))
        for A in _distributivity_sample_sets(bound)
        for B in _distributivity_sample_sets(bound)
        for C in _distributivity_sample_sets(bound)
    ),
)

# =============================================================================
# Demonstrations: Group Theory Concepts & Conjectures
# =============================================================================

# Group cardinality: the size of a group's carrier.
group_cardinality = Concept(
    name="group_cardinality",
    description=(
        "The cardinality of a group is the number of elements in its carrier (underlying set). "
        "For a group represented as an instance of Group, this is simply the size of its carrier."
    ),
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        # One group input, one numeric output.
        component_types=(ExampleType.GROUP, ExampleType.NUMERIC),
        input_arity=1,
    ),
    # Symbolic form: SetCardinality applied to the group's carrier.
    symbolic_definition=lambda grp: SetCardinality(grp.carrier),
    # Concrete evaluation: evaluate the carrier and take the length of the
    # result.
    computational_implementation=lambda grp: len(grp.carrier.evaluate()),
    # Each translation first renders the carrier with the matching `to_*`
    # method and then wraps it.
    lean4_translation=lambda grp: f"|{grp.carrier.to_lean4()}|",
    prolog_translation=lambda grp: f"group_cardinality({grp.carrier.to_prolog()}, Result)",
    z3_translation=lambda grp: f"(cardinality {grp.carrier.to_z3()})",
)

# Abelian (commutative) group predicate.
is_abelian = Concept(
    name="is_abelian",
    description="A group is abelian if its operation is commutative.",
    # Symbolic form: for all a, b in the carrier, op(a, b) = op(b, a).
    symbolic_definition=lambda grp: Forall(
        "a",
        grp.carrier,
        Forall(
            "b",
            grp.carrier,
            Equals(
                grp.op(Var("a"), Var("b")),
                grp.op(Var("b"), Var("a")),
            ),
        ),
    ),
    # Concrete evaluation: compare op(a, b) with op(b, a) for every ordered
    # pair of carrier elements. The carrier is evaluated once and
    # materialised into a tuple; evaluating it inside the nested loop would
    # repeat that work once per outer element.
    computational_implementation=lambda grp: (
        lambda elements: all(
            grp.op(a, b) == grp.op(b, a) for a in elements for b in elements
        )
    )(tuple(grp.carrier.evaluate())),
    example_structure=ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.GROUP,),
        input_arity=1,
    ),
)

# Conjugation of one group element by another.
conjugation = Concept(
    name="conjugation",
    description=(
        "Conjugation of a group element x by an element g is defined as g * x * g⁻¹. "
        "This operation is fundamental in group theory, capturing how one element 'acts' on another."
    ),
    example_structure=ExampleStructure(
        concept_type=ConceptType.FUNCTION,
        component_types=(
            ExampleType.GROUP,  # the group itself
            ExampleType.GROUPELEMENT,  # g, the conjugating element
            ExampleType.GROUPELEMENT,  # x, the element being conjugated
            ExampleType.GROUPELEMENT,  # the conjugate
        ),
        input_arity=3,
    ),
    # Symbolic form: op(op(g, x), inverse(g)), built by applying the group's
    # `op` and `inverse` through ConceptApplication.
    symbolic_definition=lambda grp, g, x: ConceptApplication(
        grp.op,
        ConceptApplication(grp.op, g, x),
        ConceptApplication(grp.inverse, g),
    ),
    # Concrete evaluation: call the group's `op` and `inverse` directly.
    computational_implementation=lambda grp, g, x: grp.op(grp.op(g, x), grp.inverse(g)),
    # NOTE(review): the Lean rendering calls `grp.inverse(g)` on the argument
    # it interpolates, whereas the other two translations only name the
    # operation -- confirm `g` has the type `grp.inverse` expects here.
    lean4_translation=lambda grp, g, x: f"({g} * {x} * {grp.inverse(g)})",
    prolog_translation=lambda grp, g, x: f"conjugate({grp.to_prolog()}, {g}, {x}, Result)",
    z3_translation=lambda grp, g, x: f"(conjugate {grp.to_z3()} {g} {x})",
)
