#!/usr/bin/env python3
"""FunSearch integration with FRAME's TheoryBuilder."""

import os
import sys
import time
import logging
import hydra
import omegaconf
from omegaconf import DictConfig, OmegaConf
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Union
import asyncio
import re
import textwrap
import yaml
import copy
import ast

# Import funsearch components
try:
    # FunSearch building blocks used throughout this module.
    from frame.funsearch.implementation import code_manipulation
    from frame.funsearch.implementation.programs_database import ProgramsDatabase, ScoresPerTest
    from frame.funsearch.implementation.config import Config as FunSearchConfig
    from frame.funsearch.implementation.config import ProgramsDatabaseConfig
    from frame.funsearch.implementation.samplers import MutationSampler, AbstractionSampler
    from frame.funsearch.implementation.evaluator import Evaluator
    from frame.funsearch.implementation.abstraction_library import Abstraction

    # Local stand-ins for classes that are missing from the implementation package.
    class FunctionSignature:
        """Plain record holding a function's name, arguments and return value description."""

        def __init__(self, name, args, returns):
            self.name = name
            self.args = args
            self.returns = returns

    class EntryPoint:
        """Thin wrapper that stores the signature it was constructed with."""

        def __init__(self, signature):
            self.signature = signature

    # Only reached when every import above succeeded.
    FUNSEARCH_AVAILABLE = True
except ImportError as import_err:
    # Chain explicitly so the traceback shows which import actually failed.
    raise ImportError(
        "FunSearch implementation not found. Please install funsearch or use the mock implementation."
    ) from import_err

# Import FRAME components
from frame.tools.llm_caller import LLMCaller

# Local imports - using absolute paths
from frame.funsearch.theory_builder_sandbox import TheoryBuilderSandbox
from frame.funsearch.db_persister import save_database, load_database

logger = logging.getLogger(__name__)

class FunSearchManager:
    """Manages the FunSearch evolutionary process with FRAME."""
    
    def __init__(self, cfg: DictConfig):
        """
        Initialize the FunSearch manager.

        Setup runs in a fixed order because later steps consume state created
        by earlier ones: spec/template loading, LLM caller, abstraction setup,
        programs database, sandbox, worker-count safeguard, evaluation of the
        initial program, evaluators, and finally the samplers.

        Args:
            cfg: Configuration object. Keys read here are ``tb_config_path``,
                ``spec_path``, ``llm.*``, ``programs_database.num_islands``,
                ``num_evaluators``, ``frame.evaluation_timeout_seconds``,
                ``samples_per_prompt``, ``prompt_yaml_path`` and
                ``abstraction`` (the latter only when abstraction is enabled).

        Raises:
            ValueError: If ``tb_config_path`` is missing or empty, or if
                TheoryBuilder's ``experiment.num_workers`` multiplied by
                ``programs_database.num_islands`` is greater than the worker
                limit.
            Exception: Anything raised by ``setup_sandbox`` propagates.
        """
        self.cfg = cfg
        # Entries are (iteration, reward, best_program_code_at_iteration_end).
        self.iteration_rewards_log: List[Tuple[int, float, Optional[str]]] = []
        # Entries are (iteration, island_id, best_score_for_island_at_iteration).
        self.island_progress_log: List[Tuple[int, int, float]] = []

        # Store the path to the TheoryBuilder base configuration
        if not hasattr(self.cfg, 'tb_config_path') or not self.cfg.tb_config_path:
            raise ValueError("Missing required configuration key: 'tb_config_path' must be provided.")
        self.tb_config_path = self.cfg.tb_config_path

        # The current working directory is used as the run directory
        # (Hydra is expected to have created and switched into it).
        self.output_dir = Path(os.getcwd())
        logger.info(f"Using Hydra run directory as main output directory: {self.output_dir}")

        # Log to a file inside the run directory; no stream handler is
        # installed so nothing is echoed to stdout.
        log_file = self.output_dir / "funsearch.log"
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
            handlers=[
                logging.FileHandler(log_file),
            ]
        )

        # Load the function specification
        self.spec_path = self.cfg.spec_path
        logger.info(f"Using specification from {self.spec_path}")

        # Load the specification once, up front, to obtain the initial function
        # source and the program template. Both are None if loading failed.
        initial_function_code_str, self.template = self._load_spec_and_template()

        # Set up LLM caller with model config from the Hydra config
        model_config = omegaconf.OmegaConf.create({
            "name": self.cfg.llm.model_name,
            "temperature": self.cfg.llm.temperature,
            "top_p": self.cfg.llm.top_p,
            "max_tokens": self.cfg.llm.max_tokens
        })
        self.llm_caller = LLMCaller(model_config=model_config, logger=logger)

        # --- Abstraction setup ---
        # setup_abstraction is defined outside this view; presumably it sets
        # self.abstraction_enabled and the abstraction_* attributes read
        # further below (TODO confirm).
        self.setup_abstraction()
        self.abstractions_dir = None
        if self.abstraction_enabled:
            # Dedicated directory for abstraction files within the run directory
            self.abstractions_dir = self.output_dir / 'abstractions'
            os.makedirs(self.abstractions_dir, exist_ok=True)
            logger.info(f"Abstraction files will be stored in: {self.abstractions_dir}")
            # Start every island with a fresh placeholder abstraction file.
            num_islands = self.cfg.programs_database.num_islands
            for island_id in range(num_islands):
                file_path = self.abstractions_dir / f'island_{island_id}_abstractions.py'
                with open(file_path, 'w') as f:
                    f.write("# Initial abstraction file\n")
        # --- End abstraction setup ---

        # Load or create ProgramsDatabase
        db_created, initial_function_obj = self.load_or_create_database()

        # Determine shared DB path (using absolute path)
        self.shared_db_path = str((self.output_dir / "episodic_cache.db").resolve())
        logger.info(f"Shared Episodic Cache DB will be at: {self.shared_db_path}")

        # Set up the sandbox, handing it the abstractions directory and shared DB path
        self.setup_sandbox(abstractions_dir_path=self.abstractions_dir, shared_db_path=self.shared_db_path)

        # --- Safeguard for excessive parallelism ---
        # TheoryBuilder's num_workers comes from the base config loaded by the sandbox
        tb_base_cfg = self.sandbox.base_cfg
        tb_num_workers = omegaconf.OmegaConf.select(tb_base_cfg, 'experiment.num_workers', default=64)

        # num_islands comes from the FunSearch config
        num_islands = self.cfg.programs_database.num_islands

        total_potential_tb_workers = tb_num_workers * num_islands

        logger.info(f"Potential TheoryBuilder worker calculation for safeguard: "
                    f"TheoryBuilder experiment.num_workers ({tb_num_workers}) * "
                    f"FunSearch programs_database.num_islands ({num_islands}) = {total_potential_tb_workers}")

        # Only a total strictly greater than WORKER_LIMIT is rejected; a total
        # equal to the limit is accepted.
        WORKER_LIMIT = 256
        if total_potential_tb_workers > WORKER_LIMIT:
            error_msg = (
                f"Potential concurrent TheoryBuilder worker count ({total_potential_tb_workers}) exceeds the limit of {WORKER_LIMIT}. "
                f"This is calculated from: TheoryBuilder's experiment.num_workers ({tb_num_workers}) * "
                f"FunSearch's programs_database.num_islands ({num_islands}). "
                f"Please reduce these values in your configurations to prevent system overload."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
        else:
            logger.info(f"Potential concurrent TheoryBuilder worker count ({total_potential_tb_workers}) is within the limit (<= {WORKER_LIMIT}).")
        # --- End safeguard ---

        # Evaluate the initial function if the database was just created
        if db_created and initial_function_obj and initial_function_code_str:
            logger.info("Evaluating initial program...")
            self.evaluate_initial_program(initial_function_obj, initial_function_code_str)
        elif db_created:
            logger.error("Database created, but failed to get initial function object or code string for evaluation.")

        # Initialize evaluators; all of them share the same database, template and sandbox.
        self.evaluators = []
        for _ in range(self.cfg.num_evaluators):
            self.evaluators.append(Evaluator(
                database=self.programs_db,
                template=self.template,
                function_to_evolve="calculate_interestingness",
                function_to_run="calculate_interestingness",
                inputs=[],
                timeout_seconds=self.cfg.frame.evaluation_timeout_seconds,
                sandbox=self.sandbox
            ))

        logger.info(f"Initialized {len(self.evaluators)} evaluators.")

        # Initialize MutationSampler
        self.mutation_sampler = MutationSampler(
            database=self.programs_db,
            evaluators=self.evaluators,
            llm_caller=self.llm_caller,  # Pre-configured LLM caller
            samples_per_prompt=self.cfg.samples_per_prompt,
            prompt_yaml_path=self.cfg.prompt_yaml_path,
            abstraction_enabled=self.abstraction_enabled
        )
        logger.info("Initialized MutationSampler.")

        # Initialize AbstractionSampler if enabled
        if self.abstraction_enabled:
            self.abstraction_sampler = AbstractionSampler(
                llm_caller=self.llm_caller,
                # Abstraction-specific config plus the already-loaded prompts,
                # so the sampler does not have to load them again.
                abstraction_config={
                    'llm': self.cfg.abstraction.get('llm'),  # For overrides
                    'max_abstractions_per_step': self.abstraction_max_abstractions,
                    'abstraction_system_prompt': self.abstraction_system_prompt,
                    'abstraction_user_prompt_template': self.abstraction_user_prompt_template
                }
            )
            logger.info("Initialized AbstractionSampler.")
        else:
            # Explicitly set to None if disabled
            self.abstraction_sampler = None
        
    def _load_spec_and_template(self) -> Tuple[Optional[str], Optional[code_manipulation.Program]]:
        """Load the specification file and build the initial program template.

        The spec file is executed, then the source of the initial
        ``calculate_interestingness`` function is taken from
        ``InterestingnessSpec.get_default_function()`` or, when that class is
        absent, from a module-level ``calculate_interestingness`` function.

        Returns:
            ``(function_source, template)`` on success. ``(None, None)`` when
            the file is missing, the spec is malformed, or any other error
            occurs; the error is logged rather than raised.
        """
        spec_globals = {}
        try:
            with open(self.spec_path, 'r') as f:
                spec_code = f.read()

            # Inject the prompt yaml path into the execution context
            if hasattr(self.cfg, 'prompt_yaml_path') and self.cfg.prompt_yaml_path:
                spec_globals['prompt_yaml_path'] = self.cfg.prompt_yaml_path
                logger.info(f"Injecting prompt_yaml_path: {self.cfg.prompt_yaml_path}")
            else:
                # Not fatal here; the spec may still load without it.
                logger.error("Configuration missing 'prompt_yaml_path'. 'get_imports' might fail.")

            # NOTE(review): exec runs the spec file with the privileges of this
            # process; spec_path must only ever point at trusted code.
            exec(spec_code, spec_globals)

            calculate_interestingness_str = None
            if 'InterestingnessSpec' in spec_globals:
                spec_class = spec_globals['InterestingnessSpec']
                if not hasattr(spec_class, 'get_default_function'):
                    logger.error("InterestingnessSpec found but has no 'get_default_function' method.")
                    raise ValueError("Could not get function code from InterestingnessSpec.")
                calculate_interestingness_str = spec_class.get_default_function()
            elif 'calculate_interestingness' in spec_globals:
                # Fallback: try getting source from function object if spec class not found
                import inspect
                calculate_interestingness_fn = spec_globals['calculate_interestingness']
                calculate_interestingness_str = inspect.getsource(calculate_interestingness_fn)
            else:
                raise ValueError("'InterestingnessSpec' class or 'calculate_interestingness' function not found in spec file")

            if not calculate_interestingness_str:
                raise ValueError("Could not extract initial function code string from spec file")

            initial_function_obj = code_manipulation.text_to_function(calculate_interestingness_str)

            # The preface comes from the spec class's get_imports() when
            # available; otherwise a default set of imports is used.
            if 'InterestingnessSpec' in spec_globals and hasattr(spec_globals['InterestingnessSpec'], 'get_imports'):
                preface = spec_globals['InterestingnessSpec'].get_imports()
            else:
                preface = "from typing import Dict, Any, List, Set, Optional, Union\nimport math"

            template = code_manipulation.Program(
                preface=preface,
                functions=[initial_function_obj]
            )
            logger.info("Successfully loaded spec and created template.")
            return calculate_interestingness_str, template

        except FileNotFoundError:
            logger.error(f"Specification file not found: {self.spec_path}")
        except ValueError as ve:
            logger.error(f"Error processing spec file {self.spec_path}: {ve}")
        except Exception as e:
            logger.error(f"Unexpected error loading spec file {self.spec_path}: {e}", exc_info=True)

        # Every failure path above falls through to here.
        return None, None

    def load_or_create_database(self) -> Tuple[bool, Optional[code_manipulation.Function]]:
        """Create a new ProgramsDatabase and store it on ``self.programs_db``.

        Loading from a backup is not implemented: when
        ``cfg.database_backup_path`` is set and that path exists, this raises
        instead of loading.

        Returns:
            Tuple (db_created: bool, initial_function_obj: Optional[Function]).
            ``db_created`` is False when ``create_new_database`` returned a
            falsy database.

        Raises:
            NotImplementedError: If an existing database backup is configured.
        """
        backup_path = self.cfg.database_backup_path
        if backup_path and os.path.exists(backup_path):
            raise NotImplementedError("Loading database from backup not implemented yet, we should not be using this setting presently.")

        logger.info("Creating new ProgramsDatabase")
        self.programs_db, initial_function_obj = self.create_new_database()
        db_created = bool(self.programs_db)

        return db_created, initial_function_obj

    def create_new_database(self) -> Tuple[Optional[ProgramsDatabase], Optional[code_manipulation.Function]]:
        """Build a fresh ProgramsDatabase seeded with the template's first function.

        Returns:
            Tuple (ProgramsDatabase | None, initial_function_obj | None).
            ``(None, None)`` when the template is missing or has no functions;
            ``(mock_db, None)`` when FunSearch is flagged as unavailable.
        """
        if not FUNSEARCH_AVAILABLE:
            mock_db = ProgramsDatabase()
            logger.warning("Created mock ProgramsDatabase")
            # A mock database has no initial function to hand back.
            return mock_db, None

        # The template is prepared by _load_spec_and_template during __init__.
        if not (self.template and self.template.functions):
            logger.error("Template not loaded correctly, cannot create database.")
            return None, None

        seed_function = self.template.functions[0]

        # Translate the Hydra section into the database's own config object.
        db_section = self.cfg.programs_database
        database_config = ProgramsDatabaseConfig(
            functions_per_prompt=db_section.functions_per_prompt,
            num_islands=db_section.num_islands,
            reset_period=db_section.reset_period,
            cluster_sampling_temperature_init=db_section.cluster_sampling_temperature_init,
            cluster_sampling_temperature_period=db_section.cluster_sampling_temperature_period,
            # Optional key; falls back to 32768 when absent.
            abstraction_max_prompt_chars=db_section.get('abstraction_max_prompt_chars', 32768),
        )

        new_db = ProgramsDatabase(
            config=database_config,
            template=self.template,
            function_to_evolve="calculate_interestingness"
        )

        # Register the seed function with no island id and a placeholder score.
        new_db.register_program(seed_function, None, {"default": 0.0})
        logger.info("Created new ProgramsDatabase with initial function")

        # The caller also receives the seed function so that it can be evaluated.
        return new_db, seed_function
    
    def setup_sandbox(self, abstractions_dir_path: Optional[Path] = None, shared_db_path: Optional[str] = None):
        """Create the TheoryBuilder sandbox and store it on ``self.sandbox``.

        Args:
            abstractions_dir_path: Directory holding the abstraction files, or
                None; forwarded to the sandbox unchanged.
            shared_db_path: Path of the shared database, or None; forwarded to
                the sandbox unchanged.

        Raises:
            FileNotFoundError: If the base TheoryBuilder config file does not exist.
            Exception: Any other setup failure is logged with a traceback and re-raised.
        """
        try:
            # Path provided via configuration (validated in __init__).
            base_config_path = self.tb_config_path
            logger.info(f"Loading base TheoryBuilder config from: {base_config_path}")

            # The path is used exactly as configured; a relative path is
            # therefore interpreted against the current working directory.
            if not os.path.exists(base_config_path):
                logger.error(f"Base TheoryBuilder config file not found at resolved path: {base_config_path}")
                raise FileNotFoundError(f"Base TheoryBuilder config not found: {base_config_path}")

            loaded_base_cfg = omegaconf.OmegaConf.load(base_config_path)
            logger.info("Successfully loaded base TheoryBuilder config.")

            # Read the visualization flag from the main config (defaults to False)
            save_visualizations_flag = self.cfg.frame.get('save_eval_visualizations', False)
            logger.info(f"Sandbox will be configured with save_visualizations={save_visualizations_flag}")

            self.sandbox = TheoryBuilderSandbox(
                base_cfg_fragment=loaded_base_cfg,
                episodes=self.cfg.frame.evaluation_episodes_M,
                evals_base_dir=self.output_dir / "evaluations",  # Evaluations live inside the run dir
                main_run_output_dir=self.output_dir,
                prompt_yaml_path=self.cfg.prompt_yaml_path,
                abstractions_dir_path=abstractions_dir_path,
                save_visualizations=save_visualizations_flag,
                shared_db_path=shared_db_path
            )
            logger.info("TheoryBuilder sandbox initialized.")
        except FileNotFoundError as fnf:
            logger.error(f"Fatal error setting up sandbox: {fnf}")
            # Re-raise the exception to halt execution
            raise
        except Exception as e:
            logger.error(f"Unexpected error setting up sandbox: {e}", exc_info=True)
            # Re-raise to halt execution
            raise
    
    def evaluate_initial_program(self, initial_function_obj: code_manipulation.Function, initial_code_str: str):
        """Run the spec's initial program in the sandbox and register its score.

        Args:
            initial_function_obj: Function object registered in the database on success.
            initial_code_str: Source code handed to the sandbox for evaluation.

        Any exception raised while evaluating or registering is logged with a
        traceback and not propagated.
        """
        logger.info("Evaluating initial program from spec file...")
        try:
            # The raw function code is run with no island id, tagged as iteration 0.
            score, success = self.sandbox.run(
                main_function_code=initial_code_str,
                island_id=None,
                iteration=0,
            )
            if not success:
                logger.error("Initial program evaluation failed.")
                return

            logger.info(f"Initial program evaluation successful. Score: {score}")
            # Register with island_id=None and the score wrapped in a dict.
            self.programs_db.register_program(
                program=initial_function_obj,
                island_id=None,
                scores_per_test={'eval_score': score},
            )
        except Exception as e:
            logger.error(f"Exception during initial program evaluation: {e}", exc_info=True)
    
    def save_database(self):
        """Persist the programs database to the backup path and to a timestamped copy.

        Does nothing (apart from a warning) when no backup path is configured.
        """
        backup_path = self.cfg.database_backup_path
        if not backup_path:
            logger.warning("No database backup path provided, skipping backup")
            return

        # Primary backup at the configured location. The bare name resolves to
        # the module-level save_database helper, not to this method.
        save_database(self.programs_db, backup_path)

        # Second copy, stamped with the current local time, kept in a
        # dedicated subdirectory of the run output directory.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backups_dir = self.output_dir / "database_backups"
        os.makedirs(backups_dir, exist_ok=True)
        save_database(self.programs_db, backups_dir / f"programs_db_{stamp}.pkl")
    
    def run(self, iterations: int = 10):
        """
        Run the FunSearch process for a specified number of iterations.

        Each iteration optionally performs an abstraction step, runs the
        mutation sampler, records per-island best scores, saves the database,
        logs the best program, writes checkpoint info and records the best
        reward. A keyboard interrupt or any other exception ends the loop
        without propagating; the final save and the summaries in the
        ``finally`` block run either way.

        Args:
            iterations: Number of iterations to run
        """
        logger.info(f"Starting FunSearch with {iterations} iterations")

        try:
            for i in range(iterations):
                logger.info(f"Iteration {i+1}/{iterations}")

                # --- Abstraction Step ---
                # Runs every `frequency` iterations (1-based count), never on the first one.
                if self.cfg.abstraction.enabled and (i + 1) % self.cfg.abstraction.frequency == 0 and i > 0:
                    logger.info(f"--- Performing Abstraction Step (Iteration {i+1}) ---")
                    self.perform_abstraction_step()
                    logger.info("--- Abstraction Step Complete ---")
                # --- End Abstraction Step ---

                # Run one round of sampling and evaluation (mutation)
                logger.info("Running Mutation Sampler...")
                # The sampler is called after any abstraction step of this iteration.
                self.mutation_sampler.sample(iteration=i)

                # --- Populate island_progress_log ---
                if hasattr(self.programs_db, '_best_score_per_island') and self.programs_db._best_score_per_island:
                    num_islands_for_log = len(self.programs_db._best_score_per_island)
                    for island_idx in range(num_islands_for_log):
                        current_island_score = self.programs_db._best_score_per_island[island_idx]
                        self.island_progress_log.append((i + 1, island_idx, current_island_score))
                # --- End populate island_progress_log ---

                # Save the database after every iteration
                self.save_database()

                # Log the best program so far
                self.show_best_program()

                # Save checkpoint information
                self.save_checkpoint_info(i + 1)

                # Record the best reward at the end of this iteration
                current_best_reward, current_best_code, _ = self._get_overall_best_program_info()
                self.iteration_rewards_log.append((i + 1, current_best_reward, current_best_code))

                # Add a small delay between iterations
                time.sleep(1)

        except KeyboardInterrupt:
            logger.info("FunSearch process interrupted by user")
        except Exception as e:
            # Include the traceback: this handler swallows the exception, so
            # the log is the only place the failure location is recorded.
            logger.error(f"Error in FunSearch process: {e}", exc_info=True)
        finally:
            # Final save of the database
            self.save_database()
            logger.info("FunSearch process completed")
            # Show the best program found overall
            self.show_best_program()
            # Log final abstraction summary
            self._log_final_abstraction_summary()
            # Log iteration rewards summary
            self._log_iteration_rewards()
            # Log island progress summary
            self._log_island_progress_summary()
    
    def _log_iteration_rewards(self):
        """Log the best reward recorded for each iteration, then the bare list of rewards."""
        logger.info("--- Iteration Rewards Summary ---")

        if self.iteration_rewards_log:
            logger.info("Best reward at the end of each iteration:")
            # Each entry is (iteration, reward, program_code); the code is not logged here.
            for iteration, reward, _code in self.iteration_rewards_log:
                logger.info(f"  Iteration {iteration}: Best Reward = {reward}")

            logger.info("Summary list of rewards per iteration:")
            logger.info(str([entry[1] for entry in self.iteration_rewards_log]))
        else:
            logger.info("No iteration rewards recorded.")

        logger.info("--- End Iteration Rewards Summary ---")

    def _log_island_progress_summary(self):
        """Log island scores three ways: per iteration, per island over time, and final values."""

        def render(value):
            # Floats other than -inf are shown with four decimals; anything
            # else (including -inf and non-floats) falls back to str().
            if isinstance(value, float) and value != float('-inf'):
                return f"{value:.4f}"
            return str(value)

        logger.info("--- Island Progress Summary ---")

        # Part 1: scores grouped by iteration, islands in ascending id order.
        logger.info("--- Per-Iteration Island Best Scores ---")
        if self.island_progress_log:
            records_by_iteration = {}
            for record in self.island_progress_log:
                records_by_iteration.setdefault(record[0], []).append(record)

            for iteration_num in sorted(records_by_iteration):
                logger.info(f"  Iteration {iteration_num}:")
                ordered_records = sorted(records_by_iteration[iteration_num], key=lambda rec: rec[1])
                for _, island_id, score in ordered_records:
                    logger.info(f"    Island {island_id}: Current Best Score = {render(score)}")
        else:
            logger.info("No island progress recorded for per-iteration view.")

        # Part 2: for every configured island, its logged scores in iteration order.
        logger.info("\n--- Historical Best Scores Per Island (across iterations) ---")
        if self.island_progress_log:
            try:
                num_islands = self.cfg.programs_database.num_islands
                timeline = {island_idx: [] for island_idx in range(num_islands)}

                for iteration_num, island_id, score in self.island_progress_log:
                    # Entries whose island id is outside the configured range are ignored.
                    if 0 <= island_id < num_islands:
                        timeline[island_id].append((iteration_num, score))

                for island_idx in range(num_islands):
                    chronological = sorted(timeline[island_idx], key=lambda pair: pair[0])
                    history = [render(s) for _, s in chronological]
                    logger.info(f"  Island {island_idx}: Best scores over iterations = {history}")
            except Exception as e:
                logger.error(f"Error generating historical scores per island: {e}", exc_info=True)
                logger.info("Could not reliably determine number of islands or process historical data.")
        else:
            logger.info("No island progress recorded to show historical scores.")

        # Part 3: the database's current per-island best scores.
        logger.info("\n--- Final Best Scores Per Island ---")
        if hasattr(self.programs_db, '_best_score_per_island') and self.programs_db._best_score_per_island:
            rendered_finals = []
            for island_idx, score in enumerate(self.programs_db._best_score_per_island):
                text = render(score)
                logger.info(f"  Island {island_idx}: Final Best Score = {text}")
                rendered_finals.append(text)
            logger.info(f"List of final best scores across islands: {rendered_finals}")
        else:
            logger.info("Could not retrieve final best scores per island from the database.")

        logger.info("--- End Island Progress Summary ---")

    def _get_overall_best_program_info(self) -> Tuple[float, Optional[str], int]:
        """Return ``(reward, code, island_id)`` for the best program across all islands.

        Falls back to ``(-inf, None, -1)`` when the database exposes no
        per-island tracking, the tracking lists are empty or differ in length,
        or no island holds a program.
        """
        best_reward, best_code, best_island_id = float('-inf'), None, -1

        tracking_available = (
            FUNSEARCH_AVAILABLE
            and hasattr(self.programs_db, '_best_program_per_island')
            and hasattr(self.programs_db, '_best_score_per_island')
        )
        if not tracking_available:
            return best_reward, best_code, best_island_id

        island_scores = self.programs_db._best_score_per_island
        island_programs = self.programs_db._best_program_per_island

        # Both lists must be non-empty and aligned index-for-index.
        if not island_scores or len(island_scores) != len(island_programs):
            logger.warning("_best_score_per_island or _best_program_per_island is empty or mismatched in length.")
            return best_reward, best_code, best_island_id

        for island_id, (score, program_obj) in enumerate(zip(island_scores, island_programs)):
            # Only a strictly better score replaces the current best, so the
            # lowest-numbered island wins ties.
            if program_obj and score > best_reward:
                best_reward = score
                try:
                    best_code = str(program_obj)
                except Exception as e:
                    logger.warning(f"Could not convert program object to string for island {island_id}: {e}")
                    # The reward and island id are still recorded; only the code is dropped.
                    best_code = None
                best_island_id = island_id

        return best_reward, best_code, best_island_id

    def show_best_program(self):
        """Log the code of the best program found so far, if there is one."""
        reward, source, island = self._get_overall_best_program_info()

        # Nothing to show until some island has a program with printable code.
        if not source:
            logger.info("No evaluated programs available yet")
            return

        separator = "-" * 40
        logger.info(f"Best program so far (found in island {island}, reward={reward}):")
        logger.info(separator)
        logger.info(source)
        logger.info(separator)
    
    def save_checkpoint_info(self, iteration: int):
        """
        Save checkpoint information about the current state.
        
        Args:
            iteration: Current iteration number
        """
        # Create checkpoints subdirectory if it doesn't exist
        checkpoints_dir = self.output_dir / "checkpoints"
        os.makedirs(checkpoints_dir, exist_ok=True)
        
        checkpoint_file = checkpoints_dir / f"checkpoint_{iteration}.txt"
        with open(checkpoint_file, 'w') as f:
            f.write(f"FunSearch checkpoint at iteration {iteration}\n")
            f.write(f"Timestamp: {datetime.now().isoformat()}\n")
            
            # Get stats from ProgramsDatabase
            num_programs_in_db = 0
            best_overall_reward = float('-inf')
            best_overall_program_code = None
            best_overall_island_id = -1
            
            if hasattr(self.programs_db, '_islands'):
                 num_programs_in_db = sum(island._num_programs for island in self.programs_db._islands)
                 # Best reward is tracked per island, find the max
                 if self.programs_db._best_score_per_island:
                     best_overall_reward = max(self.programs_db._best_score_per_island)
                     # Find the island and program corresponding to the best overall reward
                     for island_id, score in enumerate(self.programs_db._best_score_per_island):
                          if score == best_overall_reward:
                              program_obj = self.programs_db._best_program_per_island[island_id]
                              if program_obj:
                                   best_overall_program_code = str(program_obj)
                                   best_overall_island_id = island_id
                                   break # Found the first best, stop searching
            
            f.write(f"Number of programs in database: {num_programs_in_db}\n")
            
            if best_overall_reward > float('-inf'):
                f.write(f"Best OVERALL reward so far: {best_overall_reward} (found in Island {best_overall_island_id})\n")
                f.write("Best Overall Program Code:\n")
                f.write("-" * 20 + "\n")
                f.write(best_overall_program_code if best_overall_program_code else "(Code not available)")
                f.write("\n" + "-" * 20 + "\n")
                # Add note about visualization location
                f.write("Note: Detailed evaluation logs and visualizations for the episode that generated this program")
                f.write(f" *should* be located within the 'evaluations/island_{best_overall_island_id}/iteration_X/eval_Y' subdirectory")
                f.write(" structure, where X and Y depend on when this specific program was evaluated.\n")
            else:
                f.write("No evaluated programs recorded yet\n")
                
            # --- Per-Island Details ---
            f.write("\n--- Per-Island Best Programs and Rewards ---")
            num_islands = self.cfg.programs_database.num_islands
            for island_id in range(num_islands):
                f.write(f"\n--- Island {island_id} ---")
                
                # Get best score for this island
                best_island_score = self.programs_db._best_score_per_island[island_id]
                if best_island_score > float('-inf'):
                    f.write(f"\nBest Reward: {best_island_score}")
                else:
                    f.write("\nBest Reward: (None recorded yet)")
                
                # Get best program code for this island
                program_obj = self.programs_db._best_program_per_island[island_id]
                if program_obj:
                    try:
                        code_str = str(program_obj)
                        f.write("\nBest Program Code:\n")
                        f.write("-" * 20 + "\n")
                        f.write(code_str)
                        f.write("\n" + "-" * 20 + "\n")
                    except Exception as e:
                        f.write(f"\nBest Program Code: (Error converting to string: {e})\n")
                else:
                    f.write("\nBest Program Code: (None recorded yet)\n")
                    
                # Add note about visualization location for this island's best
                f.write("Note: Visualizations for this island's best program's evaluation episode")
                f.write(" *should* be within the corresponding 'evaluations/island_{island_id}/iteration_X/eval_Y' subdirectory.\n")
            f.write("\n--- End Per-Island Details ---")
            # --- End Per-Island Details ---

    def setup_abstraction(self):
        """Read the abstraction settings from the config and load the prompt templates.

        ``self.abstraction_enabled`` is taken from ``cfg.abstraction.enabled``
        (``False`` when the section or key is absent). When enabled, the
        frequency, sample count, per-step abstraction limit and prompt YAML
        path are stored on the instance and the system/user prompt templates
        are loaded from that YAML file. A missing path, missing file, missing
        key or any other load error is logged and switches the feature off.
        """
        self.abstraction_enabled = self.cfg.get('abstraction', {}).get('enabled', False)
        if not self.abstraction_enabled:
            logger.info("Abstraction feature is disabled in configuration.")
            return

        logger.info("Abstraction feature is enabled.")

        # Remaining settings, each with the default used when the key is absent
        settings = self.cfg.abstraction
        self.abstraction_frequency = settings.get('frequency', 10)
        self.abstraction_programs_to_sample = settings.get('programs_to_sample', 5)
        self.abstraction_max_abstractions = settings.get('max_abstractions_per_step', 3)
        self.abstraction_prompt_path = settings.get('prompt_yaml_path')

        prompt_path = self.abstraction_prompt_path
        if not prompt_path:
            logger.error("Abstraction prompt_yaml_path is not configured. Disabling abstraction.")
            self.abstraction_enabled = False
            return

        # Load abstraction prompt YAML; every failure mode disables the feature
        try:
            with open(prompt_path, 'r') as f:
                self.abstraction_prompts_config = yaml.safe_load(f)
            loaded = self.abstraction_prompts_config
            # The YAML document is expected to provide both of these keys
            self.abstraction_system_prompt = loaded['abstraction_system_prompt']
            self.abstraction_user_prompt_template = loaded['abstraction_user_prompt_template']
            logger.info(f"Loaded abstraction prompts from {prompt_path}")
        except FileNotFoundError:
            logger.error(f"Abstraction prompt file not found: {prompt_path}. Disabling abstraction.")
            self.abstraction_enabled = False
        except KeyError as e:
            logger.error(f"Missing key {e} in abstraction prompt file {prompt_path}. Disabling abstraction.")
            self.abstraction_enabled = False
        except Exception as e:
            logger.error(f"Failed to load or parse abstraction prompts from {prompt_path}: {e}. Disabling abstraction.", exc_info=True)
            self.abstraction_enabled = False
           
    def _log_final_abstraction_summary(self):
        """Logs a summary of the learned abstractions per island at the end of the run."""
        if not self.cfg.abstraction.enabled:
            return # Don't log if abstraction was never enabled
        
        logger.info("--- Final Abstraction Library Summary ---")
        if not hasattr(self, 'programs_db') or not hasattr(self.programs_db, '_abstraction_libraries'):
            logger.warning("ProgramsDatabase or abstraction libraries not initialized. Cannot log summary.")
            logger.info("--- End Summary --- ")
            return
        
        num_islands = self.cfg.programs_database.num_islands
        total_abstractions = 0
        for island_id in range(num_islands):
            library = self.programs_db._abstraction_libraries.get(island_id)
            if library and len(library) > 0:
                logger.info(f"Island {island_id}: Found {len(library)} abstractions:")
                sorted_abs = sorted(library.get_abstractions(), key=lambda x: x.name)
                for abs_obj in sorted_abs:
                    logger.info(f"  - {abs_obj.name}{abs_obj.signature}: {abs_obj.description}")
                total_abstractions += len(library)
            else:
                logger.info(f"Island {island_id}: Found 0 abstractions.")
                
        logger.info(f"Total abstractions found across all islands: {total_abstractions}")
        logger.info("--- End Summary --- ")

    #Note(_; 4/16): Could be moved to a different file
    def perform_abstraction_step(self):
        """Orchestrates the abstraction process, including updating files.

        For every configured island this fetches the island's top programs,
        asks the abstraction sampler for new abstractions, stores any results
        in the database's abstraction library and rewrites
        ``island_<id>_abstractions.py`` inside ``self.abstractions_dir``.

        Each island is handled by its own coroutine; all of them are gathered
        on a single event loop, so they overlap while awaiting the sampler.
        Every sampler call is bounded by ``cfg.abstraction.task_timeout_seconds``
        (default 120 seconds). Errors and timeouts are logged per island and
        never propagate to the caller.

        Must be called from synchronous code: it starts its own event loop
        with ``asyncio.run``.
        """
        if not self.abstraction_enabled:
            logger.warning("perform_abstraction_step called but abstraction is disabled.")
            return
        if not self.abstraction_sampler or not self.abstractions_dir:
            logger.error("Abstraction enabled, but sampler or abstractions_dir not initialized.")
            return

        num_islands = self.cfg.programs_database.num_islands
        num_to_sample = self.abstraction_programs_to_sample
        # Get abstraction timeout from config, default to a reasonable value (e.g., 120 seconds)
        abstraction_task_timeout = self.cfg.abstraction.get('task_timeout_seconds', 120.0)

        async def run_abstraction_for_island_task(island_id: int):
            """Runs the abstraction generation task and UPDATES the file."""
            logger.info(f"[Island {island_id}] Starting abstraction task (timeout: {abstraction_task_timeout}s)...")
            try:
                # 1. Get top programs
                top_programs = self.programs_db.get_top_programs_for_island(island_id, num_to_sample)
                if not top_programs:
                    logger.warning(f"[Island {island_id}] No programs found for abstraction. Skipping.")
                    return

                # 2. Get current library state for context
                if not hasattr(self.programs_db, '_abstraction_libraries') or not isinstance(self.programs_db._abstraction_libraries, dict):
                    logger.error(f"[Island {island_id}] programs_db._abstraction_libraries is missing or not a dict!")
                    current_library = None
                else:
                    logger.info(f"[Island {island_id}] Checking library. Keys available: {list(self.programs_db._abstraction_libraries.keys())}")
                    current_library = self.programs_db._abstraction_libraries.get(island_id)
                    logger.info(f"[Island {island_id}] Library object retrieved: type={type(current_library)}, value='{current_library}'")

                current_abstractions_str = "(None)"
                if current_library is not None:
                    if len(current_library) > 0:
                        current_abstractions_str = current_library.format_for_sampler_prompt()
                else:
                    logger.warning(f"[Island {island_id}] Could not find abstraction library object for this island ID in database.")

                # 3. Call the Abstraction Sampler *WITH TIMEOUT*
                new_abstractions = [] # Default to empty list
                try:
                    # wait_for cancels the sampler call and raises TimeoutError
                    # once the configured budget is exhausted.
                    new_abstractions = await asyncio.wait_for(
                        self.abstraction_sampler.generate_abstractions(
                            top_programs=top_programs,
                            island_id=island_id,
                            current_abstractions_str=current_abstractions_str
                        ),
                        timeout=abstraction_task_timeout,
                    )
                except asyncio.TimeoutError:
                    logger.error(f"[Island {island_id}] Abstraction generation timed out after {abstraction_task_timeout} seconds.")
                    # new_abstractions remains empty, so the file/DB won't be updated below

                # 4. Update the Database AND the abstraction file
                if not new_abstractions:
                    # Generation was attempted but yielded nothing or timed out
                    logger.info(f"[Island {island_id}] No valid new abstractions found or generation timed out.")
                    return

                logger.info(f"[Island {island_id}] Found {len(new_abstractions)} new abstractions.")
                # Update DB library first
                self.programs_db.update_abstraction_library(island_id, new_abstractions)

                # Get updated file content from DB library
                # Pass include_imports=True to ensure the file is self-contained
                updated_content = self.programs_db.get_abstraction_library_content(island_id, include_imports=True)
                if updated_content is None:
                    logger.error(f"[Island {island_id}] Failed to get updated content from library after update.")
                    return

                # Overwrite the stable file
                file_path = self.abstractions_dir / f'island_{island_id}_abstractions.py'
                try:
                    with open(file_path, 'w') as f:
                        f.write(updated_content)
                    logger.info(f"[Island {island_id}] Updated abstraction file: {file_path}")
                except Exception as write_e:
                    logger.error(f"[Island {island_id}] Failed to write updated abstraction file {file_path}: {write_e}")

            except Exception as e:
                # Contain the failure so one island cannot abort the others
                logger.error(f"[Island {island_id}] Error during abstraction task (outside generation call): {e}", exc_info=True)
            finally:
                logger.info(f"[Island {island_id}] Abstraction task finished.")

        async def run_all_tasks():
            # gather() must be awaited inside a running loop and needs real
            # coroutines, hence the async wrapper around the per-island tasks.
            await asyncio.gather(
                *(run_abstraction_for_island_task(island_id) for island_id in range(num_islands))
            )

        logger.info(f"Starting concurrent abstraction step for {num_islands} islands...")
        start_time = time.time()
        try:
            asyncio.run(run_all_tasks())
        except Exception as e:
            logger.error(f"Error running concurrent abstraction tasks: {e}", exc_info=True)
        end_time = time.time()
        logger.info(f"Concurrent abstraction step finished in {end_time - start_time:.2f} seconds.")

@hydra.main(config_path="../configs", config_name="funsearch-mid")
def main(cfg: DictConfig):
    """
    Entry point: print the configuration as YAML, then run FunSearch.

    The number of iterations comes from the ``iterations`` config key and
    defaults to 10 when that key is absent.

    Args:
        cfg: Hydra configuration object
    """
    # Flushed immediately so the banner is visible even if start-up stalls later
    print("--- FunSearch main() started ---", flush=True)
    print(OmegaConf.to_yaml(cfg))

    # Build the manager first, then read the iteration count and start the search
    funsearch_manager = FunSearchManager(cfg)
    iteration_count = cfg.get("iterations", 10)
    funsearch_manager.run(iterations=iteration_count)

# Script entry point: main() is called without arguments here; the @hydra.main
# decorator on main is what supplies the cfg argument.
if __name__ == "__main__":
    main() 