"""Tests for the math environment using a natural-numbers knowledge graph.

This test file creates a non-enumerative policy that follows the same sequence
of actions as in test_knowledge_graph.py to test the math environment.
"""

import pytest
import numpy as np
from typing import List, Dict, Any, Optional, Union, Tuple
from datetime import datetime
import os

from frame.environments.math_env import MathEnv, ValidAction
from frame.knowledge_base.knowledge_graph import (
    KnowledgeGraph,
    ConstructionStep,
    NodeType,
)
from frame.knowledge_base.entities import (
    Concept,
    Example,
    ExampleType,
    ExampleStructure,
    ConceptType,
    Expression,
    Var,
    ConceptApplication,
    Succ,
    Nat,
    Conjecture,
)
from frame.knowledge_base.demonstrations import is_prime
from frame.productions.concepts.map_iterate import MapIterateRule
from frame.productions.concepts.compose import ComposeRule
from frame.productions.concepts.negate import NegateRule
from frame.productions.concepts.exists import ExistsRule
from frame.productions.concepts.match import MatchRule
from frame.productions.concepts.specialize import SpecializeRule
from frame.productions.concepts.size import SizeRule
from frame.productions.concepts.constant import ConstantRule
from frame.productions.concepts.forall import ForallRule
from frame.productions.conjectures.equivalence import EquivalenceRule
from frame.productions.conjectures.implication import ImplicationRule
from frame.productions.conjectures.nonexistence import NonexistenceRule
from frame.productions.conjectures.exclusivity import ExclusivityRule
from frame.knowledge_base.initial_concepts import create_successor_concept, zero_concept, one_concept
from frame.policies.base import Policy
from frame.productions.base import ProductionRule

class NatSequencePolicy(Policy):
    """
    Scripted, non-enumerative policy that replays a caller-supplied list of
    actions, mirroring the construction order used in test_knowledge_graph.py.

    Tests append ``ValidAction`` objects to ``action_sequence`` (usually after
    reading the entity IDs produced by earlier steps); ``select_action`` hands
    them out one at a time in insertion order.
    """

    def __init__(self):
        """Create the policy with an empty action queue and the cursor at 0."""
        super().__init__(requires_enumeration=False)
        self.action_sequence = []  # queued ValidAction objects, consumed in order
        self.current_step = 0  # index of the next action to hand out
        self._setup_action_sequence()

    def _setup_action_sequence(self):
        """Hook for pre-populating ``action_sequence``.

        Intentionally empty: the actions reference entity IDs that only exist
        once earlier steps have run, so tests append them directly instead.
        """
        return None

    def select_action(self, env: MathEnv) -> Optional[ValidAction]:
        """
        Hand out the next queued action and advance the cursor.

        Args:
            env: The math environment (unused; the sequence is predetermined).

        Returns:
            The next queued ValidAction, or None once the queue is exhausted
            (the cursor is left unchanged in that case).
        """
        position = self.current_step
        if position < len(self.action_sequence):
            self.current_step = position + 1
            return self.action_sequence[position]
        return None

    def update(self,
               env: MathEnv,
               action: Union[int, ValidAction],
               reward: float,
               done: bool) -> None:
        """No-op: a scripted policy has nothing to learn from feedback."""
        return None


@pytest.fixture
def initial_graph():
    """Knowledge graph seeded with the two primitive concepts: zero, then successor."""
    seeded = KnowledgeGraph()
    seeded.add_concept(zero_concept)
    successor = create_successor_concept()
    seeded.add_concept(successor)
    return seeded


@pytest.fixture
def production_rules():
    """Fresh instances of every production rule the tests may apply.

    Concept-forming rules come first, then conjecture-forming rules, with
    ForallRule last. The tests look rules up by type, not by position.
    """
    rule_types = (
        MapIterateRule,
        MatchRule,
        ExistsRule,
        SpecializeRule,
        NegateRule,
        ComposeRule,
        SizeRule,
        NonexistenceRule,
        EquivalenceRule,
        ImplicationRule,
        ExclusivityRule,
        ForallRule,
    )
    return [rule_type() for rule_type in rule_types]


def _rule_index(env: MathEnv, rule_cls: type) -> int:
    """Return the index in ``env.rules`` of the first rule that is a ``rule_cls``.

    Raises:
        StopIteration: If the environment has no rule of that type.
    """
    return next(i for i, rule in enumerate(env.rules) if isinstance(rule, rule_cls))


def _execute_action(
    env: MathEnv, policy: NatSequencePolicy, action: ValidAction
) -> Tuple[Any, Any, bool, bool, Dict[str, Any]]:
    """Queue ``action`` on ``policy``, let the policy select it, and step ``env``.

    Routing the action through the policy keeps ``policy.current_step`` equal
    to the number of steps taken. Callers must have no other actions pending,
    otherwise the policy would select an older one (checked below).

    Returns:
        The ``(graph, reward, done, truncated, info)`` tuple from ``env.step``.
    """
    policy.action_sequence.append(action)
    selected = policy.select_action(env)
    assert selected is action, "policy had unconsumed actions queued"
    return env.step(selected)


def _new_entity_id(info: Dict[str, Any], label: str) -> Any:
    """Return the first ID in ``info["new_entities"]``.

    Fails the test with a descriptive message (rather than a bare
    KeyError/IndexError) when the step produced no new entity.
    """
    new_entities = info.get("new_entities")
    assert new_entities, f"Step creating {label!r} produced no new entities: {info!r}"
    return new_entities[0]


def test_math_env_nat_sequence(initial_graph, production_rules):
    """Drive MathEnv through a fixed construction sequence over the naturals.

    Starting from zero and successor, production rules applied through
    ``env.step`` build the following concepts: addition, multiplication, power,
    square, divides, is_even, is_odd, divisible_by_4, tau, is_prime,
    prime_square, is_square, even_square, has_all_even_divisors and even_prime.
    The same mechanism builds two conjectures: nonexistence of prime squares,
    and exclusivity of even primes. Each new entity is renamed and, where
    useful, given examples or a computational implementation. The test then
    checks that every expected concept exists and that exactly two
    conjectures were created.
    """
    env = MathEnv(
        initial_graph=initial_graph,
        production_rules=production_rules,
        max_steps=20,
        enumerate_actions=False,  # Disable action enumeration; the policy supplies actions
    )
    policy = NatSequencePolicy()
    graph, _ = env.reset()

    # Rule indices are looked up by type, so they do not depend on fixture order.
    map_iterate_idx = _rule_index(env, MapIterateRule)
    match_idx = _rule_index(env, MatchRule)
    exists_idx = _rule_index(env, ExistsRule)
    specialize_idx = _rule_index(env, SpecializeRule)
    negate_idx = _rule_index(env, NegateRule)
    compose_idx = _rule_index(env, ComposeRule)
    size_idx = _rule_index(env, SizeRule)
    forall_idx = _rule_index(env, ForallRule)
    nonexistence_idx = _rule_index(env, NonexistenceRule)
    exclusivity_idx = _rule_index(env, ExclusivityRule)

    total_reward = 0  # summed over every env.step call below

    # Add successor and zero, keeping the IDs that the actions below reference.
    successor = create_successor_concept()
    successor_id = graph.add_concept(successor)
    zero_id = graph.add_concept(zero_concept)
    print(f"Added zero concept with ID: {zero_id}")

    # Build "one" as a constant from the successor example (0, 1), i.e. succ(0) = 1.
    print("\nCreating one_concept using ConstantRule...")
    constant = ConstantRule()
    one = constant.apply(
        successor, example=Example((0, 1), successor.examples.example_structure)
    )
    one.name = "one"
    one_id = graph.add_concept(
        one,
        ConstructionStep(
            constant,
            [successor_id],
            {"example": Example((0, 1), successor.examples.example_structure)},
            datetime.now(),
        ),
    )
    print(f"Added one concept with ID: {one_id}")

    # Action 1: addition = MapIterate(successor).
    graph, reward, _, _, info = _execute_action(
        env,
        policy,
        ValidAction(rule_idx=map_iterate_idx, input_nodes=[successor_id], params={}),
    )
    total_reward += reward
    addition_id = _new_entity_id(info, "addition")
    graph.nodes[addition_id]["entity"].name = "addition"

    # Action 2: multiplication = MapIterate(addition, zero).
    graph, reward, _, _, info = _execute_action(
        env,
        policy,
        ValidAction(
            rule_idx=map_iterate_idx, input_nodes=[addition_id, zero_id], params={}
        ),
    )
    total_reward += reward
    multiplication_id = _new_entity_id(info, "multiplication")
    graph.nodes[multiplication_id]["entity"].name = "multiplication"

    # Power iterates multiplication starting from the predefined one_concept.
    # This rebinds one_id; the ConstantRule-built "one" above is not used again.
    one_id = graph.add_concept(one_concept)

    # Actions 3-5 are queued together and run by the loop below:
    # power, square (match both multiplication inputs), and divides
    # (exists over multiplication's first input).
    policy.action_sequence.extend(
        [
            ValidAction(
                rule_idx=map_iterate_idx,
                input_nodes=[multiplication_id, one_id],
                params={},
            ),
            ValidAction(
                rule_idx=match_idx,
                input_nodes=[multiplication_id],
                params={"indices_to_match": [0, 1]},
            ),
            ValidAction(
                rule_idx=exists_idx,
                input_nodes=[multiplication_id],
                params={"indices_to_quantify": [0]},
            ),
        ]
    )

    done = truncated = False
    while not (done or truncated) and policy.current_step < len(policy.action_sequence):
        action = policy.select_action(env)
        if action is None:
            break

        graph, reward, done, truncated, info = env.step(action)
        total_reward += reward
        print(f"Step {policy.current_step}: Reward {reward}")

        new_entities = info.get("new_entities")
        if not new_entities:
            continue
        print(f"New entities: {new_entities}")

        # policy.current_step counts the actions taken so far, so it identifies
        # which queued action has just run.
        if policy.current_step == 3:  # Power concept
            graph.nodes[new_entities[0]]["entity"].name = "power"
        elif policy.current_step == 4:  # Square concept
            square_id = new_entities[0]
            graph.nodes[square_id]["entity"].name = "square"
        elif policy.current_step == 5:  # Divides concept
            divides_id = new_entities[0]
            graph.nodes[divides_id]["entity"].name = "divides"

            # divides(a, b): a is non-zero and a divides b.
            divides = graph.nodes[divides_id]["entity"]
            divides.set_computational_implementation(
                lambda a, b: a != 0 and b % a == 0
            )

            # Constant 2, used to specialize divides (is_even) and tau (is_prime).
            two_concept = Concept(
                name="two",
                description="The natural number 2",
                symbolic_definition=lambda: 2,
                computational_implementation=lambda: 2,
                example_structure=ExampleStructure(
                    concept_type=ConceptType.CONSTANT,
                    component_types=(ExampleType.NUMERIC,),
                    input_arity=0,
                ),
            )
            two_concept.add_example((2,))  # Add example for the value 2
            two_id = graph.add_concept(two_concept)

            # is_even = divides with its first argument fixed to 2.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=specialize_idx,
                    input_nodes=[divides_id, two_id],
                    params={"index_to_specialize": 0},
                ),
            )
            total_reward += reward
            is_even_id = _new_entity_id(info, "is_even")
            is_even = graph.nodes[is_even_id]["entity"]
            is_even.name = "is_even"
            for n in (2, 4, 0):
                is_even.add_example((n,))
            for n in (1, 3):
                is_even.add_nonexample((n,))

            # is_odd = Negate(is_even).
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(rule_idx=negate_idx, input_nodes=[is_even_id], params={}),
            )
            total_reward += reward
            is_odd_id = _new_entity_id(info, "is_odd")
            graph.nodes[is_odd_id]["entity"].name = "is_odd"

            # Constant 4, used to build divisible_by_4.
            # NOTE(review): unlike two_concept this declares ConceptType.FUNCTION
            # (with input_arity=0) and Nat(4) as its symbolic definition; kept
            # as-is, confirm whether ConceptType.CONSTANT was intended.
            four_concept = Concept(
                name="four",
                description="The constant four",
                symbolic_definition=lambda: Nat(4),
                computational_implementation=lambda: 4,
                example_structure=ExampleStructure(
                    concept_type=ConceptType.FUNCTION,
                    component_types=(ExampleType.NUMERIC,),  # Output type is numeric
                    input_arity=0,  # No inputs for a constant
                ),
            )
            four_id = graph.add_concept(four_concept)

            # divisible_by_4 = divides with its first argument fixed to 4.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=specialize_idx,
                    input_nodes=[divides_id, four_id],
                    params={"index_to_specialize": 0},
                ),
            )
            total_reward += reward
            divisible_by_4_id = _new_entity_id(info, "divisible_by_4")
            graph.nodes[divisible_by_4_id]["entity"].name = "divisible_by_4"

            # tau (number of divisors) = Size(divides) over index 1.
            # NOTE(review): with divides(a, b) meaning "a divides b", index 1 is
            # the multiple rather than the divisor; the computational
            # implementation installed below counts divisors explicitly.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=size_idx,
                    input_nodes=[divides_id],
                    params={"indices_to_quantify": [1]},
                ),
            )
            total_reward += reward
            tau_id = _new_entity_id(info, "tau")
            tau = graph.nodes[tau_id]["entity"]
            tau.name = "tau"
            # (n, number of divisors of n)
            tau_values = (
                (1, 1),
                (2, 2),
                (3, 2),
                (4, 3),
                (6, 4),
                (7, 2),
                (8, 4),
                (12, 6),
            )
            for example in tau_values:
                tau.add_example(example)
            tau.add_nonexample((1, 4))  # 1 does not have 4 divisors

            # Count divisors i of n in 1..n directly via divides(i, n).
            divides = graph.nodes[divides_id]["entity"]
            tau.set_computational_implementation(
                lambda n: sum(1 for i in range(1, n + 1) if divides.compute(i, n))
            )
            # The implementation must agree with every recorded example.
            for n, divisor_count in tau_values:
                assert tau.compute(n) == divisor_count, f"{n} has {divisor_count} divisors"

            # is_prime = tau with its output fixed to 2 (exactly two divisors).
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=specialize_idx,
                    input_nodes=[tau_id, two_id],
                    params={"index_to_specialize": 1},
                ),
            )
            total_reward += reward
            is_prime_id = _new_entity_id(info, "is_prime")
            is_prime_concept = graph.nodes[is_prime_id]["entity"]
            is_prime_concept.name = "is_prime"
            for p in (2, 3, 5, 7, 11):
                is_prime_concept.add_example((p,))
            # 1 has a single divisor; the others are composite.
            for n in (1, 4, 6, 8, 9):
                is_prime_concept.add_nonexample((n,))

            # prime_square = Compose(square, is_prime), feeding square's output
            # into is_prime's input.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=compose_idx,
                    input_nodes=[square_id, is_prime_id],
                    params={"output_to_input_map": {0: 0}},
                ),
            )
            total_reward += reward
            prime_square_id = _new_entity_id(info, "prime_square")
            prime_square = graph.nodes[prime_square_id]["entity"]
            prime_square.name = "prime_square"
            # No perfect square is prime, so only nonexamples are recorded.
            for n in (4, 9, 16, 25):
                prime_square.add_nonexample((n,))

            # Conjecture: prime_square has no instances.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=nonexistence_idx,
                    input_nodes=[prime_square_id],
                    params={},
                ),
            )
            total_reward += reward
            _new_entity_id(info, "no_prime_squares conjecture")

            # is_square = exists x such that square(x) = n.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=exists_idx,
                    input_nodes=[square_id],
                    params={"indices_to_quantify": [0]},  # Quantify over the input
                ),
            )
            total_reward += reward
            is_square_id = _new_entity_id(info, "is_square")
            is_square = graph.nodes[is_square_id]["entity"]
            is_square.name = "is_square"
            # Float square root: exact for the small values used in this test.
            is_square.set_computational_implementation(
                lambda n: int(n**0.5) ** 2 == n
            )
            for n in (0, 1, 4, 9, 16):
                is_square.add_example((n,))
            for n in (2, 3, 5):
                is_square.add_nonexample((n,))

            # even_square = is_even AND is_square on a shared variable.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=compose_idx,
                    input_nodes=[is_even_id, is_square_id],
                    params={"shared_vars": {0: 0}},
                ),
            )
            total_reward += reward
            even_square_id = _new_entity_id(info, "even_square")
            even_square = graph.nodes[even_square_id]["entity"]
            even_square.name = "even_square"
            for n in (0, 4, 16):
                even_square.add_example((n,))
            # 1 and 9 are odd squares; 2 is even but not a square.
            for n in (1, 2, 9):
                even_square.add_nonexample((n,))

            divisible_by_4 = graph.nodes[divisible_by_4_id]["entity"]
            for n in (4, 8):
                divisible_by_4.add_example((n,))
            for n in (2, 6):
                divisible_by_4.add_nonexample((n,))

            # has_all_even_divisors = forall over divides' index 1, checked by is_even.
            print("\nCreating has_all_even_divisors from divides and is_even...")
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=forall_idx,
                    input_nodes=[divides_id, is_even_id],
                    params={"indices_to_quantify": [1], "indices_to_map": {1: 0}},
                ),
            )
            total_reward += reward
            has_all_even_divisors_id = _new_entity_id(info, "has_all_even_divisors")
            has_all_even_divisors = graph.nodes[has_all_even_divisors_id]["entity"]
            has_all_even_divisors.name = "has_all_even_divisors"
            # Every positive integer has the odd divisor 1, so none of these has
            # only even divisors (previously 4 and 8 were wrongly added as
            # examples despite the same reasoning).
            # NOTE(review): index 1 of divides(a, b) is the multiple, so the
            # rule-built concept may actually mean "all multiples of n are
            # even"; these nonexamples follow the concept's name. Confirm.
            for n in (2, 4, 6, 8):
                has_all_even_divisors.add_nonexample((n,))

            # NOTE(5/7): the even_square <=> divisible_by_4 equivalence conjecture
            # (EquivalenceRule on [even_square_id, divisible_by_4_id]) is
            # intentionally not created: its examples turned out to be
            # inconsistent once example handling was improved.

            # even_prime = is_even AND is_prime on a shared variable.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=compose_idx,
                    input_nodes=[is_even_id, is_prime_id],
                    params={"shared_vars": {0: 0}},
                ),
            )
            total_reward += reward
            even_prime_id = _new_entity_id(info, "even_prime")
            even_prime = graph.nodes[even_prime_id]["entity"]
            even_prime.name = "even_prime"
            even_prime.add_example((2,))
            even_prime.add_nonexample((1,))  # 1 is odd
            even_prime.add_nonexample((4,))  # 4 is not prime
            even_prime.add_nonexample((5,))  # 5 is odd

            # Conjecture: 2 is the only even prime.
            graph, reward, _, _, info = _execute_action(
                env,
                policy,
                ValidAction(
                    rule_idx=exclusivity_idx,
                    input_nodes=[even_prime_id],
                    params={"valid_set": {(2,)}},
                ),
            )
            total_reward += reward
            _new_entity_id(info, "appl_even_prime conjecture")

    # The environment must have run, and every queued action must have been
    # consumed (the loop stops early on done/truncated).
    assert policy.current_step > 0
    assert policy.current_step == len(policy.action_sequence), (
        f"Only {policy.current_step} of {len(policy.action_sequence)} actions ran"
    )

    print("\nFinal statistics:")
    print(f"Total steps: {policy.current_step}")
    print(f"Total reward: {total_reward}")
    print(f"Number of concepts: {len(graph.get_all_concepts())}")
    print(f"Number of conjectures: {len(graph.get_all_conjectures())}")

    print("\nGenerating knowledge graph visualizations...")
    os.makedirs("data/visualizations", exist_ok=True)
    graph.visualize_construction_tree(
        output_file="data/visualizations/math_env_nat_construction_tree"
    )

    concept_names = [
        graph.nodes[c_id]["entity"].name for c_id in graph.get_all_concepts()
    ]
    print(f"Concepts created: {concept_names}")
    expected_concepts = (
        "addition",
        "multiplication",
        "power",
        "square",
        "divides",
        "two",
        "is_even",
        "is_odd",
        "four",
        "divisible_by_4",
        "is_prime",
        "prime_square",
        "is_square",
        "even_square",
        "tau",
        "has_all_even_divisors",
        "even_prime",
    )
    missing = [name for name in expected_concepts if name not in concept_names]
    assert not missing, f"Missing concepts: {missing}"

    conjecture_ids = graph.get_all_conjectures()
    conjecture_names = [graph.nodes[c_id]["entity"].name for c_id in conjecture_ids]
    print(f"Conjectures created: {conjecture_names}")
    # Expected: no_prime_squares and appl_even_prime.
    assert len(conjecture_ids) == 2, f"Expected 2 conjectures, got {conjecture_names}"


def test_action_enumeration_validation(initial_graph, production_rules):
    """
    Test that the action enumeration correctly includes the expected next actions.

    This test verifies that after each rule application, the enumerated valid actions
    contain the next rule we expect to use in the sequence (addition, multiplication,
    power, square, divides, not_divides), and then executes that action.
    """
    # Create the environment with action enumeration enabled
    env = MathEnv(
        initial_graph=initial_graph,
        production_rules=production_rules,
        max_steps=20,
        enumerate_actions=True,  # Enable action enumeration
    )

    # Reset the environment
    graph, _ = env.reset()

    def rule_index(rule_cls):
        """Return the index in env.rules of the first rule that is an instance of rule_cls."""
        for idx, rule in enumerate(env.rules):
            if isinstance(rule, rule_cls):
                return idx
        raise AssertionError(f"No {rule_cls.__name__} found in env.rules")

    # Only the rules this sequence actually applies are looked up.
    map_iterate_idx = rule_index(MapIterateRule)
    match_idx = rule_index(MatchRule)
    exists_idx = rule_index(ExistsRule)
    negate_idx = rule_index(NegateRule)

    print("\n=== Testing Action Enumeration Validation ===")

    # Seed the graph with the base concepts the action sequence builds on.
    successor = create_successor_concept()
    successor_id = graph.add_concept(successor)
    print(f"Added successor concept with ID: {successor_id}")

    zero_id = graph.add_concept(zero_concept)
    print(f"Added zero concept with ID: {zero_id}")

    one_id = graph.add_concept(one_concept)
    print(f"Added one concept with ID: {one_id}")

    # The concepts were added behind the env's back, so recompute valid actions.
    env.valid_actions = env._compute_valid_actions()
    print(f"Computed {len(env.valid_actions)} valid actions after adding all concepts")

    def matches(action, rule_idx, input_nodes, params):
        """True if *action* uses rule_idx on input_nodes (and params, when given)."""
        return (
            action.rule_idx == rule_idx
            and action.input_nodes == input_nodes
            and (params is None or action.params == params)
        )

    def find_and_execute_action(step_name, rule_idx, input_nodes, params=None):
        """Assert the described action is enumerated, execute it, return (graph, new id).

        When params is None, any params are accepted for the match.
        """
        print(f"\n{step_name}")
        print(f"Looking for action in {len(env.valid_actions)} valid actions")

        # Constructed only to make a failure message self-describing.
        expected_action = ValidAction(
            rule_idx=rule_idx,
            input_nodes=input_nodes,
            params={} if params is None else params,
        )

        print(
            f"Expected: Rule={env.rules[rule_idx].name}, Inputs={input_nodes}, Params={params}"
        )

        # Single pass: print every action for debugging and remember the first match.
        print("Available actions:")
        action_idx = None
        for i, action in enumerate(env.valid_actions):
            print(
                f"  {i}: Rule={env.rules[action.rule_idx].name}, Inputs={action.input_nodes}, Params={action.params}"
            )
            if matches(action, rule_idx, input_nodes, params):
                print(f"  ✓ Match found at index {i}")
                if action_idx is None:
                    action_idx = i

        assert action_idx is not None, (
            f"Expected action not found in valid_actions: {expected_action}"
        )
        print(f"✓ Found expected action at index {action_idx}")

        # Execute the action by its index in the enumerated list
        next_graph, _, _, _, info = env.step(action_idx)

        new_entity_id = info["new_entities"][0]
        print(f"Created entity with ID: {new_entity_id}")

        return next_graph, new_entity_id

    def print_negate_actions(title):
        """Print the currently enumerated NegateRule actions under *title*."""
        print(title)
        for i, action in enumerate(env.valid_actions):
            if action.rule_idx == negate_idx:
                print(
                    f"  {i}: Rule={env.rules[action.rule_idx].name}, Inputs={action.input_nodes}"
                )

    # Step 1: Create addition from successor
    graph, addition_id = find_and_execute_action(
        "Step 1: Creating addition concept", map_iterate_idx, [successor_id]
    )

    # Step 2: Create multiplication from addition and zero
    graph, multiplication_id = find_and_execute_action(
        "Step 2: Creating multiplication concept",
        map_iterate_idx,
        [addition_id, zero_id],
    )

    # Step 3: Create power from multiplication and one
    graph, power_id = find_and_execute_action(
        "Step 3: Creating power concept", map_iterate_idx, [multiplication_id, one_id]
    )

    # Step 4: Create square from multiplication
    graph, square_id = find_and_execute_action(
        "Step 4: Creating square concept",
        match_idx,
        [multiplication_id],
        {"indices_to_match": [0, 1]},
    )

    # Step 5: Create divides concept from multiplication
    graph, divides_id = find_and_execute_action(
        "Step 5: Creating divides concept", exists_idx, [multiplication_id]
    )

    print_negate_actions("\n=== Valid Actions Before Applying NegateRule ===")

    # Step 6: Create not_divides from divides
    graph, not_divides_id = find_and_execute_action(
        "Step 6: Creating not_divides concept", negate_idx, [divides_id]
    )

    print_negate_actions("\n=== Valid Actions After Applying NegateRule ===")

    print("\n✓ All expected actions were found in the enumerated valid actions!")
    print("=== Action Enumeration Validation Test Completed Successfully ===\n")


def test_entity_removal(initial_graph, production_rules):
    """Test the entity removal functionality of the environment.

    Covers removing an unknown id, removing while removal is disabled,
    removal with rediscovery (the producing action is dropped from
    applied_actions and can be re-applied), and removal without rediscovery
    (the action stays recorded, so re-applying it raises ValueError).
    """
    # Create environment with entity removal enabled
    env = MathEnv(
        initial_graph=initial_graph,
        production_rules=production_rules,
        allow_entity_removal=True,
    )

    def concept_id_by_name(name):
        """Return the id of the concept named *name* in the env's graph."""
        return next(
            c_id
            for c_id in env.graph.get_all_concepts()
            if env.graph.nodes[c_id]["entity"].name == name
        )

    # Get the zero and successor concepts from the initial graph
    zero_id = concept_id_by_name("zero")
    successor_id = concept_id_by_name("successor")
    initial_ids = {zero_id, successor_id}

    def created_concept_id():
        """Return the id of a concept not in the initial graph, or None."""
        return next(
            (c_id for c_id in env.graph.get_all_concepts() if c_id not in initial_ids),
            None,
        )

    # Test removing non-existent entity
    assert not env.remove_concept("non_existent_id")

    # Test removing when removal is disabled
    env.allow_entity_removal = False
    assert not env.remove_concept(successor_id)

    # Test removal with rediscovery
    env.allow_entity_removal = True
    map_iterate_idx = next(
        i for i, rule in enumerate(env.rules) if isinstance(rule, MapIterateRule)
    )
    action = ValidAction(rule_idx=map_iterate_idx, input_nodes=[successor_id], params={})
    env.step(action)

    new_id = created_concept_id()
    assert new_id is not None
    assert action in env.applied_actions

    # Remove the concept with rediscovery enabled
    assert env.remove_concept(new_id, allow_rediscovery=True)
    assert new_id not in env.graph
    assert action not in env.applied_actions

    # Re-apply the exact same action; allowed because it was removed from applied_actions
    env.step(action)
    # The re-created concept is not guaranteed to reuse the old id, so look it up again
    new_id = created_concept_id()
    assert new_id is not None

    # Test removal without rediscovery
    assert env.remove_concept(new_id, allow_rediscovery=False)
    assert new_id not in env.graph
    assert action in env.applied_actions  # Action should still be in applied_actions

    # The action is still recorded as applied, so recreating the concept must fail
    with pytest.raises(ValueError):
        env.step(action)
        
def test_entity_removal_with_valid_actions(initial_graph, production_rules):
    """Test entity removal when valid actions are being tracked.

    Removing with update_valid_actions=True must shrink env.valid_actions;
    removing with update_valid_actions=False must leave its length unchanged.
    """
    # Create environment with both entity removal and action enumeration enabled
    env = MathEnv(
        initial_graph=initial_graph,
        production_rules=production_rules,
        allow_entity_removal=True,
        enumerate_actions=True,
    )

    def concept_id_by_name(name):
        """Return the id of the concept named *name* in the env's graph."""
        return next(
            c_id
            for c_id in env.graph.get_all_concepts()
            if env.graph.nodes[c_id]["entity"].name == name
        )

    # Get the zero, successor concepts from the initial graph
    zero_id = concept_id_by_name("zero")
    successor_id = concept_id_by_name("successor")
    initial_ids = {zero_id, successor_id}

    def created_concept_id():
        """Return the id of the concept that was not in the initial graph."""
        # Exclude *both* initial concepts; excluding only successor could yield zero_id.
        return next(
            c_id for c_id in env.graph.get_all_concepts() if c_id not in initial_ids
        )

    # Create a concept using MapIterateRule
    map_iterate_idx = next(
        i for i, rule in enumerate(env.rules) if isinstance(rule, MapIterateRule)
    )
    action = ValidAction(rule_idx=map_iterate_idx, input_nodes=[successor_id], params={})
    env.step(action)
    new_id = created_concept_id()

    # Store initial valid actions count
    initial_valid_actions_count = len(env.valid_actions)

    # Remove the concept with update_valid_actions=True
    assert env.remove_concept(new_id, allow_rediscovery=True, update_valid_actions=True)

    # Verify valid actions were updated
    assert len(env.valid_actions) < initial_valid_actions_count

    # Test with update_valid_actions=False.
    # Re-create the concept with an equivalent action; this is allowed because the
    # previous removal used allow_rediscovery=True.
    new_action = ValidAction(rule_idx=map_iterate_idx, input_nodes=[successor_id], params={})
    env.step(new_action)
    new_id = created_concept_id()

    # Store valid actions count before removal
    before_removal_count = len(env.valid_actions)

    # Remove without updating valid actions
    assert env.remove_concept(new_id, update_valid_actions=False)

    # Verify valid actions count hasn't changed
    assert len(env.valid_actions) == before_removal_count

def test_entity_removal_with_dependencies(initial_graph, production_rules):
    """Test removing entities that have dependent entities in the graph.

    Builds the chain
        successor -> addition
        addition + zero -> multiplication
        multiplication -> divides
        divides + two -> is_even
    then checks that removing a concept also removes every concept built on
    it while leaving its inputs in place, and finally that all remaining
    concepts can be removed until the graph holds no concepts.
    """
    # Create environment with entity removal enabled
    env = MathEnv(
        initial_graph=initial_graph,
        production_rules=production_rules,
        allow_entity_removal=True,
    )

    def concept_id_by_name(name):
        """Return the id of the concept named *name* in the env's graph."""
        return next(
            c_id
            for c_id in env.graph.get_all_concepts()
            if env.graph.nodes[c_id]["entity"].name == name
        )

    def rule_index(rule_cls):
        """Return the index in env.rules of the first rule that is an instance of rule_cls."""
        return next(i for i, rule in enumerate(env.rules) if isinstance(rule, rule_cls))

    def apply_and_name(rule_idx, input_nodes, params, name):
        """Apply one rule, rename the concept it creates to *name*, return its id."""
        action = ValidAction(rule_idx=rule_idx, input_nodes=input_nodes, params=params)
        _, _, _, _, info = env.step(action)
        # Use the id reported by the env rather than relying on get_all_concepts() ordering.
        new_id = info["new_entities"][0]
        env.graph.nodes[new_id]["entity"].name = name
        return new_id

    # Get the zero and successor concepts from the initial graph
    zero_id = concept_id_by_name("zero")
    successor_id = concept_id_by_name("successor")

    # Find indices of the rules we'll use
    map_iterate_idx = rule_index(MapIterateRule)
    exists_idx = rule_index(ExistsRule)
    specialize_idx = rule_index(SpecializeRule)

    # Step 1: Create addition from successor alone
    addition_id = apply_and_name(map_iterate_idx, [successor_id], {}, "addition")

    # Step 2: Create multiplication from addition and zero
    multiplication_id = apply_and_name(
        map_iterate_idx, [addition_id, zero_id], {}, "multiplication"
    )

    # Step 3: Create divides from multiplication
    divides_id = apply_and_name(
        exists_idx,
        [multiplication_id],
        {"indices_to_quantify": [0]},  # Quantify over the first argument
        "divides",
    )

    # Create two_concept for specializing divides
    two_concept = Concept(
        name="two",
        description="The natural number 2",
        symbolic_definition=lambda: 2,
        computational_implementation=lambda: 2,
        example_structure=ExampleStructure(
            concept_type=ConceptType.CONSTANT,
            component_types=(ExampleType.NUMERIC,),
            input_arity=0,
        ),
    )
    two_concept.add_example((2,))  # Add example for the value 2
    two_id = env.graph.add_concept(two_concept)

    # Step 4: Create is_even by specializing divides with two
    is_even_id = apply_and_name(
        specialize_idx,
        [divides_id, two_id],
        {"index_to_specialize": 0},  # Specialize the first argument to 2
        "is_even",
    )

    # Test removing multiplication (should also remove its dependents divides and is_even)
    assert env.remove_concept(multiplication_id)
    assert multiplication_id not in env.graph
    assert divides_id not in env.graph
    assert is_even_id not in env.graph
    # Inputs of the removed concepts must survive
    assert addition_id in env.graph
    assert successor_id in env.graph
    assert zero_id in env.graph
    assert two_id in env.graph

    # Test removing addition (should not affect successor)
    assert env.remove_concept(addition_id)
    assert addition_id not in env.graph
    assert successor_id in env.graph

    # Test removing successor (should not affect zero)
    assert env.remove_concept(successor_id)
    assert successor_id not in env.graph
    assert zero_id in env.graph

    # Test removing zero (no remaining concept depends on it)
    assert env.remove_concept(zero_id)
    assert zero_id not in env.graph

    # remove two
    assert env.remove_concept(two_id)
    assert two_id not in env.graph

    assert len(env.graph.get_all_concepts()) == 0

if __name__ == "__main__":
    # Running this file directly: delegate to pytest so that the `initial_graph`
    # and `production_rules` fixtures are resolved by pytest. Calling fixture
    # functions directly is rejected by pytest (and fails outright if they live
    # in a conftest). Only the action-enumeration test is run, with output shown.
    import sys

    sys.exit(
        pytest.main([__file__, "-s", "-k", "test_action_enumeration_validation"])
    )
