import json
import logging
import sqlite3
import time
from dataclasses import asdict, dataclass, field
from functools import wraps
from pathlib import Path
import random
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union
import math
from .complexity import analyze_code_metrics
from .parents import CombinedParentSelector
from .inspirations import CombinedContextSelector
from .islands import CombinedIslandManager
from .display import DatabaseDisplay
from shinka.llm.embedding import EmbeddingClient

logger = logging.getLogger(__name__)


def clean_nan_values(obj: Any) -> Any:
    """Recursively replace non-finite floats in ``obj`` with ``None``.

    Dicts, lists and tuples are rebuilt with their items cleaned (dict keys
    are left untouched). Python and NumPy floats that are NaN or +/-inf
    become ``None``; finite NumPy float scalars become plain ``float``;
    float-typed arrays go through ``tolist()`` and are then cleaned
    element-wise. Any other object is returned unchanged.
    """
    # Containers: rebuild with cleaned members.
    if isinstance(obj, dict):
        return {name: clean_nan_values(item) for name, item in obj.items()}
    if isinstance(obj, list):
        return [clean_nan_values(entry) for entry in obj]
    if isinstance(obj, tuple):
        return tuple(clean_nan_values(entry) for entry in obj)

    # Scalars that are NaN or infinite cannot be represented in strict JSON.
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    if isinstance(obj, np.floating) and not np.isfinite(obj):
        return None

    # Anything exposing a floating dtype (NumPy scalars and arrays).
    if hasattr(obj, "dtype") and np.issubdtype(obj.dtype, np.floating):
        if not np.isscalar(obj):
            # Arrays become nested lists, which are cleaned recursively.
            return clean_nan_values(obj.tolist())
        return None if (np.isnan(obj) or np.isinf(obj)) else float(obj)

    return obj


@dataclass
class DatabaseConfig:
    """Settings bundle handed to ``ProgramDatabase`` and its helper objects.

    Field order is part of the interface because the dataclass can be built
    positionally. Descriptions marked "presumably" are inferred from the
    field name only and should be confirmed against the consuming code.
    """

    # SQLite file path; a falsy value selects an in-memory database.
    db_path: Optional[str] = None
    # Presumably the number of islands the population is split into.
    num_islands: int = 4
    # Presumably the maximum number of programs kept in the archive.
    archive_size: int = 100

    # --- Inspiration / context sampling ---
    # Presumably the share of elite programs used when selecting context.
    elite_selection_ratio: float = 0.3
    # Number of archive inspirations requested when sampling.
    num_archive_inspirations: int = 5
    # Number of top-k inspirations requested when sampling.
    num_top_k_inspirations: int = 2

    # --- Island model (semantics live in the island manager; TODO confirm) ---
    migration_interval: int = 10
    migration_rate: float = 0.1
    island_elitism: bool = True
    enforce_island_separation: bool = True

    # --- Parent selection (consumed by the parent selector; TODO confirm) ---
    parent_selection_strategy: str = "power_law"

    # Presumably parameters of the "power_law" strategy.
    exploitation_alpha: float = 1.0
    exploitation_ratio: float = 0.2

    # Presumably a sharpness parameter of a weighted strategy.
    parent_selection_lambda: float = 10.0

    # Presumably the beam width of a beam-search strategy.
    num_beams: int = 5


def db_retry(max_retries=5, initial_delay=0.1, backoff_factor=2):
    """Build a decorator that retries a function on SQLite database errors.

    The wrapped function is attempted up to ``max_retries`` times in total.
    After each failed attempt except the last, the wrapper sleeps for the
    current delay (starting at ``initial_delay`` seconds) and multiplies the
    delay by ``backoff_factor``. The final failure is logged and re-raised
    unchanged. Exceptions that are not ``sqlite3.DatabaseError`` propagate
    immediately without a retry.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as
            1 so the wrapped function is always invoked at least once.
        initial_delay: Sleep, in seconds, before the second attempt.
        backoff_factor: Multiplier applied to the delay after every sleep.

    Returns:
        A decorator preserving the wrapped function's metadata.
    """
    # Previously a non-positive max_retries skipped the loop entirely, so the
    # function was never called and a misleading RuntimeError was raised.
    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            delay = initial_delay
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                # OperationalError and IntegrityError are subclasses of
                # DatabaseError, so this single clause covers all three.
                except sqlite3.DatabaseError as e:
                    if attempt == attempts:
                        logger.error(
                            f"DB operation {func.__name__} failed after "
                            f"{attempts} retries: {e}"
                        )
                        raise
                    logger.warning(
                        f"DB operation {func.__name__} failed with "
                        f"{type(e).__name__}: {e}. "
                        f"Retrying in {delay:.2f}s..."
                    )
                    time.sleep(delay)
                    delay *= backoff_factor
            # Not reachable: the loop runs at least once and every iteration
            # either returns or raises. Kept as a guard for static analysis.
            raise RuntimeError(
                f"DB retry logic failed for function {func.__name__} without "
                "raising an exception."
            )

        return wrapper

    return decorator


@dataclass
class Program:
    """A single candidate program together with its lineage and evaluation.

    Field order is part of the interface because the dataclass can be built
    positionally; only ``id`` and ``code`` are required.
    """

    # --- Identity ---
    id: str
    code: str
    language: str = "python"

    # --- Lineage ---
    parent_id: Optional[str] = None
    archive_inspiration_ids: List[str] = field(default_factory=list)
    top_k_inspiration_ids: List[str] = field(default_factory=list)
    island_idx: Optional[int] = None
    generation: int = 0
    timestamp: float = field(default_factory=time.time)
    code_diff: Optional[str] = None

    # --- Evaluation ---
    combined_score: float = 0.0
    public_metrics: Dict[str, Any] = field(default_factory=dict)
    private_metrics: Dict[str, Any] = field(default_factory=dict)
    text_feedback: Union[str, List[str]] = ""
    correct: bool = False
    children_count: int = 0

    # --- Derived features ---
    complexity: float = 0.0
    embedding: List[float] = field(default_factory=list)
    embedding_pca_2d: List[float] = field(default_factory=list)
    embedding_pca_3d: List[float] = field(default_factory=list)
    embedding_cluster_id: Optional[int] = None

    # --- Island migration log ---
    migration_history: List[Dict[str, Any]] = field(default_factory=list)

    # --- Free-form extras ---
    metadata: Dict[str, Any] = field(default_factory=dict)

    # --- Archive membership flag ---
    in_archive: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the program as a plain dict with NaN/inf floats set to None."""
        return clean_nan_values(asdict(self))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Program":
        """Build a ``Program`` from a mapping, tolerating malformed containers.

        Keys that are not dataclass fields are ignored. Dict- and list-typed
        fields whose value is missing or of the wrong type fall back to an
        empty dict/list. The input mapping is not modified.

        Args:
            data: Mapping of field names to values; must provide ``id`` and
                ``code``.

        Returns:
            The constructed ``Program``.

        Raises:
            TypeError: If a required field (``id`` or ``code``) is absent.
        """
        # Expected container type for every field that must never be None.
        container_fields = (
            ("public_metrics", dict),
            ("private_metrics", dict),
            ("metadata", dict),
            ("archive_inspiration_ids", list),
            ("top_k_inspiration_ids", list),
            ("embedding", list),
            ("embedding_pca_2d", list),
            ("embedding_pca_3d", list),
            ("migration_history", list),
        )

        # Work on a filtered copy so the caller's mapping is left untouched.
        known_fields = cls.__dataclass_fields__
        kwargs = {key: value for key, value in data.items() if key in known_fields}

        for name, expected_type in container_fields:
            if not isinstance(kwargs.get(name), expected_type):
                kwargs[name] = expected_type()

        return cls(**kwargs)


class ProgramDatabase:
    

    def __init__(self, config: DatabaseConfig, read_only: bool = False):
        """Open the SQLite store described by ``config`` and load its state.

        With a ``db_path`` the file is opened read-write (creating parent
        directories as needed) or, when ``read_only`` is true, through a
        ``mode=ro`` URI. Without a ``db_path`` an in-memory database is used.
        Tables are created only for writable connections; persisted metadata
        is loaded in both modes.

        Args:
            config: Database settings; ``config.db_path`` selects the file.
            read_only: Open an existing file without write access.

        Raises:
            FileNotFoundError: If ``read_only`` is set and the file is missing.
        """
        self.config = config
        self.conn: Optional[sqlite3.Connection] = None
        self.cursor: Optional[sqlite3.Cursor] = None
        self.read_only = read_only
        self.embedding_client = EmbeddingClient()

        self.last_iteration: int = 0
        self.best_program_id: Optional[str] = None
        self.beam_search_parent_id: Optional[str] = None

        # Set by add() when the island manager requests a migration.
        self._schedule_migration: bool = False

        # Created below, once the connection and cursor exist.
        self.island_manager: Optional[CombinedIslandManager] = None

        db_path_str = getattr(self.config, "db_path", None)

        if db_path_str:
            db_file = Path(db_path_str).resolve()
            if not read_only:
                # An empty main file next to leftover WAL/SHM side files
                # cannot be recovered from; drop the side files first.
                db_wal_file = Path(f"{db_file}-wal")
                db_shm_file = Path(f"{db_file}-shm")
                if (
                    db_file.exists()
                    and db_file.stat().st_size == 0
                    and (db_wal_file.exists() or db_shm_file.exists())
                ):
                    logger.warning(
                        f"Database file {db_file} is empty but WAL/SHM files "
                        "exist. This may indicate an unclean shutdown. "
                        "Removing WAL/SHM files to attempt recovery."
                    )
                    if db_wal_file.exists():
                        db_wal_file.unlink()
                    if db_shm_file.exists():
                        db_shm_file.unlink()
                db_file.parent.mkdir(parents=True, exist_ok=True)
                self.conn = sqlite3.connect(str(db_file), timeout=30.0)
                logger.debug(f"Connected to SQLite database: {db_file}")
            else:
                if not db_file.exists():
                    raise FileNotFoundError(
                        f"Database file not found for read-only connection: {db_file}"
                    )
                # as_uri() percent-encodes the path, so characters that are
                # special in a URI ('?', '#', '%') cannot be mistaken for the
                # query string, a fragment or an escape sequence.
                db_uri = f"{db_file.as_uri()}?mode=ro"
                self.conn = sqlite3.connect(db_uri, uri=True, timeout=30.0)
                logger.debug(
                    "Connected to SQLite database in read-only mode: %s",
                    db_file,
                )
        else:
            self.conn = sqlite3.connect(":memory:")
            logger.info("Initialized in-memory SQLite database.")

        # Rows are accessed by column name throughout the class.
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()
        if not self.read_only:
            self._create_tables()
        self._load_metadata_from_db()

        self.island_manager = CombinedIslandManager(
            cursor=self.cursor,
            conn=self.conn,
            config=self.config,
        )

        count = self._count_programs_in_db()
        logger.debug(f"DB initialized with {count} programs.")
        logger.debug(
            f"Last iter: {self.last_iteration}. Best ID: {self.best_program_id}"
        )

    def _create_tables(self):
        if not self.cursor or not self.conn:
            raise ConnectionError("DB not connected.")

        
        
        self.cursor.execute("PRAGMA journal_mode = WAL;")
        self.cursor.execute("PRAGMA busy_timeout = 30000;")  
        self.cursor.execute(
            "PRAGMA wal_autocheckpoint = 1000;"
        )  
        self.cursor.execute("PRAGMA synchronous = NORMAL;")  
        self.cursor.execute("PRAGMA cache_size = -64000;")  
        self.cursor.execute("PRAGMA temp_store = MEMORY;")
        self.cursor.execute("PRAGMA foreign_keys = ON;")  

        self.cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS programs (
                id TEXT PRIMARY KEY,
                code TEXT NOT NULL,
                language TEXT NOT NULL,
                parent_id TEXT,
                archive_inspiration_ids TEXT,  -- JSON serialized List[str]
                top_k_inspiration_ids TEXT,    -- JSON serialized List[str]
                generation INTEGER NOT NULL,
                timestamp REAL NOT NULL,
                code_diff TEXT,     -- Stores edit difference
                combined_score REAL,
                public_metrics TEXT, -- JSON serialized Dict[str, Any]
                private_metrics TEXT, -- JSON serialized Dict[str, Any]
                text_feedback TEXT, -- Text feedback for the program
                complexity REAL,   -- Calculated complexity metric
                embedding TEXT,    -- JSON serialized List[float]
                embedding_pca_2d TEXT, -- JSON serialized List[float]
                embedding_pca_3d TEXT, -- JSON serialized List[float]
                embedding_cluster_id INTEGER,
                correct BOOLEAN DEFAULT 0,  -- Correct (0=False, 1=True)
                children_count INTEGER NOT NULL DEFAULT 0,
                metadata TEXT,      -- JSON serialized Dict[str, Any]
                migration_history TEXT, -- JSON of migration events
                island_idx INTEGER  -- Add island_idx to the schema
            )
            """
        )

        
        idx_cmds = [
            "CREATE INDEX IF NOT EXISTS idx_programs_generation ON "
            "programs(generation)",
            "CREATE INDEX IF NOT EXISTS idx_programs_timestamp ON programs(timestamp)",
            "CREATE INDEX IF NOT EXISTS idx_programs_complexity ON "
            "programs(complexity)",
            "CREATE INDEX IF NOT EXISTS idx_programs_parent_id ON programs(parent_id)",
            "CREATE INDEX IF NOT EXISTS idx_programs_children_count ON "
            "programs(children_count)",
            "CREATE INDEX IF NOT EXISTS idx_programs_island_idx ON "
            "programs(island_idx)",
        ]
        for cmd in idx_cmds:
            self.cursor.execute(cmd)

        self.cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS archive (
                program_id TEXT PRIMARY KEY,
                FOREIGN KEY (program_id) REFERENCES programs(id)
                    ON DELETE CASCADE
            )
            """
        )

        self.cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS metadata_store (
                key TEXT PRIMARY KEY, value TEXT
            )
            """
        )

        self.conn.commit()

        
        self._run_migrations()

        logger.debug("Database tables and indices ensured to exist.")

    def _run_migrations(self):
        
        if not self.cursor or not self.conn:
            raise ConnectionError("DB not connected.")

        
        try:
            
            self.cursor.execute("PRAGMA table_info(programs)")
            columns = [row[1] for row in self.cursor.fetchall()]

            if "text_feedback" not in columns:
                logger.info("Adding text_feedback column to programs table")
                self.cursor.execute(
                    "ALTER TABLE programs ADD COLUMN text_feedback TEXT DEFAULT ''"
                )
                self.conn.commit()
                logger.info("Successfully added text_feedback column")
        except sqlite3.Error as e:
            logger.error(f"Error during text_feedback migration: {e}")
            

    @db_retry()
    def _load_metadata_from_db(self):
        """Load persisted run state from ``metadata_store`` into attributes.

        Reads ``last_iteration`` (default 0), ``best_program_id`` and
        ``beam_search_parent_id`` (default ``None``; the literal string
        ``"None"`` is also treated as unset). When a key is missing or unset
        and the database is writable, its default is written back so the row
        exists afterwards.

        Raises:
            ConnectionError: If no cursor is available.
        """
        if not self.cursor:
            raise ConnectionError("DB cursor not available.")

        def fetch_value(key: str) -> Optional[str]:
            """Return the stored value for ``key``, or None if absent/NULL."""
            self.cursor.execute(
                "SELECT value FROM metadata_store WHERE key = ?", (key,)
            )
            row = self.cursor.fetchone()
            return row["value"] if row else None

        stored_iteration = fetch_value("last_iteration")
        self.last_iteration = (
            int(stored_iteration) if stored_iteration is not None else 0
        )
        # Initialise only when nothing usable is stored. The earlier check was
        # inverted: it rewrote an existing value on every load and left a NULL
        # value uninitialised, unlike the two keys handled below.
        if stored_iteration is None and not self.read_only:
            self._update_metadata_in_db("last_iteration", str(self.last_iteration))

        stored_best = fetch_value("best_program_id")
        best_unset = stored_best is None or stored_best == "None"
        self.best_program_id = None if best_unset else str(stored_best)
        if best_unset and not self.read_only:
            self._update_metadata_in_db("best_program_id", None)

        stored_beam = fetch_value("beam_search_parent_id")
        beam_unset = stored_beam is None or stored_beam == "None"
        self.beam_search_parent_id = None if beam_unset else str(stored_beam)
        if beam_unset and not self.read_only:
            self._update_metadata_in_db("beam_search_parent_id", None)

    @db_retry()
    def _update_metadata_in_db(self, key: str, value: Optional[str]):
        """Insert or overwrite one ``metadata_store`` entry and commit.

        Args:
            key: Metadata key to write.
            value: Text value to store; ``None`` is stored as SQL NULL.

        Raises:
            ConnectionError: If the connection or cursor is not set.
        """
        if not (self.cursor and self.conn):
            raise ConnectionError("DB not connected.")

        upsert_sql = "INSERT OR REPLACE INTO metadata_store (key, value) VALUES (?, ?)"
        self.cursor.execute(upsert_sql, (key, value))
        self.conn.commit()

    @db_retry()
    def _count_programs_in_db(self) -> int:
        """Return the number of rows in ``programs`` (0 without a cursor)."""
        if not self.cursor:
            return 0

        self.cursor.execute("SELECT COUNT(*) FROM programs")
        count_row = self.cursor.fetchone()
        if not count_row:
            return 0
        # Rows are sqlite3.Row objects, addressable by the column label.
        return count_row["COUNT(*)"]

    @db_retry()
    def add(self, program: Program, verbose: bool = False) -> str:
        """Persist ``program`` and run the follow-up bookkeeping.

        Steps, in order: assign an island, fill in ``complexity`` when it is
        still 0.0, serialise the container fields to JSON, insert the row and
        bump the parent's ``children_count`` in one transaction, then update
        the archive, the best-program tracker, embeddings/clusters and
        ``last_iteration``. Finally island copies and migration scheduling
        are handled through the island manager.

        Args:
            program: The program to store; it is mutated in place (island,
                complexity, metadata, embedding normalisation).
            verbose: When true, print a summary of the stored program.

        Returns:
            The id of the stored program.

        Raises:
            PermissionError: If the database was opened read-only.
            ConnectionError: If the connection or cursor is not set.
            sqlite3.IntegrityError: If the insert violates a constraint.
        """
        if self.read_only:
            raise PermissionError("Cannot add program in read-only mode.")
        if not (self.cursor and self.conn):
            raise ConnectionError("DB not connected.")

        self.island_manager.assign_island(program)

        # A complexity of exactly 0.0 is treated as "not computed yet".
        if program.complexity == 0.0:
            try:
                analysis = analyze_code_metrics(program.code, program.language)
                program.complexity = analysis.get("complexity_score", 0.0)
                if program.metadata is None:
                    program.metadata = {}
                program.metadata["code_analysis_metrics"] = analysis
            except Exception as e:
                logger.warning(
                    f"Could not calculate complexity for program {program.id}: {e}"
                )
                # Fall back to the source length as a crude complexity proxy.
                program.complexity = float(len(program.code))

        # The embedding is serialised as-is below, so it must be a list.
        if not isinstance(program.embedding, list):
            logger.warning(
                f"Program {program.id} embedding is not a list, "
                "defaulting to empty list."
            )
            program.embedding = []

        # JSON-encode every container column (None falls back to empty).
        encoded = {
            "public_metrics": json.dumps(program.public_metrics or {}),
            "private_metrics": json.dumps(program.private_metrics or {}),
            "metadata": json.dumps(program.metadata or {}),
            "archive_inspiration_ids": json.dumps(
                program.archive_inspiration_ids or []
            ),
            "top_k_inspiration_ids": json.dumps(program.top_k_inspiration_ids or []),
            "embedding": json.dumps(program.embedding),
            "embedding_pca_2d": json.dumps(program.embedding_pca_2d or []),
            "embedding_pca_3d": json.dumps(program.embedding_pca_3d or []),
            "migration_history": json.dumps(program.migration_history or []),
        }

        # text_feedback is stored as one string; lists are newline-joined.
        raw_feedback = program.text_feedback
        if raw_feedback is None:
            feedback_text = ""
        elif isinstance(raw_feedback, list):
            feedback_text = "\n".join(str(part) for part in raw_feedback)
        else:
            feedback_text = str(raw_feedback)

        # Values in the same order as the column list of the INSERT below.
        insert_values = (
            program.id,
            program.code,
            program.language,
            program.parent_id,
            encoded["archive_inspiration_ids"],
            encoded["top_k_inspiration_ids"],
            program.generation,
            program.timestamp,
            program.code_diff,
            program.combined_score,
            encoded["public_metrics"],
            encoded["private_metrics"],
            feedback_text,
            program.complexity,
            encoded["embedding"],
            encoded["embedding_pca_2d"],
            encoded["embedding_pca_3d"],
            program.embedding_cluster_id,
            program.correct,
            program.children_count,
            encoded["metadata"],
            program.island_idx,
            encoded["migration_history"],
        )

        # Insert and parent update succeed or fail together.
        self.conn.execute("BEGIN TRANSACTION")

        try:
            self.cursor.execute(
                """
                INSERT INTO programs
                   (id, code, language, parent_id, archive_inspiration_ids,
                    top_k_inspiration_ids, generation, timestamp, code_diff,
                    combined_score, public_metrics, private_metrics,
                    text_feedback, complexity, embedding, embedding_pca_2d,
                    embedding_pca_3d, embedding_cluster_id, correct,
                    children_count, metadata, island_idx, migration_history)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
                           ?, ?, ?, ?, ?, ?)
                """,
                insert_values,
            )

            if program.parent_id:
                self.cursor.execute(
                    "UPDATE programs SET children_count = children_count + 1 "
                    "WHERE id = ?",
                    (program.parent_id,),
                )

            self.conn.commit()
            logger.info(
                "Program %s added to DB - score: %s.",
                program.id,
                program.combined_score,
            )
        except Exception as e:
            self.conn.rollback()
            if isinstance(e, sqlite3.IntegrityError):
                logger.error(f"IntegrityError for program {program.id}: {e}")
            else:
                logger.error(f"Error adding program {program.id}: {e}")
            raise

        # Post-insert bookkeeping; each step manages its own persistence.
        self._update_archive(program)
        self._update_best_program(program)
        self._recompute_embeddings_and_clusters()

        if program.generation > self.last_iteration:
            self.last_iteration = program.generation
            self._update_metadata_in_db("last_iteration", str(self.last_iteration))

        if verbose:
            self._print_program_summary(program)

        if self.island_manager.needs_island_copies(program):
            logger.info(
                f"Creating copies of initial program {program.id} for all islands"
            )
            self.island_manager.copy_program_to_islands(program)
            # Drop the one-shot marker so the copies are not requested again.
            if program.metadata:
                program.metadata.pop("_needs_island_copies", None)
                self.cursor.execute(
                    "UPDATE programs SET metadata = ? WHERE id = ?",
                    (json.dumps(program.metadata), program.id),
                )
                self.conn.commit()

        if self.island_manager.should_schedule_migration(program):
            self._schedule_migration = True

        self.check_scheduled_operations()
        return program.id

    def _program_from_row(self, row: sqlite3.Row) -> Optional[Program]:
        
        if not row:
            return None

        program_data = dict(row)

        
        public_metrics_text = program_data.get("public_metrics")
        if public_metrics_text:
            try:
                program_data["public_metrics"] = json.loads(public_metrics_text)
            except json.JSONDecodeError:
                program_data["public_metrics"] = {}
        else:
            program_data["public_metrics"] = {}

        private_metrics_text = program_data.get("private_metrics")
        if private_metrics_text:
            try:
                program_data["private_metrics"] = json.loads(private_metrics_text)
            except json.JSONDecodeError:
                program_data["private_metrics"] = {}
        else:
            program_data["private_metrics"] = {}

        
        metadata_text = program_data.get("metadata")
        if metadata_text:
            try:
                program_data["metadata"] = json.loads(metadata_text)
            except json.JSONDecodeError:
                program_data["metadata"] = {}
        else:
            program_data["metadata"] = {}

        
        if "text_feedback" not in program_data or program_data["text_feedback"] is None:
            program_data["text_feedback"] = ""

        
        archive_insp_ids_text = program_data.get("archive_inspiration_ids")
        if archive_insp_ids_text:
            try:
                program_data["archive_inspiration_ids"] = json.loads(
                    archive_insp_ids_text
                )
            except json.JSONDecodeError:
                program_data["archive_inspiration_ids"] = []
        else:
            program_data["archive_inspiration_ids"] = []

        top_k_insp_ids_text = program_data.get("top_k_inspiration_ids")
        if top_k_insp_ids_text:
            try:
                program_data["top_k_inspiration_ids"] = json.loads(top_k_insp_ids_text)
            except json.JSONDecodeError:
                logger.warning(
                    "Could not decode top_k_inspiration_ids for "
                    f"program {program_data.get('id')}. "
                    "Defaulting to empty list."
                )
                program_data["top_k_inspiration_ids"] = []
        else:
            program_data["top_k_inspiration_ids"] = []

        
        embedding_text = program_data.get("embedding")
        if embedding_text:
            try:
                program_data["embedding"] = json.loads(embedding_text)
            except json.JSONDecodeError:
                logger.warning(
                    f"Could not decode embedding for program "
                    f"{program_data.get('id')}. Defaulting to empty list."
                )
                program_data["embedding"] = []
        else:
            program_data["embedding"] = []

        embedding_pca_2d_text = program_data.get("embedding_pca_2d")
        if embedding_pca_2d_text:
            try:
                program_data["embedding_pca_2d"] = json.loads(embedding_pca_2d_text)
            except json.JSONDecodeError:
                program_data["embedding_pca_2d"] = []
        else:
            program_data["embedding_pca_2d"] = []

        embedding_pca_3d_text = program_data.get("embedding_pca_3d")
        if embedding_pca_3d_text:
            try:
                program_data["embedding_pca_3d"] = json.loads(embedding_pca_3d_text)
            except json.JSONDecodeError:
                program_data["embedding_pca_3d"] = []
        else:
            program_data["embedding_pca_3d"] = []

        
        migration_history_text = program_data.get("migration_history")
        if migration_history_text:
            try:
                program_data["migration_history"] = json.loads(migration_history_text)
            except json.JSONDecodeError:
                logger.warning(
                    f"Could not decode migration_history for program "
                    f"{program_data.get('id')}. Defaulting to empty list."
                )
                program_data["migration_history"] = []
        else:
            program_data["migration_history"] = []

        
        program_data["in_archive"] = bool(program_data.get("in_archive", 0))

        return Program.from_dict(program_data)

    @db_retry()
    def get(self, program_id: str) -> Optional[Program]:
        """Look up one program by id.

        Args:
            program_id: Primary key of the program.

        Returns:
            The matching program, or ``None`` if no row has that id.

        Raises:
            ConnectionError: If no cursor is available.
        """
        cursor = self.cursor
        if not cursor:
            raise ConnectionError("DB not connected.")

        cursor.execute("SELECT * FROM programs WHERE id = ?", (program_id,))
        return self._program_from_row(cursor.fetchone())

    @db_retry()
    def sample(
        self,
        target_generation=None,
        novelty_attempt=None,
        max_novelty_attempts=None,
        resample_attempt=None,
        max_resample_attempts=None,
    ) -> Tuple[Program, List[Program], List[Program]]:
        """Pick a parent program plus inspiration programs for the next step.

        While not every island is initialised, the earliest program (by
        timestamp) is returned without inspirations. Otherwise a random
        initialised island is chosen, a parent is drawn from it by the parent
        selector, and archive/top-k inspirations are drawn by the context
        selector. The five keyword arguments are only forwarded to the
        sampling summary printer.

        Returns:
            ``(parent, archive_inspirations, top_k_inspirations)``.

        Raises:
            ConnectionError: If no cursor is available.
            RuntimeError: If no program exists or no parent could be sampled.
        """
        if not self.cursor:
            raise ConnectionError("DB not connected.")

        # Progress information that is passed through to the summary output.
        progress_args = (
            target_generation,
            novelty_attempt,
            max_novelty_attempts,
            resample_attempt,
            max_resample_attempts,
        )

        if not self.island_manager.are_all_islands_initialized():
            self.cursor.execute("SELECT * FROM programs ORDER BY timestamp ASC LIMIT 1")
            first_row = self.cursor.fetchone()
            if not first_row:
                raise RuntimeError("No programs found in database")

            seed_program = self._program_from_row(first_row)
            if not seed_program:
                raise RuntimeError("Failed to load initial program")

            logger.info(
                f"Not all islands initialized. Using initial program {seed_program.id} "
                "without inspirations."
            )
            self._print_sampling_summary_helper(seed_program, [], [], *progress_args)
            return seed_program, [], []

        # Regular path: every island has programs, so pick one at random.
        island_choice = random.choice(self.island_manager.get_initialized_islands())
        logger.debug(f"Sampling from island {island_choice}")

        selector = CombinedParentSelector(
            cursor=self.cursor,
            conn=self.conn,
            config=self.config,
            get_program_func=self.get,
            best_program_id=self.best_program_id,
            beam_search_parent_id=self.beam_search_parent_id,
            last_iteration=self.last_iteration,
            update_metadata_func=self._update_metadata_in_db,
            get_best_program_func=self.get_best_program,
        )
        parent = selector.sample_parent(island_idx=island_choice)
        if not parent:
            raise RuntimeError(f"Failed to sample parent from island {island_choice}")

        # Fall back to fixed counts for config objects lacking these fields.
        archive_quota = getattr(self.config, "num_archive_inspirations", 5)
        top_k_quota = getattr(self.config, "num_top_k_inspirations", 2)

        context = CombinedContextSelector(
            cursor=self.cursor,
            conn=self.conn,
            config=self.config,
            get_program_func=self.get,
            best_program_id=self.best_program_id,
            get_island_idx_func=self.island_manager.get_island_idx,
            program_from_row_func=self._program_from_row,
        )
        archive_inspirations, top_k_inspirations = context.sample_context(
            parent, archive_quota, top_k_quota
        )

        logger.debug(
            f"Sampled parent {parent.id} from island {island_choice}, "
            f"{len(archive_inspirations)} archive inspirations, "
            f"{len(top_k_inspirations)} top-k inspirations."
        )

        self._print_sampling_summary_helper(
            parent, archive_inspirations, top_k_inspirations, *progress_args
        )

        return parent, archive_inspirations, top_k_inspirations

    def _print_sampling_summary_helper(
        self,
        parent,
        archive_inspirations,
        top_k_inspirations,
        target_generation=None,
        novelty_attempt=None,
        max_novelty_attempts=None,
        resample_attempt=None,
        max_resample_attempts=None,
    ):
        
        if not hasattr(self, "_database_display"):
            self._database_display = DatabaseDisplay(
                cursor=self.cursor,
                conn=self.conn,
                config=self.config,
                island_manager=self.island_manager,
                count_programs_func=self._count_programs_in_db,
                get_best_program_func=self.get_best_program,
            )

        self._database_display.print_sampling_summary(
            parent,
            archive_inspirations,
            top_k_inspirations,
            target_generation,
            novelty_attempt,
            max_novelty_attempts,
            resample_attempt,
            max_resample_attempts,
        )

    @db_retry()
    def get_best_program(self, metric: Optional[str] = None) -> Optional[Program]:
        """Return the best correct program, or ``None`` if there is none.

        Ranking criterion, in order of precedence:

        1. ``metric`` given: highest ``public_metrics[metric]`` among correct
           programs that report that metric.
        2. Otherwise, if any correct program has a ``combined_score``: the
           highest ``combined_score``.
        3. Otherwise: the highest mean of the ``public_metrics`` values.

        When ``metric`` is ``None`` and the tracked ``best_program_id`` still
        resolves to a correct program, that program is returned without
        re-ranking. The winner of a re-ranking becomes the tracked best and
        is persisted to metadata unless the database is read-only.

        Args:
            metric: Optional name of a public metric to rank by.

        Returns:
            The best correct program under the criterion above, or ``None``
            when no correct program matches.

        Raises:
            ConnectionError: If there is no database cursor.
        """
        if not self.cursor:
            raise ConnectionError("DB not connected.")

        # Fast path: trust the tracked best while it is still valid.
        if metric is None and self.best_program_id:
            program = self.get(self.best_program_id)
            if program and program.correct:
                return program
            logger.warning(
                f"Tracked best_program_id '{self.best_program_id}' "
                "not found or incorrect. Re-evaluating."
            )
            if not self.read_only:
                self._update_metadata_in_db("best_program_id", None)
            self.best_program_id = None

        self.cursor.execute("SELECT * FROM programs WHERE correct = 1")
        all_rows = self.cursor.fetchall()
        if not all_rows:
            logger.debug("No correct programs found in database.")
            return None

        programs = []
        for row_data in all_rows:
            p_dict = dict(row_data)
            # Decode the JSON text columns; a missing or malformed value
            # becomes an empty dict, as in get_top_programs.
            for json_key in ("public_metrics", "private_metrics", "metadata"):
                raw = p_dict.get(json_key)
                try:
                    p_dict[json_key] = json.loads(raw) if raw else {}
                except json.JSONDecodeError:
                    p_dict[json_key] = {}
            programs.append(Program.from_dict(p_dict))

        if metric:
            candidates = [
                p for p in programs if p.public_metrics and metric in p.public_metrics
            ]

            def rank(p_item):
                return p_item.public_metrics[metric]

            log_key = f"metric '{metric}'"
        elif any(p.combined_score is not None for p in programs):
            candidates = [p for p in programs if p.combined_score is not None]

            def rank(p_item):
                # Use the score itself: `score or -inf` would demote a
                # legitimate score of 0.0 to minus infinity.
                return p_item.combined_score

            log_key = "combined_score"
        else:
            candidates = [p for p in programs if p.public_metrics]

            def rank(p_item):
                return sum(p_item.public_metrics.values()) / len(p_item.public_metrics)

            log_key = "average metrics"

        if not candidates:
            logger.debug("No correct programs matched criteria for get_best_program.")
            return None

        # max() keeps the first of equally ranked programs, like taking the
        # head of a stable descending sort.
        best_overall = max(candidates, key=rank)
        logger.debug(f"Best correct program by {log_key}: {best_overall.id}")

        # NOTE(review): this also replaces the tracked best when ranking by an
        # explicit `metric`; confirm that is intended.
        if self.best_program_id != best_overall.id:
            logger.info(
                "Updating tracked best program from "
                f"'{self.best_program_id}' to '{best_overall.id}'."
            )
            self.best_program_id = best_overall.id
            if not self.read_only:
                self._update_metadata_in_db("best_program_id", self.best_program_id)
        return best_overall

    @db_retry()
    def get_all_programs(self) -> List[Program]:
        """Return every stored program, with an ``in_archive`` flag per row.

        Each ``programs`` row is left-joined against ``archive`` so the row
        carries ``in_archive`` (1 or 0) before it is converted by
        ``_program_from_row``. Rows that convert to ``None`` are dropped.

        Raises:
            ConnectionError: If there is no database cursor.
        """
        if not self.cursor:
            raise ConnectionError("DB not connected.")

        self.cursor.execute(
            """
            SELECT p.*,
                   CASE WHEN a.program_id IS NOT NULL THEN 1 ELSE 0 END as in_archive
            FROM programs p
            LEFT JOIN archive a ON p.id = a.program_id
            """
        )
        converted = map(self._program_from_row, self.cursor.fetchall())
        return [prog for prog in converted if prog is not None]

    @db_retry()
    def get_programs_by_generation(self, generation: int) -> List[Program]:
        """Return the programs whose ``generation`` column equals ``generation``.

        Rows are converted with ``_program_from_row``; rows that convert to
        ``None`` are left out.

        Raises:
            ConnectionError: If there is no database cursor.
        """
        if not self.cursor:
            raise ConnectionError("DB not connected.")

        self.cursor.execute(
            "SELECT * FROM programs WHERE generation = ?", (generation,)
        )
        found = []
        for record in self.cursor.fetchall():
            prog = self._program_from_row(record)
            if prog is not None:
                found.append(prog)
        return found

    @db_retry()
    def get_top_programs(
        self,
        n: int = 10,
        metric: Optional[str] = "combined_score",
        correct_only: bool = False,
    ) -> List[Program]:
        """Return up to ``n`` programs ranked by ``metric``, best first.

        * ``"combined_score"`` and ``"timestamp"`` are ordered and limited by
          SQL (descending); for ``"combined_score"`` rows with a NULL score
          are excluded.
        * Any other metric name is looked up in each program's
          ``public_metrics``; programs that do not report it are skipped.
        * A falsy ``metric`` ranks by the mean of the ``public_metrics``
          values.

        Args:
            n: Maximum number of programs to return.
            metric: Ranking criterion as described above.
            correct_only: When true, only rows with ``correct = 1`` are read.

        Returns:
            The ranked programs, or an empty list when no row matches.

        Raises:
            ConnectionError: If there is no database cursor.
        """
        if not self.cursor:
            raise ConnectionError("DB not connected.")

        ranked_by_sql = metric in ["combined_score", "timestamp"]
        where_correct = "WHERE correct = 1" if correct_only else ""

        if metric == "combined_score":
            sql = """
                SELECT * FROM programs
                WHERE combined_score IS NOT NULL
            """
            if correct_only:
                sql += " AND correct = 1"
            sql += " ORDER BY combined_score DESC LIMIT ?"
            self.cursor.execute(sql, (n,))
        elif metric == "timestamp":
            self.cursor.execute(
                f"SELECT * FROM programs {where_correct} "
                "ORDER BY timestamp DESC LIMIT ?",
                (n,),
            )
        else:
            # Ranking happens in Python below, so every candidate row is read.
            self.cursor.execute(f"SELECT * FROM programs {where_correct}")

        fetched = self.cursor.fetchall()
        if not fetched:
            return []

        programs = []
        for record in fetched:
            fields = dict(record)
            # The JSON text columns decode to dicts; empty or malformed text
            # becomes an empty dict.
            for json_key in ("public_metrics", "private_metrics", "metadata"):
                raw_text = fields.get(json_key)
                decoded = {}
                if raw_text:
                    try:
                        decoded = json.loads(raw_text)
                    except json.JSONDecodeError:
                        decoded = {}
                fields[json_key] = decoded
            programs.append(Program.from_dict(fields))

        if ranked_by_sql:
            # Already ordered and limited by the query.
            return programs[:n]

        if metric:
            pool = [
                prog
                for prog in programs
                if prog.public_metrics and metric in prog.public_metrics
            ]

            def rank_key(prog):
                return prog.public_metrics.get(metric, -float("inf"))

        else:
            pool = [prog for prog in programs if prog.public_metrics]

            def rank_key(prog):
                values = prog.public_metrics
                if not values:
                    return -float("inf")
                return sum(values.values()) / len(values)

        return sorted(pool, key=rank_key, reverse=True)[:n]

    def save(self, path: Optional[str] = None) -> None:
        """Write ``last_iteration`` to metadata and commit the connection.

        ``path`` is never written to. It is only compared with the connected
        database's path so a mismatch (or an in-memory database) can be
        reported in the log.

        Args:
            path: Optional path the caller believes it is saving to.
        """
        if not (self.conn and self.cursor):
            logger.warning("No DB connection, skipping save.")
            return

        connected_path = self.config.db_path
        if path:
            if not connected_path:
                logger.warning(
                    f"Attempting to save with path '{path}' but current "
                    "database is in-memory. Metadata will be committed to the "
                    "in-memory instance."
                )
            elif Path(path).resolve() != Path(connected_path).resolve():
                logger.warning(
                    f"Save path '{path}' differs from connected DB "
                    f"'{connected_path}'. Metadata saved to "
                    "connected DB."
                )

        self._update_metadata_in_db("last_iteration", str(self.last_iteration))
        self.conn.commit()

        iteration, best_id = self.last_iteration, self.best_program_id
        logger.info(
            f"Database state committed. Last iteration: "
            f"{iteration}. Best: {best_id}"
        )

    def load(self, path: str) -> None:
        """Switch this instance to the SQLite database at ``path``.

        Closes the current connection (if any), opens ``path``, ensures the
        tables exist, and reloads metadata from the new file.
        ``self.config.db_path`` is updated to the resolved path.

        If the database file is empty while ``-wal``/``-shm`` side files
        exist, those side files are removed before connecting as a recovery
        attempt.

        Args:
            path: Location of the database file. A missing file is not an
                error; its parent directory is created.
        """
        logger.info(f"Loading database from '{path}'...")
        if self.conn:
            db_display_name = self.config.db_path or ":memory:"
            logger.info(f"Closing existing connection to '{db_display_name}'.")
            self.conn.close()

        db_path_obj = Path(path).resolve()

        db_wal_file = Path(f"{db_path_obj}-wal")
        db_shm_file = Path(f"{db_path_obj}-shm")
        if (
            db_path_obj.exists()
            and db_path_obj.stat().st_size == 0
            and (db_wal_file.exists() or db_shm_file.exists())
        ):
            # The message is already fully formatted. Passing a %-argument
            # here as well makes logging fail to format the record.
            logger.warning(
                f"Database file {db_path_obj} is empty but WAL/SHM files "
                "exist. This may indicate an unclean shutdown. Removing "
                "WAL/SHM files to attempt recovery."
            )
            if db_wal_file.exists():
                db_wal_file.unlink()
            if db_shm_file.exists():
                db_shm_file.unlink()

        self.config.db_path = str(db_path_obj)

        if not db_path_obj.exists():
            logger.warning(
                f"DB file '{db_path_obj}' not found. New DB created if writes occur."
            )
            db_path_obj.parent.mkdir(parents=True, exist_ok=True)

        self.conn = sqlite3.connect(str(db_path_obj), timeout=30.0)
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()

        # A cached display helper was built with the previous cursor and
        # connection. Drop it so the next print call rebuilds it.
        try:
            del self._database_display
        except AttributeError:
            pass

        self._create_tables()
        self._load_metadata_from_db()

        count = self._count_programs_in_db()
        logger.info(
            f"Loaded DB from '{db_path_obj}'. {count} programs. "
            f"Last iter: {self.last_iteration}."
        )

    def _is_better(self, program1: Program, program2: Program) -> bool:
        """Return ``True`` when ``program1`` should rank above ``program2``.

        Tie-breakers are applied in this order:

        1. A correct program beats an incorrect one.
        2. A higher ``combined_score`` wins; a program with a score beats one
           without.
        3. A higher mean of the ``public_metrics`` values wins (an empty
           mapping counts as minus infinity). If computing the means raises,
           the result is ``False``.
        4. The later ``timestamp`` wins.
        """
        ok1, ok2 = bool(program1.correct), bool(program2.correct)
        if ok1 != ok2:
            return ok1

        score1, score2 = program1.combined_score, program2.combined_score
        if score1 is None or score2 is None:
            if score1 is not None:
                return True
            if score2 is not None:
                return False
        elif score1 != score2:
            return score1 > score2

        def _mean_public_metric(prog):
            metrics = prog.public_metrics
            if not metrics:
                return -float("inf")
            return sum(metrics.values()) / len(metrics)

        try:
            mean1 = _mean_public_metric(program1)
            mean2 = _mean_public_metric(program2)
            if mean1 != mean2:
                return mean1 > mean2
        except Exception:
            return False

        return program1.timestamp > program2.timestamp

    @db_retry()
    def _update_archive(self, program: Program) -> None:
        """Offer ``program`` to the size-bounded archive of correct programs.

        While the archive holds fewer than ``config.archive_size`` entries the
        program is simply added. Once it is full, the program replaces the
        worst archived entry (as judged by ``_is_better``) only if it is
        better than that entry. Incorrect programs are never archived. The
        connection is committed before returning.

        Args:
            program: The candidate program.
        """
        if (
            not self.cursor
            or not self.conn
            or not hasattr(self.config, "archive_size")
            or self.config.archive_size <= 0
        ):
            logger.debug("Archive update skipped (config/DB issue or size <= 0).")
            return

        if not program.correct:
            logger.debug(f"Program {program.id} not added to archive (not correct).")
            return

        self.cursor.execute("SELECT COUNT(*) FROM archive")
        count = (self.cursor.fetchone() or [0])[0]

        if count < self.config.archive_size:
            self.cursor.execute(
                "INSERT OR IGNORE INTO archive (program_id) VALUES (?)",
                (program.id,),
            )
        else:
            self._replace_worst_in_archive(program)
        self.conn.commit()

    def _replace_worst_in_archive(self, program: Program) -> None:
        """Swap ``program`` in for the worst archived entry if it is better.

        Helper for ``_update_archive`` when the archive is full. It does not
        commit; the caller does.
        """
        # A program that is already archived must not evict another member:
        # deleting the worst entry and re-inserting this id would shrink the
        # archive or hit a duplicate insert.
        self.cursor.execute(
            "SELECT 1 FROM archive WHERE program_id = ? LIMIT 1", (program.id,)
        )
        if self.cursor.fetchone() is not None:
            logger.debug(f"Program {program.id} is already in the archive.")
            return

        self.cursor.execute(
            "SELECT a.program_id, p.combined_score, p.timestamp, p.correct "
            "FROM archive a JOIN programs p ON a.program_id = p.id"
        )
        archived_rows = self.cursor.fetchall()
        if not archived_rows:
            # No archive entry joins to a stored program, so there is nothing
            # to compare against; just add the newcomer.
            self.cursor.execute(
                "INSERT OR IGNORE INTO archive (program_id) VALUES (?)",
                (program.id,),
            )
            return

        # Lightweight stand-ins carrying only the fields selected above.
        archived = [
            Program(
                id=r_data["program_id"],
                code="",
                combined_score=r_data["combined_score"],
                timestamp=r_data["timestamp"],
                correct=bool(r_data["correct"]),
            )
            for r_data in archived_rows
        ]

        worst_in_archive = archived[0]
        for p_archived in archived[1:]:
            if self._is_better(worst_in_archive, p_archived):
                worst_in_archive = p_archived

        if self._is_better(program, worst_in_archive):
            self.cursor.execute(
                "DELETE FROM archive WHERE program_id = ?",
                (worst_in_archive.id,),
            )
            self.cursor.execute(
                "INSERT INTO archive (program_id) VALUES (?)", (program.id,)
            )
            logger.info(
                f"Program {program.id} replaced {worst_in_archive.id} in archive."
            )

    @db_retry()
    def _update_best_program(self, program: Program) -> None:
        """Promote ``program`` to tracked best if it beats the current one.

        Incorrect programs are ignored. When there is no resolvable current
        best, or ``_is_better`` prefers ``program``, ``best_program_id`` is
        updated in memory and in metadata and the change is logged.
        """
        if not program.correct:
            logger.debug(f"Program {program.id} not considered for best (not correct).")
            return

        incumbent = self.get(self.best_program_id) if self.best_program_id else None
        if incumbent is not None and not self._is_better(program, incumbent):
            return

        self.best_program_id = program.id
        self._update_metadata_in_db("best_program_id", self.best_program_id)

        headline = f"New best program: {program.id}"
        new_score = program.combined_score or 0.0
        if incumbent:
            old_score = incumbent.combined_score or 0.0
            detail = (
                f" (gen: {incumbent.generation} → {program.generation}, "
                f"score: {old_score:.4f} → {new_score:.4f}, "
                f"island: {incumbent.island_idx} → {program.island_idx})"
            )
        else:
            detail = (
                f" (gen: {program.generation}, score: {new_score:.4f}, initialized "
                f"island: {program.island_idx})."
            )
        logger.info(headline + detail)

    def print_summary(self, console=None) -> None:
        """Print a database summary through the cached ``DatabaseDisplay``.

        The display is created on first use. The current ``last_iteration``
        is passed to it on every call, so the summary does not report a stale
        iteration when the display was cached earlier.

        Args:
            console: Forwarded unchanged to ``DatabaseDisplay.print_summary``.
        """
        if not hasattr(self, "_database_display"):
            self._database_display = DatabaseDisplay(
                cursor=self.cursor,
                conn=self.conn,
                config=self.config,
                island_manager=self.island_manager,
                count_programs_func=self._count_programs_in_db,
                get_best_program_func=self.get_best_program,
            )

        # Refresh on every call, not only at creation: the cached display may
        # predate the current iteration or may never have been given one.
        self._database_display.set_last_iteration(self.last_iteration)
        self._database_display.print_summary(console)

    def _print_program_summary(self, program) -> None:
        
        if not hasattr(self, "_database_display"):
            self._database_display = DatabaseDisplay(
                cursor=self.cursor,
                conn=self.conn,
                config=self.config,
                island_manager=self.island_manager,
                count_programs_func=self._count_programs_in_db,
                get_best_program_func=self.get_best_program,
            )

        self._database_display.print_program_summary(program)

    def check_scheduled_operations(self):
        """Run a pending island migration, if one has been scheduled.

        When ``_schedule_migration`` is set, the island manager performs a
        migration for ``last_iteration`` and the flag is cleared.
        """
        if not self._schedule_migration:
            return

        logger.info("Running scheduled migration operation")
        self.island_manager.perform_migration(self.last_iteration)
        self._schedule_migration = False

    def close(self):
        """Close the database connection and forget the stale handles.

        Safe to call more than once. After closing, ``conn`` and ``cursor``
        are reset to ``None`` so the "DB not connected" guards in the other
        methods fire, instead of the methods failing on a closed connection.
        """
        if self.conn:
            self.conn.close()
        self.conn = None
        self.cursor = None

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        
        if not vec1 or not vec2 or len(vec1) != len(vec2):
            return 0.0

        arr1 = np.array(vec1, dtype=np.float32)
        arr2 = np.array(vec2, dtype=np.float32)

        norm_a = np.linalg.norm(arr1)
        norm_b = np.linalg.norm(arr2)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = np.dot(arr1, arr2) / (norm_a * norm_b)
        return float(similarity)

    @db_retry()
    def compute_similarity_thread_safe(
        self, vec: List[float], island_idx: int
    ) -> List[float]:
        """Cosine similarities of ``vec`` to an island's stored embeddings.

        Opens a private connection to ``self.config.db_path`` rather than
        using ``self.conn``. Rows whose embedding is empty or cannot be
        decoded are skipped, so one corrupt row does not abort the whole
        computation.

        Args:
            vec: Query embedding.
            island_idx: Island whose programs are compared against.

        Returns:
            One similarity per usable stored embedding, in query order;
            an empty list when the island has no embeddings.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        conn = None
        try:
            # NOTE(review): presumably db_path is a real file here; a private
            # connection to an in-memory database would not see its data.
            conn = sqlite3.connect(
                self.config.db_path, check_same_thread=False, timeout=60.0
            )
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            cursor.execute(
                "SELECT id, embedding FROM programs "
                "WHERE island_idx = ? AND embedding IS NOT NULL "
                "AND embedding != '[]'",
                (island_idx,),
            )
            rows = cursor.fetchall()

            similarities = []
            for row in rows:
                try:
                    db_embedding = json.loads(row["embedding"])
                except json.JSONDecodeError:
                    logger.warning(
                        f"Could not decode embedding for program {row['id']}"
                    )
                    continue
                if db_embedding:
                    similarities.append(self._cosine_similarity(vec, db_embedding))
            return similarities

        except Exception as e:
            logger.error(f"Thread-safe similarity computation failed: {e}")
            raise
        finally:
            if conn:
                conn.close()

    @db_retry()
    def compute_similarity(
        self, code_embedding: List[float], island_idx: int
    ) -> List[float]:
        """Cosine similarities of ``code_embedding`` to an island's embeddings.

        Uses the instance's own cursor. Every selected row contributes one
        score; a row whose embedding is empty after decoding, or cannot be
        decoded, contributes ``0.0``.

        Args:
            code_embedding: Query embedding; an empty one yields ``[]``.
            island_idx: Island whose programs are compared against.

        Returns:
            One score per selected row, in query order.

        Raises:
            ConnectionError: If there is no database cursor.
        """
        if not self.cursor:
            raise ConnectionError("DB not connected.")

        if not code_embedding:
            logger.warning("Empty code embedding provided to compute_similarity")
            return []

        self.cursor.execute(
            """
            SELECT id, embedding FROM programs 
            WHERE island_idx = ? AND embedding IS NOT NULL AND embedding != '[]'
            """,
            (island_idx,),
        )
        candidates = self.cursor.fetchall()
        if not candidates:
            logger.debug(f"No programs with embeddings found in island {island_idx}")
            return []

        def _score(record) -> float:
            # Undecodable or empty embeddings count as "no similarity".
            try:
                stored = json.loads(record["embedding"])
            except json.JSONDecodeError:
                logger.warning(f"Could not decode embedding for program {record['id']}")
                return 0.0
            if not stored:
                return 0.0
            return self._cosine_similarity(code_embedding, stored)

        scores = [_score(record) for record in candidates]

        logger.debug(
            f"Computed {len(scores)} similarity scores for "
            f"island {island_idx}"
        )
        return scores

    @db_retry()
    def get_most_similar_program(
        self, code_embedding: List[float], island_idx: int
    ) -> Optional[Program]:
        """Return the island program whose embedding is closest to the query.

        Closeness is cosine similarity as computed by ``_cosine_similarity``.
        Rows with an empty or undecodable embedding are skipped; the first
        row with the highest similarity wins.

        Args:
            code_embedding: Query embedding; an empty one yields ``None``.
            island_idx: Island to search.

        Returns:
            The most similar program, or ``None`` when the island has no
            usable embeddings.

        Raises:
            ConnectionError: If there is no database cursor.
        """
        if not self.cursor:
            raise ConnectionError("DB not connected.")

        if not code_embedding:
            logger.warning("Empty code embedding provided to get_most_similar_program")
            return None

        self.cursor.execute(
            """
            SELECT id, embedding FROM programs 
            WHERE island_idx = ? AND embedding IS NOT NULL AND embedding != '[]'
            """,
            (island_idx,),
        )
        rows = self.cursor.fetchall()

        if not rows:
            logger.debug(f"No programs with embeddings found in island {island_idx}")
            return None

        # Start below every possible cosine value. A -1.0 sentinel with a
        # strict ">" would never accept a candidate that is exactly opposite
        # to the query.
        max_similarity = -float("inf")
        most_similar_id = None

        for row in rows:
            try:
                embedding = json.loads(row["embedding"])
            except json.JSONDecodeError:
                logger.warning(f"Could not decode embedding for program {row['id']}")
                continue
            if not embedding:
                continue
            similarity = self._cosine_similarity(code_embedding, embedding)
            if similarity > max_similarity:
                max_similarity = similarity
                most_similar_id = row["id"]

        # Compare with None so a falsy id is still looked up.
        if most_similar_id is not None:
            return self.get(most_similar_id)
        return None

    @db_retry()
    def get_most_similar_program_thread_safe(
        self, code_embedding: List[float], island_idx: int
    ) -> Optional[Program]:
        """Find the island program closest to the query, on a private connection.

        Opens its own connection to ``self.config.db_path`` instead of using
        ``self.conn``. Similarity is computed by ``_cosine_similarity``, which
        returns ``0.0`` for zero-norm vectors; dividing by the norms directly
        would produce NaN, and NaN would then be picked as the maximum. Rows
        whose embedding is empty, undecodable, or of a different length than
        the query are skipped.

        Args:
            code_embedding: Query embedding; an empty one yields ``None``.
            island_idx: Island to search.

        Returns:
            The most similar program, or ``None`` when nothing usable is
            found or any error occurs (errors are logged, not raised).
        """
        if not code_embedding:
            logger.warning(
                "Empty code embedding provided to get_most_similar_program_thread_safe"
            )
            return None

        conn = None
        try:
            conn = sqlite3.connect(
                self.config.db_path, check_same_thread=False, timeout=60.0
            )
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            cursor.execute(
                """
                SELECT id, embedding FROM programs 
                WHERE island_idx = ? AND embedding IS NOT NULL AND embedding != '[]'
                """,
                (island_idx,),
            )

            rows = cursor.fetchall()
            if not rows:
                return None

            query_len = len(code_embedding)
            best_similarity = -float("inf")
            most_similar_id = None

            for row in rows:
                try:
                    embedding = json.loads(row["embedding"])
                    if not embedding:
                        continue
                    if len(embedding) != query_len:
                        raise ValueError(
                            f"embedding length {len(embedding)} != {query_len}"
                        )
                    similarity = self._cosine_similarity(code_embedding, embedding)
                except (json.JSONDecodeError, ValueError, TypeError) as e:
                    logger.warning(
                        f"Error computing similarity for program {row['id']}: {e}"
                    )
                    continue
                # Strict ">" keeps the first of equally similar programs.
                if similarity > best_similarity:
                    best_similarity = similarity
                    most_similar_id = row["id"]

            if most_similar_id is None:
                return None

            cursor.execute("SELECT * FROM programs WHERE id = ?", (most_similar_id,))
            row = cursor.fetchone()

            if row:
                return self._program_from_row(row)
            return None

        except Exception as e:
            # Deliberate best effort: callers receive None rather than an error.
            logger.error(f"Error in get_most_similar_program_thread_safe: {e}")
            return None
        finally:
            if conn:
                conn.close()

    @db_retry()
    def _recompute_embeddings_and_clusters(self, num_clusters: int = 4):
        if self.read_only:
            return
        if not self.cursor or not self.conn:
            raise ConnectionError("DB not connected.")

        self.cursor.execute(
            "SELECT id, embedding FROM programs "
            "WHERE embedding IS NOT NULL AND embedding != '[]'"
        )
        rows = self.cursor.fetchall()

        if len(rows) < num_clusters:
            logger.info(
                f"Not enough programs with embeddings ({len(rows)}) to "
                f"perform clustering. Need at least {num_clusters}."
            )
            return

        program_ids = [row["id"] for row in rows]
        embeddings = [json.loads(row["embedding"]) for row in rows]

        
        try:
            logger.info(
                "Recomputing PCA-reduced embedding features for %s programs.",
                len(program_ids),
            )
            reduced_2d = self.embedding_client.get_dim_reduction(
                embeddings, method="pca", dims=2
            )
            reduced_3d = self.embedding_client.get_dim_reduction(
                embeddings, method="pca", dims=3
            )
            cluster_ids = self.embedding_client.get_embedding_clusters(
                embeddings, num_clusters=num_clusters
            )
        except Exception as e:
            logger.error(f"Failed to recompute embedding features: {e}")
            return

        
        self.conn.execute("BEGIN TRANSACTION")
        try:
            for i, program_id in enumerate(program_ids):
                embedding_pca_2d_json = json.dumps(reduced_2d[i].tolist())
                embedding_pca_3d_json = json.dumps(reduced_3d[i].tolist())
                cluster_id = int(cluster_ids[i])

                self.cursor.execute(
                    """
                    UPDATE programs
                    SET embedding_pca_2d = ?,
                        embedding_pca_3d = ?,
                        embedding_cluster_id = ?
                    WHERE id = ?
                    """,
                    (
                        embedding_pca_2d_json,
                        embedding_pca_3d_json,
                        cluster_id,
                        program_id,
                    ),
                )
            self.conn.commit()
            logger.info(
                "Successfully updated embedding features for %s programs.",
                len(program_ids),
            )
        except Exception as e:
            self.conn.rollback()
            logger.error("Failed to update programs with new embedding features: %s", e)

    @db_retry()
    def _recompute_embeddings_and_clusters_thread_safe(self, num_clusters: int = 4):
        """Recompute PCA projections and cluster ids over a private connection.

        Opens its own SQLite connection to ``self.config.db_path`` instead of
        using ``self.conn``/``self.cursor``, reads every program with a
        non-empty embedding, asks ``self.embedding_client`` for 2-D and 3-D
        PCA reductions and cluster assignments, and writes the results to the
        ``embedding_pca_2d``, ``embedding_pca_3d`` and ``embedding_cluster_id``
        columns in a single transaction.

        Does nothing in read-only mode or when fewer than ``num_clusters``
        programs have embeddings. A failure while computing the features is
        logged and the method returns without writing. A failure while
        writing is rolled back, logged and re-raised.

        Args:
            num_clusters: Number of clusters requested from the embedding
                client; also the minimum number of embedded programs needed.
        """
        if self.read_only:
            return

        conn = None
        try:
            # Private connection, closed in the `finally` below.
            # NOTE(review): presumably this lets the method run on a thread
            # other than the one that owns self.conn -- confirm with callers.
            conn = sqlite3.connect(
                self.config.db_path, check_same_thread=False, timeout=60.0
            )
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            cursor.execute(
                "SELECT id, embedding FROM programs "
                "WHERE embedding IS NOT NULL AND embedding != '[]'"
            )
            rows = cursor.fetchall()

            if len(rows) < num_clusters:
                # Stay silent when there are no embeddings at all; only log
                # when some exist but are too few to cluster.
                if len(rows) > 0:
                    logger.info(
                        f"Not enough programs with embeddings ({len(rows)}) to "
                        f"perform clustering. Need at least {num_clusters}."
                    )
                return

            program_ids = [row["id"] for row in rows]
            embeddings = [json.loads(row["embedding"]) for row in rows]

            # Compute all derived features before touching the database, so a
            # failure here leaves the stored features unchanged.
            try:
                logger.info(
                    "Recomputing PCA-reduced embedding features for %s programs.",
                    len(program_ids),
                )

                logger.info("Computing 2D PCA reduction...")
                reduced_2d = self.embedding_client.get_dim_reduction(
                    embeddings, method="pca", dims=2
                )
                logger.info("2D PCA reduction completed")

                logger.info("Computing 3D PCA reduction...")
                reduced_3d = self.embedding_client.get_dim_reduction(
                    embeddings, method="pca", dims=3
                )
                logger.info("3D PCA reduction completed")

                logger.info(f"Computing GMM clustering with {num_clusters} clusters...")
                cluster_ids = self.embedding_client.get_embedding_clusters(
                    embeddings, num_clusters=num_clusters
                )
                logger.info("GMM clustering completed")
            except Exception as e:
                logger.error(f"Failed to recompute embedding features: {e}")
                return

            # Write every row inside one explicit transaction.
            conn.execute("BEGIN TRANSACTION")
            try:
                for i, program_id in enumerate(program_ids):
                    embedding_pca_2d_json = json.dumps(reduced_2d[i].tolist())
                    embedding_pca_3d_json = json.dumps(reduced_3d[i].tolist())
                    cluster_id = int(cluster_ids[i])

                    cursor.execute(
                        """
                        UPDATE programs
                        SET embedding_pca_2d = ?,
                            embedding_pca_3d = ?,
                            embedding_cluster_id = ?
                        WHERE id = ?
                        """,
                        (
                            embedding_pca_2d_json,
                            embedding_pca_3d_json,
                            cluster_id,
                            program_id,
                        ),
                    )
                conn.commit()
                logger.info(
                    "Successfully updated embedding features for %s programs.",
                    len(program_ids),
                )
            except Exception as e:
                # Undo the partial write, then let the error propagate.
                conn.rollback()
                logger.error(
                    "Failed to update programs with new embedding features: %s", e
                )
                raise  

        except Exception as e:
            # Also reached by the re-raise above, so a write failure is
            # logged twice before it propagates to the caller.
            logger.error(f"Thread-safe embedding recomputation failed: {e}")
            raise  

        finally:
            if conn:
                conn.close()

    @db_retry()
    def get_programs_by_generation_thread_safe(self, generation: int) -> List[Program]:
        """Return the programs of one generation, read over a private connection.

        Opens its own connection to ``self.config.db_path`` instead of using
        ``self.conn``. JSON text columns are decoded before each row is passed
        to ``Program(**row)``. A value that fails to decode falls back to
        ``{}`` for the mapping-valued columns (``*_metrics`` and ``metadata``)
        and to ``[]`` for the others.

        Args:
            generation: Generation number to select.

        Returns:
            The matching programs, in query order.
        """
        json_fields = (
            "public_metrics",
            "private_metrics",
            "metadata",
            "archive_inspiration_ids",
            "top_k_inspiration_ids",
            "embedding",
            "embedding_pca_2d",
            "embedding_pca_3d",
            "migration_history",
        )
        # Columns whose decoded value is a mapping; elsewhere in this file
        # `metadata` defaults to {} just like the metrics columns.
        mapping_fields = ("public_metrics", "private_metrics", "metadata")

        conn = None
        try:
            conn = sqlite3.connect(
                self.config.db_path, check_same_thread=False, timeout=60.0
            )
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            cursor.execute("SELECT * FROM programs WHERE generation = ?", (generation,))
            rows = cursor.fetchall()

            programs = []
            for row in rows:
                if not row:
                    continue
                program_data = dict(row)
                for key in json_fields:
                    value = program_data.get(key)
                    if not isinstance(value, str):
                        continue
                    try:
                        program_data[key] = json.loads(value)
                    except json.JSONDecodeError:
                        program_data[key] = {} if key in mapping_fields else []
                programs.append(Program(**program_data))
            return programs
        finally:
            if conn:
                conn.close()

    @db_retry()
    def get_top_programs_thread_safe(
        self,
        n: int = 10,
        correct_only: bool = True,
    ) -> List[Program]:
        """
        Return up to ``n`` programs ranked by descending ``combined_score``.

        The query runs on a connection opened for this call against
        ``self.config.db_path`` and closed before returning. Rows whose
        ``combined_score`` is NULL are never returned.

        Args:
            n: Value bound to the SQL ``LIMIT``.
            correct_only: When true, restrict the query to ``correct = 1``.

        Returns:
            ``Program`` objects built with ``Program.from_dict``, best score
            first; an empty list when the query matches nothing.
        """
        # Columns holding JSON text that is decoded before building a Program.
        json_columns = (
            "public_metrics",
            "private_metrics",
            "metadata",
            "archive_inspiration_ids",
            "top_k_inspiration_ids",
            "embedding",
            "embedding_pca_2d",
            "embedding_pca_3d",
            "migration_history",
        )

        conn = None
        try:
            conn = sqlite3.connect(
                self.config.db_path, check_same_thread=False, timeout=60.0
            )
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Assemble the statement: base filter, optional correctness
            # filter, then ordering and the bound limit.
            select_clause = """
                SELECT * FROM programs
                WHERE combined_score IS NOT NULL
            """
            correctness_clause = " AND correct = 1" if correct_only else ""
            sql = (
                select_clause
                + correctness_clause
                + " ORDER BY combined_score DESC LIMIT ?"
            )
            cursor.execute(sql, (n,))

            results = []
            for row in cursor.fetchall():
                record = dict(row)

                for name in json_columns:
                    raw = record.get(name)
                    if not isinstance(raw, str):
                        continue
                    try:
                        record[name] = json.loads(raw)
                    except json.JSONDecodeError:
                        # Undecodable text becomes an empty container: a dict
                        # for the metrics columns and "metadata", else a list.
                        if name == "metadata" or name.endswith("_metrics"):
                            record[name] = {}
                        else:
                            record[name] = []

                # A missing or NULL text_feedback is normalised to "".
                if record.get("text_feedback") is None:
                    record["text_feedback"] = ""

                results.append(Program.from_dict(record))

            return results

        finally:
            # The per-call connection is closed on success and on error.
            if conn:
                conn.close()

    def _get_programs_for_island(self, island_idx: int) -> List[Program]:
        """
        Get all programs for a specific island.
        """
