import json
import logging
import sqlite3
from abc import ABC, abstractmethod
from typing import Optional, Callable, Any
import numpy as np  

# Module-level logger used by the sampling helper and the strategy classes below.
logger = logging.getLogger(__name__)


def sample_with_powerlaw(items: list, alpha: float = 1.0) -> int:
    """Draw a position in ``items`` with rank-based power-law probability.

    Position ``i`` (0-based) receives the unnormalised weight
    ``(i + 1) ** -alpha``; only the length of ``items`` matters, not its
    contents. With a positive ``alpha`` earlier positions are favoured, and
    ``alpha == 0`` gives a uniform draw.

    Args:
        items: Non-empty sequence whose positions are sampled.
        alpha: Power-law exponent.

    Returns:
        The sampled position as a built-in ``int``.

    Raises:
        ValueError: If ``items`` is empty.
    """
    if not items:
        raise ValueError("Empty items list for power-law sampling")

    ranks = np.arange(1, len(items) + 1, dtype=float)
    # An extreme exponent overflows to inf (or yields NaN); silence the
    # floating-point warning and let the check below handle the result.
    with np.errstate(over="ignore", invalid="ignore"):
        probs = ranks ** (-alpha)

    total = probs.sum()
    if not np.isfinite(total) or total <= 0:
        # Weights cannot be normalised; fall back to a uniform distribution.
        probs = np.ones(len(items))
        total = probs.sum()

    probs = probs / total
    logger.info(f"Power law probs: {probs.tolist()}")
    # np.random.choice returns a NumPy integer scalar; convert it so the
    # result matches the annotation and is safe to use anywhere a plain int
    # is required.
    return int(np.random.choice(len(items), p=probs))


def stable_sigmoid(x: float) -> float:
    """Evaluate the logistic function ``1 / (1 + e**-x)`` without overflow.

    Only a non-positive value is ever exponentiated, so the intermediate
    term stays within ``(0, 1]`` for any finite ``x``.
    """
    decay = np.exp(-abs(x))
    if x < 0:
        # Negative input: e**x / (1 + e**x), with e**x == decay.
        return decay / (1.0 + decay)
    # Non-negative input: 1 / (1 + e**-x), with e**-x == decay.
    return 1.0 / (1.0 + decay)


class ParentSamplingStrategy(ABC):
    """Abstract base for strategies that choose a parent program.

    Holds the database handles, the configuration object and the program
    loader that concrete strategies use inside ``sample_parent``.
    """

    def __init__(
        self,
        cursor: sqlite3.Cursor,
        conn: sqlite3.Connection,
        config: Any,
        get_program_func: Callable[[str], Any],
        best_program_id: Optional[str] = None,
        island_idx: Optional[int] = None,
    ):
        """Store the collaborators; nothing is queried at construction time."""
        # Sampling scope.
        self.island_idx = island_idx
        self.best_program_id = best_program_id
        # Program loader and settings.
        self.get_program = get_program_func
        self.config = config
        # Database handles.
        self.conn = conn
        self.cursor = cursor

    @abstractmethod
    def sample_parent(self) -> Any:
        """Return the program chosen as parent; implemented by subclasses."""

    def _get_island_idx(self, program_id: str) -> Optional[int]:
        """Look up the ``island_idx`` stored for ``program_id``.

        Returns ``None`` when no row with that id exists. The row is read by
        column name, so the cursor must produce name-addressable rows.
        """
        self.cursor.execute(
            "SELECT island_idx FROM programs WHERE id = ?", (program_id,)
        )
        match = self.cursor.fetchone()
        if not match:
            return None
        return match["island_idx"]


class PowerLawSamplingStrategy(ParentSamplingStrategy):
    """Pick a parent by rank-based power-law sampling with layered fallbacks.

    ``sample_parent`` tries these sources in order and stops at the first one
    that yields a program id:

    1. With probability ``config.exploitation_ratio``: archived programs,
       sorted by ``combined_score`` (descending) and power-law sampled.
    2. All programs with ``correct = 1``, ordered by ``combined_score``
       (descending) and power-law sampled.
    3. A random correct program from a randomly chosen island (only when no
       island is requested and ``config.num_islands`` is positive).
    4. ``best_program_id``, if it is correct and on the requested island.
    5. A uniformly random correct program.

    Whenever ``island_idx`` is set, steps 1, 2, 4 and 5 are restricted to that
    island.
    """

    def sample_parent(self) -> Any:
        """Return the sampled parent program, or ``None`` if none qualifies.

        Raises:
            ConnectionError: If ``config`` has no ``exploitation_ratio``
                attribute. The exception type is kept unchanged because
                callers may already catch it.
        """
        if not hasattr(self.config, "exploitation_ratio"):
            raise ConnectionError("DB/config issue for parent sampling.")

        # ``or`` short-circuits, so later stages (and their RNG draws and
        # queries) only run when every earlier stage produced nothing.
        pid = (
            self._sample_from_archive()
            or self._sample_from_correct()
            or self._sample_from_random_island()
            or self._pick_best_program()
            or self._pick_random_correct()
        )
        if not pid:
            logger.warning(
                "No parent found, database may be empty or no correct "
                "programs in specified island."
            )
            return None

        return self.get_program(pid)

    def _load_programs(self, program_ids: list) -> list:
        """Load programs for the given ids, dropping ids that fail to load."""
        loaded = (self.get_program(prog_id) for prog_id in program_ids)
        return [prog for prog in loaded if prog]

    def _rank_sample(self, programs: list) -> Any:
        """Power-law sample one entry from an already ranked program list."""
        alpha = getattr(self.config, "exploitation_alpha", 1.0)
        return programs[sample_with_powerlaw(programs, alpha)]

    def _sample_from_archive(self) -> Optional[str]:
        """Stage 1: exploit the archive with probability ``exploitation_ratio``."""
        if not (np.random.random() < self.config.exploitation_ratio):
            return None

        if self.island_idx is not None:
            self.cursor.execute(
                """SELECT a.program_id FROM archive a
                   JOIN programs p ON a.program_id = p.id
                   WHERE p.island_idx = ?""",
                (self.island_idx,),
            )
        else:
            self.cursor.execute("SELECT program_id FROM archive")

        archived_programs = self._load_programs(
            [row["program_id"] for row in self.cursor.fetchall()]
        )
        if not archived_programs:
            return None

        # The archive query is unordered; rank by score so that the power law
        # favours the best entries.
        archived_programs.sort(key=lambda p: p.combined_score or 0.0, reverse=True)
        logger.info(
            f"Island {self.island_idx} => Archived program scores: {[p.combined_score for p in archived_programs]}"
        )

        selected_prog = self._rank_sample(archived_programs)
        pid = selected_prog.id
        logger.info(
            f"Exploitation: Sampled from archive: {pid} "
            f"(Gen: {selected_prog.generation}, "
            f"Score: {selected_prog.combined_score or 0.0:.4f}, "
            f"Island: {selected_prog.island_idx})"
        )
        return pid

    def _sample_from_correct(self) -> Optional[str]:
        """Stage 2: power-law sample over all correct programs."""
        if self.island_idx is not None:
            self.cursor.execute(
                """SELECT p.id FROM programs p
                   WHERE p.correct = 1 AND p.island_idx = ?
                   ORDER BY p.combined_score DESC""",
                (self.island_idx,),
            )
        else:
            self.cursor.execute(
                """SELECT p.id FROM programs p
                   WHERE p.correct = 1
                   ORDER BY p.combined_score DESC"""
            )

        # The rows already arrive ranked by score via ORDER BY.
        correct_programs = self._load_programs(
            [row["id"] for row in self.cursor.fetchall()]
        )
        if not correct_programs:
            return None

        logger.info(
            f"Island {self.island_idx} => Correct program scores: {[p.combined_score for p in correct_programs]}"
        )
        selected_prog = self._rank_sample(correct_programs)
        pid = selected_prog.id
        logger.info(
            f"Exploration: Sampled from all correct: {pid} "
            f"(Gen: {selected_prog.generation}, "
            f"Score: {selected_prog.combined_score or 0.0:.4f}, "
            f"Island: {selected_prog.island_idx})"
        )
        return pid

    def _sample_from_random_island(self) -> Optional[str]:
        """Stage 3: random correct program from a randomly chosen island."""
        if not (
            hasattr(self.config, "num_islands")
            and self.config.num_islands > 0
            and self.island_idx is None
        ):
            return None

        self.cursor.execute("SELECT DISTINCT island_idx FROM programs")
        island_indices = [r["island_idx"] for r in self.cursor.fetchall()]
        if not island_indices:
            return None

        idx = np.random.choice(island_indices)
        if isinstance(idx, np.generic):
            # sqlite3 binds NumPy scalars as BLOBs (they expose the buffer
            # protocol), which never compare equal to an INTEGER column.
            # Convert to the native Python value before binding.
            idx = idx.item()

        self.cursor.execute(
            """SELECT p.id FROM programs p
               WHERE p.island_idx = ? AND p.correct = 1
               ORDER BY RANDOM() LIMIT 1""",
            (idx,),
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        pid = row["id"]
        prog = self.get_program(pid)
        if prog:
            score = prog.combined_score or 0.0
            logger.info(
                f"Exploration: Sampled from island {idx}: {pid} "
                f"(Gen: {prog.generation}, Score: {score:.4f})"
            )
        return pid

    def _pick_best_program(self) -> Optional[str]:
        """Stage 4: use ``best_program_id`` if it is correct and in scope."""
        if not self.best_program_id:
            return None

        best_prog = self.get_program(self.best_program_id)
        in_scope = best_prog and (
            self.island_idx is None or best_prog.island_idx == self.island_idx
        )
        if not (in_scope and best_prog.correct):
            return None

        pid = self.best_program_id
        score = best_prog.combined_score or 0.0
        logger.info(
            f"Exploitation: Return best program: {pid} "
            f"(Gen: {best_prog.generation}, Score: {score:.4f})"
        )
        return pid

    def _pick_random_correct(self) -> Optional[str]:
        """Stage 5: uniformly random correct program (island-scoped if set)."""
        if self.island_idx is not None:
            self.cursor.execute(
                """SELECT id FROM programs
                   WHERE correct = 1 AND island_idx = ?
                   ORDER BY RANDOM() LIMIT 1""",
                (self.island_idx,),
            )
        else:
            self.cursor.execute(
                """SELECT id FROM programs WHERE correct = 1
                   ORDER BY RANDOM() LIMIT 1"""
            )
        row = self.cursor.fetchone()
        if not row:
            return None

        pid = row["id"]
        # Loaded only to decide whether the fallback is worth logging.
        if self.get_program(pid):
            logger.info(f"Fallback: Random correct program: {pid}")
        return pid


class WeightedSamplingStrategy(ParentSamplingStrategy):
    """Sample a parent from the archive with score- and novelty-based weights.

    Each archived, correct program receives the weight ``s_i * h_i`` where

    * ``s_i = sigmoid(lambda * (score_i - median) / max(MAD, 1e-6))`` rewards
      scores above the median (``MAD`` is the median absolute deviation and
      ``lambda`` is ``config.parent_selection_lambda``), and
    * ``h_i = 1 / (1 + children_count_i)`` down-weights programs that already
      have many children.
    """

    def sample_parent(self) -> Any:
        """Return the sampled parent, or ``None`` if no program qualifies."""
        archive_rows = self._fetch_archived_correct()
        if not archive_rows:
            logger.warning("No archived programs found for weighted sampling.")
            return self._fallback_parent()

        # The rows are only used to compute weights and log lines; the
        # returned program is re-loaded through ``get_program``. Plain dicts
        # are therefore enough, and the JSON columns are left undecoded.
        candidates = [dict(row) for row in archive_rows]

        scores = [c.get("combined_score") or 0.0 for c in candidates]
        alpha_0 = np.median(scores)
        mad = np.median([abs(score - alpha_0) for score in scores])
        # Guard against division by zero when all scores coincide.
        scale_factor = max(mad, 1e-6)
        lambda_ = self.config.parent_selection_lambda

        weights = []
        for i, (candidate, alpha_i) in enumerate(zip(candidates, scores)):
            # A missing or NULL children_count counts as "no children yet".
            n_i = candidate.get("children_count") or 0

            normalized_diff = (alpha_i - alpha_0) / scale_factor
            s_i = stable_sigmoid(lambda_ * normalized_diff)
            h_i = 1 / (1 + n_i)
            w_i = s_i * h_i
            weights.append(w_i)
            logger.debug(
                f"I-{self.island_idx} => P-{i}: w_i: {w_i:.2f}, s_i: {s_i:.2f}, h_i: {h_i:.2f}, alpha_i: {alpha_i:.2f}, alpha_0: {alpha_0:.2f}, "
                f"normalized_diff: {normalized_diff:.2f}, scale_factor: {scale_factor:.2f}"
            )

        weights_sum = sum(weights)
        if weights_sum > 0:
            probabilities = [w / weights_sum for w in weights]
        else:
            logger.warning(
                "All parent selection weights are zero, falling back to "
                "uniform sampling."
            )
            num_eligible = len(candidates)
            probabilities = [1.0 / num_eligible] * num_eligible
        logger.info(
            f"Island {self.island_idx} => Probabilities: {np.array(probabilities).tolist()}"
        )
        logger.info(
            f"Island {self.island_idx} => Scores: {[c.get('combined_score') for c in candidates]}"
        )

        # Sample a position instead of handing arbitrary objects to NumPy.
        chosen = candidates[int(np.random.choice(len(candidates), p=probabilities))]

        logger.info(
            f"Sampled parent {chosen.get('id')} "
            f"(Gen: {chosen.get('generation')}, "
            f"Score: {chosen.get('combined_score') or 0.0:.4f}, "
            f"Children: {chosen.get('children_count') or 0}, "
            f"Island: {chosen.get('island_idx')})"
        )

        return self.get_program(chosen.get("id"))

    def _fetch_archived_correct(self) -> list:
        """Fetch archived programs with ``correct = 1`` (island-scoped if set)."""
        if self.island_idx is not None:
            self.cursor.execute(
                """
                SELECT p.*
                FROM programs p
                JOIN archive a ON p.id = a.program_id
                WHERE p.correct = 1 AND p.island_idx = ?
                """,
                (self.island_idx,),
            )
        else:
            self.cursor.execute(
                """
                SELECT p.*
                FROM programs p
                JOIN archive a ON p.id = a.program_id
                WHERE p.correct = 1
                """
            )
        return self.cursor.fetchall()

    def _fallback_parent(self) -> Any:
        """Pick a parent when the archive has no eligible entry.

        Prefers ``best_program_id`` when it is on the requested island,
        otherwise a uniformly random correct program. Returns ``None`` when
        neither exists.
        """
        if self.best_program_id:
            best_prog = self.get_program(self.best_program_id)
            # NOTE(review): unlike the random fallback below, this does not
            # require ``best_prog.correct`` -- confirm that is intended.
            if best_prog and (
                self.island_idx is None or best_prog.island_idx == self.island_idx
            ):
                return best_prog

        if self.island_idx is not None:
            self.cursor.execute(
                """SELECT id FROM programs
                   WHERE correct = 1 AND island_idx = ?
                   ORDER BY RANDOM() LIMIT 1""",
                (self.island_idx,),
            )
        else:
            self.cursor.execute(
                """SELECT id FROM programs WHERE correct = 1
                   ORDER BY RANDOM() LIMIT 1"""
            )
        row = self.cursor.fetchone()
        return self.get_program(row["id"]) if row else None


class BeamSearchSamplingStrategy(ParentSamplingStrategy):
    """Keep expanding one parent until it has ``num_beams`` children.

    The current parent id is held in ``beam_search_parent_id``. Whenever it
    changes, the new id is reported through ``update_metadata`` (if given)
    under the key ``"beam_search_parent_id"``.
    """

    def __init__(
        self,
        cursor: sqlite3.Cursor,
        conn: sqlite3.Connection,
        config: Any,
        get_program_func: Callable[[str], Any],
        best_program_id: Optional[str] = None,
        island_idx: Optional[int] = None,
        beam_search_parent_id: Optional[str] = None,
        last_iteration: int = 0,
        update_metadata_func: Optional[Callable[[str, Optional[str]], None]] = None,
        get_best_program_func: Optional[Callable[[], Any]] = None,
    ):
        """Store the beam-search state on top of the base collaborators."""
        super().__init__(
            cursor, conn, config, get_program_func, best_program_id, island_idx
        )
        # Optional callbacks supplied by the owner of this strategy.
        self.get_best_program_func = get_best_program_func
        self.update_metadata = update_metadata_func
        # Beam-search bookkeeping.
        self.last_iteration = last_iteration
        self.beam_search_parent_id = beam_search_parent_id

    def sample_parent(self) -> Any:
        """Return the current beam parent, a replacement, or a fallback.

        Falls back to ``best_program_id`` and then to a random correct
        program; returns ``None`` when the final query finds no row.
        """
        beam_width = getattr(self.config, "num_beams", 5)

        # No beam parent yet: start from the best program, if one is known.
        if not self.beam_search_parent_id and self.get_best_program_func:
            starter = self.get_best_program_func()
            if starter:
                self.beam_search_parent_id = starter.id
                if self.update_metadata:
                    self.update_metadata(
                        "beam_search_parent_id", self.beam_search_parent_id
                    )
                logger.info(
                    f"Beam search: Selected new parent {self.beam_search_parent_id} "
                    f"(Gen: {starter.generation}, "
                    f"Score: {starter.combined_score or 0.0:.4f})"
                )

        current = None
        if self.beam_search_parent_id:
            current = self.get_program(self.beam_search_parent_id)

        if current:
            self.cursor.execute(
                "SELECT COUNT(*) FROM programs WHERE parent_id = ?",
                (self.beam_search_parent_id,),
            )
            count_row = self.cursor.fetchone()
            expanded = count_row[0] if count_row else 0

            if expanded < beam_width:
                logger.info(
                    f"Beam search: Continue with parent {self.beam_search_parent_id} "
                    f"({expanded}/{beam_width} children)"
                )
                return current

            # The beam is full: move on to the best program, if available.
            if self.get_best_program_func:
                successor = self.get_best_program_func()
                if successor:
                    self.beam_search_parent_id = successor.id
                    if self.update_metadata:
                        self.update_metadata(
                            "beam_search_parent_id", self.beam_search_parent_id
                        )
                    logger.info(
                        f"Beam search: Switch to new parent {self.beam_search_parent_id} "
                        f"(Gen: {successor.generation}, "
                        f"Score: {successor.combined_score or 0.0:.4f})"
                    )
                    return successor

        # Nothing usable from the beam: best known program, then random.
        if self.best_program_id:
            return self.get_program(self.best_program_id)

        self.cursor.execute(
            "SELECT id FROM programs WHERE correct = 1 ORDER BY RANDOM() LIMIT 1"
        )
        any_row = self.cursor.fetchone()
        if not any_row:
            return None
        return self.get_program(any_row["id"])


class BestOfNSamplingStrategy(ParentSamplingStrategy):
    """Always return the same starting program as parent.

    Prefers the lowest-id correct program of generation 0 and otherwise the
    earliest correct program (lowest generation, then lowest id). Both lookups
    are restricted to ``island_idx`` when it is set.
    """

    def sample_parent(self) -> Any:
        """Return the initial program, its fallback, or ``None``."""
        unscoped = self.island_idx is None

        # First choice: a correct generation-0 program.
        if unscoped:
            self.cursor.execute(
                """SELECT id FROM programs
                   WHERE generation = 0 AND correct = 1
                   ORDER BY id LIMIT 1"""
            )
        else:
            self.cursor.execute(
                """SELECT id FROM programs
                   WHERE generation = 0 AND island_idx = ? AND correct = 1
                   ORDER BY id LIMIT 1""",
                (self.island_idx,),
            )

        seed_row = self.cursor.fetchone()
        seed = self.get_program(seed_row["id"]) if seed_row else None
        if seed:
            pid = seed_row["id"]
            logger.info(
                f"Best-of-N: Selected initial program {pid} "
                f"(Gen: {seed.generation}, "
                f"Score: {seed.combined_score or 0.0:.4f}, "
                f"Island: {seed.island_idx})"
            )
            return seed

        # Second choice: the earliest correct program of any generation.
        logger.warning(
            "No generation 0 program found, falling back to any correct program"
        )
        if unscoped:
            self.cursor.execute(
                """SELECT id FROM programs
                   WHERE correct = 1
                   ORDER BY generation ASC, id ASC LIMIT 1"""
            )
        else:
            self.cursor.execute(
                """SELECT id FROM programs
                   WHERE correct = 1 AND island_idx = ?
                   ORDER BY generation ASC, id ASC LIMIT 1""",
                (self.island_idx,),
            )

        earliest_row = self.cursor.fetchone()
        earliest = self.get_program(earliest_row["id"]) if earliest_row else None
        if not earliest:
            logger.warning("No suitable parent found for best-of-n strategy")
            return None

        pid = earliest_row["id"]
        logger.info(
            f"Best-of-N: Fallback to earliest correct program {pid} "
            f"(Gen: {earliest.generation}, "
            f"Score: {earliest.combined_score or 0.0:.4f}, "
            f"Island: {earliest.island_idx})"
        )
        return earliest


class CombinedParentSelector:
    """Dispatch parent sampling to the strategy named in the configuration.

    ``config.parent_selection_strategy`` selects one of ``"power_law"``,
    ``"weighted"``, ``"beam_search"`` or ``"best_of_n"``. If the strategy
    yields nothing, the selector falls back to the best program and then to a
    random correct program.
    """

    def __init__(
        self,
        cursor: sqlite3.Cursor,
        conn: sqlite3.Connection,
        config: Any,
        get_program_func: Callable[[str], Any],
        best_program_id: Optional[str] = None,
        beam_search_parent_id: Optional[str] = None,
        last_iteration: int = 0,
        update_metadata_func: Optional[Callable[[str, Optional[str]], None]] = None,
        get_best_program_func: Optional[Callable[[], Any]] = None,
    ):
        """Store the collaborators that are handed to the chosen strategy."""
        self.cursor = cursor
        self.conn = conn
        self.config = config
        self.get_program = get_program_func
        self.best_program_id = best_program_id
        self.beam_search_parent_id = beam_search_parent_id
        self.last_iteration = last_iteration
        self.update_metadata = update_metadata_func
        self.get_best_program_func = get_best_program_func

    def sample_parent(self, island_idx: Optional[int] = None) -> Any:
        """Sample a parent program, optionally restricted to one island.

        Args:
            island_idx: Island to sample from, or ``None`` for no restriction.

        Returns:
            The selected program.

        Raises:
            ValueError: If the configured strategy name is unknown, or if
                neither the strategy nor the fallbacks find a program.
        """
        strategy_name = self.config.parent_selection_strategy
        # Arguments shared by every strategy constructor, in positional order.
        common = (
            self.cursor,
            self.conn,
            self.config,
            self.get_program,
            self.best_program_id,
            island_idx,
        )

        if strategy_name == "power_law":
            strategy = PowerLawSamplingStrategy(*common)
        elif strategy_name == "weighted":
            strategy = WeightedSamplingStrategy(*common)
        elif strategy_name == "beam_search":
            strategy = BeamSearchSamplingStrategy(
                *common,
                beam_search_parent_id=self.beam_search_parent_id,
                last_iteration=self.last_iteration,
                update_metadata_func=self.update_metadata,
                get_best_program_func=self.get_best_program_func,
            )
        elif strategy_name == "best_of_n":
            strategy = BestOfNSamplingStrategy(*common)
        else:
            raise ValueError(f"Unknown parent selection strategy: {strategy_name}")

        parent = strategy.sample_parent()
        if parent:
            return parent

        # Fallback 1: the best program, if it is correct and in scope. It is
        # kept in its own variable so that a rejected candidate can never be
        # returned by accident further down.
        if self.best_program_id:
            best = self.get_program(self.best_program_id)
            if (
                best
                and best.correct
                and (island_idx is None or best.island_idx == island_idx)
            ):
                return best

        # Fallback 2: a random correct program. Both branches filter on
        # ``correct = 1`` so an incorrect program is never handed out.
        if island_idx is not None:
            self.cursor.execute(
                """SELECT id FROM programs
                   WHERE correct = 1 AND island_idx = ?
                   ORDER BY RANDOM() LIMIT 1""",
                (island_idx,),
            )
        else:
            self.cursor.execute(
                """SELECT id FROM programs
                   WHERE correct = 1
                   ORDER BY RANDOM() LIMIT 1"""
            )
        row = self.cursor.fetchone()
        fallback = self.get_program(row["id"]) if row else None

        if not fallback:
            raise ValueError("Database empty or parent sampling failed.")

        return fallback
