"""
Test suite for knowledge graph construction and manipulation.

This file builds mathematical concepts and conjectures with production rules and
checks that the knowledge graph tracks the relationships between them, survives a
save/load round trip, and records how many steps old each entity is.
"""

from datetime import datetime
from frame.productions.concepts.map_iterate import MapIterateRule
from frame.knowledge_base.initial_concepts import zero_concept, one_concept, create_successor_concept
from frame.productions.concepts.compose import ComposeRule
from frame.productions.concepts.negate import NegateRule
from frame.productions.concepts.exists import ExistsRule
from frame.productions.concepts.match import MatchRule
from frame.productions.concepts.specialize import SpecializeRule
from frame.productions.concepts.constant import ConstantRule
from frame.productions.concepts.forall import ForallRule
from frame.productions.conjectures.equivalence import EquivalenceRule
from frame.productions.conjectures.implication import ImplicationRule
from frame.productions.conjectures.nonexistence import NonexistenceRule
from frame.productions.conjectures.exclusivity import ExclusivityRule


from frame.knowledge_base.knowledge_graph import (
    KnowledgeGraph,
    ConstructionStep,
    NodeType,
)
from frame.knowledge_base.entities import (
    Concept,
    Example,
    ExampleType,
    ExampleStructure,
    ConceptType,
    Expression,
    Var,
    ConceptApplication,
    Succ,
    Zero,
    Nat,
    Exists,
    And,
    Not,
    Equals,
    Lambda,
    Fold,
    ExampleCollection,
)
from frame.interestingness.learning.dsl_primitives import get_entity_step_age
from frame.knowledge_base.demonstrations import is_prime
import os


def _print_related_concepts(graph, header, concept_id):
    """Print ``header`` followed by the name of every concept returned by
    ``graph.get_related_concepts(concept_id)``, one per line."""
    print(header)
    for related_id in graph.get_related_concepts(concept_id):
        print(f"- {graph.nodes[related_id]['entity'].name}")


def _print_tracked_instances(graph, limit=20):
    """Print the instances the graph tracks for each example structure,
    showing at most ``limit`` sorted values per structure."""
    print("\nTracked Instances by Structure:")
    # NOTE(review): reads a private KnowledgeGraph attribute; switch to a
    # public accessor if one exists.
    for structure, values in graph._instances_by_structure.items():
        print(f"\nStructure: {structure}")
        print(f"  Number of instances: {len(values)}")
        print("  Instances:", end=" ")
        sorted_values = sorted(values)
        if len(sorted_values) > limit:
            print(sorted_values[:limit], "...")
        else:
            print(sorted_values)


def test_knowledge_graph_construction():
    """Build arithmetic concepts and conjectures with production rules.

    Starting from successor and zero, derives addition, multiplication, power,
    square, divides, parity/divisibility predicates and several composed
    predicates. Each one is registered in a ``KnowledgeGraph`` together with the
    ``ConstructionStep`` (rule, input ids, rule parameters) that produced it.
    The test then checks that the graph survives a save/load round trip and
    renders a visualization of the reloaded graph.

    Files are written under ``data/graphs`` and ``data/visualizations``,
    relative to the current working directory.

    Returns:
        The constructed ``KnowledgeGraph``. NOTE(review): pytest >= 7.2 warns
        when a test returns a non-None value; the return is kept for backward
        compatibility with any caller that reuses the graph.
    """
    print("\n=== Testing Knowledge Graph Construction ===")

    # Initialize knowledge graph and rules
    graph = KnowledgeGraph()
    map_iterate = MapIterateRule()
    match = MatchRule()
    exists = ExistsRule()
    specialize = SpecializeRule()
    negate = NegateRule()
    compose = ComposeRule()
    nonexistence = NonexistenceRule()
    equivalence = EquivalenceRule()
    exclusivity = ExclusivityRule()

    # Successor is the primitive the arithmetic concepts are built from.
    successor = create_successor_concept()
    successor_id = graph.add_concept(successor)

    # Addition: iterate successor.
    print("\nCreating addition from successor...")
    addition = map_iterate.apply(successor)
    addition.name = "addition"  # Rename for clarity
    addition_id = graph.add_concept(
        addition, ConstructionStep(map_iterate, [successor_id], {}, datetime.now())
    )
    # Includes the zero edge cases (x + 0, 0 + x, 0 + 0).
    for example in [
        (2, 3, 5), (0, 5, 5), (5, 0, 5), (0, 0, 0),
        (1, 0, 1), (0, 1, 1), (4, 0, 4), (0, 4, 4),
    ]:
        addition.add_example(example)

    zero_id = graph.add_concept(zero_concept)

    # One: a constant concept derived from successor's example succ(0) = 1.
    # The same Example object is passed to the rule and recorded in the step.
    print("\nCreating one_concept using ConstantRule...")
    constant = ConstantRule()
    one_example = Example((0, 1), successor.examples.example_structure)
    one = constant.apply(successor, example=one_example)
    one.name = "one"
    graph.add_concept(
        one,
        ConstructionStep(
            constant, [successor_id], {"example": one_example}, datetime.now()
        ),
    )

    # Alternative construction of one: fix successor's input to zero, i.e.
    # one = successor(zero). It only exercises SpecializeRule and is not added
    # to the graph.
    print("\nAlternative: Creating one_concept using SpecializeRule...")
    one_alt = specialize.apply(successor, zero_concept, index_to_specialize=0)
    one_alt.name = "one_alt"

    # Multiplication: iterate addition starting from zero.
    print("\nCreating multiplication from addition...")
    multiplication = map_iterate.apply(addition, zero_concept)
    multiplication.name = "multiplication"
    multiplication_id = graph.add_concept(
        multiplication,
        ConstructionStep(map_iterate, [addition_id, zero_id], {}, datetime.now()),
    )
    # x * 0 performs no additions; 0 * x adds 0 x times -- both give 0.
    for example in [
        (2, 3, 6), (3, 4, 12), (5, 0, 0), (0, 5, 0), (0, 0, 0),
        (1, 0, 0), (0, 1, 0), (4, 0, 0), (0, 4, 0),
    ]:
        multiplication.add_example(example)

    # TODO: Synthesize Succ(Zero) as a concept for the accumulator
    # NOTE(review): the accumulator one_concept is not a graph node (the "one"
    # added above is a different object), so the step records only
    # multiplication as an input.
    print("\nCreating power from multiplication...")
    power = map_iterate.apply(multiplication, one_concept)
    power.name = "power"
    power_id = graph.add_concept(
        power, ConstructionStep(map_iterate, [multiplication_id], {}, datetime.now())
    )
    for example in [(2, 3, 8), (3, 2, 9), (2, 4, 16)]:
        power.add_example(example)

    # Square: multiplication with both inputs forced equal.
    print("\nCreating square from multiplication...")
    square = match.apply(multiplication, indices_to_match=[0, 1])
    square.name = "square"
    square_id = graph.add_concept(
        square,
        ConstructionStep(
            match, [multiplication_id], {"indices_to_match": [0, 1]}, datetime.now()
        ),
    )
    for example in [(2, 4), (3, 9), (4, 16)]:
        square.add_example(example)

    # Divides: exists k such that k * a == b.
    print("\nCreating divides from multiplication...")
    divides = exists.apply(multiplication, indices_to_quantify=[0])
    divides.name = "divides"

    # Note(_; 2/19): The exists rule currently does not generate computational implementations.
    # We need to find a way to automatically derive computational interpretations
    # for existentially quantified concepts. For now, we manually add the implementation.
    divides.set_computational_implementation(
        lambda a, b: a != 0
        and any(multiplication.compute(k, a) == b for k in range(abs(b) + 1))
    )

    divides_id = graph.add_concept(
        divides,
        ConstructionStep(
            exists, [multiplication_id], {"indices_to_quantify": [0]}, datetime.now()
        ),
    )
    divides.add_example((2, 4))  # 2 divides 4
    divides.add_example((3, 6))  # 3 divides 6
    divides.add_nonexample((2, 5))  # 2 does not divide 5

    # Constants 2 and 4, used to specialize divides.
    two_concept = Concept(
        name="two",
        description="The constant two",
        symbolic_definition=lambda: Nat(2),
        computational_implementation=lambda: 2,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC,),  # Output type is numeric
            input_arity=0,  # No inputs for a constant
        ),
    )
    graph.add_concept(two_concept)

    four_concept = Concept(
        name="four",
        description="The constant four",
        symbolic_definition=lambda: Nat(4),
        computational_implementation=lambda: 4,
        example_structure=ExampleStructure(
            concept_type=ConceptType.FUNCTION,
            component_types=(ExampleType.NUMERIC,),  # Output type is numeric
            input_arity=0,  # No inputs for a constant
        ),
    )
    graph.add_concept(four_concept)

    # is_even: divides with the divisor fixed to 2.
    print("\nCreating is_even from divides...")
    is_even = specialize.apply(divides, two_concept, index_to_specialize=0)
    is_even.name = "is_even"
    is_even_id = graph.add_concept(
        is_even,
        ConstructionStep(
            specialize, [divides_id], {"index_to_specialize": 0}, datetime.now()
        ),
    )
    is_even.add_example((2,))
    is_even.add_example((4,))
    is_even.add_nonexample((1,))
    is_even.add_nonexample((3,))

    # is_odd: negation of is_even.
    print("\nCreating is_odd from is_even...")
    is_odd = negate.apply(is_even)
    is_odd.name = "is_odd"
    graph.add_concept(
        is_odd, ConstructionStep(negate, [is_even_id], {}, datetime.now())
    )

    # divisible_by_4: divides with the divisor fixed to 4.
    print("\nCreating divisible_by_4 from divides...")
    divisible_by_4 = specialize.apply(divides, four_concept, index_to_specialize=0)
    divisible_by_4.name = "divisible_by_4"
    divisible_by_4_id = graph.add_concept(
        divisible_by_4,
        ConstructionStep(
            specialize, [divides_id], {"index_to_specialize": 0}, datetime.now()
        ),
    )
    divisible_by_4.add_example((4,))
    divisible_by_4.add_example((8,))
    divisible_by_4.add_nonexample((2,))
    divisible_by_4.add_nonexample((6,))

    # has_all_even_divisors: every divisor of n is even.
    print("\nCreating has_all_even_divisors from divides and is_even...")
    forall = ForallRule()
    forall_params = {"indices_to_quantify": [0], "indices_to_map": {0: 0}}
    has_all_even_divisors = forall.apply(divides, is_even, **forall_params)
    has_all_even_divisors.name = "has_all_even_divisors"
    graph.add_concept(
        has_all_even_divisors,
        ConstructionStep(forall, [divides_id, is_even_id], forall_params, datetime.now()),
    )
    # 1 divides every n, so no positive n qualifies.
    has_all_even_divisors.add_nonexample((4,))  # 4 has divisors 1,2,4 - not all even
    has_all_even_divisors.add_nonexample((8,))  # 8 has divisors 1,2,4,8 - not all even
    has_all_even_divisors.add_nonexample((2,))  # 2 has divisors 1,2 - not all even
    has_all_even_divisors.add_nonexample((6,))  # 6 has divisors 1,2,3,6 - not all even

    # TODO: Synthesize is_prime via production rules
    print("\nImporting is_prime (should be synthesized)...")
    is_prime_id = graph.add_concept(is_prime)

    # prime_square(x) = is_prime(square(x)): square's output feeds is_prime's input.
    print("\nCreating prime_square from is_prime and square...")
    prime_square_params = {"output_to_input_map": {0: 0}}
    prime_square = compose.apply(
        square,  # inner concept f(x) = x²
        is_prime,  # outer concept P(y) = is_prime(y)
        **prime_square_params,
    )
    prime_square.name = "prime_square"
    # Inputs are recorded in the same order they were passed to compose.apply.
    prime_square_id = graph.add_concept(
        prime_square,
        ConstructionStep(
            compose, [square_id, is_prime_id], prime_square_params, datetime.now()
        ),
    )
    # No square of an integer is prime, so these are all nonexamples.
    prime_square.add_nonexample((4,))
    prime_square.add_nonexample((9,))
    prime_square.add_nonexample((25,))

    print("\nCreating conjecture about nonexistence of prime squares...")
    no_prime_squares = nonexistence.apply(prime_square)
    graph.add_conjecture(
        no_prime_squares,
        ConstructionStep(nonexistence, [prime_square_id], {}, datetime.now()),
    )

    # is_square: exists x such that square(x) == n.
    print("\nCreating is_square predicate from square function...")
    is_square = exists.apply(square, indices_to_quantify=[0])  # Quantify over the input
    is_square.name = "is_square"
    is_square.set_computational_implementation(lambda n: int(n**0.5) ** 2 == n)
    is_square_id = graph.add_concept(
        is_square,
        ConstructionStep(
            exists, [square_id], {"indices_to_quantify": [0]}, datetime.now()
        ),
    )
    for n in (0, 1, 4, 9, 16):  # perfect squares
        is_square.add_example((n,))
    for n in (2, 3, 5, 6, 7, 8):  # not perfect squares
        is_square.add_nonexample((n,))

    # even_square: conjunction of is_even and is_square on a shared variable.
    print("\nCreating even_square from is_even and is_square...")
    shared_params = {"shared_vars": {0: 0}}
    even_square = compose.apply(is_even, is_square, **shared_params)
    even_square.name = "even_square"
    even_square_id = graph.add_concept(
        even_square,
        ConstructionStep(compose, [is_even_id, is_square_id], shared_params, datetime.now()),
    )
    # NOTE(review): add_concept may already derive these edges from the
    # ConstructionStep inputs; confirm whether the explicit calls are needed.
    graph.add_construction_edge(is_even_id, even_square_id)
    graph.add_construction_edge(is_square_id, even_square_id)

    # Reset both example collections so the two predicates are labelled on
    # exactly the same inputs below.
    even_square.examples = ExampleCollection(even_square.examples.example_structure)
    divisible_by_4.examples = ExampleCollection(
        divisible_by_4.examples.example_structure
    )

    print("\nAdding consistent examples for even_square and divisible_by_4...")

    def is_even_square(n):
        return n % 2 == 0 and int(n**0.5) ** 2 == n

    def is_divisible_by_4(n):
        return n % 4 == 0

    for n in range(50):  # Check numbers 0 through 49
        is_es = is_even_square(n)
        is_div4 = is_divisible_by_4(n)
        # Label each concept by its own predicate. Disagreements (e.g. 8 is
        # divisible by 4 but not a square) are what rule out the equivalence.
        (even_square.add_example if is_es else even_square.add_nonexample)((n,))
        (divisible_by_4.add_example if is_div4 else divisible_by_4.add_nonexample)((n,))
        if is_es and is_div4:
            print(f"  Added {n} as example to both concepts")
        elif not is_es and not is_div4:
            print(f"  Added {n} as nonexample to both concepts")
        elif is_es:
            print(
                f"  Added {n} as example to even_square but nonexample to divisible_by_4"
            )
        else:
            print(
                f"  Added {n} as nonexample to even_square but example to divisible_by_4"
            )

    # Conjecture even_square <=> divisible_by_4. If the equivalence rule
    # refuses, fall back to whichever one-directional implications apply.
    print("\nTesting if equivalence rule can be applied...")
    can_apply = equivalence.can_apply(even_square, divisible_by_4, verbose=True)
    print(f"Can apply equivalence rule: {can_apply}")

    if can_apply:
        even_square_div4 = equivalence.apply(even_square, divisible_by_4)
        graph.add_conjecture(
            even_square_div4,
            ConstructionStep(
                equivalence, [even_square_id, divisible_by_4_id], {}, datetime.now()
            ),
        )
    else:
        print("\nUsing implication rule instead...")
        implication = ImplicationRule()
        for premise, premise_id, conclusion, conclusion_id in (
            (even_square, even_square_id, divisible_by_4, divisible_by_4_id),
            (divisible_by_4, divisible_by_4_id, even_square, even_square_id),
        ):
            if implication.can_apply(premise, conclusion):
                conjecture = implication.apply(premise, conclusion)
                graph.add_conjecture(
                    conjecture,
                    ConstructionStep(
                        implication, [premise_id, conclusion_id], {}, datetime.now()
                    ),
                )

    # even_prime: conjunction of is_even and is_prime on a shared variable.
    print("\nCreating even_prime from is_even and is_prime...")
    even_prime = compose.apply(is_even, is_prime, **shared_params)
    even_prime.name = "even_prime"
    even_prime_id = graph.add_concept(
        even_prime,
        ConstructionStep(compose, [is_even_id, is_prime_id], shared_params, datetime.now()),
    )
    # NOTE(review): see the note on the even_square construction edges.
    graph.add_construction_edge(is_even_id, even_prime_id)
    graph.add_construction_edge(is_prime_id, even_prime_id)

    even_prime.add_example((2,))
    even_prime.add_nonexample((4,))
    even_prime.add_nonexample((9,))
    even_prime.add_nonexample((25,))

    # Conjecture that 2 is the only even prime.
    print("\nChecking if rule can be applied to even_prime...")
    exclusivity_params = {"valid_set": {(2,)}}
    can_apply = exclusivity.can_apply(even_prime, **exclusivity_params)
    print(f"Can apply: {can_apply}")

    if can_apply:
        print("\nApplying rule to create excl_even_prime_numbers conjecture...")
        excl_even_prime = exclusivity.apply(even_prime, **exclusivity_params)
        graph.add_conjecture(
            excl_even_prime,
            ConstructionStep(
                exclusivity, [even_prime_id], exclusivity_params, datetime.now()
            ),
        )

    # Print graph statistics
    print("\nKnowledge Graph Statistics:")
    print(f"Number of concepts: {len(graph.get_all_concepts())}")
    print(f"Number of conjectures: {len(graph.get_all_conjectures())}")
    print(f"Total number of nodes: {graph.number_of_nodes()}")
    print(f"Total number of edges: {graph.number_of_edges()}")

    _print_tracked_instances(graph)

    print("\nTesting Related Concepts:")
    _print_related_concepts(graph, "\nConcepts constructed from successor:", successor_id)
    _print_related_concepts(
        graph, "\nConcepts used in constructing multiplication:", multiplication_id
    )
    _print_related_concepts(graph, "\nConcepts related to is_even:", is_even_id)

    print("\nConstruction Depths:")
    for concept_id in [addition_id, multiplication_id, power_id, square_id, is_even_id]:
        concept = graph.nodes[concept_id]["entity"]
        depth = graph.construction_depth(concept_id)
        print(f"{concept.name}: depth {depth}")

    # Round-trip through disk. save() is given the path without an extension;
    # the saved file is read back with ".dill" appended.
    print("\nTesting graph persistence...")
    os.makedirs("data/graphs", exist_ok=True)
    graph.save("data/graphs/test_graph")
    loaded_graph = KnowledgeGraph.load("data/graphs/test_graph.dill")

    # The loaded graph must have the same structure and entity counts.
    assert loaded_graph.number_of_nodes() == graph.number_of_nodes()
    assert loaded_graph.number_of_edges() == graph.number_of_edges()
    assert len(loaded_graph.get_all_concepts()) == len(graph.get_all_concepts())
    assert len(loaded_graph.get_all_conjectures()) == len(graph.get_all_conjectures())

    # Render the loaded graph for manual inspection.
    print("\nVisualizing loaded graph...")
    os.makedirs("data/visualizations", exist_ok=True)
    loaded_graph.visualize_construction_tree(
        output_file="data/visualizations/loaded_graph_construction"
    )

    return graph


def test_step_age_tracking():
    """Verify the graph's step counter and per-entity step ages.

    Each call to ``add_concept`` advances the counter by one. An entity's
    creation step is the counter value before it was added, and its step age
    is the number of steps taken since then. The DSL primitive
    ``get_entity_step_age`` must agree with the graph method.
    """
    print("\n=== Testing Step Age Tracking ===")

    kg = KnowledgeGraph()

    # Nothing has been added yet.
    assert kg.get_step_counter() == 0

    # Step 0: successor.
    succ_concept = create_successor_concept()
    succ_id = kg.add_concept(succ_concept)
    assert kg.get_step_counter() == 1
    assert kg.get_entity_creation_step(succ_id) == 0
    assert kg.get_entity_step_age(succ_id) == 1

    # Step 1: zero. Successor is now one step older.
    z_id = kg.add_concept(zero_concept)
    assert kg.get_step_counter() == 2
    assert kg.get_entity_creation_step(z_id) == 1
    assert kg.get_entity_step_age(z_id) == 1
    assert kg.get_entity_step_age(succ_id) == 2

    # Step 2: addition, derived from successor.
    iterate_rule = MapIterateRule()
    addition_concept = iterate_rule.apply(succ_concept)
    addition_concept.name = "addition"
    sum_id = kg.add_concept(
        addition_concept,
        ConstructionStep(iterate_rule, [succ_id], {}, datetime.now()),
    )

    assert kg.get_step_counter() == 3
    assert kg.get_entity_creation_step(sum_id) == 2

    # Ages after three additions, newest first.
    ages_newest_first = [(sum_id, 1), (z_id, 2), (succ_id, 3)]
    for entity_id, expected_age in ages_newest_first:
        assert kg.get_entity_step_age(entity_id) == expected_age

    # The DSL primitive must report the same ages (checked oldest first).
    for entity_id, expected_age in reversed(ages_newest_first):
        assert get_entity_step_age(entity_id, kg) == expected_age

    print("Step age tracking test passed!")


def main():
    """Run every knowledge graph test in sequence when executed as a script."""
    print("\n=== Running Knowledge Graph Tests ===")
    for run_test in (test_knowledge_graph_construction, test_step_age_tracking):
        run_test()
    print("\n=== All Tests Complete ===")


# Script entry point; under pytest the test functions are collected directly.
if __name__ == "__main__":
    main()
