import logging
import sqlite3
from abc import ABC, abstractmethod
from typing import Optional, Callable, Any, List, Set

logger = logging.getLogger(__name__)


class ContextSelectorStrategy(ABC):
    

    def __init__(
        self,
        cursor: sqlite3.Cursor,
        conn: sqlite3.Connection,
        config: Any,
        get_program_func: Callable[[str], Any],
        best_program_id: Optional[str] = None,
        get_island_idx_func: Optional[Callable[[str], Optional[int]]] = None,
        program_from_row_func: Optional[Callable[[sqlite3.Row], Any]] = None,
    ):
        self.cursor = cursor
        self.conn = conn
        self.config = config
        self.get_program = get_program_func
        self.best_program_id = best_program_id
        self.get_island_idx = get_island_idx_func
        self.program_from_row = program_from_row_func

    @abstractmethod
    def sample_context(self, parent: Any, n: int) -> List[Any]:
        
        pass


class ArchiveInspirationSelector(ContextSelectorStrategy):
    

    def sample_context(self, parent: Any, n: int) -> List[Any]:
        
        if n <= 0:
            return []
        if not hasattr(self.config, "elite_selection_ratio"):
            raise ConnectionError(
                "Config missing elite_selection_ratio for inspiration sampling."
            )

        parent_island_idx = (
            self.get_island_idx(parent.id) if self.get_island_idx else None
        )

        inspirations: List[Any] = []
        insp_ids: Set[str] = {parent.id}

        enforce_separation = getattr(self.config, "enforce_island_separation", False)

        
        if self.best_program_id and self.best_program_id not in insp_ids:
            prog = self.get_program(self.best_program_id)
            if prog and prog.correct:
                if enforce_separation:
                    if prog.island_idx == parent_island_idx:
                        inspirations.append(prog)
                        insp_ids.add(prog.id)
                else:
                    inspirations.append(prog)
                    insp_ids.add(prog.id)

        
        num_elites = max(0, int(n * self.config.elite_selection_ratio))
        if num_elites > 0 and len(inspirations) < n and parent_island_idx is not None:
            self.cursor.execute(
                """
                SELECT p.id FROM programs p
                WHERE p.island_idx = ? AND p.correct = 1
                ORDER BY p.combined_score DESC
                LIMIT ?
                """,
                (parent_island_idx, num_elites + len(insp_ids)),
            )
            for row in self.cursor.fetchall():
                if len(inspirations) >= n:
                    break
                prog = self.get_program(row["id"])
                if prog and prog.id not in insp_ids:
                    inspirations.append(prog)
                    insp_ids.add(prog.id)

        
        if len(inspirations) < n and parent_island_idx is not None:
            needed = n - len(inspirations)
            if needed > 0:
                placeholders_rand = ",".join("?" * len(insp_ids))
                sql_rand = f"""
                    SELECT p.id FROM programs p
                    WHERE p.island_idx = ? AND p.correct = 1 
                    AND p.id NOT IN ({placeholders_rand})
                    ORDER BY RANDOM() LIMIT ?
                """
                params_rand = [parent_island_idx] + list(insp_ids) + [needed]

                self.cursor.execute(sql_rand, params_rand)
                for row in self.cursor.fetchall():
                    prog = self.get_program(row["id"])
                    if prog:  
                        inspirations.append(prog)

        
        
        if len(inspirations) < n and not enforce_separation:
            needed = n - len(inspirations)
            if needed > 0:
                placeholders_rand = ",".join("?" * len(insp_ids))
                sql_rand = f"""SELECT id FROM programs
                                 WHERE correct = 1 
                                 AND id NOT IN ({placeholders_rand})
                                 ORDER BY RANDOM() LIMIT ?
                                 """
                params_rand = list(insp_ids) + [needed]
                self.cursor.execute(sql_rand, params_rand)
                for row in self.cursor.fetchall():
                    prog = self.get_program(row["id"])
                    if prog:
                        inspirations.append(prog)

        if inspirations:
            inspiration_details = [
                f"{p.id} (Gen: {p.generation}, "
                f"Score: {p.combined_score or 0.0:.4f}, "
                f"Island: {p.island_idx})"
                for p in inspirations
            ]
            logger.info(
                f"Sampled {len(inspirations)} archive inspirations: "
                f"{inspiration_details}"
            )
        return inspirations


class TopKInspirationSelector(ContextSelectorStrategy):
    

    def sample_context(
        self, parent: Any, excluded_programs: List[Any], k: int
    ) -> List[Any]:
        
        if k <= 0:
            return []

        enforce_separation = getattr(self.config, "enforce_island_separation", False)
        parent_island_idx = parent.island_idx

        if enforce_separation and parent_island_idx is None:
            logger.debug(
                f"Parent {parent.id} has no island and island separation is enforced, "
                "skipping top-k inspirations."
            )
            return []

        
        excluded_ids: Set[str] = {parent.id}
        excluded_ids.update(prog.id for prog in excluded_programs)

        if not hasattr(self.config, "archive_size") or self.config.archive_size <= 0:
            return []

        
        placeholders = ",".join("?" * len(excluded_ids))

        if enforce_separation and parent_island_idx is not None:
            
            query = f"""
                SELECT p.*
                FROM programs p
                JOIN archive a ON p.id = a.program_id
                WHERE p.island_idx = ? AND p.id NOT IN ({placeholders}) 
                AND p.correct = 1
            """
            params = [parent_island_idx] + list(excluded_ids)
            search_scope = f"island {parent_island_idx}"
        else:
            
            query = f"""
                SELECT p.*
                FROM programs p
                JOIN archive a ON p.id = a.program_id
                WHERE p.id NOT IN ({placeholders}) 
                AND p.correct = 1
            """
            params = list(excluded_ids)
            search_scope = "all islands"

        self.cursor.execute(query, params)
        archive_rows = self.cursor.fetchall()

        if not archive_rows:
            logger.debug(
                f"No eligible archived programs for top-k selection on {search_scope}."
            )
            return []

        archive_programs = [
            self.program_from_row(row) for row in archive_rows if self.program_from_row
        ]
        archive_programs = [p for p in archive_programs if p]

        if not archive_programs:
            return []

        
        def sort_key(prog: Any) -> float:
            if prog.combined_score is not None:
                return prog.combined_score
            elif prog.public_metrics:
                return sum(prog.public_metrics.values()) / len(prog.public_metrics)
            else:
                return -float("inf")

        sorted_programs = sorted(archive_programs, key=sort_key, reverse=True)

        
        top_k = sorted_programs[:k]

        if top_k:
            inspiration_details = [
                f"{p.id} (Gen: {p.generation}, "
                f"Score: {p.combined_score or 0.0:.4f}, "
                f"Island: {p.island_idx})"
                for p in top_k
            ]
            logger.info(
                f"Selected {len(top_k)} top-k inspirations from archive on "
                f"{search_scope}: {inspiration_details}"
            )

        return top_k


class CombinedContextSelector:
    

    def __init__(
        self,
        cursor: sqlite3.Cursor,
        conn: sqlite3.Connection,
        config: Any,
        get_program_func: Callable[[str], Any],
        best_program_id: Optional[str] = None,
        get_island_idx_func: Optional[Callable[[str], Optional[int]]] = None,
        program_from_row_func: Optional[Callable[[sqlite3.Row], Any]] = None,
    ):
        self.archive_selector = ArchiveInspirationSelector(
            cursor=cursor,
            conn=conn,
            config=config,
            get_program_func=get_program_func,
            best_program_id=best_program_id,
            get_island_idx_func=get_island_idx_func,
            program_from_row_func=program_from_row_func,
        )
        self.topk_selector = TopKInspirationSelector(
            cursor=cursor,
            conn=conn,
            config=config,
            get_program_func=get_program_func,
            best_program_id=best_program_id,
            get_island_idx_func=get_island_idx_func,
            program_from_row_func=program_from_row_func,
        )

    def sample_context(
        self, parent: Any, num_archive: int, num_topk: int
    ) -> tuple[List[Any], List[Any]]:
        
        archive_inspirations = self.archive_selector.sample_context(parent, num_archive)
        top_k_inspirations = self.topk_selector.sample_context(
            parent, archive_inspirations, num_topk
        )
        return archive_inspirations, top_k_inspirations
