"""Mathematical Discovery Environment based on Gymnasium.

This environment allows an agent to explore mathematical concepts by:
1. Observing the current knowledge graph state
2. Selecting production rules and their inputs
3. Receiving rewards based on interestingness measures
"""

from typing import List, Dict, Any, Tuple, Optional, Union, Set, Literal
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from dataclasses import dataclass
from datetime import datetime
import logging
import traceback
import copy
import signal
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
import os

from frame.knowledge_base.knowledge_graph import KnowledgeGraph, NodeType, ConstructionStep
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import Entity, Concept, Conjecture, Theorem, ConceptType, Nat
from frame.provers.proof import Proof, Z3Proof
from frame.tools.cache import EpisodicCache, ProofStatus
from frame.environments.ground_truth_entities import (
    get_ground_truth_entity,
    update_entity_implementation,
    is_ground_truth_entity
)

# Module-level logger, named after this module so handlers can filter on it.
logger = logging.getLogger(__name__)

# Default limit, in seconds, that a single production-rule application may
# take. It is the default value of MathEnv's `rule_application_timeout`
# constructor argument, so callers can override it per environment.
DEFAULT_RULE_APPLICATION_TIMEOUT = 0.5

# Default limit, in seconds, for one Z3 proof attempt on a new conjecture
# inside a step. It is the default value of MathEnv's `z3_prover_timeout`
# constructor argument.
DEFAULT_Z3_PROVER_TIMEOUT = 2.50

class TimeoutError(Exception):
    """Module-specific timeout exception.

    Within this module the name shadows the builtin ``TimeoutError``; unlike
    the builtin, this class derives directly from ``Exception``.
    """

def timeout_handler(signum, frame):
    """Signal-handler callback that aborts by raising ``TimeoutError``.

    Both arguments follow the two-argument handler signature used by the
    ``signal`` module and are ignored; the function always raises.
    """
    message = "Rule application timed out"
    raise TimeoutError(message)

# Names exported by a star-import of this module. DEFAULT_Z3_PROVER_TIMEOUT
# and this module's TimeoutError are not listed, so they must be imported
# explicitly by name.
__all__ = ['MathEnv', 'ValidAction', 'DEFAULT_RULE_APPLICATION_TIMEOUT']

@dataclass
class ValidAction:
    """One applicable action: a production rule plus its inputs and parameters.

    Instances define both ``__eq__`` and ``__hash__`` so they can be stored in
    sets (the environment keeps a set of already-applied actions). The hash is
    derived from the same values that equality compares, so two actions that
    compare equal always land in the same set/dict slot.
    """
    # Index of the production rule in the environment's rule list.
    rule_idx: int
    # Ordered IDs of the knowledge-graph nodes whose entities are fed to the rule.
    input_nodes: List[str]
    # Keyword parameters forwarded to the rule when it is applied.
    params: Dict[str, Any]

    def __eq__(self, other):
        """Return True when rule index, input nodes and params all match.

        Returns ``NotImplemented`` for non-``ValidAction`` operands so Python
        can try the reflected comparison; ``action == 5`` still evaluates to
        ``False``.
        """
        if not isinstance(other, ValidAction):
            return NotImplemented
        return (self.rule_idx == other.rule_idx and
                self.input_nodes == other.input_nodes and
                self.params == other.params)

    @staticmethod
    def _freeze(value):
        """Return a hashable stand-in for ``value`` that respects ``==``.

        Containers are converted recursively (dict -> frozenset of pairs,
        list/tuple -> tuple, set -> frozenset) so that equal containers yield
        equal hashes. Hashable leaves are kept as-is, which preserves the
        language guarantee that equal values hash equally (e.g. ``1`` and
        ``1.0``). Anything else falls back to its string form.
        """
        if isinstance(value, dict):
            return frozenset((key, ValidAction._freeze(item)) for key, item in value.items())
        if isinstance(value, (list, tuple)):
            return tuple(ValidAction._freeze(item) for item in value)
        if isinstance(value, (set, frozenset)):
            return frozenset(ValidAction._freeze(item) for item in value)
        try:
            hash(value)
        except TypeError:
            # Unhashable, non-container value: best effort, as before.
            return str(value)
        return value

    def __hash__(self):
        """Hash consistent with ``__eq__`` so actions can be used in sets.

        Params are frozen rather than stringified: hashing ``str(v)`` gave
        different hashes to equal values such as ``1`` and ``1.0``, and
        sorting the items could fail on keys of mixed types.
        """
        return hash((self.rule_idx, tuple(self.input_nodes), self._freeze(self.params)))

class MathEnv(gym.Env):
    """
    Gymnasium environment for mathematical discovery.
    
    State Space (Observation):
        - Raw KnowledgeGraph object
        
    Action Space:
        - For enumerated policies: Discrete space indexing into valid_actions list
        - For non-enumerated policies: Direct ValidAction objects
    """
    
    # Define forbidden patterns as class methods with clear naming
    @classmethod
    def _pattern_self_implication(cls, graph: KnowledgeGraph, rule: ProductionRule, input_nodes: List[str]) -> bool:
        """Check for P -> P pattern."""
        return (rule.name == "implication" and 
                len(input_nodes) == 2 and 
                input_nodes[0] == input_nodes[1])
    
    @classmethod
    def _pattern_self_equivalence(cls, graph: KnowledgeGraph, rule: ProductionRule, input_nodes: List[str]) -> bool:
        """Check for P <-> P pattern."""
        return (rule.name == "equivalence" and 
                len(input_nodes) == 2 and 
                input_nodes[0] == input_nodes[1])
    
    @classmethod
    def _pattern_double_negation(cls, graph: KnowledgeGraph, rule: ProductionRule, input_nodes: List[str]) -> bool:
        """Check for not not P pattern."""
        if rule.name != "negation":
            return False
        # For each input node, check if it was created by negation
        for node_id in input_nodes:
            node_data = graph.nodes[node_id]
            if 'construction_step' in node_data:
                construction = node_data['construction_step']
                if construction and construction.rule.name == "negation":
                    return True
        return False
    
    @classmethod
    def _pattern_implication_to_negation(cls, graph: KnowledgeGraph, rule: ProductionRule, input_nodes: List[str]) -> bool:
        """Check for P -> not P pattern."""
        if rule.name != "implication" or len(input_nodes) != 2:
            return False
        # Check if the second node is a negation of the first
        target_node = graph.nodes[input_nodes[1]]
        if 'construction_step' in target_node:
            construction = target_node['construction_step']
            if (construction and 
                construction.rule.name == "negation" and 
                construction.input_node_ids[0] == input_nodes[0]):
                return True
        return False
    
    @classmethod
    def _pattern_equivalence_to_negation(cls, graph: KnowledgeGraph, rule: ProductionRule, input_nodes: List[str]) -> bool:
        """Check for P <-> not P pattern."""
        if rule.name != "equivalence" or len(input_nodes) != 2:
            return False
        # Check if either node is a negation of the other
        for i, j in [(0, 1), (1, 0)]:  # Check both directions
            target_node = graph.nodes[input_nodes[j]]
            if 'construction_step' in target_node:
                construction = target_node['construction_step']
                if (construction and 
                    construction.rule.name == "negation" and 
                    construction.input_node_ids[0] == input_nodes[i]):
                    return True
        return False
    
    @classmethod
    def _pattern_self_forall(cls, graph: KnowledgeGraph, rule: ProductionRule, input_nodes: List[str]) -> bool:
        """Check for forall(P, P) pattern - using the same concept twice in a forall rule."""
        return (rule.name == "forall" and 
                len(input_nodes) == 2 and 
                input_nodes[0] == input_nodes[1])
    
    @classmethod
    def _pattern_compose_predicates_same_concept(cls, graph: KnowledgeGraph, rule: ProductionRule, input_nodes: List[str]) -> bool:
        """Check for P(x) ∧ P(x) pattern."""
        return (rule.name == "compose" and 
                len(input_nodes) == 2 and 
                input_nodes[0] == input_nodes[1] and
                graph.nodes[input_nodes[0]]['entity'].is_predicate() and
                graph.nodes[input_nodes[1]]['entity'].is_predicate())
    
    @classmethod
    def _get_forbidden_patterns(cls):
        """Get the list of forbidden pattern checking methods."""
        return [
            cls._pattern_self_implication,
            cls._pattern_self_equivalence,
            cls._pattern_double_negation,
            cls._pattern_implication_to_negation,
            cls._pattern_equivalence_to_negation,
            cls._pattern_self_forall,
            cls._pattern_compose_predicates_same_concept,
        ]

    def __init__(
        self,
        initial_graph: KnowledgeGraph,
        production_rules: List[ProductionRule],
        max_steps: int = 1000,
        enumerate_actions: bool = True,
        update_ground_truth_implementations: bool = True,
        allow_entity_removal: bool = False,
        rule_application_timeout: float = DEFAULT_RULE_APPLICATION_TIMEOUT,
        episodic_cache: Optional[EpisodicCache] = None,
        z3_prover_timeout: float = DEFAULT_Z3_PROVER_TIMEOUT,
        z3_example_search_timeout: float = 0.5, # Default timeout for Z3 example search
        use_z3_prover: bool = True, # Flag to enable/disable Z3 prover for conjectures
        use_z3_example_search: bool = True # Flag to enable/disable Z3 example search
    ):
        """Set up the environment around a starting knowledge graph.

        Args:
            initial_graph: Graph the environment starts from. It is kept as the
                template for ``reset``; the environment works on a copy.
            production_rules: Rules an action can apply; ``ValidAction.rule_idx``
                indexes into this list.
            max_steps: Number of steps after which ``step`` reports ``done``.
            enumerate_actions: If True, valid actions are precomputed and the
                action space is ``Discrete`` over them. If False, the valid
                action list stays empty and a generic ``Space`` is used.
            update_ground_truth_implementations: Stored on the instance; it is
                consumed outside this constructor.
            allow_entity_removal: Stored on the instance; it is consumed
                outside this constructor.
            rule_application_timeout: Seconds ``step`` waits for one rule
                application to finish.
            episodic_cache: Optional cache that ``step`` consults for proof
                statuses and examples; ``None`` disables those lookups.
            z3_prover_timeout: Seconds ``step`` waits for one Z3 proof attempt.
            z3_example_search_timeout: Total seconds for Z3 example search in a
                step; ``step`` gives half to the example search and half to
                the non-example search.
            use_z3_prover: Whether ``step`` tries to prove new conjectures with Z3.
            use_z3_example_search: Whether ``step`` searches for examples with Z3.
        """
        super().__init__()

        self.initial_graph = initial_graph
        self.graph = initial_graph.copy()
        self.rules = production_rules
        self.max_steps = max_steps
        self.current_step = 0
        self.enumerate_actions = enumerate_actions
        self.update_ground_truth_implementations = update_ground_truth_implementations
        self.allow_entity_removal = allow_entity_removal
        self.rule_application_timeout = rule_application_timeout
        self.z3_prover_timeout = z3_prover_timeout
        self.episodic_cache = episodic_cache
        self.z3_example_search_timeout = z3_example_search_timeout
        self.use_z3_prover = use_z3_prover
        self.use_z3_example_search = use_z3_example_search

        # Log the configured timeout
        logger.info(f"Rule application timeout set to {self.rule_application_timeout} seconds")

        # The pattern list, applied-action set and timeout statistics are all
        # initialised before valid actions are computed below.
        self.forbidden_patterns = self._get_forbidden_patterns()

        # Track applied actions to avoid repetition
        self.applied_actions = set()

        # Track timeout statistics
        self.timeout_count = 0
        self.last_timeout_step = None
        self.timed_out_rules = {}

        if enumerate_actions:
            # Actions are chosen by index into the precomputed list.
            # NOTE(review): Discrete presumably needs at least one action —
            # confirm what should happen when no valid action exists.
            self.valid_actions = self._compute_valid_actions()
            self.action_space = spaces.Discrete(len(self.valid_actions))
        else:
            # Placeholder space: no enumerated action list is kept.
            self.valid_actions = []
            self.action_space = spaces.Space()

        # Observation space is custom - the observation is the knowledge graph itself
        self.observation_space = spaces.Space()

        # NOTE: Removed shared executor initialization - will create per-call executors

    def reset(self, seed=None, options=None) -> Tuple[KnowledgeGraph, Dict]:
        """Restore the environment to its initial state.

        The working graph is replaced by a fresh copy of ``initial_graph`` and
        the step counter, applied-action set and timeout statistics are
        cleared. With action enumeration enabled, the valid actions and the
        discrete action space are rebuilt for the new graph.

        Args:
            seed: Forwarded to the base class ``reset``.
            options: Accepted but not used by this method.

        Returns:
            The fresh knowledge graph and an empty info dict.
        """
        super().reset(seed=seed)

        # Start over from a fresh copy of the initial graph.
        self.graph = self.initial_graph.copy()
        self.current_step = 0

        # Forget which actions were applied and any timeout bookkeeping.
        self.applied_actions = set()
        self.timeout_count, self.last_timeout_step = 0, None
        self.timed_out_rules = {}

        if not self.enumerate_actions:
            # Nothing to enumerate in this mode.
            self.valid_actions = []
            return self.graph, {}

        # Rebuild the enumerated actions and matching action space.
        self.valid_actions = self._compute_valid_actions()
        self.action_space = spaces.Discrete(len(self.valid_actions))
        logger.info(f"Initial valid actions count: {len(self.valid_actions)}")

        # Debug validation of initial actions
        # TODO(_; 3/6): This is just for debugging and should be removed once the issue is fixed.
        self._debug_validate_all_actions()

        return self.graph, {}
        
    def step(self, action: Union[int, ValidAction], is_simulation: bool = False) -> Tuple[KnowledgeGraph, float, bool, bool, Dict]:
        """
        Take a step in the environment by applying a production rule.
        
        Args:
            action: Either an index into valid_actions list or a ValidAction object
            is_simulation: If True, skips potentially expensive operations like Z3 proving.
                
        Returns:
            observation: New knowledge graph state
            reward: Reward for this step
            done: Whether the episode is done
            truncated: Whether the episode was truncated
            info: Additional information
        """
        # Initialize info dictionary with default values
        info = {
            'new_entities': [],  # Always include this key, even if empty
        }
        
        # Convert index to ValidAction if needed
        if isinstance(action, int):
            if action < 0 or action >= len(self.valid_actions):
                info['error'] = f"Invalid action index: {action}"
                return self.graph, 0.0, False, False, info
            action_obj = self.valid_actions[action]
        else:
            action_obj = action
            
        # Get the rule and inputs
        rule = self.rules[action_obj.rule_idx]
        
        # Check if this action has already been applied (including timeouts)
        if action_obj in self.applied_actions:
            logger.info(f"Action already applied: {action_obj}")
            raise ValueError(f"Attempted to apply action that has already been applied or timed out: {rule.name}")
        
        # Get the actual entity objects for the input nodes
        input_entities = []
        input_entity_names = []
        for node_id in action_obj.input_nodes:
            entity, _, _ = self.graph.get_node(node_id)
            input_entities.append(entity)
            input_entity_names.append(entity.name if hasattr(entity, 'name') else str(entity))
            
        # Update info with rule and input entities
        info['rule_attempted'] = rule.name
        info['input_entities'] = action_obj.input_nodes
            
        # Log the rule being attempted
        if not is_simulation:
            logger.info(f"Trying rule: {rule.name}")
            logger.info(f"  Input entities: {input_entity_names}")
            logger.info(f"  Parameters: {action_obj.params}")

        self.current_step += 1
        done = self.current_step >= self.max_steps

        # Apply the rule with timeout using a new ThreadPoolExecutor for each call
        # This prevents stuck threads from affecting future rule applications
        future = None
        start_time = time.time()
        executor = None
        
        try:
            # Create a fresh executor for this specific rule application
            executor = ThreadPoolExecutor(max_workers=1)
                
            # Submit the rule application task
            if not is_simulation:
                logger.info(f"Applying rule {rule.name} with timeout of {self.rule_application_timeout} seconds")
            future = executor.submit(rule.apply, *input_entities, **action_obj.params)
                
            # Wait for the result with timeout
            try:
                new_entity = future.result(timeout=self.rule_application_timeout)

                # Log execution time
                execution_time = time.time() - start_time
                if not is_simulation:
                    logger.info(f"Rule {rule.name} completed in {execution_time:.3f} seconds")

                # TODO(_; 5/3): Log the new entity examples and nonexamples for debugging purposes.
                if not isinstance(new_entity, Conjecture) and not is_simulation:
                    logger.info(f"New entity examples: {[ex.value for ex in new_entity.get_examples()]}")
                    logger.info(f"New entity nonexamples: {[ex.value for ex in new_entity.get_nonexamples()]}")

                # Default reward and status
                reward = 0.0
                new_entity_id = None # Will store the ID of the added/promoted entity
                entity_to_add_or_check = new_entity # Start with the entity returned by the rule

                # --- Z3 Proof Attempt for New Conjectures (if enabled and not simulation) ---
                assert isinstance(new_entity, Entity)
                skip_z3 = False # Flag to skip Z3 block if cache hit

                if self.use_z3_prover and not is_simulation: # Check if Z3 proving is enabled
                    # --- Check Proof Cache before Z3 Call --- 
                    if isinstance(new_entity, Conjecture) and self.episodic_cache:
                        cached_status = self.episodic_cache.get_proof_status(new_entity.name)
                        if cached_status in ["Proven", "Disproven"]:
                            logger.info(f"Cache hit for {new_entity.name}: Status '{cached_status}'. Skipping Z3.")
                            if cached_status == "Proven":
                                # Promote to Theorem using cached status
                                cached_proof = Z3Proof(conjecture=new_entity.name, timestamp=datetime.now(), proof_object="Proven from Cache") 
                                theorem = Theorem(
                                    name=new_entity.name,
                                    description=new_entity.description,
                                    symbolic_definition=new_entity._symbolic,
                                    proof=cached_proof,
                                    example_structure=new_entity.examples.example_structure,
                                    lean4_translation=new_entity._lean4,
                                    prolog_translation=new_entity._prolog,
                                    z3_translation=new_entity._z3,
                                    computational_implementation=new_entity._compute,
                                    can_add_examples=new_entity.can_add_examples,
                                    can_add_nonexamples=new_entity.can_add_nonexamples,
                                )
                                entity_to_add_or_check = theorem
                                z3_proved = True # Use existing flag to trigger node modification logic below
                            skip_z3 = True # Set flag to skip the Z3 block
                    # --- End Cache Check --- 

                    # --- Z3 Prover Call --- 
                    if not is_simulation and isinstance(new_entity, Conjecture) and new_entity.has_z3_translation() and not skip_z3:
                        logger.info(f"New conjecture {new_entity.name} has Z3 translation. Attempting proof (Z3 prover enabled)... ")
                        z3_executor = None
                        z3_future = None
                        z3_start_time = time.time()
                        z3_proved = False # Track if Z3 succeeded
                        z3_final_status: ProofStatus = "Unknown" # Track status for cache update
                        try:
                            z3_template = new_entity.to_z3()
                            logger.info(f"Z3 template: {z3_template}")
                            z3_executor = ThreadPoolExecutor(max_workers=1)
                            # Ensure program attribute exists before calling run
                            z3_future = z3_executor.submit(z3_template.run)
                            
                            try:
                                z3_result = z3_future.result(timeout=self.z3_prover_timeout)
                                z3_execution_time = time.time() - z3_start_time
                                logger.info(f"Z3 Result for {new_entity.name}: Proved={z3_result.proved} (took {z3_execution_time:.3f}s)")

                                if z3_result.proved:
                                    z3_proved = True # Mark as proved for node modification
                                    z3_final_status = "Proven"
                                    logger.info(f"Conjecture {new_entity.name} proven by Z3! Promoting to Theorem.")
                                    proof = Z3Proof(
                                        conjecture=new_entity.name, 
                                        timestamp=datetime.now(),
                                        proof_object=z3_result
                                    )
                                    # Create Theorem object - this will become the entity in the node
                                    theorem = Theorem(
                                        name=new_entity.name,
                                        description=new_entity.description,
                                        symbolic_definition=new_entity._symbolic,
                                        proof=proof,
                                        example_structure=new_entity.examples.example_structure,
                                        lean4_translation=new_entity._lean4,
                                        prolog_translation=new_entity._prolog,
                                        z3_translation=new_entity._z3,
                                        computational_implementation=new_entity._compute,
                                        can_add_examples=new_entity.can_add_examples,
                                        can_add_nonexamples=new_entity.can_add_nonexamples,
                                    )
                                    # Update the entity we'll add/check later
                                    entity_to_add_or_check = theorem
                                else:
                                    if z3_result.timed_out:
                                        z3_final_status = "Unknown"
                                        logger.info(f"Conjecture {new_entity.name} Z3 Unknown due to timeout.")
                                    else:
                                        z3_final_status = "Disproven"
                                        logger.info(f"Conjecture {new_entity.name} disproven.")

                            except FuturesTimeoutError:
                                z3_execution_time = time.time() - z3_start_time
                                z3_final_status = "Unknown"
                                logger.warning(f"Z3 proof attempt for {new_entity.name} timed out after {z3_execution_time:.3f}s (limit: {self.z3_prover_timeout}s)")
                                if z3_future: z3_future.cancel()
                            except Exception as z3_run_error:
                                logger.error(f"Error running Z3 for conjecture {new_entity.name}: {z3_run_error}")
                                z3_final_status = "Unknown"
                                logger.error(f"Z3 template code: {z3_template.code if z3_template else 'N/A'}")
                                if z3_future: z3_future.cancel()

                        except Exception as z3_setup_error:
                            logger.error(f"Error setting up Z3 check for {new_entity.name}: {z3_setup_error}")
                            z3_final_status = "Unknown"
                        finally:
                            if z3_executor:
                                z3_executor.shutdown(wait=False)
                        # --- End Z3 Prover Call ---

                        # --- Modify Node In Place If Proved (Inside Z3 Prover Check) ---
                        if z3_proved:
                            # Add original conjecture first to get ID - essential step
                            construction_step_for_promo = ConstructionStep(
                                rule=rule,
                                input_node_ids=action_obj.input_nodes,
                                parameters=action_obj.params,
                                timestamp=datetime.now()
                            )
                            # Add the *original* entity to secure the node ID
                            conjecture_id = self.graph.add_conjecture(new_entity, construction_step_for_promo)
                            new_entity_id = conjecture_id # Store the ID

                            # Now modify the node data with the Theorem
                            self.graph.nodes[conjecture_id]['entity'] = entity_to_add_or_check # Use the created Theorem object
                            self.graph.nodes[conjecture_id]['node_type'] = NodeType.THEOREM
                            self.graph.nodes[conjecture_id]['proof'] = proof # Add proof object
                            info['promoted_to_theorem'] = True # Still useful info
                            # Add construction edges (might be redundant)
                            for input_id in action_obj.input_nodes:
                                self.graph.add_construction_edge(input_id, new_entity_id)
                        # --- End Modify Node --- 

                        # --- Update Proof Cache After Z3 Attempt (if Z3 was run and enabled) ---
                        if self.episodic_cache and z3_final_status != "Unknown": # Only update if Z3 ran and produced a result
                            logger.info(f"Updating proof status for {new_entity.name} to {z3_final_status} in cache")
                            self.episodic_cache.update_proof_status(new_entity.name, "Conjecture", z3_final_status)
                        # --- End Cache Update --- 
                    # --- End of Z3 Prover Logic within use_z3_prover check ---
                # --- End Z3 Proof Attempt Check --- 

                # --- Add Entity / Check Override (Outside Z3 Prover Check) ---
                # If Z3 ran and promoted, new_entity_id is already set. Otherwise, add the entity.
                if new_entity_id is None:
                    # Add the entity (original Concept/Conjecture or Theorem if Z3 failed but entity was somehow Theorem)
                    construction_step = ConstructionStep(
                        rule=rule,
                        input_node_ids=action_obj.input_nodes,
                        parameters=action_obj.params,
                        timestamp=datetime.now()
                    )
                    if isinstance(entity_to_add_or_check, Concept):
                        new_entity_id = self.graph.add_concept(entity_to_add_or_check, construction_step)
                    elif isinstance(entity_to_add_or_check, Conjecture):
                        new_entity_id = self.graph.add_conjecture(entity_to_add_or_check, construction_step)
                    elif isinstance(entity_to_add_or_check, Theorem): # Should only happen if Z3 proved it but adding failed above somehow
                        new_entity_id = self.graph.add_theorem(entity_to_add_or_check, construction_step)
                    else:
                        # Should not happen
                        logger.error(f"Entity to add has unexpected type: {type(entity_to_add_or_check)}")
                        # Handle error appropriately, maybe raise or return error state

                    # Add construction edges if entity was newly added
                    if new_entity_id:
                        for input_id in action_obj.input_nodes:
                            self.graph.add_construction_edge(input_id, new_entity_id)

                # Now that the node exists (either added or modified), apply post-processing
                # Make sure we have a valid ID before proceeding
                if new_entity_id:
                    current_entity_in_node = self.graph.nodes[new_entity_id]['entity']

                    # --- Z3-Based Example/Non-Example Finding (if enabled) ---
                    if self.use_z3_example_search: # Check if Z3 example search is enabled
                        if isinstance(current_entity_in_node, Concept) and \
                           current_entity_in_node.has_z3_translation() and \
                           (not current_entity_in_node.get_examples() or not current_entity_in_node.get_nonexamples()) and not is_simulation:

                            needs_example = not current_entity_in_node.get_examples()
                            needs_nonexample = not current_entity_in_node.get_nonexamples()

                            # --- Check Cache Before Z3 Search --- 
                            if self.episodic_cache:
                                cached_pos, cached_neg = self.episodic_cache.get_examples(current_entity_in_node.name)
                                if needs_example and cached_pos:
                                    logger.info(f"Found {len(cached_pos)} cached example(s) for '{current_entity_in_node.name}'. Adding them.")
                                    added_from_cache = 0
                                    for example_to_add in cached_pos:
                                        try:
                                            current_entity_in_node.add_example(example_to_add, override=True)
                                            added_from_cache += 1
                                        except Exception as cache_add_err:
                                            logger.warning(f"Failed to add cached example {example_to_add} to '{current_entity_in_node.name}': {cache_add_err}")
                                    if added_from_cache > 0:
                                        needs_example = False # No longer need to search Z3 if we added at least one

                                if needs_nonexample and cached_neg:
                                    logger.info(f"Found {len(cached_neg)} cached non-example(s) for '{current_entity_in_node.name}'. Adding them.")
                                    added_from_cache = 0
                                    for nonexample_to_add in cached_neg:
                                        try:
                                            current_entity_in_node.add_nonexample(nonexample_to_add, override=True)
                                            added_from_cache += 1
                                        except Exception as cache_add_err:
                                            logger.warning(f"Failed to add cached non-example {nonexample_to_add} to '{current_entity_in_node.name}': {cache_add_err}")
                                    if added_from_cache > 0:
                                        needs_nonexample = False # No longer need to search Z3 if we added at least one
                            # --- End Cache Check --- 

                            # --- Proceed with Z3 Search if Still Needed --- 
                            if (needs_example or needs_nonexample) and not is_simulation:
                                logger.info(f"Concept '{current_entity_in_node.name}' still lacks examples/nonexamples after cache check. Attempting Z3 search (timeout: {self.z3_example_search_timeout}s)...")
                                try:
                                    # Note(_; 5/4): Works for Nat types right, would need to change for other types.
                                    arity = len(current_entity_in_node.examples.example_structure.component_types)
                                    if arity is None:
                                        logger.warning(f"Cannot perform Z3 example search for '{current_entity_in_node.name}': Invalid or missing arity ({arity}).")
                                    else:
                                        z3_template = current_entity_in_node.to_z3()
                                        if hasattr(z3_template, 'program') and hasattr(z3_template, 'check_example'):
                                            max_attempts = 50 # Limit attempts per search type

                                            if needs_example:
                                                logger.debug(f"Searching for example for '{current_entity_in_node.name}'...")
                                                found_example = self._find_example_with_z3(
                                                    entity=current_entity_in_node,
                                                    search_type='example',
                                                    arity=arity,
                                                    z3_template=z3_template,
                                                    timeout_seconds=self.z3_example_search_timeout / 2, # Split timeout
                                                    max_attempts=max_attempts
                                                )
                                                if not found_example:
                                                    logger.warning(f"Failed to find Z3 example for '{current_entity_in_node.name}'.")

                                            if needs_nonexample:
                                                logger.debug(f"Searching for non-example for '{current_entity_in_node.name}'...")
                                                found_nonexample = self._find_example_with_z3(
                                                    entity=current_entity_in_node,
                                                    search_type='nonexample',
                                                    arity=arity,
                                                    z3_template=z3_template,
                                                    timeout_seconds=self.z3_example_search_timeout / 2, # Split timeout
                                                    max_attempts=max_attempts
                                                )
                                                if not found_nonexample:
                                                    logger.warning(f"Failed to find Z3 non-example for '{current_entity_in_node.name}'.")
                                                else:
                                                    logger.warning(f"Concept '{current_entity_in_node.name}' has Z3 translation but no suitable 'check_example' method found in Z3 program.")
                                        else:
                                            logger.warning(f"Concept '{current_entity_in_node.name}' has Z3 translation but no suitable 'check_example' method found in Z3 program.")
                                except Exception as search_setup_err:
                                    logger.error(f"Error during Z3 example search setup for '{current_entity_in_node.name}': {search_setup_err}")
                    # --- End Z3-Based Example Finding ---

                    # --- Name Override Check ---
                    if not is_simulation:
                        override_name = self._check_name_override(current_entity_in_node) # Check the entity now in the node
                        has_override = False
                        if override_name:
                            logger.info(f"Applying name override: {current_entity_in_node.name} -> {override_name}")
                            # Update the name *in the node's entity object*
                            self.graph.nodes[new_entity_id]['entity'].name = override_name
                            has_override = True
                            reward = 1.0 # Reward for ground truth
                            self.graph.nodes[new_entity_id]['has_manual_override'] = True
                    # --- End Name Override Check ---


                # Add this action to the applied actions set only if entity was successfully added/promoted
                if new_entity_id:
                    self.applied_actions.add(action_obj)
                    info['new_entities'] = [new_entity_id] # Update info with the actual ID
                    info['rule_applied'] = rule.name # Indicate which rule was successfully applied

                # Update valid actions if enumeration is enabled and entity was added/promoted
                if self.enumerate_actions and new_entity_id:
                    try:
                        self._update_valid_actions([new_entity_id])
                            
                        # Debug validation of all actions
                        # TODO(_; 3/6): This is just for debugging and should be removed once the issue is fixed.
                        self._debug_validate_all_actions()
                            
                        # Remove the action from valid_actions to prevent applying it again
                        if isinstance(action, int):
                            self.valid_actions.pop(action)
                        else:
                            # Find and remove the action by comparing rule_idx, input_nodes, and params
                            for i, valid_action in enumerate(self.valid_actions):
                                if (valid_action.rule_idx == action_obj.rule_idx and
                                    valid_action.input_nodes == action_obj.input_nodes and
                                    valid_action.params == action_obj.params):
                                    self.valid_actions.pop(i)
                                    break
                    except Exception as e:
                        error_traceback = traceback.format_exc()
                        logger.error(f"Error updating actions: {e}")
                        logger.error(error_traceback)
                        logger.error(f"Error occurred while updating actions for new entity: {new_entity.name}")
                        # Continue execution instead of raising the exception
                        # This allows the step to complete even if there was an error updating actions
                        pass
                    
                # Update info with success data
                info['new_entities'] = [new_entity_id]
                info['rule_applied'] = rule.name
                    
                # Log the number of valid actions after the step if enumeration is enabled
                if self.enumerate_actions and not is_simulation:
                    logger.info(f"Valid actions after step: {len(self.valid_actions)}")
                    
                # Make sure we safely shutdown the executor
                if executor:
                    executor.shutdown(wait=False)
                    executor = None
                    
                # Return observation, reward, done, truncated, info
                return self.graph, reward, done, False, info
                
            except FuturesTimeoutError:
                # Handle timeout within the try block to ensure executor cleanup
                execution_time = time.time() - start_time
                logger.error(f"Rule {rule.name} timed out after {execution_time:.3f} seconds (limit: {self.rule_application_timeout}s)")
                
                # Update timeout statistics
                self.timeout_count += 1
                self.last_timeout_step = self.current_step
                if rule.name in self.timed_out_rules:
                    self.timed_out_rules[rule.name] += 1
                else:
                    self.timed_out_rules[rule.name] = 1
                
                # Cancel the future and aggressive shutdown of the executor
                if future:
                    future.cancel()
                
                # Forcefully shutdown the executor to kill the thread
                if executor:
                    executor.shutdown(wait=False)
                    executor = None
                
                # Add action to applied_actions to prevent it from being tried again
                self.applied_actions.add(action_obj)
                
                # Return error info
                info['error'] = f"Rule {rule.name} timed out after {execution_time:.3f} seconds (limit: {self.rule_application_timeout}s)"
                info['timeout'] = True  # Add a specific flag for timeouts
                info['timeout_stats'] = {
                    'total_timeouts': self.timeout_count,
                    'timed_out_rules': self.timed_out_rules.copy()
                }
                
                return self.graph, 0.0, done, False, info
                
        except Exception as e:
            # If rule application fails, return the current state with zero reward
            error_traceback = traceback.format_exc()
            logger.error(f"Rule application failed: {e}")
            logger.error(error_traceback)
            
            # Cancel the future if possible
            if future:
                future.cancel()
            
            # Forcefully shutdown the executor to kill the thread
            if executor:
                executor.shutdown(wait=False)
                executor = None
            
            # Return the unchanged graph with error information
            info['error'] = str(error_traceback)
            
            return self.graph, 0.0, done, False, info
    
    def _is_forbidden_pattern(self, rule: ProductionRule, input_nodes: List[str]) -> bool:
        """
        Report whether applying ``rule`` to ``input_nodes`` trips a forbidden-pattern check.

        Every callable in ``self.forbidden_patterns`` is invoked as
        ``check(self.graph, rule, input_nodes)``, in order, and evaluation
        stops at the first one that returns a truthy value.

        Args:
            rule: The production rule to be applied
            input_nodes: List of input node IDs

        Returns:
            bool: True if at least one registered check flags the action
        """
        # The generator is lazy, so checks after the first match are never run.
        offending_check = next(
            (
                check
                for check in self.forbidden_patterns
                if check(self.graph, rule, input_nodes)
            ),
            None,
        )
        if offending_check is None:
            return False

        logger.debug(f"Forbidden pattern detected: {offending_check.__name__} for rule {rule.name}")
        return True

    def _compute_valid_actions(self, concepts: Optional[List[str]] = None, conjectures: Optional[List[str]] = None) -> List[ValidAction]:
        """
        Enumerate every valid action for the current state.

        For each production rule, the rule's input specification(s) are handed
        to ``_generate_actions_for_input_types``, which appends the applicable
        actions to a shared list. Actions flagged by
        ``_is_forbidden_pattern`` are removed before returning.

        Args:
            concepts: Optional list of concept IDs to consider. If None, uses all concepts in the graph.
            conjectures: Optional list of conjecture IDs to consider. If None, uses all conjectures in the graph.

        Returns:
            List[ValidAction]: All valid actions in the current state
        """
        concept_ids = self.graph.get_all_concepts() if concepts is None else concepts
        conjecture_ids = self.graph.get_all_conjectures() if conjectures is None else conjectures

        collected = []
        for rule_idx, rule in enumerate(self.rules):
            spec = rule.get_input_types()

            # A rule either returns one list of (entity_type, concept_type)
            # tuples, or a list of such lists (alternatives). Treat the first
            # form as a single alternative so both share one code path.
            if spec and isinstance(spec[0], list):
                alternatives = spec
            else:
                alternatives = [spec]

            for alternative in alternatives:
                self._generate_actions_for_input_types(
                    rule_idx, rule, alternative, concept_ids, conjecture_ids, collected
                )

        # Drop anything a registered forbidden-pattern check rejects.
        return [
            candidate
            for candidate in collected
            if not self._is_forbidden_pattern(self.rules[candidate.rule_idx], candidate.input_nodes)
        ]
    
    def _generate_actions_for_input_types(self, rule_idx, rule, input_types,
                                  all_concepts, all_conjectures, valid_actions):
        """
        Append to ``valid_actions`` every applicable action of ``rule`` for one input specification.

        For each ``(entity_type, concept_type)`` slot in ``input_types`` the
        matching graph entities are collected, all slot combinations are
        enumerated, and each parameterization offered by the rule is kept only
        when ``rule.can_apply`` accepts it. Failures are logged and skipped
        rather than raised.

        Args:
            rule_idx: Index of the rule
            rule: The rule object
            input_types: Input types specification for the rule
            all_concepts: List of all concept IDs
            all_conjectures: List of all conjecture IDs
            valid_actions: List to append valid actions to (mutated in place)
        """
        def names_of(entities):
            # Labels for error messages; falls back to str() when there is no name.
            return [item.name if hasattr(item, 'name') else str(item) for item in entities]

        try:
            # One list of (entity, entity_id) pairs per input slot.
            candidates_per_slot = []

            for entity_type, concept_type in input_types:
                slot_candidates = []

                if entity_type == Concept:
                    for concept_id in all_concepts:
                        entity, _node_type, _ = self.graph.get_node(concept_id)
                        if isinstance(entity, Concept):
                            actual_type = entity.examples.example_structure.concept_type
                            # concept_type may be None (any), a single type, or a list of accepted types.
                            type_matches = (
                                concept_type is None
                                or actual_type == concept_type
                                or (isinstance(concept_type, list) and actual_type in concept_type)
                            )
                            if type_matches:
                                slot_candidates.append((entity, concept_id))

                elif entity_type == Conjecture:
                    # Conjectures are not filtered by type.
                    for conjecture_id in all_conjectures:
                        entity, _node_type, _ = self.graph.get_node(conjecture_id)
                        if isinstance(entity, Conjecture):
                            slot_candidates.append((entity, conjecture_id))

                candidates_per_slot.append(slot_candidates)

            try:
                combinations = self._generate_input_combinations(candidates_per_slot)
            except Exception as err:
                logger.error(f"Error generating input combinations for rule {rule.name}: {err}")
                return

            for combo in combinations:
                input_entities = [entity for entity, _ in combo]
                input_ids = [entity_id for _, entity_id in combo]

                try:
                    parameterizations = rule.get_valid_parameterizations(*input_entities)
                except Exception as err:
                    input_entity_names = names_of(input_entities)
                    logger.error(f"Error getting valid parameterizations for rule {rule.name} with inputs {input_entity_names}: {err}")
                    continue

                for params in parameterizations:
                    # can_apply receives a deep copy so it cannot mutate the stored params.
                    params_for_check = copy.deepcopy(params)

                    # verbose=False suppresses debug prints during action enumeration.
                    try:
                        applicable = rule.can_apply(*input_entities, **params_for_check, verbose=False)
                    except Exception as err:
                        input_entity_names = names_of(input_entities)
                        logger.error(f"Error checking if rule {rule.name} can be applied with inputs {input_entity_names} and params {params}: {err}")
                        continue

                    if not applicable:
                        continue

                    valid_actions.append(
                        ValidAction(
                            rule_idx=rule_idx,
                            input_nodes=input_ids,
                            params=params
                        )
                    )
        except Exception as err:
            error_traceback = traceback.format_exc()
            logger.error(f"Error in _generate_actions_for_input_types for rule {rule.name}: {err}")
            logger.error(error_traceback)
    
    def _generate_input_combinations(self, input_entities_by_type):
        """
        Generate all combinations of input entities.
        
        Args:
            input_entities_by_type: List of lists of entities, one list per required type
            
        Returns:
            List of input entity combinations
        """
        # If no input types, return empty list
        if not input_entities_by_type:
            return []
            
        # If any input type has no entities, return empty list
        if any(len(entities) == 0 for entities in input_entities_by_type):
            return []
            
        # If only one input type, return each entity as a single-element list
        if len(input_entities_by_type) == 1:
            return [[entity] for entity in input_entities_by_type[0]]
        
        # Generate all combinations (original logic)
        combinations = [[entity] for entity in input_entities_by_type[0]]
        
        # Add each additional input type
        for entities in input_entities_by_type[1:]:
            new_combinations = []
            for combo in combinations:
                for entity in entities:
                    new_combinations.append(combo + [entity])
            combinations = new_combinations
            
        return combinations
    
    def _update_valid_actions(self, new_entities: List[str]) -> None:
        """
        Update valid actions based on newly created entities.
        
        This method checks all possible rule applications that involve at least
        one of the new entities and adds them to the valid actions list.
        
        Args:
            new_entities: List of newly created entity IDs
        """
        if not new_entities:
            return
            
        # First validate existing actions to ensure they're still valid
        self._validate_existing_actions()
        
        # Get all entities from the graph
        all_concepts = self.graph.get_all_concepts()
        all_conjectures = self.graph.get_all_conjectures()
        
        # Process each new entity separately
        for new_entity_id in new_entities:
            # Process each rule
            for rule_idx, rule in enumerate(self.rules):
                # Get input types for this rule
                input_types = rule.get_input_types()
                
                # Handle both formats of input_types (single list or list of alternatives)
                if input_types and isinstance(input_types[0], list):
                    # New format: list of alternative input specifications
                    for alternative in input_types:
                        self._check_rule_with_new_entity(
                            rule_idx, rule, alternative, new_entity_id, all_concepts, all_conjectures
                        )
                else:
                    # Original format: single list of tuples
                    self._check_rule_with_new_entity(
                        rule_idx, rule, input_types, new_entity_id, all_concepts, all_conjectures
                    )
                    
        # Filter out any actions that would create forbidden patterns
        self.valid_actions = [
            action for action in self.valid_actions 
            if not self._is_forbidden_pattern(self.rules[action.rule_idx], action.input_nodes)
        ]
    
    def _check_rule_with_new_entity(self, rule_idx, rule, input_types, new_entity_id, all_concepts, all_conjectures):
        """
        Register every applicable use of ``rule`` that involves the new entity.

        Candidates are collected for every input slot of ``input_types``, so a
        rule that mixes concept and conjecture inputs is handled too: only the
        slot the new entity itself occupies has to match the new node's type.
        Combinations that do not contain ``new_entity_id`` are skipped, and an
        action already present in ``self.valid_actions`` is not added again.
        Errors are logged and swallowed so that one misbehaving rule does not
        abort the update.

        Args:
            rule_idx: Index of the rule
            rule: The rule object
            input_types: Input types specification for the rule
            new_entity_id: ID of the new entity
            all_concepts: List of all concept IDs
            all_conjectures: List of all conjecture IDs
        """
        try:
            _, new_node_type, _ = self.graph.get_node(new_entity_id)

            # Entity class a slot must request for the new entity to fit into it.
            if new_node_type == NodeType.CONCEPT:
                new_entity_class = Concept
            elif new_node_type == NodeType.CONJECTURE:
                new_entity_class = Conjecture
            else:
                # Other node types never take part in these input slots.
                return

            # Cheap early exit: no slot of this rule can hold the new entity.
            if not any(entity_type == new_entity_class for entity_type, _ in input_types):
                return

            # One list of (entity, entity_id) pairs per input slot.
            input_entities_by_type = []
            new_entity_is_compatible = False

            for entity_type, concept_type in input_types:
                entities = []

                if entity_type == Concept:
                    for concept_id in all_concepts:
                        entity, _, _ = self.graph.get_node(concept_id)
                        if not isinstance(entity, Concept):
                            continue

                        # concept_type may be None (any), a single type, or a list of accepted types.
                        actual_type = entity.examples.example_structure.concept_type
                        if concept_type is None or actual_type == concept_type or (
                            isinstance(concept_type, list) and actual_type in concept_type
                        ):
                            entities.append((entity, concept_id))

                elif entity_type == Conjecture:
                    for conjecture_id in all_conjectures:
                        entity, _, _ = self.graph.get_node(conjecture_id)
                        if isinstance(entity, Conjecture):
                            entities.append((entity, conjecture_id))

                # The new entity is usable only in a slot of its own kind that
                # actually accepted it (e.g. passed the concept-type filter).
                if entity_type == new_entity_class and any(
                    entity_id == new_entity_id for _, entity_id in entities
                ):
                    new_entity_is_compatible = True

                input_entities_by_type.append(entities)

            # If the new entity fits no slot, this rule gains no new actions.
            if not new_entity_is_compatible:
                return

            try:
                combinations = self._generate_input_combinations(input_entities_by_type)
            except Exception as e:
                logger.error(f"Error generating input combinations for rule {rule.name}: {e}")
                return

            for combo in combinations:
                input_entities = [pair[0] for pair in combo]
                input_ids = [pair[1] for pair in combo]

                # Only combinations that use the new entity are new.
                if new_entity_id not in input_ids:
                    continue

                try:
                    parameterizations = rule.get_valid_parameterizations(*input_entities)
                except Exception as e:
                    input_entity_names = [entity.name if hasattr(entity, 'name') else str(entity) for entity in input_entities]
                    logger.error(f"Error getting valid parameterizations for rule {rule.name} with inputs {input_entity_names}: {e}")
                    continue

                for params in parameterizations:
                    # Deep copy so can_apply cannot mutate the params stored on the action.
                    can_apply_params = copy.deepcopy(params)

                    # verbose=False suppresses debug prints during action enumeration.
                    try:
                        can_apply_result = rule.can_apply(*input_entities, **can_apply_params, verbose=False)
                    except Exception as e:
                        input_entity_names = [entity.name if hasattr(entity, 'name') else str(entity) for entity in input_entities]
                        logger.error(f"Error checking if rule {rule.name} can be applied with inputs {input_entity_names} and params {params}: {e}")
                        continue

                    if not can_apply_result:
                        continue

                    new_action = ValidAction(rule_idx, input_ids, params)

                    # Skip actions that are already listed.
                    already_listed = any(
                        existing_action.rule_idx == new_action.rule_idx
                        and existing_action.input_nodes == new_action.input_nodes
                        and existing_action.params == new_action.params
                        for existing_action in self.valid_actions
                    )
                    if not already_listed:
                        self.valid_actions.append(new_action)
        except Exception as e:
            error_traceback = traceback.format_exc()
            logger.error(f"Error in _check_rule_with_new_entity for rule {rule.name} with new entity {new_entity_id}: {e}")
            logger.error(error_traceback)
    
    def _find_example_with_z3(
        self,
        entity: Concept,
        search_type: Literal['example', 'nonexample'],
        arity: int,
        z3_template: Any, # Should be Z3Template, avoid circular import
        timeout_seconds: float,
        max_attempts: int
    ) -> bool:
        """
        Randomly probe the concept's Z3 template for an example or a non-example.

        Each attempt draws ``arity`` integers from [0, 5], asks
        ``z3_template.check_example`` for a verdict, and stops at the first
        tuple whose verdict matches ``search_type``. A hit is added to the
        entity (with ``override=True``) and, when ``self.episodic_cache`` is
        set, to the cache as well. Checks that time out, return an unknown
        verdict, or raise are skipped and the search continues.

        Args:
            entity: The Concept entity to find an example/non-example for.
            search_type: 'example' or 'nonexample'.
            arity: The number of arguments the concept takes.
            z3_template: The Z3Template object for the concept.
            timeout_seconds: Maximum time allowed for this search.
            max_attempts: Maximum number of random tuples to check.

        Returns:
            True if an example/non-example of the specified type was found and added, False otherwise.
        """
        began = time.time()
        tries = 0

        while time.time() - began < timeout_seconds and tries < max_attempts:
            tries += 1
            # Note(_; 5/4): Only Nat arguments are generated; other types would need a different sampler.
            # The small [0, 5] range keeps examples comparable across concepts,
            # preventing some irrelevant conjectures.
            draws = np.random.randint(0, 6, size=arity)
            nat_args = [Nat(int(draw)) for draw in draws]
            # Plain-int tuple: the form stored on the entity and in the cache.
            plain_args = tuple([nat.value for nat in nat_args])

            try:
                # Assumes program.check_example(tuple) -> bool exists
                logger.info(f"Checking {[nat.value for nat in nat_args]} for {entity.name}") # TODO(_; 5/4): Remove after debugging.
                outcome = z3_template.check_example(entity.is_function() or entity.is_constant(), nat_args)
                verdict = outcome.proved
                if outcome.timed_out:
                    logger.debug(f"Z3 check for {plain_args} timed out.")
                    continue

                if verdict is None:
                    logger.debug(f"Z3 check for {plain_args} returned Unknown.")
                    continue

                if search_type == 'example':
                    if verdict:
                        logger.info(f"Z3 search found example for '{entity.name}': {plain_args}")
                        entity.add_example(plain_args, override=True)
                        if self.episodic_cache:
                            self.episodic_cache.add_example(entity.name, plain_args)
                        return True
                elif search_type == 'nonexample' and not verdict:
                    logger.info(f"Z3 search found non-example for '{entity.name}': {plain_args}")
                    entity.add_nonexample(plain_args, override=True)
                    if self.episodic_cache:
                        self.episodic_cache.add_nonexample(entity.name, plain_args)
                    return True

            except Exception as check_err:
                # One failing check must not end the search.
                logger.warning(f"Error checking Z3 {search_type} {nat_args} for '{entity.name}': {check_err}")

        # Reaching this point means nothing was found; report which budget ran out.
        elapsed = time.time() - began
        if tries >= max_attempts:
            logger.warning(f"Z3 {search_type} search for '{entity.name}' reached max attempts ({max_attempts}).")
        elif elapsed >= timeout_seconds:
            logger.warning(f"Z3 {search_type} search for '{entity.name}' timed out after {elapsed:.3f}s (limit: {timeout_seconds}s).")

        return False

    def render(self, mode='human'):
        """Log a summary of the current state.

        Reports the current step and the number of valid actions, concepts,
        conjectures and theorems through the module logger. ``mode`` is
        accepted but not used by this implementation.
        """
        step_header = f"\nCurrent State (Step {self.current_step}):"
        logger.info(step_header)

        action_count = len(self.valid_actions)
        logger.info(f"Number of valid actions: {action_count}")

        concept_count = len(self.graph.get_all_concepts())
        logger.info(f"Number of concepts: {concept_count}")

        conjecture_count = len(self.graph.get_all_conjectures())
        logger.info(f"Number of conjectures: {conjecture_count}")

        theorem_count = len(self.graph.get_all_theorems())
        logger.info(f"Number of theorems: {theorem_count}")
    
    def close(self):
        """Clean up environment resources.

        Logs the accumulated timeout statistics (only when at least one
        timeout was counted), cancels any pending alarm signal on a
        best-effort basis, and then delegates to the parent class's ``close``.
        """
        # Report timeout statistics
        if self.timeout_count > 0:
            logger.info("Timeout statistics for this environment:")
            logger.info(f"  Total timeouts: {self.timeout_count}")
            logger.info(f"  Last timeout step: {self.last_timeout_step}")
            logger.info(f"  Timed out rules: {self.timed_out_rules}")

        # Best effort: signal.alarm is not available on every platform, so a
        # failure here must never prevent the environment from closing. The
        # reason is logged instead of being silently discarded.
        try:
            signal.alarm(0)
        except Exception as alarm_err:
            logger.debug(f"Could not reset pending alarm signal during close: {alarm_err}")

        # Call the parent class's close method
        super().close()

    # TODO(_; 3/13): Remove this when we feel it is no longer necessary.
    def _debug_validate_all_actions(self) -> None:
        """
        Report, without pruning, entries of ``valid_actions`` that no longer apply.

        Every stored action is re-checked by resolving its input node IDs to
        entities and asking the corresponding rule's ``can_apply`` whether it
        still accepts those inputs with the stored parameters. Actions that
        are rejected, or whose check raises, are listed in a warning. The
        ``valid_actions`` list itself is left untouched.

        Does nothing when action enumeration is disabled or the list is empty.
        """
        if not (self.enumerate_actions and self.valid_actions):
            return

        stale = []

        for idx, candidate in enumerate(self.valid_actions):
            production = self.rules[candidate.rule_idx]

            try:
                # Resolve node IDs to entities, plus a printable label for each.
                resolved = [
                    node_entity
                    for node_entity, _, _ in map(self.graph.get_node, candidate.input_nodes)
                ]
                labels = [
                    node_entity.name if hasattr(node_entity, 'name') else str(node_entity)
                    for node_entity in resolved
                ]

                # Deep copy so the rule cannot mutate the stored (possibly nested) params.
                kwargs = copy.deepcopy(candidate.params)

                try:
                    if not production.can_apply(*resolved, **kwargs, verbose=False):
                        stale.append((idx, production.name, labels, candidate.params))
                except Exception as e:
                    logger.error(f"Error checking if action {idx} is valid: {e}")
                    stale.append((idx, production.name, labels, candidate.params))
            except Exception as e:
                # Entity lookup (or parameter copying) failed; report raw node IDs instead.
                logger.error(f"Error getting entity for action {idx}: {e}")
                stale.append((idx, production.name, candidate.input_nodes, candidate.params))

        # Log warnings for invalid actions
        if stale:
            logger.warning(f"Found {len(stale)} invalid actions in valid_actions list:")
            for idx, rule_name, inputs, params in stale:
                logger.warning(f"  - Action {idx}: Rule: {rule_name}, Inputs: {inputs}, Params: {params}")

    def _validate_existing_actions(self) -> None:
        """
        Prune ``valid_actions`` down to the actions whose rule still applies.

        Each stored action is re-checked by resolving its input node IDs to
        entities and calling the rule's ``can_apply`` with the stored
        parameters. Actions that are rejected, or whose check raises, are
        dropped and logged. Afterwards ``valid_actions`` is replaced by the
        surviving actions and ``action_space`` is resized to match.

        Does nothing when action enumeration is disabled or the list is empty.
        """
        if not (self.enumerate_actions and self.valid_actions):
            return

        kept = []
        dropped = []

        for idx, candidate in enumerate(self.valid_actions):
            production = self.rules[candidate.rule_idx]

            try:
                # Resolve node IDs to entities, plus a printable label for each.
                resolved = [
                    node_entity
                    for node_entity, _, _ in map(self.graph.get_node, candidate.input_nodes)
                ]
                labels = [
                    node_entity.name if hasattr(node_entity, 'name') else str(node_entity)
                    for node_entity in resolved
                ]

                # Deep copy so the rule cannot mutate the stored (possibly nested) params.
                kwargs = copy.deepcopy(candidate.params)

                try:
                    if production.can_apply(*resolved, **kwargs, verbose=False):
                        kept.append(candidate)
                    else:
                        dropped.append((idx, production.name, labels, candidate.params))
                except Exception as e:
                    logger.error(f"Error checking if action {idx} is valid: {e}")
                    dropped.append((idx, production.name, labels, candidate.params))
            except Exception as e:
                # Entity lookup (or parameter copying) failed; report raw node IDs instead.
                logger.error(f"Error getting entity for action {idx}: {e}")
                dropped.append((idx, production.name, candidate.input_nodes, candidate.params))

        # Log information about removed actions
        if dropped:
            logger.info(f"Removed {len(dropped)} invalid actions from valid_actions list")
            for idx, rule_name, inputs, params in dropped:
                logger.info(f"  - Removed action {idx}: Rule: {rule_name}, Inputs: {inputs}, Params: {params}")

        # Swap in the surviving actions and keep the action space in sync.
        # NOTE(review): gymnasium's Discrete appears to require n > 0 - confirm
        # that an empty action list cannot reach this line.
        self.valid_actions = kept
        self.action_space = spaces.Discrete(len(self.valid_actions))

    def _check_name_override(self, entity: Union[Concept, Conjecture]) -> Optional[str]:
        """
        Rename an entity to its ground truth canonical name, if it has one.

        This method:
        1. Checks if the entity has a name
        2. Uses the ground truth system to check if the name matches any known entity
        3. If found, determines the canonical name for that entity
        4. If a concept or conjecture in the graph already uses that name,
           appends asterisks until the resulting name is unused
        5. Passes the entity to ``update_entity_implementation`` and assigns
           the chosen name to ``entity.name``

        Args:
            entity: The entity to check for name overrides

        Returns:
            Optional[str]: The name assigned to the entity (the canonical name,
            possibly followed by a ``*`` tag), or None if the entity has no
            name or does not match a ground truth entity
        """
        if not hasattr(entity, 'name'):
            return None

        # Check if this is a ground truth entity
        if not is_ground_truth_entity(entity.name):
            return None

        ground_truth = get_ground_truth_entity(entity.name)
        if not ground_truth:
            return None

        canonical_name = ground_truth.canonical_name

        # Collect every name currently used by a concept or conjecture.
        # Tagged candidates ("name*", "name**", ...) must be compared against
        # ALL names in the graph, not only against entities whose name equals
        # the untagged canonical name; otherwise the tag never advances past
        # one asterisk and duplicate "name*" entries are produced.
        taken_names = set()
        for node_ids in (self.graph.get_all_concepts(), self.graph.get_all_conjectures()):
            for node_id in node_ids:
                node, _, _ = self.graph.get_node(node_id)
                if hasattr(node, 'name'):
                    taken_names.add(node.name)

        # If the canonical name is already in use, add a distinguishing tag
        if canonical_name in taken_names:
            tag = 1
            while f"{canonical_name}{'*' * tag}" in taken_names:
                tag += 1
            canonical_name = f"{canonical_name}{'*' * tag}"

        # Update name and optionally update implementation based on configuration
        update_entity_implementation(
            entity,
            entity.name,
            update_implementation=self.update_ground_truth_implementations
        )
        # Override the canonical name if we added a tag
        entity.name = canonical_name
        return canonical_name

    def remove_entity(
        self,
        entity_id: str,
        allow_rediscovery: bool = False,
        update_valid_actions: bool = True # Note(_; 3/28): Presently, I don't see any reason to set this to False.
    ) -> bool:
        """
        Remove an entity (concept, conjecture, or theorem) from the graph.

        The graph's type-specific removal method reports every node and
        construction step that was removed. That information is then used to
        clean up ``applied_actions`` (when ``allow_rediscovery`` is set) and
        to prune ``valid_actions`` (when ``update_valid_actions`` is set and
        actions are being enumerated).

        Args:
            entity_id: ID of the entity to remove
            allow_rediscovery: If True, removes the construction action from applied_actions
            update_valid_actions: If True, updates valid_actions list immediately

        Returns:
            bool: True if removal was successful, False otherwise (removal
            disabled, unknown entity ID, unknown node type, or an unexpected
            error, which is logged)
        """
        # Validation checks
        if not self.allow_entity_removal:
            logger.warning("Entity removal is not enabled in this environment")
            return False

        if entity_id not in self.graph:
            logger.warning(f"Entity {entity_id} not found in graph")
            return False

        # Get entity type and construction info before removal
        entity, node_type, _ = self.graph.get_node(entity_id)

        # Log removal attempt. Guard the name lookup (as the action-validation
        # methods do) so a nameless entity cannot abort the removal.
        entity_label = entity.name if hasattr(entity, 'name') else str(entity)
        logger.info(f"Attempting to remove {node_type.value} {entity_id} ({entity_label})")

        try:
            # Remove based on entity type and get all construction steps
            if node_type == NodeType.CONCEPT:
                removed_nodes, removed_steps = self.graph.remove_concept(entity_id)
            elif node_type == NodeType.CONJECTURE:
                removed_nodes, removed_steps = self.graph.remove_conjecture(entity_id)
            elif node_type == NodeType.THEOREM:
                removed_nodes, removed_steps = self.graph.remove_theorem(entity_id)
            else:
                logger.error(f"Unknown entity type: {node_type}")
                return False

            # Remove construction actions for all removed entities if allow_rediscovery is True
            if allow_rediscovery:
                for step in removed_steps:
                    try:
                        rule_idx = self.rules.index(step.rule)
                    except ValueError:
                        # The graph has already been mutated at this point, so
                        # an unregistered rule must not abort the remaining
                        # bookkeeping (or make the removal look like it failed).
                        logger.warning(
                            f"Construction rule {step.rule} is not registered in this "
                            f"environment; skipping applied_actions cleanup for it"
                        )
                        continue

                    # Find and remove the matching action from applied_actions
                    matching_action = next(
                        (applied_action for applied_action in self.applied_actions
                         if isinstance(applied_action, ValidAction) and
                         applied_action.rule_idx == rule_idx and
                         applied_action.input_nodes == step.input_node_ids and
                         applied_action.params == step.parameters),
                        None
                    )
                    if matching_action is not None:
                        self.applied_actions.remove(matching_action)
                        logger.info("Removed construction action from applied_actions for entity")

            # Update valid actions if enabled and if we're tracking them
            # NOTE(review): unlike _validate_existing_actions, action_space is
            # not resized here - confirm callers refresh it after removal.
            if update_valid_actions and self.enumerate_actions:
                # Remove any valid actions that involve any of the removed entities
                removed_ids = set(removed_nodes)
                self.valid_actions = [
                    action for action in self.valid_actions
                    if removed_ids.isdisjoint(action.input_nodes)
                ]
                # Also remove any actions that were in applied_actions (unless allow_rediscovery is True)
                if not allow_rediscovery:
                    # Same isinstance guard as the rediscovery branch above, so
                    # a non-ValidAction entry in applied_actions cannot raise.
                    self.valid_actions = [
                        action for action in self.valid_actions
                        if not any(isinstance(applied_action, ValidAction) and
                                   action.rule_idx == applied_action.rule_idx and
                                   action.input_nodes == applied_action.input_nodes and
                                   action.params == applied_action.params
                                   for applied_action in self.applied_actions)
                    ]
                logger.info("Updated valid actions after entity removal")

            logger.info(f"Successfully removed {node_type.value} {entity_id}")
            return True

        except Exception as e:
            logger.error(f"Error removing entity {entity_id}: {str(e)}")
            return False
            
    def remove_concept(
        self,
        concept_id: str,
        allow_rediscovery: bool = False,
        update_valid_actions: bool = True
    ) -> bool:
        """Remove a concept by forwarding to ``remove_entity``.

        Args:
            concept_id: ID of the concept to remove
            allow_rediscovery: Forwarded unchanged to ``remove_entity``
            update_valid_actions: Forwarded unchanged to ``remove_entity``

        Returns:
            bool: Whatever ``remove_entity`` returns
        """
        return self.remove_entity(
            concept_id,
            allow_rediscovery=allow_rediscovery,
            update_valid_actions=update_valid_actions,
        )
        
    def remove_conjecture(
        self,
        conjecture_id: str,
        allow_rediscovery: bool = False,
        update_valid_actions: bool = True
    ) -> bool:
        """Remove a conjecture by forwarding to ``remove_entity``.

        Args:
            conjecture_id: ID of the conjecture to remove
            allow_rediscovery: Forwarded unchanged to ``remove_entity``
            update_valid_actions: Forwarded unchanged to ``remove_entity``

        Returns:
            bool: Whatever ``remove_entity`` returns
        """
        return self.remove_entity(
            conjecture_id,
            allow_rediscovery=allow_rediscovery,
            update_valid_actions=update_valid_actions,
        )
        
    def remove_theorem(
        self,
        theorem_id: str,
        allow_rediscovery: bool = False,
        update_valid_actions: bool = True
    ) -> bool:
        """Remove a theorem by forwarding to ``remove_entity``.

        Args:
            theorem_id: ID of the theorem to remove
            allow_rediscovery: Forwarded unchanged to ``remove_entity``
            update_valid_actions: Forwarded unchanged to ``remove_entity``

        Returns:
            bool: Whatever ``remove_entity`` returns
        """
        return self.remove_entity(
            theorem_id,
            allow_rediscovery=allow_rediscovery,
            update_valid_actions=update_valid_actions,
        )