"""Theory builder for mathematical discovery.

This module coordinates the process of mathematical discovery by:
1. Setting up the environment with an initial knowledge graph
2. Using a policy to guide the discovery process
3. Tracking and logging discoveries and their interestingness
"""

import os
import logging
import importlib
from typing import Dict, Any, Optional, List, Tuple, Union, Set
from dataclasses import dataclass
import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np
import copy
import secrets
import time
import asyncio
import sys
import signal
import atexit
import multiprocessing
import traceback
from datetime import datetime
import yaml # Needed for parsing the definition file
import concurrent.futures # Add this import at the top of theory_builder.py if not present
# Use pebble for process pool with termination capabilities
import pebble
from pebble import ProcessExpired # Import the specific exception for pebble timeouts
from concurrent.futures import TimeoutError as FuturesTimeoutError # Keep for future.result() timeout

from frame.environments.math_env import MathEnv, DEFAULT_RULE_APPLICATION_TIMEOUT, DEFAULT_Z3_PROVER_TIMEOUT
from frame.knowledge_base.entities import DEFAULT_EXAMPLE_VERIFICATION_TIMEOUT
from frame.knowledge_base.knowledge_graph import KnowledgeGraph, NodeType
from frame.knowledge_base.entities import Entity
from frame.policies.base import Policy
from frame.productions.base import ProductionRule
from frame.utils.experiment_utils import (
    print_discovered_entities, 
    check_for_duplicates, 
    generate_visualizations,
    cleanup_resources
)
from frame.tools.cache import EpisodicCache, ProofStatus

def generate_unique_dir():
    """Build a unique relative directory name for one run.

    The result has the shape ``YYYY-MM-DD/HH-MM-SS-mmm-<16 hex chars>``: a
    date folder, then a time-of-day stamp with millisecond precision, then a
    random suffix so that two runs started in the same millisecond still get
    different names.

    Returns:
        The generated directory name (note that it contains one ``/``).
    """
    now = datetime.now()
    # "%f" renders six microsecond digits; dropping the last three leaves
    # millisecond precision.
    stamp = now.strftime("%Y-%m-%d/%H-%M-%S-%f")[:-3]
    # 8 random bytes -> 16 hexadecimal characters.
    suffix = secrets.token_hex(8)
    return "-".join((stamp, suffix))

# Lookup table from short production-rule class names to their dotted import
# paths. get_full_rule_path() consults it to expand short names, and
# instantiate_rules() iterates over every entry when `default_rules` is enabled.
RULE_PATHS = {
    # Concept rules
    'MapIterateRule': 'frame.productions.concepts.MapIterateRule',
    'ComposeRule': 'frame.productions.concepts.ComposeRule',
    'NegateRule': 'frame.productions.concepts.NegateRule',
    'ExistsRule': 'frame.productions.concepts.ExistsRule',
    'MatchRule': 'frame.productions.concepts.MatchRule',
    'SpecializeRule': 'frame.productions.concepts.SpecializeRule',
    'SizeRule': 'frame.productions.concepts.SizeRule',
    'ConstantRule': 'frame.productions.concepts.ConstantRule',
    'ForallRule': 'frame.productions.concepts.ForallRule',
    
    # Conjecture rules
    'EquivalenceRule': 'frame.productions.conjectures.EquivalenceRule',
    'ImplicationRule': 'frame.productions.conjectures.ImplicationRule',
    'NonexistenceRule': 'frame.productions.conjectures.NonexistenceRule'
}

def get_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"

def setup_worker_logging(episode_num: int, logs_dir: str) -> logging.Logger:
    """Set up logging for a worker process.

    The logger is named ``episode_<episode_num>`` and writes every record at
    INFO level or above both to ``<logs_dir>/episode_<episode_num>.log`` and
    to the console. Calling this function again for the same episode replaces
    the previous handlers instead of stacking new ones on top.

    Args:
        episode_num: The episode number
        logs_dir: Directory to store log files (must already exist)

    Returns:
        Logger configured for this episode
    """
    # Create episode-specific logger
    logger = logging.getLogger(f"episode_{episode_num}")
    logger.setLevel(logging.INFO)

    # Remove any existing handlers. They are closed as well as detached:
    # simply rebinding ``logger.handlers`` would leave the previous
    # FileHandler's file descriptor open for the rest of the process.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Shared format for the file and console output
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler: one log file per episode
    log_file = os.path.join(logs_dir, f"episode_{episode_num}.log")
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Console handler too, for debugging
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # Important: Set propagate to False to prevent messages going to root logger
    logger.propagate = False

    return logger

def import_object_from_path(full_path: str) -> Any:
    """
    Import an object (function, class, variable) from its import path.

    The path is split at its last dot: everything before it is imported as a
    module and the final component is looked up as an attribute of that module.

    Args:
        full_path: The full import path (e.g., 'module.submodule.function_name')

    Returns:
        The imported object

    Raises:
        ImportError: If the module cannot be imported, or if ``full_path`` is
            not of the form 'module.object' (no dot, or an empty component)
        AttributeError: If the object cannot be found in the module
    """
    try:
        # Split the path into module path and object name
        module_path, separator, object_name = full_path.rpartition('.')

        # A bare name such as 'MapIterateRule' cannot be split into a module
        # and an attribute. Report that explicitly instead of letting tuple
        # unpacking fail with an unrelated "not enough values to unpack" error.
        if not (separator and module_path and object_name):
            raise ImportError(
                f"'{full_path}' is not a valid import path; expected 'module.object'"
            )

        # Import the module
        module = importlib.import_module(module_path)

        # Get the object
        return getattr(module, object_name)

    except (ImportError, AttributeError) as e:
        logging.error(f"Failed to import {full_path}: {e}")
        raise

def get_full_rule_path(rule_name: str) -> str:
    """
    Expand a short rule name into its full import path.

    Args:
        rule_name: Either a short name registered in RULE_PATHS
            (e.g., 'MapIterateRule') or an already fully qualified path

    Returns:
        The registered import path for a known short name
        (e.g., 'frame.productions.concepts.MapIterateRule'); any other value
        is returned unchanged on the assumption that it is already a full path
    """
    # Unknown names fall back to themselves.
    return RULE_PATHS.get(rule_name, rule_name)

def create_rule_from_target(target: str, **kwargs) -> ProductionRule:
    """
    Import a production rule class and instantiate it.

    Args:
        target: The import path in format 'module.submodule.ClassName'
        **kwargs: Keyword arguments forwarded to the class constructor

    Returns:
        An instantiated production rule

    Raises:
        TypeError: If the imported class is not a ProductionRule subclass
        Exception: Any error raised while importing or constructing the rule
            is logged and then re-raised unchanged
    """
    try:
        candidate = import_object_from_path(target)

        # Only genuine production rules may be constructed through this path.
        if issubclass(candidate, ProductionRule):
            return candidate(**kwargs)

        raise TypeError(f"{target} is not a ProductionRule subclass")

    except Exception as e:
        logging.error(f"Failed to create rule from {target}: {e}")
        raise

def setup_episode_directories(base_output_dir: str, episode_num: int) -> Dict[str, str]:
    """Create the on-disk layout for one episode and report where it lives.

    An ``episode_<episode_num>`` folder is created under ``base_output_dir``
    together with ``logs``, ``graphs``, ``visualizations`` and ``checkpoints``
    subfolders. Directories that already exist are left untouched.

    Args:
        base_output_dir: Base output directory
        episode_num: The episode number

    Returns:
        Dictionary with the keys ``episode_dir``, ``logs_dir``, ``graphs_dir``,
        ``viz_dir`` and ``checkpoints_dir`` mapped to the created paths
    """
    episode_root = os.path.join(base_output_dir, f"episode_{episode_num}")
    os.makedirs(episode_root, exist_ok=True)

    # (result key, folder name) for every subdirectory of the episode folder
    subfolders = (
        ("logs_dir", "logs"),
        ("graphs_dir", "graphs"),
        ("viz_dir", "visualizations"),
        ("checkpoints_dir", "checkpoints"),
    )

    layout = {"episode_dir": episode_root}
    for key, folder_name in subfolders:
        folder_path = os.path.join(episode_root, folder_name)
        os.makedirs(folder_path, exist_ok=True)
        layout[key] = folder_path

    return layout

def save_checkpoint(env: MathEnv, step: int, checkpoints_dir: str, logger: logging.Logger) -> None:
    """Write the environment's current graph to a checkpoint file.

    The file is named ``checkpoint_step_<step>_<timestamp>.pkl`` and placed in
    ``checkpoints_dir``. Saving is best-effort: a failure is logged as an
    error and does not propagate to the caller.

    Args:
        env: The environment containing the graph to save
        step: Current step number
        checkpoints_dir: Directory to save checkpoint in; when empty, a
            warning is logged and nothing is written
        logger: Logger for this episode
    """
    if checkpoints_dir:
        file_name = f"checkpoint_step_{step}_{get_timestamp()}.pkl"
        target_path = os.path.join(checkpoints_dir, file_name)

        try:
            env.graph.save(target_path)
            logger.info(f"Checkpoint saved to {target_path}")
        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
    else:
        logger.warning("Cannot save checkpoint: checkpoints directory not provided")

def instantiate_policy(policy_config: Union[DictConfig, Dict[str, Any]]) -> Policy:
    """Instantiate a policy from its configuration.

    Two configuration shapes are understood, whether they arrive as a plain
    dict or as an OmegaConf ``DictConfig`` (which is first converted to a
    plain dict with interpolations resolved):

    * ``{'_target_': 'pkg.module.Class', ...}`` - every other key is passed to
      the constructor as a keyword argument.
    * ``{'type': 'pkg.module.Class', 'params': {...}}`` - ``params`` (optional)
      supplies the constructor keyword arguments.

    The ``_target_`` key only selects the class and is never forwarded to the
    constructor, for both dict and DictConfig input.

    Args:
        policy_config: The policy configuration in one of the shapes above

    Returns:
        The instantiated policy

    Raises:
        ValueError: If the configuration matches neither shape, or if ``type``
            is not a dotted import path
        Exception: Any import or constructor error is logged (with traceback)
            and re-raised unchanged
    """
    try:
        # Normalise OmegaConf nodes to plain containers so that both config
        # flavours go through exactly the same code path below.
        if isinstance(policy_config, DictConfig):
            config = OmegaConf.to_container(policy_config, resolve=True)
        else:
            config = policy_config

        if isinstance(config, dict):
            target = config.get('_target_')
            if target:
                policy_class = import_object_from_path(target)
                # '_target_' names the class; it is not a constructor argument.
                kwargs = {k: v for k, v in config.items() if k != '_target_'}
                logging.debug(f"Instantiating policy from dict: {target} with kwargs: {kwargs}")
                return policy_class(**kwargs)

            if 'type' in config:
                policy_type = config['type']
                if '.' not in policy_type:
                    raise ValueError(f"Policy type must be a full import path, got: {policy_type}")
                # 'params' may be absent or explicitly null; both mean "no arguments".
                policy_params = config.get('params') or {}
                policy_class = import_object_from_path(policy_type)
                logging.debug(f"Instantiating policy from type/params dict: {policy_type} with params: {policy_params}")
                return policy_class(**policy_params)

        logging.error(f"Invalid policy configuration type: {type(policy_config)} - {policy_config}")
        raise ValueError(f"Invalid policy configuration: {policy_config}")
    except Exception as e:
        logging.error(f"Error instantiating policy from config {policy_config}: {e}")
        logging.error(traceback.format_exc())
        raise

def _instantiate_definition_rules(rule_names: Any, kind: str, logger: logging.Logger,
                                  announce: bool) -> List[ProductionRule]:
    """Instantiate the rules named in one section of a rule definition file.

    Entries that are not strings, or that start with '#', are skipped. A rule
    that fails to instantiate is logged and skipped; it does not abort the
    remaining entries.
    """
    loaded = []
    for rule_name in rule_names:
        # Basic check if it looks commented out
        if not isinstance(rule_name, str) or rule_name.strip().startswith('#'):
            logger.debug(f"Skipping commented or invalid {kind} rule entry: {rule_name}")
            continue

        clean_name = rule_name.strip()  # Ensure no extra whitespace
        try:
            loaded.append(create_rule_from_target(get_full_rule_path(clean_name)))
        except Exception as e:
            logger.error(f"Failed to instantiate {kind} rule '{rule_name}' from definition file: {e}")
            continue

        if announce:
            logger.info(f"Loaded {kind} rule '{clean_name}' from definition file.")
    return loaded


def _build_rule_from_spec(rule_spec: Any) -> Optional[ProductionRule]:
    """Build one rule from an explicit spec; return None for an unrecognised spec shape.

    Accepted shapes: a rule name / import path string, a mapping with
    '_target_' (remaining keys become constructor kwargs), or a mapping with
    'target' and optional 'params'. The spec passed in is never modified.
    """
    if isinstance(rule_spec, DictConfig):
        # OmegaConf nodes are not dict instances; convert so they are handled too.
        rule_spec = OmegaConf.to_container(rule_spec, resolve=True)

    if isinstance(rule_spec, str):
        return create_rule_from_target(get_full_rule_path(rule_spec))

    if isinstance(rule_spec, dict):
        if '_target_' in rule_spec:
            # Build the kwargs from a filtered copy rather than pop(): popping
            # would strip '_target_' from the caller's config, breaking any
            # later use of the same config object.
            kwargs = {k: v for k, v in rule_spec.items() if k != '_target_'}
            return create_rule_from_target(rule_spec['_target_'], **kwargs)
        if 'target' in rule_spec:
            target = get_full_rule_path(rule_spec['target'])
            return create_rule_from_target(target, **(rule_spec.get('params') or {}))

    return None


def _instantiate_explicit_rules(rule_specs: Any, kind: str, logger: logging.Logger) -> List[ProductionRule]:
    """Instantiate every explicit rule spec of one kind, logging and skipping failures."""
    loaded = []
    for rule_spec in rule_specs:
        logger.debug(f"Attempting to instantiate {kind} rule: {rule_spec}")
        try:
            rule = _build_rule_from_spec(rule_spec)
        except Exception as e:
            logger.error(f"Failed to instantiate explicit {kind} rule {rule_spec}: {e}")
            continue

        if rule is None:
            # Previously such entries were dropped silently while success was logged.
            logger.warning(f"Skipping unrecognized {kind} rule specification: {rule_spec}")
            continue

        loaded.append(rule)
        logger.debug(f"Successfully instantiated {kind} rule: {rule_spec}")
    return loaded


def instantiate_rules(rule_configs: Dict[str, Any]) -> List[ProductionRule]:
    """Instantiate production rules from configurations.

    Three strategies are tried in order; the first one that applies wins:

    1. ``definition_file``: a YAML file with optional ``concepts`` and
       ``conjectures`` lists of rule names. If the file loads, only its rules
       are returned (possibly an empty list). If it is missing, unparsable or
       malformed, the error is logged, nothing from the file is kept, and the
       next strategies are tried.
    2. ``default_rules``: when truthy, every entry of RULE_PATHS is
       instantiated. The result is returned if at least one rule was created.
    3. Explicit ``concepts`` / ``conjectures`` lists of rule specs (a name or
       import path string, a mapping with ``_target_``, or a mapping with
       ``target`` and optional ``params``).

    Individual rules that fail to instantiate are logged and skipped.

    Args:
        rule_configs: Rule configurations (not modified by this function)

    Returns:
        List of instantiated production rules
    """
    # Get a logger instance for this module
    logger = logging.getLogger(__name__)

    logger.info(f"Instantiating rules with config: {rule_configs}")

    # --- Strategy 1: Load from definition_file if specified ---
    definition_file_path = rule_configs.get('definition_file')
    if definition_file_path:
        logger.info(f"Attempting to load rules from definition file: {definition_file_path}")
        try:
            # The path is assumed to be already resolved by Hydra/OmegaConf
            if not os.path.exists(definition_file_path):
                raise FileNotFoundError(f"Rule definition file not found: {definition_file_path}")

            with open(definition_file_path, 'r') as f:
                # Use safe_load to parse the YAML structure
                defined_rules = yaml.safe_load(f)

            if not defined_rules:
                logger.warning(f"Rule definition file {definition_file_path} is empty or invalid YAML.")
                defined_rules = {}  # Avoid errors below

            if not isinstance(defined_rules, dict):
                raise TypeError(
                    f"Rule definition file {definition_file_path} must contain a mapping, "
                    f"got {type(defined_rules).__name__}"
                )

            # A section whose entries are all commented out parses as None,
            # not as an empty list, so normalise with `or []`.
            defined_concepts = defined_rules.get('concepts') or []
            logger.info(f"Found {len(defined_concepts)} potential concept rules in definition file.")
            # Rules from the file are collected separately so that a failure
            # part-way through cannot leak a partial list into the fallbacks.
            file_rules = _instantiate_definition_rules(defined_concepts, 'concept', logger, announce=False)

            defined_conjectures = defined_rules.get('conjectures') or []
            logger.info(f"Found {len(defined_conjectures)} potential conjecture rules in definition file.")
            file_rules.extend(
                _instantiate_definition_rules(defined_conjectures, 'conjecture', logger, announce=True)
            )

            # If we successfully loaded from definition file, we are done.
            logger.info(f"Finished instantiating rules from definition file. Total rules loaded: {len(file_rules)}")
            return file_rules

        except FileNotFoundError as fnf:
            logger.error(f"{fnf}. Falling back to other methods if configured.")
        except yaml.YAMLError as ye:
            logger.error(f"Error parsing rule definition file {definition_file_path}: {ye}. Falling back.")
        except Exception as e:
            logger.error(f"Unexpected error loading rule definition file {definition_file_path}: {e}. Falling back.", exc_info=True)

    rules = []

    # --- Strategy 2: Use default_rules: True if definition_file wasn't used/failed ---
    use_defaults = rule_configs.get('default_rules', False)
    logger.info(f"Checking for default rules: default_rules={use_defaults}")

    if use_defaults:
        logger.info("Instantiating default production rules (using RULE_PATHS).")
        for rule_name, rule_path in RULE_PATHS.items():
            try:
                rules.append(create_rule_from_target(rule_path))
                logger.info(f"Successfully added default rule: {rule_name} from {rule_path}")
            except Exception as e:
                logger.error(f"Failed to instantiate default rule {rule_name} from {rule_path}: {e}")

        if rules:
            return rules
        # default_rules was true but RULE_PATHS was empty or every rule failed:
        # fall through to the explicit lists, if any.
        logger.warning("default_rules=True but no default rules were instantiated.")

    # --- Strategy 3: Load from explicit lists if definition_file and default_rules weren't used/failed ---
    # `or []` also covers a key that is present but null.
    explicit_concept_rules = rule_configs.get('concepts') or []
    if explicit_concept_rules:
        logger.info(f"Processing {len(explicit_concept_rules)} explicit concept rules.")
    rules.extend(_instantiate_explicit_rules(explicit_concept_rules, 'concept', logger))

    explicit_conjecture_rules = rule_configs.get('conjectures') or []
    if explicit_conjecture_rules:
        logger.info(f"Processing {len(explicit_conjecture_rules)} explicit conjecture rules.")
    rules.extend(_instantiate_explicit_rules(explicit_conjecture_rules, 'conjecture', logger))

    logger.info(f"Finished instantiating rules. Total rules loaded: {len(rules)}")
    return rules

def create_initial_graph(graph_config: Dict[str, Any]) -> KnowledgeGraph:
    """Build the starting knowledge graph described by a configuration.

    Concepts are only loaded when ``import_from`` is present; without it the
    graph is returned empty even if ``concepts`` is listed. Each entry of
    ``concepts`` that contains a dot is used as an import path as-is, any
    other entry is looked up inside the ``import_from`` module. An imported
    object that is callable and not an Entity instance is called with no
    arguments to obtain the concept; anything else is added directly.
    A concept that fails to import or to be added is logged and skipped.

    Args:
        graph_config: Configuration for the initial graph

    Returns:
        The initialized knowledge graph
    """
    graph = KnowledgeGraph()

    # Nothing is imported without a base module, whatever else is configured.
    if 'import_from' not in graph_config:
        return graph
    base_module = graph_config['import_from']

    if 'concepts' not in graph_config:
        return graph

    for concept_path in graph_config['concepts']:
        # A dotted entry is taken to be fully qualified already.
        full_path = concept_path if '.' in concept_path else f"{base_module}.{concept_path}"

        try:
            imported = import_object_from_path(full_path)

            # Factory functions are invoked; ready-made entities are used directly.
            is_factory = callable(imported) and not isinstance(imported, Entity)
            graph.add_concept(imported() if is_factory else imported)

        except Exception as e:
            logging.error(f"Failed to add concept from {full_path}: {e}")

    return graph

def print_episode_log_excerpt(logs_dir: str, max_lines: int = 20) -> None:
    """Log the tail of an episode log file through the root logger.

    The alphabetically first ``*.log`` file in ``logs_dir`` is read and its
    last ``max_lines`` lines are emitted at INFO level, preceded by a header
    line. Any error while listing or reading is logged and swallowed, so this
    function never raises.

    Args:
        logs_dir: Directory containing the log file
        max_lines: Maximum number of lines to print; zero or a negative value
            prints only the header
    """
    try:
        # Find the log file. os.listdir() returns entries in arbitrary order,
        # so sort to make the choice deterministic when several files exist.
        log_files = sorted(f for f in os.listdir(logs_dir) if f.endswith('.log'))
        if not log_files:
            # Use logging.warning as finding no log files might be unexpected
            logging.warning("No log files found in %s", logs_dir)
            return

        log_file = os.path.join(logs_dir, log_files[0])

        # Read log file
        with open(log_file, 'r') as f:
            lines = f.readlines()

        # Get the last max_lines. The explicit guard matters: lines[-0:] is
        # the whole list, so a limit of 0 would otherwise print everything.
        excerpt = lines[-max_lines:] if max_lines > 0 else []

        # TODO: Consider if this function is still needed if main log captures worker summaries.
        logging.info(f"Log excerpt ({len(excerpt)} lines) from {log_file}:")
        for line in excerpt:
            logging.info("  " + line.strip())
    except Exception as e:
        logging.error(f"Error reading log file excerpt from {logs_dir}: {e}")

def run_episode_worker(worker_args: tuple) -> Dict[str, Any]:
    """
    Worker function executed in a separate process to run a single episode.
    
    Args:
        worker_args: Tuple containing all arguments needed for the worker
            (episode_num, cfg_dict, initial_graph_config, production_rule_configs, 
             policy_config, base_output_dir, main_seed, max_steps, checkpoint_frequency,
             save_visualizations, visualization_step_threshold, db_path, episode_timeout_seconds) # Added episode_timeout_seconds
            
    Returns:
        Dictionary containing the episode results
    """
    (episode_num, cfg_dict, initial_graph_config, production_rule_configs, 
     policy_config, base_output_dir, main_seed, max_steps, checkpoint_frequency,
     save_visualizations, visualization_step_threshold, db_path, episode_timeout_seconds) = worker_args
    
    # Configs should be passed in their correct types (dicts or DictConfig)
    # The check below is likely unnecessary if the main process handles types correctly.
    # Remove this check:
    # if isinstance(cfg_dict, str):
    #     cfg_dict = OmegaConf.create(cfg_dict)
    
    try:
        # Setup directories and logging for this episode
        episode_dirs = setup_episode_directories(base_output_dir, episode_num)
        episode_logger = setup_worker_logging(episode_num, episode_dirs['logs_dir'])
        
        # Configure root logger to redirect to episode logger
        # This is critical - capture logs from all modules
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        
        # Save original handlers
        original_handlers = root_logger.handlers.copy()
        original_level = root_logger.level
        
        # Clear root logger handlers and redirect to episode logger
        root_logger.handlers = []
        for handler in episode_logger.handlers:
            root_logger.addHandler(handler)
        
        # Also configure framework module loggers explicitly
        for module_name in ['frame', 'frame.environments', 'frame.environments.math_env', 
                           'frame.policies', 'frame.knowledge_base', 'frame.tools.cache']: # Added proof_cache logger
            module_logger = logging.getLogger(module_name)
            module_logger.handlers = []
            for handler in episode_logger.handlers:
                module_logger.addHandler(handler)
            module_logger.propagate = False
        
        try:
            # Record start time
            start_time = time.time() 

            # Get worker PID for logging
            worker_pid = os.getpid()
            episode_logger.info(f"Episode {episode_num} worker started (PID: {worker_pid}) with seed {int(main_seed) + episode_num}")
            
            # Setup RNG for this worker
            # Convert main_seed to int if it's a string to avoid type error
            worker_seed = int(main_seed) + episode_num
            np.random.seed(worker_seed)
            worker_rng = np.random.RandomState(worker_seed)
            
            # Get timeout settings from config
            rule_application_timeout = DEFAULT_RULE_APPLICATION_TIMEOUT
            example_verification_timeout = DEFAULT_EXAMPLE_VERIFICATION_TIMEOUT
            # Get Z3 timeout from config dict, fallback to default
            z3_prover_timeout_from_config = cfg_dict.get('timeouts', {}).get('z3_prover', DEFAULT_Z3_PROVER_TIMEOUT)
            # Get Z3 example search timeout from config dict, fallback to default
            z3_example_search_timeout_from_config = cfg_dict.get('timeouts', {}).get('z3_example_search', 0.5) # Default 0.5s

            # Get Z3 usage flags
            use_z3_prover_flag = cfg_dict.get('z3_usage', {}).get('use_z3_prover', True) # Default True
            use_z3_example_search_flag = cfg_dict.get('z3_usage', {}).get('use_z3_example_search', True) # Default True

            if 'timeouts' in cfg_dict:
                if 'rule_application' in cfg_dict['timeouts']:
                    rule_application_timeout = cfg_dict['timeouts']['rule_application']
                    episode_logger.info(f"Using configured rule application timeout: {rule_application_timeout} seconds")
                
                if 'example_verification' in cfg_dict['timeouts']:
                    example_verification_timeout = cfg_dict['timeouts']['example_verification']
                    episode_logger.info(f"Using configured example verification timeout: {example_verification_timeout} seconds")
                    
                    # Set the global example verification timeout in entities module
                    from frame.knowledge_base.entities import set_example_verification_timeout
                    # Increase timeout significantly from the default 0.01s
                    # TODO: Make this configurable or tune based on typical verification time
                    increased_verification_timeout = 0.2 
                    set_example_verification_timeout(increased_verification_timeout)
                    episode_logger.info(f"Set example verification timeout to {increased_verification_timeout} seconds (increased from config value {example_verification_timeout})")
            
            # Create components
            # 1. Create initial graph
            initial_graph = create_initial_graph(initial_graph_config)
            episode_logger.info(f"Created initial graph with {len(initial_graph.nodes)} nodes")
            
            # 2. Initialize production rules
            production_rules = instantiate_rules(production_rule_configs)
            episode_logger.info(f"Initialized {len(production_rules)} production rules")
            
            # 3. Initialize policy
            policy = instantiate_policy(policy_config)
            episode_logger.info(f"Initialized policy: {policy.__class__.__name__}")
            
            # Set random number generator for policy
            policy.set_rng(worker_rng)
            
            # Initialize EpisodicCache
            episodic_cache = EpisodicCache(db_path)
            episode_logger.info(f"Episodic cache initialized at: {db_path}")
            
            # 4. Initialize environment with the configured timeout
            env = MathEnv(
                initial_graph=initial_graph,
                production_rules=production_rules,
                max_steps=max_steps,
                enumerate_actions=policy.requires_enumeration,
                rule_application_timeout=rule_application_timeout,
                episodic_cache=episodic_cache, # Pass renamed cache to env
                z3_prover_timeout=z3_prover_timeout_from_config, # Pass Z3 timeout
                z3_example_search_timeout=z3_example_search_timeout_from_config, # Pass Z3 example search timeout
                use_z3_prover=use_z3_prover_flag, # Pass Z3 prover usage flag
                use_z3_example_search=use_z3_example_search_flag # Pass Z3 example search usage flag
            )
            
            # 5. Set rules for the policy if it requires them
            if hasattr(policy, 'set_rules'):
                policy.set_rules(env.rules)
            
            # Run the episode
            graph, _ = env.reset()
            done = False
            truncated = False
            episode_reward = 0
            step = 0
            
            # Counters for statistics
            successful_steps = 0
            failed_steps = 0
            no_new_entity_steps = 0
            
            # Log initial state
            episode_logger.info(f"Episode {episode_num} started with seed {worker_seed}")
            
            # Log policy information
            if hasattr(policy, 'concept_selection'):
                episode_logger.info(f"Policy concept selection strategy: {policy.concept_selection}")
            if hasattr(policy, 'action_selection'):
                episode_logger.info(f"Policy action selection strategy: {policy.action_selection}")
            
            # Log initial valid actions count
            if env.enumerate_actions:
                episode_logger.info(f"Initial valid actions count: {len(env.valid_actions)}")
                # Log more info about available rules
                if hasattr(env, 'rules'):
                    rule_names = [rule.__class__.__name__ for rule in env.rules]
                    episode_logger.info(f"Available rules: {rule_names}")
            
            # Main episode loop
            start_episode_time = time.time() # Record worker start time
            worker_internally_timed_out = False # Flag for internal timeout

            while not (done or truncated):
                # --- Check for internal worker timeout --- 
                current_time = time.time()
                if current_time - start_episode_time > episode_timeout_seconds:
                    episode_logger.error(f"Worker for Episode {episode_num} exceeded internal time limit ({episode_timeout_seconds}s). Stopping episode.")
                    worker_internally_timed_out = True
                    break # Exit the episode loop
                # --- End internal timeout check --- 

                # Select action
                action = policy.select_action(env)
                
                if action is None:
                    episode_logger.info("No valid actions available")
                    break
                
                # Log action selection details
                if hasattr(env, 'enumerate_actions') and env.enumerate_actions and hasattr(env, 'valid_actions'):
                    episode_logger.info(f"Selected action from {len(env.valid_actions)} valid actions")
                
                if hasattr(action, 'rule') and hasattr(action, 'args'):
                    rule_name = action.rule.__class__.__name__ if hasattr(action.rule, '__class__') else str(action.rule)
                    episode_logger.info(f"Selected rule: {rule_name} with args: {action.args}")
                elif hasattr(action, 'rule'):
                    rule_name = action.rule.__class__.__name__ if hasattr(action.rule, '__class__') else str(action.rule)
                    episode_logger.info(f"Selected rule: {rule_name}")
                
                # Take action
                next_graph, reward, done, truncated, info = env.step(action)
                episode_reward += reward
                
                # Log reward for this step
                episode_logger.info(f"Step {step} reward: {reward}")
                
                # Log more details about the action if available - show that it was applied
                if hasattr(action, 'rule') and hasattr(action, 'args'):
                    rule_name = action.rule.__class__.__name__ if hasattr(action.rule, '__class__') else str(action.rule)
                    episode_logger.info(f"Applied rule: {rule_name} with args: {action.args}")
                elif hasattr(action, 'rule'):
                    rule_name = action.rule.__class__.__name__ if hasattr(action.rule, '__class__') else str(action.rule)
                    episode_logger.info(f"Applied rule: {rule_name}")
                
                # Update policy
                policy.update(env, action, reward, done)
                
                # Update graph
                graph = next_graph
                
                # Log step information
                step += 1
                if 'error' in info:
                    episode_logger.warning(f"Step {step} failed: {info['error']}")
                    failed_steps += 1
                elif 'new_entities' in info and info['new_entities']:
                    entity_ids = info['new_entities']
                    entity_names = []
                    for entity_id in entity_ids:
                        entity, _, _ = graph.get_node(entity_id)
                        name = entity.name if hasattr(entity, 'name') else str(entity)
                        entity_names.append(name)
                    
                    episode_logger.info(f"Step {step}: Created {len(entity_ids)} new entities: {entity_names}")
                    successful_steps += 1
                else:
                    episode_logger.info(f"Step {step}: No new entities created")
                    no_new_entity_steps += 1
                
                # Save checkpoint if needed
                if step % checkpoint_frequency == 0:
                    save_checkpoint(env, step, episode_dirs['checkpoints_dir'], episode_logger)
                
                # Check if we've reached the maximum number of steps
                if step >= max_steps:
                    episode_logger.info(f"Reached maximum number of steps ({max_steps})")
                    break
            
            # --- Post-Loop Processing ---
            final_graph_path = None # Initialize as None

            # --- Graph Saving --- 
            if not worker_internally_timed_out:
                # Save final graph for this episode (only if not timed out)
                timestamp = get_timestamp()
                final_graph_path = os.path.join(episode_dirs['graphs_dir'], f"final_graph_{timestamp}.pkl")
                try:
                    episode_logger.info("Attempting to save final graph...")
                    graph.save(final_graph_path)
                    episode_logger.info(f"Episode graph saved to {final_graph_path}")
                except Exception as e:
                    episode_logger.error(f"Failed to save episode graph: {e}", exc_info=True)
                    final_graph_path = None # Set to None if saving failed
            else: # worker_internally_timed_out is True
                episode_logger.warning("Internal worker timeout reached. Skipping final graph SAVE.")

            # --- Visualization --- 
            attempt_visualization = False # Flag to determine if we should try to visualize
            if save_visualizations: # Only proceed if save_visualizations is true globally
                if not worker_internally_timed_out:
                    # Normal completion: respect visualization_step_threshold
                    if visualization_step_threshold is None or step < visualization_step_threshold:
                        attempt_visualization = True
                    else:
                        episode_logger.info(f"Skipping visualizations (normal completion) - episode steps ({step}) exceeded threshold ({visualization_step_threshold})")
                else: # Worker internally timed out
                    if step <= 150:
                        episode_logger.info(f"Internal worker timeout occurred, but step count ({step}) <= 150. Attempting visualization.")
                        attempt_visualization = True
                    else:
                        episode_logger.info(f"Skipping visualizations (internal timeout) - step count ({step}) > 150 and worker internally timed out.")
            else: # save_visualizations is False
                episode_logger.info("Skipping visualizations as save_visualizations is False.")

            if attempt_visualization:
                # Generate visualizations for this episode if under step threshold (only if not timed out)
                # --- Add Timeout for Visualization ---
                viz_executor = None
                try:
                    # The timestamp for visualization filename should be defined even if graph saving was skipped
                    # Re-fetch timestamp if it wasn't set during graph saving (i.e. if worker timed out)
                    timestamp_for_viz = timestamp if 'timestamp' in locals() and timestamp else get_timestamp()
                    episode_logger.info("Attempting to generate visualizations (5s timeout)...")
                    viz_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
                    future = viz_executor.submit(
                        generate_visualizations,
                        graph,
                        episode_dirs['viz_dir'],
                        timestamp_for_viz, # Use potentially re-fetched timestamp
                        episode_logger
                    )
                    try:
                        future.result(timeout=5.0) # 5 second timeout
                        episode_logger.info(f"Visualizations saved to {episode_dirs['viz_dir']}")
                    except FuturesTimeoutError:
                         episode_logger.error("Visualization generation timed out after 5 seconds.")
                         if future: future.cancel() # Attempt to cancel
                    except Exception as viz_err_inner:
                         episode_logger.error(f"Visualization generation failed internally: {viz_err_inner}", exc_info=True)

                except Exception as viz_setup_err:
                    # Catch errors during setup/submission, less likely
                    episode_logger.error(f"Error setting up visualization generation: {viz_setup_err}", exc_info=True)
                finally:
                    if viz_executor:
                        viz_executor.shutdown(wait=False) # Ensure executor is cleaned up
                # --- End Timeout for Visualization ---

            # Log episode statistics
            episode_logger.info(f"Episode {episode_num} finished with reward: {episode_reward}")
            episode_logger.info(f"Successful steps: {successful_steps}")
            episode_logger.info(f"Failed steps: {failed_steps}")
            episode_logger.info(f"No new entity steps: {no_new_entity_steps}")

            # Close environment (with error handling)
            try:
                if hasattr(env, 'close'):
                    episode_logger.info("Attempting to close environment...")
                    env.close()
                    episode_logger.info("Environment closed.")
            except Exception as e:
                 episode_logger.error(f"Error closing environment: {e}", exc_info=True)

            # Calculate and log duration
            end_time = time.time()
            duration = end_time - start_time
            episode_logger.info(f"Episode {episode_num} (PID: {worker_pid}) duration: {duration:.2f} seconds")
            
            # Extract top entity names for reporting (with error handling)
            top_entity_names = []
            try:
                # Get the top 5 most interesting entities if available
                top_entities = []
                if hasattr(policy, 'get_top_entities'):
                    top_entities = policy.get_top_entities(graph, 5)
                elif hasattr(graph, 'get_entities'):
                    # If policy doesn't have the method, just get the last 5 entities added
                    all_entities = graph.get_entities()
                    top_entities = all_entities[-5:] if len(all_entities) > 5 else all_entities
                
                top_entity_names = [entity.name if hasattr(entity, 'name') else str(entity) for entity in top_entities]
            except Exception as e:
                episode_logger.error(f"Error getting top entities: {e}", exc_info=True)

            # --- Get final timeout stats --- 
            final_timeout_count = env.timeout_count if hasattr(env, 'timeout_count') else 0
            final_total_steps = step # Total steps taken in the episode
            final_timed_out_rules = {}
            if hasattr(env, 'timed_out_rules'):
                 try:
                      final_timed_out_rules = env.timed_out_rules.copy()
                 except Exception as copy_err:
                      episode_logger.error(f"Error copying timed_out_rules: {copy_err}")
            # --- End get timeout stats --- 

            # Return results based on whether it was an internal timeout or normal completion
            if worker_internally_timed_out:
                # Ensure logs are flushed before returning if possible (might still hang in finally)
                # logging.shutdown() # Avoid calling shutdown here, let the finally block handle cleanup
                return {
                    "episode_num": episode_num,
                    "reward": episode_reward, # Return reward accumulated so far
                    "status": "timeout", # Indicate timeout
                    "error": f"Worker exceeded internal time limit of {episode_timeout_seconds} seconds after {step} steps.",
                    "total_steps": final_total_steps,
                    "timeout_count": final_timeout_count, # Rule timeouts that happened before worker timeout
                    "successful_steps": successful_steps,
                    "failed_steps": failed_steps,
                    "no_new_entity_steps": no_new_entity_steps,
                    "duration_seconds": time.time() - start_time, # Use current time for duration
                    "final_graph_path": None, # Graph saving skipped
                    "top_entity_names": top_entity_names, # Still try to report entities found
                    "timeout_stats": {
                        'total_timeouts': final_timeout_count,
                        'timed_out_rules': final_timed_out_rules
                    },
                    # Include directory info for potential partial logs/checkpoints
                    "episode_dir": episode_dirs["episode_dir"],
                    "logs_dir": episode_dirs["logs_dir"],
                    "graphs_dir": episode_dirs["graphs_dir"],
                    "viz_dir": episode_dirs["viz_dir"],
                    "checkpoints_dir": episode_dirs["checkpoints_dir"]
                }
            else: # Normal completion
                return {
                    "episode_num": episode_num,
                    "reward": episode_reward,
                    "status": "completed",
                    "total_steps": final_total_steps, # Include total steps
                    "timeout_count": final_timeout_count, # Include timeout count
                    "successful_steps": successful_steps,
                    "failed_steps": failed_steps,
                    "no_new_entity_steps": no_new_entity_steps,
                    "duration_seconds": duration, # Use calculated duration
                    "final_graph_path": final_graph_path, # Include path if saved
                    "top_entity_names": top_entity_names,
                    "timeout_stats": {
                        'total_timeouts': final_timeout_count,
                        'timed_out_rules': final_timed_out_rules
                    },
                    "episode_dir": episode_dirs["episode_dir"],
                    "logs_dir": episode_dirs["logs_dir"],
                    "graphs_dir": episode_dirs["graphs_dir"],
                    "viz_dir": episode_dirs["viz_dir"],
                    "checkpoints_dir": episode_dirs["checkpoints_dir"]
                }
            
        finally:
            # Restore original root logger configuration
            root_logger.handlers = original_handlers
            root_logger.setLevel(original_level)
            # Close episodic cache connection (with error handling and timeout)
            cache_close_executor = None
            try:
                if 'episodic_cache' in locals() and episodic_cache:
                    # --- Add Timeout for Cache Close --- 
                    cache_close_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
                    future = cache_close_executor.submit(episodic_cache.close)
                    try:
                        # Increase timeout for cache closing
                        cache_close_timeout = 10.0 # Increased from 3.0
                        episode_logger.info(f"Attempting to close episodic cache connection ({cache_close_timeout}s timeout)...")
                        future.result(timeout=cache_close_timeout) # Use the increased timeout
                        episode_logger.info("Episodic cache connection closed successfully.")
                    except FuturesTimeoutError:
                        episode_logger.error(f"Episodic cache close timed out after {cache_close_timeout} seconds.") # Use variable in message
                        if future: future.cancel()
                    except Exception as cache_close_err_inner:
                        episode_logger.error(f"Episodic cache close failed internally: {cache_close_err_inner}", exc_info=True)
                    # --- End Timeout for Cache Close --- 
            except Exception as e:
                 # Catch errors during setup/submission of cache close, less likely
                 episode_logger.error(f"Error setting up episodic cache close: {e}", exc_info=True)
            finally:
                 if cache_close_executor:
                      cache_close_executor.shutdown(wait=False)
            
    except Exception as e:
        # Log error
        if 'episode_logger' in locals():
            # Include PID in error log if possible
            worker_pid = os.getpid() if 'worker_pid' not in locals() else worker_pid
            episode_logger.error(f"Episode {episode_num} (PID: {worker_pid}) failed: {e}\n{traceback.format_exc()}")
        else:
            # Use logging.error here. It might go to a default root logger if episode logger failed.
            logging.error(f"Error in worker for episode {episode_num} before logger init: {e}")
            logging.error(traceback.format_exc())

        # Return failure information
        episode_dirs_dict = episode_dirs if 'episode_dirs' in locals() else {}
        # Ensure essential keys are present in the failure dictionary
        failure_result = {
            "episode_num": episode_num,
            "reward": 0.0,
            "status": "failed",
            "error": str(e),
            "traceback": traceback.format_exc(),
            "total_steps": step if 'step' in locals() else 0,
            "timeout_count": final_timeout_count if 'final_timeout_count' in locals() else 0,
            "successful_steps": successful_steps if 'successful_steps' in locals() else 0,
            "failed_steps": failed_steps if 'failed_steps' in locals() else 0,
            "no_new_entity_steps": no_new_entity_steps if 'no_new_entity_steps' in locals() else 0,
            "duration_seconds": time.time() - start_time if 'start_time' in locals() else 0.0,
            "final_graph_path": None,
            "top_entity_names": [],
            "timeout_stats": {
                 'total_timeouts': final_timeout_count if 'final_timeout_count' in locals() else 0,
                 'timed_out_rules': final_timed_out_rules if 'final_timed_out_rules' in locals() else {}
            },
            **episode_dirs_dict # Add directory info if available
        }
        return failure_result

@dataclass
class TheoryBuilderConfig:
    """Configuration for theory building experiments.

    The first six fields have no default and must always be supplied; the
    remaining fields are optional.

    ``TheoryBuilder.run`` reads the mode-related fields
    (``evaluate_multiple_interestingness``, ``num_interestingness_to_generate``,
    ``num_episodes`` and ``mark_best_episode``) from ``self.config``, which is
    presumably an instance of this class built by ``_process_config`` (not
    visible in this part of the file -- TODO confirm).
    """
    # NOTE(review): comments marked "presumably" are inferred from the field
    # name and from same-named locals in the episode worker; verify them against
    # `_process_config` and the worker's argument list.

    # --- Required fields ---
    policy_type: str  # presumably selects which policy implementation is built
    policy_params: Dict[str, Any]  # presumably keyword parameters for that policy
    max_steps: int  # presumably the per-episode step cap
    num_episodes: int # M: episodes run per evaluation (per generated function in N x M mode)
    seed: int  # presumably the RNG seed; see `_maybe_set_random_seed` for how 0/missing is handled
    checkpoint_frequency: int  # presumably "save a checkpoint every this many steps"

    # --- Optional reporting / output flags ---
    print_concepts: bool = False
    print_graph: bool = False
    check_duplicates: bool = False
    save_visualizations: bool = True  # presumably the global on/off switch for episode visualizations

    # --- Execution limits ---
    num_workers: int = 1  # presumably the number of parallel episode workers
    episode_timeout_seconds: int = 10  # presumably the wall-clock budget of one episode
    visualization_step_threshold: Optional[int] = None  # presumably: None means "no step limit for visualizations"

    # Gates best-episode marking in standard mode. `run` currently only logs
    # when marking is skipped; the marking call itself is commented out there.
    mark_best_episode: bool = True

    # --- N x M evaluation ---
    # When True, `run` generates N interestingness functions and evaluates each
    # for M (`num_episodes`) episodes instead of doing a single standard run.
    evaluate_multiple_interestingness: bool = False # Enable N x M mode
    num_interestingness_to_generate: int = 1 # N: Functions to generate in N x M mode

class TheoryBuilder:
    """Coordinates the process of mathematical discovery."""
    
    def __init__(self, cfg: DictConfig):
        """
        Prepare a theory-building run from a Hydra configuration.

        The constructor does not instantiate the environment, policy or
        production rules. It only records the configuration pieces that the
        episode workers need later: output location, episodic cache path, raw
        config snapshot, initial graph, production rules and policy config.

        Args:
            cfg: Hydra configuration
        """
        # A random seed is generated here when the config asks for one
        # (seed=0 or not provided).
        self._maybe_set_random_seed(cfg)

        # Output directories must exist before the config is processed.
        self._setup_output_directories()

        # Episodic cache location: the top-level key is tried first, then the
        # one nested under `theory_building`; otherwise fall back to a file
        # inside the main output directory.
        user_db_path = (
            OmegaConf.select(cfg, 'episodic_cache_db_path', default=None)
            or OmegaConf.select(cfg, 'theory_building.episodic_cache_db_path', default=None)
        )
        if not user_db_path:
            self.db_path = os.path.join(self.output_dir, "episodic_cache.db")
            logging.info(f"Episodic cache DB path not provided, using default: {self.db_path}")
        else:
            self.db_path = user_db_path
            logging.info(f"Using provided episodic cache DB path: {self.db_path}")

        # Typed configuration; provides flags such as evaluate_multiple_interestingness.
        self.config = self._process_config(cfg)

        # Plain-container snapshot of the raw config for the workers. It is taken
        # *before* any interestingness function is generated, so later changes
        # to `cfg` do not leak into it.
        self.cfg_dict = copy.deepcopy(OmegaConf.to_container(cfg, resolve=True))

        # In standard mode a single interestingness function may be generated
        # (or supplied) up front; in N x M mode generation is deferred to run().
        self.interestingness_function_path = None
        if not self.config.evaluate_multiple_interestingness:
            # Receives the live cfg object because it may modify it.
            self._maybe_generate_interestingness_function(cfg)
            policy_node = cfg.policy
            if hasattr(policy_node, 'params') and hasattr(policy_node.params, 'interestingness_function_path'):
                self.interestingness_function_path = policy_node.params.interestingness_function_path
            elif hasattr(policy_node, 'interestingness_function_path'):
                self.interestingness_function_path = policy_node.interestingness_function_path

        def _section_for_workers(key, missing_message):
            """Fetch `key` from the snapshot, top level first, then under `theory_building`."""
            if hasattr(cfg, key):
                return self.cfg_dict.get(key, {})
            if hasattr(cfg, 'theory_building') and hasattr(cfg.theory_building, key):
                return self.cfg_dict.get('theory_building', {}).get(key, {})
            logging.warning(missing_message)
            return {}

        # 1. Initial graph configuration
        self.initial_graph_config = _section_for_workers(
            'initial_state',
            "No initial state found in configuration, using empty graph",
        )

        # 2. Production rule configurations
        self.production_rule_configs = _section_for_workers(
            'production_rules',
            "No production rules found in configuration",
        )

        # 3. Policy configuration, copied from the (possibly modified) live cfg
        #    and kept as a DictConfig.
        self.policy_config = copy.deepcopy(cfg.policy)

        # Report which interestingness function standard mode will use.
        if self.interestingness_function_path:
            logging.info(f"(Standard Mode Init) Using interestingness function: {self.interestingness_function_path}")
        elif not self.config.evaluate_multiple_interestingness:
            logging.info("(Standard Mode Init) No specific interestingness function path set.")

    def run(self) -> float:
        """Run the theory building process based on the configured mode.

        The mode is chosen by ``self.config.evaluate_multiple_interestingness``:

        * **N x M evaluation** -- generate ``num_interestingness_to_generate``
          (N) interestingness functions and evaluate each one for
          ``num_episodes`` (M) episodes. Each function gets its own
          ``function_<idx>`` directory under ``self.output_dir`` and its own
          copy of the policy config pointing at the generated file. A function
          whose generation fails is recorded as ``generation_failed`` and
          skipped.
        * **Standard** -- run ``num_episodes`` episodes once, using the policy
          configuration prepared in ``__init__``.

        Returns:
            N x M mode: the summed ``total_reward`` of all evaluations with
            status ``evaluation_completed`` divided by their summed
            ``completed_episodes`` (``0.0`` when no episode completed).
            Standard mode: the ``avg_reward`` of the single evaluation.
        """
        start_time = time.time()
        
        if self.config.evaluate_multiple_interestingness:
            # --- N x M Evaluation Mode ---
            logging.info(f"Running in N x M evaluation mode: Generating {self.config.num_interestingness_to_generate} functions, evaluating each for {self.config.num_episodes} episodes.")
            overall_results = []
            
            for func_idx in range(self.config.num_interestingness_to_generate):
                logging.info(f"\n{'='*20} Starting Evaluation for Function {func_idx + 1}/{self.config.num_interestingness_to_generate} {'='*20}")
                
                # 1. Generate the interestingness function
                logging.info(f"Generating interestingness function {func_idx + 1}...")
                generated_path = self._generate_interestingness_function()
                
                if generated_path is None:
                    logging.error(f"Failed to generate interestingness function {func_idx + 1}. Skipping evaluation.")
                    # Keep a placeholder so the summary still lists this function.
                    overall_results.append({
                        "function_index": func_idx,
                        "status": "generation_failed",
                        "avg_reward": 0.0,
                        "results": []
                    })
                    continue
                
                logging.info(f"Generated interestingness function {func_idx + 1}: {generated_path}")
                
                # 2. Prepare config and output dir for this function's evaluation
                func_output_dir = os.path.join(self.output_dir, f"function_{func_idx}")
                os.makedirs(func_output_dir, exist_ok=True)
                logging.info(f"Results directory for function {func_idx + 1}: {func_output_dir}")
                
                # Start from the *original* policy structure captured in
                # self.cfg_dict (before any __init__ modifications) and work on
                # a private copy so functions do not affect each other.
                func_policy_config = copy.deepcopy(self.cfg_dict.get('policy', {}))

                # Write the path into `params` only when it really is a mapping.
                # A config with `params: null` yields None here; assigning into
                # it would raise TypeError. Falling back to the top level matches
                # how _evaluate_interestingness_function looks the path up.
                policy_params = func_policy_config.get('params')
                path_holder = policy_params if isinstance(policy_params, dict) else func_policy_config
                path_holder['interestingness_function_path'] = generated_path
                # Ensure concept selection uses it if applicable (might be redundant if type implies it)
                if 'concept_selection' in path_holder:
                    path_holder['concept_selection'] = "INTERESTINGNESS"

                logging.debug(f"Policy config for function {func_idx + 1}: {func_policy_config}")

                # 3. Evaluate the function
                logging.info(f"Evaluating function {func_idx + 1} for {self.config.num_episodes} episodes...")
                
                # Run M episodes for this specific function
                func_results = self._evaluate_interestingness_function(
                    function_index=func_idx,
                    policy_config=func_policy_config, # Pass the modified config dict
                    base_output_dir=func_output_dir,
                    num_episodes_to_run=self.config.num_episodes # M
                )
                
                overall_results.append(func_results)
                logging.info(f"Evaluation complete for function {func_idx + 1}. Average reward: {func_results['avg_reward']:.4f}")
                logging.info(f"{'='*60}")

            # --- Log overall N x M summary ---
            logging.info(f"\n{'='*30} OVERALL N x M EVALUATION SUMMARY {'='*30}")
            total_completed = 0
            total_reward_across_all = 0.0
            generation_failures = 0
            best_overall_avg_reward = float('-inf')
            best_function_index = -1
            avg_rewards_per_function = [] # To calculate std dev of averages

            for res in overall_results:
                # `.get` with defaults: generation_failed placeholders carry
                # only a subset of the keys.
                status = res.get("status", "unknown")
                avg_rew = res.get("avg_reward", 0.0)
                std_dev = res.get("reward_std_dev", 0.0)
                num_completed = res.get("completed_episodes", 0)
                num_failed = res.get("failed_episodes", 0)
                num_timeout = res.get("timeout_episodes", 0)
                
                logging.info(f"Function {res['function_index'] + 1}: Status={status}, Avg Reward={avg_rew:.4f}, Std Dev={std_dev:.4f}, Completed={num_completed}, Failed={num_failed}, Timeout={num_timeout}")
                
                if status == "evaluation_completed":
                    total_completed += num_completed
                    # Use the total reward for the function to calculate overall average correctly
                    total_reward_across_all += res.get("total_reward", 0.0) 
                    avg_rewards_per_function.append(avg_rew)
                    if avg_rew > best_overall_avg_reward:
                        best_overall_avg_reward = avg_rew
                        best_function_index = res['function_index']
                elif status == "generation_failed":
                    generation_failures += 1

            # Episode-weighted mean over everything that completed.
            overall_avg_reward = total_reward_across_all / total_completed if total_completed > 0 else 0.0
            # Spread of the per-function averages (0.0 when nothing was evaluated).
            std_dev_of_avg_rewards = np.std(avg_rewards_per_function) if len(avg_rewards_per_function) > 0 else 0.0
            
            logging.info(f"Total Functions Generated: {self.config.num_interestingness_to_generate - generation_failures}/{self.config.num_interestingness_to_generate}")
            logging.info(f"Total Completed Episodes (across all successful evaluations): {total_completed}")
            logging.info(f"Overall Average Reward (across all completed episodes): {overall_avg_reward:.4f}")
            logging.info(f"Std Dev of Average Rewards (across functions): {std_dev_of_avg_rewards:.4f}")
            if best_function_index != -1:
                logging.info(f"Best Performing Function Index: {best_function_index + 1} (Avg Reward: {best_overall_avg_reward:.4f})")
            
            end_time = time.time()
            logging.info(f"Total N x M evaluation time: {end_time - start_time:.2f} seconds")
            logging.info(f"{'='*80}")
            
            # Return overall average reward as the final result for N x M mode
            return overall_avg_reward

        else:
            # --- Standard Mode ---
            logging.info(f"Running in standard mode for {self.config.num_episodes} episodes.")
            
            # The single function path (if any) is already reflected in self.policy_config from __init__
            # The base output dir is self.output_dir
            standard_results = self._evaluate_interestingness_function(
                function_index=0, # Only one "function group" in this mode
                policy_config=self.policy_config, # Use the config prepared during __init__
                base_output_dir=self.output_dir, # Save directly in the main output dir
                num_episodes_to_run=self.config.num_episodes # M
            )
            
            end_time = time.time()
            logging.info(f"Standard run completed in {end_time - start_time:.2f} seconds")
            logging.info(f"Average reward: {standard_results['avg_reward']:.4f}")
            logging.info(f"Reward std dev: {standard_results['reward_std_dev']:.4f}")
            logging.info(f"Completed: {standard_results['completed_episodes']}, Failed: {standard_results['failed_episodes']}, Timeout: {standard_results['timeout_episodes']}")
            
            # Also mark the best episode globally in standard mode, ONLY IF more than one episode was run AND marking is enabled
            if self.config.mark_best_episode and self.config.num_episodes > 1 and standard_results.get("best_episode_details"):
                pass # TODO(_; 4/24): This is buggy, I'm just leaving it out for now.
                # self._mark_best_episode_global(standard_results["best_episode_details"])
            elif not self.config.mark_best_episode:
                logging.info("Skipping global best episode marking as mark_best_episode is False.")
            elif self.config.num_episodes <= 1:
                logging.info("Skipping best episode marking as only one episode was run.")

            # Return the average reward from the standard run
            return standard_results['avg_reward']

    def _evaluate_interestingness_function(
        self,
        function_index: int,
        policy_config: Union[DictConfig, Dict[str, Any]], # Accept DictConfig or dict
        base_output_dir: str,
        num_episodes_to_run: int
    ) -> Dict[str, Any]:
        """
        Run M episodes for one interestingness function and aggregate their results.

        Each episode is executed by ``run_episode_worker`` in a ``pebble.ProcessPool``:
        with a pool size of 1 every episode gets its own fresh single-worker pool;
        otherwise episodes are processed in blocks of ``pool_size`` with a fresh pool
        per block. Every scheduled task carries an orchestrator-side timeout of
        1.2 x ``config.episode_timeout_seconds``.

        Args:
            function_index: The index (0 to N-1) of the function group being evaluated.
            policy_config: The specific policy configuration (DictConfig or dict) to use.
            base_output_dir: The directory where results for THIS function's episodes should be stored.
            num_episodes_to_run: M, the number of episodes to run for evaluation.

        Returns:
            A dictionary with the keys ``function_index``, ``status``, ``avg_reward``,
            ``total_reward``, ``completed_episodes``, ``failed_episodes``,
            ``timeout_episodes``, ``reward_std_dev``, ``avg_steps``,
            ``best_episode_reward``, ``best_episode_details`` and ``individual_results``.
            Reward/step averages cover episodes whose status is "completed" or "timeout".
        """
        tag = f"(Eval Func {function_index+1})"

        # --- Log Interestingness Function Content ---
        interestingness_func_path = None
        if isinstance(policy_config, (DictConfig, dict)):
            # Try to get from params first, then directly
            if 'params' in policy_config and isinstance(policy_config['params'], (DictConfig, dict)):
                interestingness_func_path = policy_config['params'].get('interestingness_function_path')
            if not interestingness_func_path:
                interestingness_func_path = policy_config.get('interestingness_function_path')

        if interestingness_func_path:
            # Resolve path if it's relative
            if not os.path.isabs(interestingness_func_path):
                original_path_for_log = interestingness_func_path
                try:
                    import hydra # Import hydra here for utility access
                    resolved_path = hydra.utils.to_absolute_path(interestingness_func_path)
                    logging.info(f"{tag} Resolved relative path '{original_path_for_log}' to '{resolved_path}' using hydra.utils.")
                    interestingness_func_path = resolved_path
                except ImportError:
                    logging.warning(f"{tag} Hydra utilities not available. Attempting to resolve '{original_path_for_log}' using os.path.abspath.")
                    current_cwd_for_log = os.getcwd()
                    resolved_path = os.path.abspath(interestingness_func_path)
                    logging.info(f"{tag} CWD for abspath: {current_cwd_for_log}. Resolved relative path '{original_path_for_log}' to '{resolved_path}' using os.path.abspath.")
                    interestingness_func_path = resolved_path
                except Exception as e:
                    logging.error(f"{tag} Error resolving relative path '{original_path_for_log}': {e}. Will attempt to use as is.")
            try:
                if os.path.exists(interestingness_func_path):
                    with open(interestingness_func_path, 'r') as f_content:
                        content = f_content.read()
                    logging.info(f"{tag} Using interestingness function from: {interestingness_func_path}")
                    logging.info(f"--- Interestingness Function Content (Eval Func {function_index+1}) ---\n{content}\n--- End Content ---")
                else:
                    logging.warning(f"{tag} Interestingness function path specified but not found: {interestingness_func_path}")
            except Exception as e:
                logging.error(f"{tag} Error reading interestingness function file {interestingness_func_path}: {e}")
        else:
            logging.info(f"{tag} No specific interestingness function path provided in policy_config for this run.")
        # --- End Log Interestingness Function Content ---

        results = []
        completed_episode_rewards = []  # Rewards of "completed" and "timeout" episodes
        stats = {
            "total_reward": 0.0,
            "total_steps": 0,
            "episodes_with_steps": 0,  # Denominator for the reward/step averages
            "completed": 0,
            "failed": 0,
            "timeout": 0,  # Episodes that timed out *internally* within the worker
            "best_reward": float('-inf'),
            "best_result": None,
        }

        # Global episode numbers for this function are first_global_episode_num + m,
        # so the local (0-based) index is recovered by subtraction. A modulo would
        # wrap around whenever num_episodes_to_run exceeds config.num_episodes.
        first_global_episode_num = function_index * self.config.num_episodes

        def record_episode(result, global_ep_num):
            """Store one episode result, fold it into the aggregates, then log it."""
            results.append(result)
            if not result:
                return

            status = result.get("status", "unknown")
            if status in ["completed", "timeout"]:
                stats["total_reward"] += result.get("reward", 0.0)
                stats["total_steps"] += result.get('total_steps', 0)
                stats["episodes_with_steps"] += 1
                completed_episode_rewards.append(result.get("reward", 0.0))
            if status == "completed":
                stats["completed"] += 1
                reward = result.get("reward", 0.0)
                if reward > stats["best_reward"]:
                    stats["best_reward"] = reward
                    stats["best_result"] = result # Keep track of the best result dict
            elif status == "timeout":
                stats["timeout"] += 1
            elif status.startswith("failed"):
                stats["failed"] += 1

            # Logging comes last so that a formatting problem in the log helper
            # cannot leave a stored result out of the counters.
            self._process_episode_result_logging(
                result,
                global_ep_num,
                global_ep_num - first_global_episode_num,
                function_index,
                num_episodes_to_run,
            )

        # Prepare arguments for each worker process for this function
        all_worker_args = []
        for m in range(num_episodes_to_run):
            global_episode_num = first_global_episode_num + m
            # Ensure policy_config is a plain dict if it came as DictConfig
            policy_config_dict = OmegaConf.to_container(policy_config, resolve=True) if isinstance(policy_config, DictConfig) else policy_config
            args = (
                global_episode_num,
                self.cfg_dict,
                self.initial_graph_config,
                self.production_rule_configs,
                policy_config_dict, # Pass as dict
                base_output_dir,
                self.config.seed,
                self.config.max_steps,
                self.config.checkpoint_frequency,
                self.config.save_visualizations,
                self.config.visualization_step_threshold,
                self.db_path,
                self.config.episode_timeout_seconds # Worker's internal timeout
            )
            all_worker_args.append(args)

        # Treat a non-positive num_workers as 1: zero would divide by zero in the
        # block computation below and a negative value would silently run nothing.
        pool_size = min(max(self.config.num_workers, 1), num_episodes_to_run)
        logging.info(f"{tag} Using {pool_size} workers for {num_episodes_to_run} episodes.")

        # NOTE(review): Pebble documents that a task exceeding its schedule() timeout
        # surfaces as TimeoutError from future.result(), while ProcessExpired signals a
        # worker that died unexpectedly. The status labels below are kept as they were
        # because callers may match on them - confirm the labels reflect the intent.

        # --- Choose Execution Strategy: Sequential with Isolation or Parallel ---
        if pool_size == 1 and num_episodes_to_run > 0:
            # --- Sequential Execution with Process Isolation ---
            logging.info(f"{tag} Running episodes sequentially with fresh processes for isolation.")
            for worker_args_tuple in all_worker_args:
                global_ep_num_seq = worker_args_tuple[0]
                single_episode_pool = None
                result = None # Initialize result for this episode scope
                try:
                    # Create a new pool for *this episode only*
                    single_episode_pool = pebble.ProcessPool(max_workers=1)
                    future = single_episode_pool.schedule(run_episode_worker, args=(worker_args_tuple,), timeout=self.config.episode_timeout_seconds * 1.2)

                    # Wait for this single episode to complete or timeout via pebble
                    result = future.result()
                except ProcessExpired as e:
                    logging.critical(f'{tag} Worker for Episode (Global: {global_ep_num_seq}) TERMINATED by orchestrator: {e}')
                    result = {"episode_num": global_ep_num_seq, "reward": 0.0, "status": "failed_orchestrator_timeout_terminated", "error": str(e)}
                except FuturesTimeoutError: # Should be rare with pebble, but handle defensively
                    logging.error(f'{tag} Orchestrator timed out waiting for result from future for Episode (Global: {global_ep_num_seq})')
                    result = {"episode_num": global_ep_num_seq, "reward": 0.0, "status": "failed_orchestrator_result_timeout", "error": "Orchestrator result timeout"}
                except Exception as exc:
                    logging.error(f'{tag} Episode (Global: {global_ep_num_seq}) generated an exception during execution/result retrieval: {exc}', exc_info=True)
                    result = {"episode_num": global_ep_num_seq, "reward": 0.0, "status": "failed_in_main", "error": str(exc)}
                finally:
                    # Ensure the single-episode pool is always cleaned up
                    if single_episode_pool:
                        single_episode_pool.close()
                        single_episode_pool.join()
                        logging.debug(f"{tag} Single episode pool for global_ep_num {global_ep_num_seq} closed and joined.")

                # Record the outcome AFTER the pool has been cleaned up
                record_episode(result, global_ep_num_seq)
            # --- End Sequential Execution ---

        elif num_episodes_to_run > 0: # Check needed episodes > 0 before creating pool
            # --- Parallel Execution with Blocks ---
            logging.info(f"{tag} Running episodes in parallel with pool size {pool_size}, in blocks.")

            num_blocks = (num_episodes_to_run + pool_size - 1) // pool_size  # Ceiling division
            logging.info(f"{tag} Total episodes: {num_episodes_to_run}, Pool size: {pool_size}, Number of blocks: {num_blocks}")

            for block_idx in range(num_blocks):
                start_episode_idx_in_block = block_idx * pool_size
                end_episode_idx_in_block = min((block_idx + 1) * pool_size, num_episodes_to_run)

                current_block_worker_args = all_worker_args[start_episode_idx_in_block:end_episode_idx_in_block]

                if not current_block_worker_args: # Should not happen if logic is correct
                    logging.warning(f"{tag} Block {block_idx + 1}/{num_blocks} has no worker arguments. Skipping.")
                    continue

                logging.info(f"{tag} Starting Block {block_idx + 1}/{num_blocks}, "
                             f"processing episodes {start_episode_idx_in_block + 1} to {end_episode_idx_in_block} "
                             f"(global indices from worker_args).")

                # Episodes of this block that already have an entry in `results`;
                # consulted by the block-level error handler to avoid duplicates.
                recorded_in_block = set()
                try:
                    # Create a new pool for *this block only*
                    with pebble.ProcessPool(max_workers=pool_size) as executor:
                        future_to_episode_in_block = {
                            executor.schedule(run_episode_worker, args=(args,), timeout=self.config.episode_timeout_seconds * 1.2):
                            args[0] for args in current_block_worker_args # args[0] is global_episode_num
                        }

                        for future in concurrent.futures.as_completed(future_to_episode_in_block):
                            global_ep_num = future_to_episode_in_block[future]
                            result = None
                            try:
                                result = future.result(timeout=10.0)
                            except ProcessExpired as e:
                                logging.critical(f'{tag} Worker for Episode (Global: {global_ep_num}) TERMINATED by orchestrator (Block {block_idx+1}): {e}')
                                result = {"episode_num": global_ep_num, "reward": 0.0, "status": "failed_orchestrator_timeout_terminated", "error": str(e)}
                            except FuturesTimeoutError:
                                logging.error(f'{tag} Orchestrator timed out waiting for result from future for Episode (Global: {global_ep_num}) (Block {block_idx+1})')
                                result = {"episode_num": global_ep_num, "reward": 0.0, "status": "failed_orchestrator_result_timeout", "error": "Orchestrator result timeout"}
                            except Exception as exc:
                                logging.error(f'{tag} Episode (Global: {global_ep_num}) generated an exception during result processing (Block {block_idx+1}): {exc}', exc_info=True)
                                result = {"episode_num": global_ep_num, "reward": 0.0, "status": "failed_in_main", "error": str(exc)}

                            recorded_in_block.add(global_ep_num)
                            record_episode(result, global_ep_num)
                    # Pool for this block is automatically closed and joined here by the 'with' statement
                    logging.info(f"{tag} Block {block_idx + 1}/{num_blocks} completed and its ProcessPool terminated.")

                except Exception as block_pool_error:
                    logging.error(f"{tag} Error occurred during ProcessPool execution for Block {block_idx + 1}: {block_pool_error}", exc_info=True)
                    # Mark the episodes of this block that have no result yet as failed.
                    # Episodes already recorded keep their single, real entry.
                    for args_tuple in current_block_worker_args:
                        global_ep_num_fail = args_tuple[0]
                        if global_ep_num_fail in recorded_in_block:
                            continue
                        results.append({"episode_num": global_ep_num_fail, "reward": 0.0, "status": "failed_block_pool_error", "error": str(block_pool_error)})
                        stats["failed"] += 1
                        # No reward, no steps for these.
                    # Continue to the next block if possible
            # --- End Parallel Execution with Blocks ---
        else:
            logging.info(f"{tag} No episodes to run (num_episodes_to_run={num_episodes_to_run}).")

        # --- Final Processing and Return (Common for both Sequential and Parallel) ---
        total_reward_for_func = stats["total_reward"]
        episodes_with_steps = stats["episodes_with_steps"]
        best_episode_reward = stats["best_reward"]
        best_episode_result_local = stats["best_result"]

        if self.config.mark_best_episode and num_episodes_to_run > 1 and best_episode_result_local:
            self._mark_best_episode_local(best_episode_result_local)

        avg_reward = total_reward_for_func / episodes_with_steps if episodes_with_steps > 0 else 0.0
        reward_std_dev = np.std(completed_episode_rewards) if episodes_with_steps > 0 else 0.0
        avg_steps = stats["total_steps"] / episodes_with_steps if episodes_with_steps > 0 else 0.0

        logging.info(f"--- Evaluation Summary for Function {function_index + 1} ---")
        logging.info(f"Average Reward (Completed + TimedOut): {avg_reward:.4f}")
        logging.info(f"Reward Std Dev (Completed + TimedOut): {reward_std_dev:.4f}")
        logging.info(f"Average Steps (Completed + TimedOut): {avg_steps:.2f}")
        logging.info(f"Total Reward (Completed + TimedOut): {total_reward_for_func:.4f}")
        logging.info(f"Completed Episodes: {stats['completed']}")
        logging.info(f"Failed Episodes: {stats['failed']}")
        logging.info(f"Worker Timeout Episodes: {stats['timeout']}")
        if best_episode_result_local:
            # Safely get the global episode number from the stored result dict
            best_global_ep_num = best_episode_result_local.get('episode_num', 'N/A')
            logging.info(f"Best Episode Reward (local to this function run): {best_episode_reward:.4f} (Global Ep: {best_global_ep_num})")
        logging.info("--- End Summary ---")

        return {
            "function_index": function_index,
            "status": "evaluation_completed",
            "avg_reward": avg_reward,
            "total_reward": total_reward_for_func,
            "completed_episodes": stats["completed"],
            "failed_episodes": stats["failed"],
            "timeout_episodes": stats["timeout"],
            "reward_std_dev": reward_std_dev,
            "avg_steps": avg_steps,
            "best_episode_reward": best_episode_reward if best_episode_result_local else float('-inf'),
            "best_episode_details": best_episode_result_local,
            "individual_results": results
        }

    # Renamed helper function for clarity
    def _process_episode_result_logging(
        self, 
        result: Dict[str, Any], 
        global_ep_num: int, 
        local_episode_index: int, 
        function_index: int, 
        num_episodes_to_run: int
    ) -> None:
        """Helper function to log a single episode's result status. Does not update aggregates."""
        status = result.get("status", "unknown")
        log_message = f"(Eval Func {function_index+1}) Episode {local_episode_index + 1}/{num_episodes_to_run} (Global: {global_ep_num}) finished with worker status: {status}"

        reward = 0.0
        total_steps = 0
        if status in ["completed", "timeout"]:
            reward = result.get("reward", 0.0)
            total_steps = result.get('total_steps', 0)

        if status == "completed":
            duration = result.get("duration_seconds", -1.0)
            log_message += f" and reward: {reward:.4f}, steps: {total_steps} in {duration:.2f}s"
            # Check worker's internal rule timeout threshold
            rule_timeout_count = result.get('timeout_count', 0)
            if total_steps > 0 and rule_timeout_count > 0: # Only log if timeouts occurred
                 timeout_percentage = (rule_timeout_count / total_steps) * 100
                 # Log as warning if threshold exceeded
                 log_level = logging.ERROR if timeout_percentage > 25.0 else logging.WARNING
                 logging.log(log_level, f"Episode {global_ep_num} had internal rule timeouts: {rule_timeout_count}/{total_steps} ({timeout_percentage:.1f}%) rules timed out.")
        elif status == "timeout":
            duration = result.get("duration_seconds", -1.0)
            error_msg = result.get("error", "Worker internal timeout")
            log_message += f". Recorded Reward: {reward:.4f}, Steps: {total_steps}, Duration: {duration:.2f}s. Error: {error_msg}"
            logging.warning(log_message + '.') # Log as warning if it was an internal worker timeout
            return # Avoid double logging info below
        elif status.startswith("failed"): # Covers failed_orchestrator_timeout, failed_in_main etc.
            error_msg = result.get('error', 'Unknown worker error')
            log_message += f". Error: {error_msg}"
            logging.error(log_message + '.')
            if 'traceback' in result and result['traceback']:
                logging.error(f"Traceback for failed episode {global_ep_num}:\n{result['traceback']}")
            return # Avoid double logging info below
        else: # Unknown status
            logging.warning(f"Received unknown status '{status}' for episode {global_ep_num}")
        
        logging.info(log_message + '.') # Log the base message for completed/unknown

    def _mark_best_episode_local(self, best_episode_result: Dict[str, Any]) -> None:
        """Rename the best episode directory within its function's subdirectory."""
        if not best_episode_result or 'episode_dir' not in best_episode_result:
            logging.warning("Could not mark local best episode: Missing result or episode_dir.")
            return
        
        try:
            old_path = best_episode_result['episode_dir']
            # Ensure the path exists before trying to rename
            if not os.path.isdir(old_path):
                 logging.warning(f"Could not mark local best episode: Directory does not exist: {old_path}")
                 # Set episode_dir to indicate failure to find it? Or just return?
                 # Returning for now, the result dict won't be updated.
                 return

            # Place 'best_episode' within the parent directory of the episode run
            # e.g., if old_path is '.../function_0/episode_3', new_path is '.../function_0/best_episode'
            parent_dir = os.path.dirname(old_path)
            new_path = os.path.join(parent_dir, "best_episode") 
            
            # Ensure parent directory exists (it should, being the function_X dir)
            if not os.path.isdir(parent_dir):
                 logging.error(f"Cannot mark local best episode: Parent directory {parent_dir} not found.")
                 return

            logging.info(f"Attempting to mark local best episode: Renaming {old_path} to {new_path}")

            # Safer rename: remove existing link/dir first
            if os.path.lexists(new_path): # Use lexists to handle dangling symlinks too
                 logging.warning(f"Existing 'best_episode' found at {new_path}. Removing it.")
                 if os.path.isdir(new_path) and not os.path.islink(new_path):
                      import shutil
                      shutil.rmtree(new_path)
                      logging.info(f"Removed existing best_episode directory: {new_path}")
                 else:
                      os.remove(new_path)
                      logging.info(f"Removed existing best_episode file/link: {new_path}")

            # Rename the directory
            os.rename(old_path, new_path)
            
            # Update the paths in the result dictionary (important!)
            # Store the new path before updating potentially missing keys
            updated_episode_dir = new_path
            for key in ['logs_dir', 'graphs_dir', 'viz_dir', 'checkpoints_dir']:
                if key in best_episode_result and best_episode_result[key] is not None:
                    # Replace the original base path part with the new 'best_episode' path part
                    # Be careful with separators and potential double replacements
                    # Safer: reconstruct path based on new base 'new_path' and relative part
                    relative_part = os.path.basename(best_episode_result[key]) # e.g., 'logs', 'graphs'
                    original_base = os.path.dirname(best_episode_result[key]) # e.g., '.../function_0/episode_3'
                    if original_base == old_path: # Ensure we only replace if it matches the old episode path
                         best_episode_result[key] = os.path.join(new_path, relative_part)
                    else:
                         logging.warning(f"Path mismatch when updating '{key}': expected base '{old_path}', got '{original_base}'")
                elif key not in best_episode_result:
                    logging.debug(f"Key '{key}' not found in best_episode_result during path update.")

            # Update the main episode directory path
            best_episode_result['episode_dir'] = updated_episode_dir
            
            logging.info(f"Successfully marked local best episode directory: {new_path}")
            logging.debug(f"Updated best_episode_result paths: {best_episode_result}")

        except Exception as e:
            logging.error(f"Failed to mark local best episode directory (from {old_path}): {e}")
            logging.error(traceback.format_exc())
            # Attempt to leave paths in result dict unchanged on error

    # Rename original _mark_best_episode to avoid conflicts
    def _mark_best_episode_global(self, best_episode_result: Dict[str, Any]) -> None: 
        """Rename the best episode directory globally (used in standard mode)."""
        # (Keep original implementation of _mark_best_episode here)
        if not best_episode_result or 'episode_dir' not in best_episode_result:
            logging.warning("Could not mark global best episode: Missing result or episode_dir.")
            return
        try:
            old_path = best_episode_result['episode_dir']
            if not os.path.isdir(old_path):
                 logging.warning(f"Could not mark global best episode: Directory does not exist: {old_path}")
                 return

            parent_dir = os.path.dirname(old_path) # Should be the main output dir
            new_path = os.path.join(parent_dir, "best_episode")
            
            logging.info(f"Attempting to mark global best episode: Renaming {old_path} to {new_path}")
            
            if os.path.lexists(new_path):
                logging.warning(f"Existing global 'best_episode' found at {new_path}. Removing it.")
                if os.path.isdir(new_path) and not os.path.islink(new_path):
                    import shutil; shutil.rmtree(new_path)
                else: os.remove(new_path)
            
            os.rename(old_path, new_path)
            
            # Update paths in the result dict
            updated_episode_dir = new_path
            for key in ['logs_dir', 'graphs_dir', 'viz_dir', 'checkpoints_dir']:
                 if key in best_episode_result and best_episode_result[key] is not None:
                      relative_part = os.path.basename(best_episode_result[key])
                      original_base = os.path.dirname(best_episode_result[key])
                      if original_base == old_path:
                           best_episode_result[key] = os.path.join(new_path, relative_part)
                      else:
                           logging.warning(f"Path mismatch when updating global '{key}': expected base '{old_path}', got '{original_base}'")

            best_episode_result['episode_dir'] = updated_episode_dir
            logging.info(f"Successfully marked global best episode directory: {new_path}")
        except Exception as e:
             logging.error(f"Failed to mark global best episode directory: {e}")
             logging.error(traceback.format_exc())

    def _process_config(self, cfg: DictConfig) -> TheoryBuilderConfig:
        """Process Hydra config into internal config.

        Reads the ``experiment``, ``logging``, ``output`` and ``policy`` sections of
        ``cfg`` and builds a ``TheoryBuilderConfig`` from them. A section that is
        absent (or null) is treated as empty, so every value in it falls back to
        the default listed below.

        Args:
            cfg: Hydra configuration.

        Returns:
            The populated ``TheoryBuilderConfig``.
        """
        # -- Log type of incoming cfg --
        logging.debug(f"_process_config received cfg of type: {type(cfg)}")

        def select_section(name: str):
            """Return the top-level section ``name``, or an empty DictConfig if absent/null."""
            # A plain {} must not be used as the fallback: OmegaConf.select only
            # works on OmegaConf containers, so the follow-up selects would raise.
            section = OmegaConf.select(cfg, name, default=None)
            return OmegaConf.create({}) if section is None else section

        experiment_cfg = select_section('experiment')
        logging_cfg = select_section('logging')
        output_cfg = select_section('output')
        policy_cfg = select_section('policy')

        # -- Log types of extracted configs --
        logging.debug(f"Type of experiment_cfg: {type(experiment_cfg)}")
        logging.debug(f"Type of logging_cfg: {type(logging_cfg)}")
        logging.debug(f"Type of output_cfg: {type(output_cfg)}")
        logging.debug(f"Type of policy_cfg: {type(policy_cfg)}")

        # Standard params
        num_episodes = OmegaConf.select(experiment_cfg, 'num_episodes', default=1) # Default M
        num_workers = OmegaConf.select(experiment_cfg, 'num_workers', default=1)
        episode_timeout_seconds = OmegaConf.select(experiment_cfg, 'episode_timeout_seconds', default=15)
        max_steps = OmegaConf.select(experiment_cfg, 'max_steps', default=100) # Provide a default max_steps
        seed = OmegaConf.select(experiment_cfg, 'seed', default=0) # Default seed if not set

        # N x M params
        evaluate_multiple_interestingness = OmegaConf.select(experiment_cfg, 'evaluate_multiple_interestingness', default=False)
        num_interestingness_to_generate = OmegaConf.select(experiment_cfg, 'num_interestingness_to_generate', default=1) # Default N

        checkpoint_frequency = OmegaConf.select(logging_cfg, 'checkpoint_frequency', default=10)

        print_concepts = OmegaConf.select(output_cfg, 'print_concepts', default=False)
        print_graph = OmegaConf.select(output_cfg, 'print_graph', default=False)
        save_visualizations = OmegaConf.select(output_cfg, 'save_visualizations', default=True)
        visualization_step_threshold = OmegaConf.select(output_cfg, 'visualization_step_threshold', default=None)
        mark_best_episode = OmegaConf.select(output_cfg, 'mark_best_episode', default=True)

        # NOTE(review): check_duplicates is the only flag that is never read from
        # cfg; it is always False here. Confirm whether that is intentional.
        check_duplicates = False

        # Determine policy type/target safely
        policy_type = OmegaConf.select(policy_cfg, 'type', default=None)
        if policy_type is None:
            policy_type = OmegaConf.select(policy_cfg, '_target_', default='UnknownPolicy')
        policy_params = OmegaConf.select(policy_cfg, 'params', default={})

        # Create the config object
        builder_config = TheoryBuilderConfig(
            policy_type=policy_type,
            policy_params=policy_params,
            max_steps=max_steps,
            num_episodes=num_episodes, # M
            seed=seed,
            checkpoint_frequency=checkpoint_frequency,
            print_concepts=print_concepts,
            print_graph=print_graph,
            check_duplicates=check_duplicates,
            save_visualizations=save_visualizations,
            num_workers=num_workers,
            episode_timeout_seconds=episode_timeout_seconds,
            visualization_step_threshold=visualization_step_threshold,
            # N x M flags
            evaluate_multiple_interestingness=evaluate_multiple_interestingness,
            num_interestingness_to_generate=num_interestingness_to_generate,
            mark_best_episode=mark_best_episode # Pass the flag to the config object
        )

        # Log the mode being used
        if builder_config.evaluate_multiple_interestingness:
            logging.info(f"Config processed: N x M evaluation mode ENABLED (N={builder_config.num_interestingness_to_generate}, M={builder_config.num_episodes}).")
        else:
            logging.info(f"Config processed: Standard evaluation mode ENABLED (M={builder_config.num_episodes}).")

        return builder_config
        
    def _maybe_set_random_seed(self, cfg: DictConfig) -> None:
        """
        Ensure the configuration carries a usable random seed.

        If ``experiment.seed`` is absent, ``None`` or ``0``, a seed is derived
        from the current wall-clock time (in milliseconds) and written to
        ``cfg.experiment.seed``; the ``experiment`` section is created when it
        does not exist. Otherwise the configured seed is left untouched.
        Modifies cfg in place.

        The write temporarily overrides struct mode on the affected nodes only
        for the duration of the assignment, so the struct flags of ``cfg`` and
        ``cfg.experiment`` are the same after the call as before it.

        Args:
            cfg: Hydra configuration
        """
        seed_value = OmegaConf.select(cfg, 'experiment.seed', default=None)

        # A non-zero seed in the config wins; nothing to generate.
        if seed_value is not None and seed_value != 0:
            logging.info(f"Using seed from config: {seed_value}")
            return

        # Use current time (milliseconds) to derive a seed in [1, 999999].
        # 0 is the "no seed given" sentinel checked above, so never emit it.
        random_seed = int(time.time() * 1000) % 1000000 or 1

        # force_add lets the key (and a missing/null ``experiment`` section) be
        # created even when the config is in struct mode, without permanently
        # changing any struct flag.
        OmegaConf.update(cfg, 'experiment.seed', random_seed, force_add=True)

        logging.info(f"Generated random seed: {random_seed} and updated config.")
            
    def _maybe_generate_interestingness_function(self, cfg: DictConfig) -> None:
        """
        Generate a single interestingness function when the config requests it.

        The request flag is read from ``policy.generate_interestingness`` and,
        if that is falsy, from ``policy.params.generate_interestingness``.
        When generation is requested, the resulting file path is written back
        into ``cfg`` (under ``policy.params`` if that section exists, otherwise
        directly under ``policy``). On success ``concept_selection`` is also
        set to ``"INTERESTINGNESS"``; on failure the path is set to ``""``.
        Modifies the input cfg object in place.

        Args:
            cfg: The configuration object (Hydra DictConfig).
        """
        # Per the original notes: only invoked from __init__ in standard
        # (non N x M) mode; N x M generation happens in the run loop.
        requested = (
            OmegaConf.select(cfg, 'policy.generate_interestingness', default=False)
            or OmegaConf.select(cfg, 'policy.params.generate_interestingness', default=False)
        )
        if not requested:
            logging.info("Interestingness generation not requested in config for standard mode.")
            return  # Nothing to do when generation was not asked for

        logging.info("Standard mode init: Generating single interestingness function...")
        interestingness_path = self._generate_interestingness_function()

        # --- Write the outcome back into the cfg handed to __init__ ---
        # Open the policy node (and its params child, if any) for modification.
        policy_node = cfg.policy
        OmegaConf.set_struct(policy_node, False)
        has_params = hasattr(policy_node, 'params')
        if has_params:
            OmegaConf.set_struct(policy_node.params, False)
            target_node = policy_node.params
        else:
            target_node = policy_node

        if interestingness_path:
            target_node.interestingness_function_path = interestingness_path
            target_node.concept_selection = "INTERESTINGNESS"
            logging.info(f"Standard mode init: Updated config with generated path: {interestingness_path}")
        else:
            logging.warning("Standard mode init: Failed to generate interestingness function.")
            # Leave an explicit empty path behind when generation failed
            target_node.interestingness_function_path = ""

        # Switch struct mode back on for the nodes opened above
        if has_params:
            OmegaConf.set_struct(policy_node.params, True)
        OmegaConf.set_struct(policy_node, True)

    def _generate_interestingness_function(self, model_config: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """
        Produce an interestingness function file via the one-shot LLM generator.

        The model config and prompt template are looked up in the directories
        named by the ``MODEL_CONFIG_DIR`` / ``PROMPT_TEMPLATE_DIR`` environment
        variables when set (and non-empty), otherwise under
        ``<project_root>/frame/configs``. Generated programs are written to
        ``<project_root>/frame/interestingness/learning/generated_programs``.

        Args:
            model_config: Optional model configuration dictionary. Accepted for
                interface compatibility; this method does not read it.

        Returns:
            The path to the generated function file, or None if generation failed
            (any exception is logged with its traceback and swallowed).
        """
        logger = logging.getLogger(__name__)

        try:
            from frame.interestingness.learning.algorithms.one_shot_llm import (
                OneShotLLMGenerator,
            )
            # Imported but not referenced below; kept so an import failure here
            # is still reported through the except branch.
            from frame.interestingness.learning.dsl_primitives import ALL_PRIMITIVES

            # Unique conversation id: second-resolution timestamp + 8 hex chars
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            random_hash = secrets.token_hex(4)
            conversation_id = f"interestingness_{timestamp}_{random_hash}"

            # Project root is the directory two levels above this file
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            configs_root = os.path.join(project_root, "frame", "configs")

            # Environment overrides take precedence over the in-repo defaults
            model_dir = os.environ.get("MODEL_CONFIG_DIR") or os.path.join(configs_root, "models")
            model_config_path = os.path.join(model_dir, "gpt4o-mini.yaml")

            prompt_dir = os.environ.get("PROMPT_TEMPLATE_DIR") or os.path.join(configs_root, "prompts")
            prompt_template_path = os.path.join(prompt_dir, "interestingness_prompt.yaml")

            output_dir = os.path.join(project_root, "frame", "interestingness", "learning", "generated_programs")
            os.makedirs(output_dir, exist_ok=True)

            logger.info(f"Using model config at: {model_config_path}")
            logger.info(f"Using prompt template at: {prompt_template_path}")
            logger.info(f"Will save generated function to directory: {output_dir}")

            # The generator is given the prompt template path rather than a
            # pre-built prompt.
            generator = OneShotLLMGenerator(
                model_config_path=model_config_path,
                prompt_template_path=prompt_template_path,
                output_dir=output_dir,
                logger=logger
            )

            logger.info(f"Generating interestingness function with ID: {conversation_id}")
            # generate() is a coroutine; drive it to completion synchronously.
            # Only the second element of its result (the file path) is used.
            generation_result = asyncio.run(generator.generate(conversation_id=conversation_id))
            _, function_path = generation_result

            logger.info(f"Saved interestingness function to {function_path}")
            return function_path

        except Exception as e:
            logger.error(f"Error generating interestingness function: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None

    def _setup_output_directories(self) -> None:
        """
        Set up the output directory and main log file for the experiment.

        Reads Hydra's runtime output directory, stores it on
        ``self.output_dir``, makes sure it exists, and replaces every handler
        on the root logger with a single file handler writing to
        ``<output_dir>/theory_builder.log``.

        Handlers that were previously attached to the root logger are detached
        through the logging API and closed, so any log files they held open
        are released.

        On any failure the error is logged and ``self.output_dir`` is set to
        ``None``; no exception propagates to the caller.
        """
        try:
            # Get Hydra's output directory
            output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
            self.output_dir = output_dir

            # Create output directory
            os.makedirs(output_dir, exist_ok=True)

            # Build the new handler first: if this fails, the existing
            # handlers are still in place to report the error.
            main_log_file = os.path.join(output_dir, "theory_builder.log")
            main_handler = logging.FileHandler(main_log_file)
            main_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

            # Detach and close existing root handlers. Iterate over a copy
            # because removeHandler mutates the underlying list.
            root_logger = logging.getLogger()
            for stale_handler in list(root_logger.handlers):
                root_logger.removeHandler(stale_handler)
                stale_handler.close()
            root_logger.addHandler(main_handler)

        except Exception as e:
            logging.error(f"Error setting up output directories: {e}")
            self.output_dir = None

@hydra.main(version_base=None, config_path="configs", config_name="theory_discovery")
def main(cfg: DictConfig) -> Optional[Any]:
    """Run the theory building process.

    Logs a short summary of the configuration, constructs a ``TheoryBuilder``
    from ``cfg`` and runs it. ``cleanup_resources()`` is always called once the
    run attempt ends, whether it succeeded, failed, or was interrupted.

    Args:
        cfg: Hydra configuration for the experiment.

    Returns:
        The total reward returned by ``TheoryBuilder.run()``, or ``None`` if
        the run was interrupted or raised an exception (the error and its
        traceback are logged rather than re-raised).
    """
    # Read banner values with defaults so that a missing optional key cannot
    # crash the process before the builder is even constructed.
    experiment_name = OmegaConf.select(cfg, 'experiment.name', default='unnamed')
    policy_name = OmegaConf.select(cfg, 'policy._target_', default=None)
    if policy_name is None:
        policy_name = OmegaConf.select(cfg, 'policy.type', default='unknown')
    max_steps = OmegaConf.select(cfg, 'experiment.max_steps', default='not set')
    num_episodes = OmegaConf.select(cfg, 'experiment.num_episodes', default=1)
    num_workers = OmegaConf.select(cfg, 'experiment.num_workers', default=1)

    # Log startup information instead of printing
    logging.info("Starting TheoryBuilder with configuration:")
    logging.info(f"- Experiment: {experiment_name}")
    logging.info(f"- Policy: {policy_name}")
    logging.info(f"- Max steps: {max_steps}")
    logging.info(f"- Episodes: {num_episodes}")
    logging.info(f"- Workers: {num_workers}")

    # Initialize and run the theory builder
    try:
        builder = TheoryBuilder(cfg)
        total_reward = builder.run()

        # Final completion message
        logging.info(f"\nExperiment completed successfully with total reward: {total_reward}")
        return total_reward
    except KeyboardInterrupt:
        # Log interruption instead of printing
        logging.warning("\nExperiment interrupted by user")
        return None
    except Exception as e:
        # Log the error and its traceback instead of printing
        logging.error(f"Error running experiment: {e}")
        logging.error(traceback.format_exc())
        return None
    finally:
        # Ensure resources are cleaned up on every exit path
        cleanup_resources()

# Main execution block
if __name__ == "__main__":
    # Ask Hydra for full stack traces on errors
    os.environ["HYDRA_FULL_ERROR"] = "1"

    # Generate unique directory name first
    unique_dir = generate_unique_dir()
    output_dir = os.path.join(os.getcwd(), "outputs/theory_discovery", unique_dir)

    # Export the custom output location through environment variables.
    # NOTE(review): presumably the Hydra config reads these via env
    # interpolation - confirm against configs/theory_discovery.
    os.environ["HYDRA_RUN_DIR"] = output_dir
    os.environ["HYDRA_OUTPUT_SUBDIR"] = "."

    # Global flag to track interruption. Nothing in this block reads it; it is
    # kept in case code elsewhere in the module refers to it.
    interrupted = False
    # Set once cleanup_resources() has run from one of the exit paths below,
    # so that cleanup is not repeated.
    cleanup_done = False

    # Set up multiprocessing to spawn new processes
    # This is safer than fork(), especially when dealing with threads
    multiprocessing.set_start_method('spawn', force=True)

    def signal_handler(sig, frame):
        """Clean up (at most once) and terminate the process immediately.

        The process always exits from here, even when cleanup already ran
        earlier; otherwise a later SIGINT/SIGTERM would be swallowed and the
        process could no longer be interrupted.
        """
        global cleanup_done
        try:
            # flush=True because os._exit() below does not flush stdio buffers
            print("\nInterrupt received, cleaning up resources...", flush=True)
            if not cleanup_done:
                cleanup_resources()
                cleanup_done = True
                print("Cleanup complete, exiting.", flush=True)
        except Exception as e:
            print(f"Error during cleanup: {e}", flush=True)
        finally:
            # Use os._exit to force immediate termination without unwinding
            # back into whatever code the signal interrupted.
            os._exit(1)

    # Register signal handlers for clean shutdown
    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Termination signal

    def cleanup_at_exit():
        """Run cleanup at interpreter exit unless it has already been done."""
        global cleanup_done
        if not cleanup_done:
            cleanup_resources()
            cleanup_done = True

    atexit.register(cleanup_at_exit)

    try:
        # Run with Hydra
        main()
    except KeyboardInterrupt:
        if not cleanup_done:
            print("\nInterrupt received, cleaning up resources...")
            cleanup_resources()
            cleanup_done = True
            print("Cleanup complete, exiting.")
        sys.exit(0)
    except Exception as e:
        # Always report the error, whether or not cleanup has already run
        print(f"Error occurred: {e}")
        if not cleanup_done:
            cleanup_resources()
            cleanup_done = True
        sys.exit(1)