# flake8: noqa: E501

import logging
from typing import Any, Callable, Dict, List, Optional

import numpy as np
from haipr.utils import AA_ALPHABET_WITH_EXTRAS
from omegaconf import OmegaConf, SCMode
from scipy.spatial.distance import hamming

import pygad

from .base_generator import BaseSequenceGenerator

# Module-level logger shared by every method in this file; the logger name is
# a fixed string rather than being derived from ``__name__``.
logger = logging.getLogger("haipr.sequence_generators.pygad")


class GeneticSequenceGenerator(pygad.GA, BaseSequenceGenerator):
    def __init__(
        self,
        ga_params: Optional[Dict[str, Any]] = None,
        population_factor: int = 256,
        plateau_threshold: int = 15,
        max_mutation_frequency: float = 0.1,
        use_injection: bool = False,
        max_consecutive_injections: int = 3,
        max_distance: Optional[float] = None,
        max_mutations: Optional[int] = None,
        max_retries: int = 10,
        populate_cache_with_training_data: bool = False,
    ):
        """
        Store the configuration and initialise the bookkeeping state.

        The underlying ``pygad.GA`` is not initialised here; that happens in
        ``setup_generator`` once the data and the gene space are known.

        Args:
            ga_params: Keyword arguments later forwarded to ``pygad.GA``.
                Accepts an OmegaConf config (converted to a plain container
                with interpolations resolved) or a plain mapping (copied).
            population_factor: Factor to determine population size (only
                stored here; not read by the code visible in this file).
            plateau_threshold: Number of generations without improvement
                before plateau handling (retry / injection / stop) kicks in.
            max_mutation_frequency: Stored on the instance (not read by the
                code visible in this file).
            use_injection: Whether training data may be injected once the
                retry budget is exhausted; if False the run stops instead.
            max_consecutive_injections: Maximum consecutive injections
                without improvement before stopping.
            max_distance: Maximum normalised Hamming distance (over the
                variable positions) allowed between a chromosome and the
                starting chromosome.
            max_mutations: Used to derive ``max_distance`` when that is not
                given, and to bound mutations in the initial population.
            max_retries: Maximum number of plateau-triggered resets.
            populate_cache_with_training_data: If True, the fitness cache is
                seeded with the training sequences and their labels.
        """
        self.max_mutation_frequency = max_mutation_frequency
        self.population_factor = population_factor
        self.max_consecutive_injections = max_consecutive_injections
        self.use_injection = use_injection
        self.max_distance = max_distance
        self.max_mutations = max_mutations
        self.populate_cache_with_training_data = populate_cache_with_training_data
        self.max_retries = max_retries
        self.num_retries = 0
        # interrupt hook for graceful shutdown
        self.stop_run = False

        # Normalise ga_params to a plain python dict.
        if ga_params is None:
            self.ga_params = {}
        elif OmegaConf.is_config(ga_params):
            self.ga_params = OmegaConf.to_container(
                ga_params, resolve=True, structured_config_mode=SCMode.DICT)
        else:
            # OmegaConf.to_container rejects non-config inputs, so a plain
            # mapping is copied instead; the copy also keeps the caller's
            # dict from being mutated when setup_generator adds keys.
            self.ga_params = dict(ga_params)
        assert isinstance(
            self.ga_params, dict), "ga_params must be a dictionary"

        # Internal state
        self.generation_count = 0
        self.initialized = False
        # Keyed by the decoded sequence string.
        self.fitness_cache: Dict[str, float] = {}

        # Tracking variables for external access
        self.best_sequence = None
        self.best_fitness = -float("inf")
        self.all_sequences: List[str] = []
        self.all_fitnesses: List[float] = []

        # Plateau detection for training data injection
        self.plateau_threshold = plateau_threshold
        self.generations_without_improvement = 0
        self.last_best_fitness = -float("inf")

        # Consecutive injection tracking and monitoring
        self.consecutive_injections_without_improvement = 0
        self.total_injections = 0
        self.injection_history: List[Dict[str, Any]] = []
        self.fitness_before_injection = -float("inf")
        self.early_stopped = False
        self.stop_reason = ""

    def setup_generator(self, data, alphabet_per_position, fitness_callback: Callable):
        """
        Bind the generator to a dataset and initialise the underlying GA.

        Args:
            data: Dataset object; this method reads ``data.representative``
                and calls ``data.get_sequences()`` and ``data.get_labels()``.
            alphabet_per_position: Allowed symbols for every sequence position.
            fitness_callback: Called with a list of sequences; its result is
                unpacked as ``(fitness_scores, stage_outputs)`` by the
                fitness methods of this class.

        Raises:
            ValueError: From ``setup_initial_population`` when neither
                ``max_distance`` nor ``max_mutations`` was configured.
        """
        self.data = data
        self.fitness_callback = fitness_callback

        # Store reference to external metrics logging (will be set by inference)
        self.metrics_logger = None

        # Store reference to callback for creating new MLflow runs (will be set by inference)
        self.new_run_callback = None

        # Set alphabet_per_position FIRST before encoding anything
        self.alphabet_per_position = alphabet_per_position
        # Also sets self.variable_positions as a side effect.
        self.gene_space = self._create_gene_space(alphabet_per_position)

        # An explicit max_distance wins; otherwise derive it as the fraction
        # of variable positions that max_mutations corresponds to.
        if self.max_distance is None and self.max_mutations is not None:
            self.max_distance = self.max_mutations / \
                len(self.variable_positions)

        self.sequence = data.representative
        self.starting_chromosome = self.encode(self.sequence)

        # Now we can safely encode the data sequences
        data_sequences = data.get_sequences()
        self.data_population_encoded = [
            self.encode(seq) for seq in data_sequences]
        self.data_fitness_values = data.get_labels().tolist()
        # _inject_training_data() and reset() read the training labels under
        # this name, which was never assigned in this file.
        self.initial_fitness_values = self.data_fitness_values

        # Initialize ga_params if None
        if self.ga_params is None:
            self.ga_params = {}

        # Ensure ga_params is a dictionary for type checking
        assert isinstance(
            self.ga_params, dict), "ga_params must be a dictionary"

        # Configure GA parameters
        self.ga_params["num_genes"] = len(self.gene_space)
        self.ga_params["gene_space"] = self.gene_space
        self.ga_params["fitness_func"] = self.fitness_func
        self.ga_params["on_start"] = self.on_start
        self.ga_params["on_generation"] = self.on_generation
        self.ga_params["on_fitness"] = self.on_fitness
        self.ga_params["on_crossover"] = self.on_crossover
        self.ga_params["on_mutation"] = self.on_mutation
        self.ga_params["on_parents"] = self.on_parents
        self.ga_params["mutation_type"] = self.distance_constrained_mutation
        self.ga_params["crossover_type"] = self.distance_constrained_crossover
        self.ga_params["initial_population"] = self.setup_initial_population()

        if self.populate_cache_with_training_data:
            # Seed the cache with the known training labels.
            for seq, fitness in zip(data_sequences, self.data_fitness_values):
                self.fitness_cache[seq] = fitness

        super().__init__(**self.ga_params)
        self.initialized = True

    def set_metrics_logger(self, logger_func: Optional[Callable]) -> None:
        """
        Set a function to call for logging generation metrics.

        Args:
            logger_func: Callable invoked elsewhere in this class as
                ``logger_func(metrics_dict, step=generation_count)``.
                ``None`` disables metric logging.

        Note:
            ``setup_generator`` resets this hook to ``None``, so it has to be
            registered after ``setup_generator`` has run.
        """
        self.metrics_logger = logger_func

    def set_new_run_callback(self, callback_func: Optional[Callable]) -> None:
        """
        Set a function to call for creating new MLflow runs.

        Args:
            callback_func: Zero-argument callable invoked by ``reset`` before
                the tracking state is cleared. ``None`` disables the hook.

        Note:
            ``setup_generator`` resets this hook to ``None``, so it has to be
            registered after ``setup_generator`` has run.
        """
        self.new_run_callback = callback_func

    def cal_pop_fitness(self):
        """
        Calculate fitness for the entire population at once using batch processing.

        Sequences already present in ``self.fitness_cache`` are served from
        the cache. The remaining sequences are de-duplicated and sent to
        ``self.fitness_callback`` in a single call, so a sequence that occurs
        several times in the population is evaluated only once.

        Returns:
            np.ndarray: One fitness value per individual, in population
            order. Individuals whose evaluation failed get ``0.0``.
        """
        logger.debug(
            f"Calculating fitness for population of {len(self.population)} individuals"
        )

        # Decode all sequences in the population
        all_sequences = [self.decode(solution) for solution in self.population]
        pop_fitness = np.zeros(len(all_sequences))

        # Serve cache hits directly; group the population indices of every
        # uncached sequence so duplicates share one evaluation.
        pending: Dict[str, List[int]] = {}
        cache_hits = 0

        for idx, sequence in enumerate(all_sequences):
            if sequence in self.fitness_cache:
                pop_fitness[idx] = self.fitness_cache[sequence]
                cache_hits += 1
            else:
                pending.setdefault(sequence, []).append(idx)
        logger.info(f"{cache_hits} HITS in Fitness Cache")

        # Batch evaluate uncached sequences if any
        if pending:
            unique_sequences = list(pending)
            logger.debug(
                f"Batch evaluating {len(unique_sequences)} new sequences")
            try:
                fitness_scores, _stage_outputs = self.fitness_callback(
                    unique_sequences
                )

                for i, sequence in enumerate(unique_sequences):
                    # A callback returning too few scores is padded with 0.0.
                    fitness = fitness_scores[i] if i < len(
                        fitness_scores) else 0.0
                    pop_fitness[pending[sequence]] = fitness
                    self.fitness_cache[sequence] = fitness

            except Exception as e:
                logger.error(f"Error calculating batch fitness: {e}")
                # Best effort: every uncached individual falls back to 0.0.
                for indices in pending.values():
                    pop_fitness[indices] = 0.0

        logger.debug(
            f"Population fitness range: [{np.min(pop_fitness):.3f}, {np.max(pop_fitness):.3f}]"
        )

        return pop_fitness

    def fitness_func(self, ga_instance, solution, solution_idx):
        """
        Individual fitness function handed to PyGAD.

        Cached values are returned where available; only sequences missing
        from ``self.fitness_cache`` are sent to ``self.fitness_callback``.

        Args:
            ga_instance: The GA instance (unused).
            solution: A single chromosome (1-D) or a batch of chromosomes (2-D).
            solution_idx: Index or indices of the solution(s) (unused).

        Returns:
            A single fitness value for a 1-D ``solution``, or a list with one
            value per row for a 2-D batch. Sequences whose evaluation failed
            get ``0.0``.
        """
        # A batch is recognised by its dimensionality; the length of a single
        # chromosome is the number of genes and says nothing about batching.
        is_batch = np.ndim(solution) > 1
        chromosomes = solution if is_batch else [solution]
        sequences = [self.decode(chromosome) for chromosome in chromosomes]

        # Evaluate each uncached sequence once, preserving first-seen order.
        missing = [
            sequence for sequence in dict.fromkeys(sequences)
            if sequence not in self.fitness_cache
        ]

        if missing:
            try:
                fitness_scores, _stage_outputs = self.fitness_callback(missing)

                for sequence, fitness in zip(missing, fitness_scores):
                    self.fitness_cache[sequence] = fitness
            except Exception as e:
                logger.error(
                    f"Error calculating fitness for uncached sequence: {e}")

        fitness_values = [
            self.fitness_cache.get(sequence, 0.0) for sequence in sequences
        ]

        return fitness_values if is_batch else fitness_values[0]

    def setup_initial_population(self):
        """
        Build the initial population around the starting chromosome.

        The first individual is the unmodified starting chromosome. Every
        other individual is a copy with a number of substitutions drawn from
        a Poisson distribution with mean 1.5 and clamped to
        ``[1, mutation_cap]``. ``mutation_cap`` is the largest count allowed
        by ``max_distance``, by ``max_mutations`` (when set) and by the number
        of variable positions. Substitutions are placed on variable positions
        only, so each one counts towards the distance.

        Returns:
            np.ndarray: Encoded population with one row per individual.

        Raises:
            ValueError: If ``max_distance`` is not set.
        """
        population_size = self.ga_params.get("sol_per_pop", 100)
        max_mutations = getattr(self, "max_mutations", 1)
        max_distance = getattr(self, "max_distance", None)

        if max_distance is None:
            raise ValueError(
                "max_distance must be set for distance-constrained initialization."
            )

        variable_positions = list(self.variable_positions)
        num_variable = len(variable_positions)

        # Largest substitution count that keeps the normalised Hamming
        # distance within max_distance; the epsilon absorbs float round-off
        # in max_distance = max_mutations / num_variable.
        mutation_cap = int(np.floor(max_distance * num_variable + 1e-9))
        if max_mutations is not None:
            mutation_cap = min(mutation_cap, int(max_mutations))
        mutation_cap = max(0, min(mutation_cap, num_variable))

        # Start with the starting chromosome
        initial_population = [self.starting_chromosome.copy()]

        logger.info(
            f"Setting up initial population with {population_size} sequences")
        logger.info(f"Max mutations: {max_mutations}")
        logger.info(f"Max distance: {max_distance}")

        mut_dist = []

        for _ in range(1, population_size):
            new_individual = self.starting_chromosome.copy()

            # Sample number of mutations for this individual
            num_mutations = 0
            if mutation_cap >= 1:
                num_mutations = int(
                    np.clip(np.random.poisson(lam=1.5), 1, mutation_cap))
            mut_dist.append(num_mutations)

            # Randomly select unique variable positions to mutate
            mutation_positions = []
            if num_mutations > 0:
                mutation_positions = np.random.choice(
                    variable_positions, size=num_mutations, replace=False
                )

            for pos in mutation_positions:
                # Pick a mutation different from the current value
                possible_mutations = [
                    aa for aa in self.gene_space[pos] if aa != new_individual[pos]
                ]

                if possible_mutations:
                    new_individual[pos] = np.random.choice(possible_mutations)
                # If no possible mutation, leave as is

            initial_population.append(new_individual)

        logger.info(f"Initial population size: {len(initial_population)}")

        # With a population of one there is nothing to summarise, and
        # np.min / np.max would raise on an empty list.
        if mut_dist:
            logger.info(
                f"Initial Population distance distribution:\n"
                f"mean {np.mean(mut_dist):.2f} \n"
                f"min {np.min(mut_dist):.2f} \n"
                f"max {np.max(mut_dist):.2f}"
            )

        return np.array(initial_population)

    def on_start(self, ga_instance):
        """Hook registered as ``on_start``: reset the generation counter and log the start."""
        self.generation_count = 0
        logger.info("Starting genetic algorithm")

    def on_generation(self, ga_instance):
        """
        Per-generation hook: record results, update tracking, log, and decide
        whether the run continues.

        Best-fitness and plateau tracking are updated before the summary log
        and the metrics are emitted, so the reported values include the
        current generation.

        Args:
            ga_instance: The running GA; ``population`` and
                ``last_generation_fitness`` are read from it.

        Returns:
            ``"stop"`` to end the run, ``"retry"`` after a plateau-triggered
            reset, or ``None`` to continue.
        """

        if self.stop_run:
            return "stop"
        self.generation_count += 1

        # Get current population sequences and fitness
        current_sequences = [
            self.decode(solution) for solution in ga_instance.population
        ]
        current_fitness = ga_instance.last_generation_fitness

        # Distance of every chromosome to the starting chromosome, measured
        # over the variable positions only.
        distances_to_start = [
            hamming(
                chromosome[self.variable_positions],
                self.starting_chromosome[self.variable_positions],
            )
            for chromosome in ga_instance.population
        ]
        min_distance = np.min(distances_to_start)
        max_distance = np.max(distances_to_start)
        mean_distance = np.mean(distances_to_start)

        # Track best solution
        best_idx = np.argmax(current_fitness)
        current_best_fitness = current_fitness[best_idx]
        current_best_sequence = current_sequences[best_idx]
        fitness_improved = current_best_fitness > self.best_fitness

        # Store all sequences and fitness for this generation
        self.all_sequences.extend(current_sequences)
        self.all_fitnesses.extend(current_fitness.tolist())

        # Update tracking first so the log line and metrics below do not lag
        # one generation behind.
        if fitness_improved:
            self.best_fitness = current_best_fitness
            self.best_sequence = current_best_sequence
            self.generations_without_improvement = 0  # Reset plateau counter

            # Reset consecutive injection counter when we see actual improvement
            if self.consecutive_injections_without_improvement > 0:
                logger.info(
                    f"Fitness improved after {self.consecutive_injections_without_improvement} "
                    f"consecutive injections - resetting injection counter"
                )
                self.consecutive_injections_without_improvement = 0

            logger.info(f"New best fitness found: {current_best_fitness:.4f}")
        else:
            self.generations_without_improvement += 1

        logger.info(
            f"Generation {self.generation_count}: Best fitness = {current_best_fitness:.4f} "
            f"(Plateau: {self.generations_without_improvement}, "
            f"Consecutive injections: {self.consecutive_injections_without_improvement})"
        )

        # Log generation metrics if logger is available; a failing metrics
        # backend must not abort the run.
        if self.metrics_logger is not None:
            try:
                generation_metrics = {
                    "generation_mean_fitness": np.mean(current_fitness),
                    "generation_std_fitness": np.std(current_fitness),
                    "generation_min_fitness": np.min(current_fitness),
                    "generation_max_fitness": np.max(current_fitness),
                    "overall_best_fitness": self.best_fitness,
                    "generations_without_improvement": self.generations_without_improvement,
                    "consecutive_injections_without_improvement": self.consecutive_injections_without_improvement,
                    "total_injections": self.total_injections,
                    "fitness_improved": int(fitness_improved),
                    "distance_to_start_min": min_distance,
                    "distance_to_start_max": max_distance,
                    "distance_to_start_mean": mean_distance,
                }
                self.metrics_logger(generation_metrics,
                                    step=self.generation_count)
            except Exception as e:
                logger.warning(f"Failed to log generation metrics: {e}")

        # Check for plateau and inject training data if needed
        if self.generations_without_improvement >= self.plateau_threshold:
            logger.info(
                f"Plateau detected after {self.generations_without_improvement} generations without improvement"
            )

            if self.num_retries < self.max_retries:
                self.num_retries += 1
                self.reset(keep_cache=True)

                # Per the original author's note, pygad only reacts to the
                # return value "stop"; "retry" lets the loop continue.
                return "retry"
            elif not self.use_injection:
                return "stop"

            # Store fitness before injection for monitoring
            self.fitness_before_injection = self.best_fitness

            # Attempt injection
            if self._inject_training_data(ga_instance):
                self.generations_without_improvement = 0  # Reset plateau counter
                self.total_injections += 1
                self.consecutive_injections_without_improvement += 1

                # Log injection event with detailed monitoring
                injection_event = {
                    "generation": self.generation_count,
                    "fitness_before": self.fitness_before_injection,
                    "fitness_after": self.best_fitness,
                    "consecutive_count": self.consecutive_injections_without_improvement,
                    "total_injections": self.total_injections,
                }
                self.injection_history.append(injection_event)

                logger.info(
                    f"Training data injected (#{self.total_injections}). "
                    f"Consecutive injections without improvement: {self.consecutive_injections_without_improvement}"
                )

                # Check for early stopping condition
                if (
                    self.consecutive_injections_without_improvement
                    >= self.max_consecutive_injections
                ):
                    self.early_stopped = True
                    self.stop_reason = (
                        f"Reached maximum consecutive injections ({self.max_consecutive_injections}) "
                        f"without improvement"
                    )
                    logger.warning(
                        f"Early stopping triggered: {self.stop_reason}")

                    return "stop"  # pygad way of gracefully stopping
            else:
                logger.warning(
                    "Failed to inject training data - no suitable sequences found"
                )

                return "stop"

    def on_fitness(self, ga_instance, fitness_list):
        """Hook registered as ``on_fitness``: emit a debug line, nothing else."""
        message = f"Fitness calculated for generation {self.generation_count}"
        logger.debug(message)

    def on_crossover(self, ga_instance, offspring_crossover):
        """Hook registered as ``on_crossover``: emit a debug line, nothing else."""
        message = f"Crossover completed for generation {self.generation_count}"
        logger.debug(message)

    def on_mutation(self, ga_instance, offspring_mutation):
        """Hook registered as ``on_mutation``: emit a debug line, nothing else."""
        message = f"Mutation completed for generation {self.generation_count}"
        logger.debug(message)

    def on_parents(self, ga_instance, selected_parents):
        """Hook registered as ``on_parents``: emit a debug line, nothing else."""
        message = f"Parents selected for generation {self.generation_count}"
        logger.debug(message)

    def _inject_training_data(self, ga_instance) -> bool:
        """
        Replace the worst individual with a training sequence that beats the current best.

        Candidates are the training sequences whose fitness exceeds
        ``self.best_fitness``; one is drawn at random from the lower-fitness
        half of those candidates (the ones closest to the current best).

        Args:
            ga_instance: The pygad GA instance whose ``population`` and
                ``last_generation_fitness`` are modified in place.

        Returns:
            bool: True if injection was successful, False otherwise
            (no better training example, or any error during injection).
        """
        logger.info("Injecting training data due to plateau")
        try:
            # Get training sequences and their fitness values
            training_sequences = self.data.get_sequences()
            # ``initial_fitness_values`` is not assigned anywhere in this
            # file; fall back to the labels stored by setup_generator so the
            # lookup cannot fail with an AttributeError.
            training_fitness = getattr(self, "initial_fitness_values", None)
            if training_fitness is None:
                training_fitness = self.data_fitness_values

            # Find training examples with fitness better than current best
            better_indices = [
                i
                for i, fitness in enumerate(training_fitness)
                if fitness > self.best_fitness
            ]

            if not better_indices:
                logger.warning(
                    "No training examples found with fitness better than current best"
                )

                return False

            # select from the lower half of better examples (closer to current best)
            better_indices.sort(key=lambda i: training_fitness[i])
            selection_pool = better_indices[: max(1, len(better_indices) // 2)]
            selected_idx = np.random.choice(selection_pool)

            selected_sequence = training_sequences[selected_idx]
            selected_fitness = training_fitness[selected_idx]
            encoded_sequence = self.encode(selected_sequence)

            # Replace worst individual with selected training example and
            # keep the cache and the GA's fitness vector in sync with it.
            current_fitness = ga_instance.last_generation_fitness
            worst_idx = np.argmin(current_fitness)
            ga_instance.population[worst_idx] = encoded_sequence
            self.fitness_cache[selected_sequence] = selected_fitness
            ga_instance.last_generation_fitness[worst_idx] = selected_fitness

            # Update our best fitness tracking if this injected sequence is better
            if selected_fitness > self.best_fitness:
                self.best_fitness = selected_fitness
                self.best_sequence = selected_sequence
                logger.info(
                    f"New overall best fitness from injected data: {selected_fitness:.4f}"
                )

            return True

        except Exception as e:
            logger.error(f"Failed to inject training data: {e}")

            return False

    @staticmethod
    def aa2int(aa: str) -> int:
        """Map a symbol to its integer index in ``AA_ALPHABET_WITH_EXTRAS``.

        Raises:
            KeyError: If ``aa`` is not part of the alphabet.
        """
        index_by_symbol = dict(
            zip(AA_ALPHABET_WITH_EXTRAS, range(len(AA_ALPHABET_WITH_EXTRAS)))
        )
        return index_by_symbol[aa]

    @staticmethod
    def int2aa(i: int) -> str:
        """Return the symbol stored at index ``i`` of ``AA_ALPHABET_WITH_EXTRAS``."""
        return AA_ALPHABET_WITH_EXTRAS[i]

    def encode(self, sequence: str) -> np.ndarray:
        """
        Convert a sequence into its integer representation, position by position.

        Every position is encoded, fixed positions included.

        Raises:
            ValueError: If the sequence length differs from the number of
                entries in ``self.alphabet_per_position``.
        """
        length_matches = len(sequence) == len(self.alphabet_per_position)

        if not length_matches:
            raise ValueError(
                f"Sequence length {len(sequence)} doesn't match expected length {len(self.alphabet_per_position)}"
            )

        encoded = [self.aa2int(symbol) for symbol in sequence]

        return np.array(encoded)

    def decode(self, int_sequence: np.ndarray) -> str:
        """
        Turn an integer-encoded chromosome back into its sequence string.

        Raises:
            ValueError: If the chromosome length differs from the number of
                entries in ``self.alphabet_per_position``.
        """
        length_matches = len(int_sequence) == len(self.alphabet_per_position)

        if not length_matches:
            raise ValueError(
                f"Encoded sequence length {len(int_sequence)} doesn't match expected length {len(self.alphabet_per_position)}"
            )

        symbols = []
        for code in int_sequence:
            # Cast first: chromosome entries may be numpy scalars.
            symbols.append(self.int2aa(int(code)))

        return "".join(symbols)

    def _create_gene_space(self, alphabet_per_position: List[List[str]]):
        """
        Build the per-position gene space and record which positions are variable.

        Side effect: sets ``self.variable_positions`` to the indices whose
        alphabet has more than one entry.

        Returns:
            A list with one entry per position, each holding the integer
            codes of the symbols allowed there.
        """
        logger.debug("Creating gene space")

        # Track which positions are variable (more than one option)
        # NOTE: when choosing de-novo design this includes the fixed binding partner.
        variable = []
        for position, alphabet in enumerate(alphabet_per_position):
            if len(alphabet) > 1:
                variable.append(position)
        self.variable_positions = variable

        # The gene space covers ALL positions (pygad requires this); fixed
        # positions simply end up with a single option.
        gene_space = []
        for alphabet in alphabet_per_position:
            codes = []
            for aa in alphabet:
                codes.append(self.aa2int(aa))
            gene_space.append(codes)

        logger.debug(
            f"Created gene space with {len(gene_space)} total positions")
        logger.debug(
            f"Variable positions: {len(self.variable_positions)} out of {len(gene_space)}"
        )

        return gene_space

    def get_best_solution(self):
        """Return ``(best_sequence, best_fitness)`` as tracked so far."""
        sequence = self.best_sequence
        fitness = self.best_fitness

        return sequence, fitness

    def get_all_solutions(self):
        """Return ``(all_sequences, all_fitnesses)``; the stored lists themselves, not copies."""
        sequences = self.all_sequences
        fitnesses = self.all_fitnesses

        return sequences, fitnesses

    def get_monitoring_statistics(self) -> Dict[str, Any]:
        """
        Snapshot the run-level counters used for monitoring.

        Returns:
            Dict with generation, early-stop, injection and plateau counters,
            the best fitness, and a shallow copy of the injection history.
        """
        history_snapshot = self.injection_history.copy()

        return dict(
            generation_count=self.generation_count,
            early_stopped=self.early_stopped,
            stop_reason=self.stop_reason,
            total_injections=self.total_injections,
            consecutive_injections_without_improvement=self.consecutive_injections_without_improvement,
            max_consecutive_injections=self.max_consecutive_injections,
            injection_history=history_snapshot,
            plateau_threshold=self.plateau_threshold,
            generations_without_improvement=self.generations_without_improvement,
            best_fitness=self.best_fitness,
            total_sequences_generated=len(self.all_sequences),
        )

    def get_injection_summary(self) -> Dict[str, Any]:
        """
        Get a summary of injection events and their effectiveness.

        An injection counts as effective when its recorded ``fitness_after``
        exceeds its ``fitness_before``.

        Returns:
            Dict with the keys ``total_injections``, ``effective_injections``,
            ``effectiveness_rate``, ``average_improvement``,
            ``injection_generations`` and ``consecutive_without_improvement``.
            The same keys are present whether or not any injection happened.
        """
        consecutive = self.consecutive_injections_without_improvement

        if not self.injection_history:
            return {
                "total_injections": 0,
                "effective_injections": 0,
                "effectiveness_rate": 0.0,
                "average_improvement": 0.0,
                "injection_generations": [],
                "consecutive_without_improvement": consecutive,
            }

        # Calculate injection effectiveness
        effective_injections = 0
        total_improvement = 0.0
        injection_generations = []

        for event in self.injection_history:
            improvement = event["fitness_after"] - event["fitness_before"]
            total_improvement += improvement
            injection_generations.append(event["generation"])

            if improvement > 0:
                effective_injections += 1

        # The history is non-empty here, so the divisions are safe.
        num_events = len(self.injection_history)

        return {
            "total_injections": self.total_injections,
            "effective_injections": effective_injections,
            "effectiveness_rate": effective_injections / num_events,
            "average_improvement": total_improvement / num_events,
            "injection_generations": injection_generations,
            "consecutive_without_improvement": consecutive,
        }

    def reset(self, keep_cache: bool = False):
        """
        Reset the tracking state to start a new run.

        Does nothing until ``setup_generator`` has completed. ``num_retries``
        is deliberately left untouched so the retry budget spans resets.

        Args:
            keep_cache: If True, the fitness cache survives the reset.
        """

        if not self.initialized:
            return

        # Request new MLflow run if callback is available
        if self.new_run_callback is not None:
            logger.info("Requesting new MLflow run for trajectory restart")
            self.new_run_callback()

        # Reset tracking variables
        self.generation_count = 0

        if not keep_cache:
            self.fitness_cache = {}
        self.best_sequence = None
        self.best_fitness = -float("inf")
        self.all_sequences = []
        self.all_fitnesses = []

        # Reset plateau tracking
        self.generations_without_improvement = 0
        self.last_best_fitness = -float("inf")

        # Reset injection tracking and monitoring
        self.consecutive_injections_without_improvement = 0
        self.total_injections = 0
        self.injection_history = []
        self.fitness_before_injection = -float("inf")
        self.early_stopped = False
        self.stop_reason = ""

        # Reseed the RNG with a seed drawn from the current RNG state
        random_seed = np.random.randint(0, 1000000)
        np.random.seed(random_seed)
        logger.info(f"New random seed: {random_seed}")
        # Recorded in the stored GA parameters.
        self.ga_params["random_seed"] = random_seed

        if self.metrics_logger is not None:
            reset_metrics = {
                "random_seed": self.ga_params.get("random_seed"),
                "num_retries": self.num_retries,
                "generation_count": self.generation_count,
            }
            reset_metrics = {k: v for k,
                             v in reset_metrics.items() if v is not None}
            # Same policy as on_generation: a failing metrics backend must
            # not abort the run.
            try:
                self.metrics_logger(reset_metrics, step=self.generation_count)
            except Exception as e:
                logger.warning(f"Failed to log reset metrics: {e}")

        # Cache initial fitness values again
        if self.populate_cache_with_training_data:
            initial_sequences = self.data.get_sequences()
            # ``initial_fitness_values`` is not assigned anywhere in this
            # file; fall back to the labels stored by setup_generator.
            initial_fitness = getattr(self, "initial_fitness_values", None)
            if initial_fitness is None:
                initial_fitness = self.data_fitness_values

            for seq, fitness in zip(initial_sequences, initial_fitness):
                self.fitness_cache[seq] = fitness

    def distance_constrained_mutation(self, offspring, ga_instance):
        """
        Custom mutation operator that keeps every result within ``self.max_distance``.

        Each offspring gets up to ``max_attempts`` tries at a random mutation
        of ``mutation_num_genes`` variable positions; a try is accepted when
        the Hamming distance to the starting chromosome (over the variable
        positions) does not exceed ``self.max_distance``. Offspring for which
        no valid mutation was found are dropped, and the missing slots are
        filled with valid mutations of randomly chosen offspring.

        The fill phase is bounded. If no valid mutation can be found within
        the budget, an unmutated copy of a random offspring is used, so the
        method always terminates and always returns ``offspring.shape[0]`` rows.

        Args:
            offspring: 2-D array of chromosomes to mutate.
            ga_instance: GA instance; ``mutation_num_genes`` is read from it
                (a missing or ``None`` value is treated as 1).

        Returns:
            np.ndarray: Mutated population with as many rows as ``offspring``.
            Row order is not guaranteed to match the input.
        """
        mutation_num_genes = getattr(ga_instance, "mutation_num_genes", 1)
        if mutation_num_genes is None:
            mutation_num_genes = 1
        population_size = offspring.shape[0]
        max_attempts = max(1, mutation_num_genes * 10)
        # Budget per missing slot in the fill phase.
        max_fill_attempts = max_attempts * 10
        # With fewer variable positions than requested, mutate all of them.
        mutate_all = len(self.variable_positions) < mutation_num_genes

        def mutate_once(parent):
            """Return one random mutation of ``parent`` if it is within max_distance, else None."""
            chromosome = parent.copy()

            if mutate_all:
                mutation_positions = self.variable_positions
            else:
                mutation_positions = np.random.choice(
                    self.variable_positions, size=mutation_num_genes, replace=False
                )

            for pos in mutation_positions:
                possible_mutations = [
                    aa for aa in self.gene_space[pos] if aa != chromosome[pos]
                ]

                if possible_mutations:
                    chromosome[pos] = np.random.choice(possible_mutations)

            distance = hamming(
                chromosome[self.variable_positions],
                self.starting_chromosome[self.variable_positions],
            )

            return chromosome if distance <= self.max_distance else None

        mutated_offspring = []

        # First pass: try to mutate every offspring in place.
        for i in range(population_size):
            for _ in range(max_attempts):
                candidate = mutate_once(offspring[i])

                if candidate is not None:
                    mutated_offspring.append(candidate)

                    break

        # Fill pass: replace dropped offspring with valid mutations of random
        # offspring, falling back to an unmutated copy when the budget runs out.
        while len(mutated_offspring) < population_size:
            candidate = None

            for _ in range(max_fill_attempts):
                parent_idx = np.random.choice(population_size)
                candidate = mutate_once(offspring[parent_idx])

                if candidate is not None:
                    break

            if candidate is None:
                parent_idx = np.random.choice(population_size)
                candidate = offspring[parent_idx].copy()

            mutated_offspring.append(candidate)

        return np.array(mutated_offspring)

    def distance_constrained_crossover(self, parents, offspring_size, ga_instance):
        """
        Custom crossover function that respects the distance constraint.

        Children are produced by single-point crossover between two distinct,
        randomly chosen parents. A child is kept only when its Hamming
        distance to ``self.starting_chromosome``, measured over
        ``self.variable_positions`` with ``scipy.spatial.distance.hamming``
        (fraction of differing positions), is at most ``self.max_distance``.

        The number of crossover attempts is capped at 20 per requested
        offspring. Any slots still empty afterwards are filled with copies of
        randomly chosen parents; those copies are NOT re-checked against the
        distance constraint.

        Args:
            parents: 2D array of parent chromosomes, one per row.
            offspring_size: Sequence whose first entry is the number of
                offspring to produce.
            ga_instance: Not used by this method; kept as part of the
                signature.

        Returns:
            ``np.ndarray`` built from the list of offspring rows.

        Raises:
            ValueError: If offspring are requested but ``parents`` has no rows.
        """
        offspring = []
        num_offspring_needed = offspring_size[0]
        num_parents = parents.shape[0]
        max_attempts = num_offspring_needed * 20  # Prevent infinite loops

        if num_offspring_needed > 0 and num_parents == 0:
            raise ValueError("Cannot perform crossover without any parents.")

        # Loop-invariant: the starting chromosome restricted to variable positions.
        reference = self.starting_chromosome[self.variable_positions]

        # Sampling two distinct parents needs at least two rows; with a single
        # parent skip crossover entirely and rely on the fallback below.
        can_cross = num_parents >= 2
        attempts = 0

        while can_cross and len(offspring) < num_offspring_needed and attempts < max_attempts:
            # Randomly select two distinct parents
            parent_indices = np.random.choice(
                num_parents, size=2, replace=False)
            parent1 = parents[parent_indices[0], :]
            parent2 = parents[parent_indices[1], :]

            # Perform a single-point crossover at a random variable position
            random_split_point = np.random.choice(self.variable_positions)
            child = np.concatenate(
                (parent1[:random_split_point], parent2[random_split_point:])
            )

            # Only keep the child if it satisfies the distance constraint
            distance = hamming(child[self.variable_positions], reference)

            if distance <= self.max_distance:
                offspring.append(child)

            attempts += 1

        # If we couldn't generate enough valid offspring, fill the rest with
        # copies of random parents (as fallback)
        while len(offspring) < num_offspring_needed:
            parent_idx = np.random.choice(num_parents)
            offspring.append(parents[parent_idx, :].copy())

        return np.array(offspring)

    def run(self):
        """
        Run the genetic algorithm using pygad's native run method.
        This handles the entire evolution process internally.

        Before delegating to ``super().run()`` all per-run tracking state
        (best sequence/fitness, plateau counters, injection bookkeeping and
        the early-stop flags) is reset. Afterwards a summary is logged and,
        if ``self.metrics_logger`` is set, a dictionary of final metrics is
        passed to it.

        Returns:
            Tuple ``(self.best_sequence, self.best_fitness)`` as they stand
            after the run.

        Raises:
            RuntimeError: If ``self.initialized`` is falsy.
        """

        if not self.initialized:
            raise RuntimeError(
                "Generator not initialized. Call setup_generator first.")

        # Reset tracking variables
        self.generation_count = 0
        self.best_sequence = None
        self.best_fitness = -float("inf")
        self.all_sequences = []
        self.all_fitnesses = []

        # Reset plateau tracking
        self.generations_without_improvement = 0
        self.last_best_fitness = -float("inf")

        # Reset injection tracking and monitoring
        self.consecutive_injections_without_improvement = 0
        self.total_injections = 0
        self.injection_history = []
        self.fitness_before_injection = -float("inf")
        self.early_stopped = False
        self.stop_reason = ""

        # Run the genetic algorithm using pygad's native method with early stopping support
        super().run()

        # Derive the status from the early-stop flags so the log agrees with
        # the "early_stopped" metric reported below.
        # NOTE(review): early_stopped / stop_reason are presumably set by a
        # callback during the run; that code is outside this section.
        if self.early_stopped:
            if self.stop_reason:
                completion_status = f"stopped early ({self.stop_reason})"
            else:
                completion_status = "stopped early"
        else:
            completion_status = "completed normally"

        # Log comprehensive run summary
        logger.info(
            f"Genetic algorithm {completion_status} after {self.generation_count} generations"
        )
        logger.info(f"Best sequence: {self.best_sequence}")
        logger.info(f"Best fitness: {self.best_fitness:.4f}")
        logger.info(f"Total training data injections: {self.total_injections}")

        if self.injection_history:
            logger.info("Injection history:")

            for i, event in enumerate(self.injection_history):
                improvement = event["fitness_after"] - event["fitness_before"]
                logger.info(
                    f"  Injection {i+1} (Gen {event['generation']}): "
                    f"fitness {event['fitness_before']:.4f} → {event['fitness_after']:.4f} "
                    f"(Δ={improvement:+.4f})"
                )

        # Log final monitoring metrics

        if self.metrics_logger is not None:
            # Best effort: a failing metrics sink must not discard the result.
            try:
                final_metrics = {
                    "run_completed_normally": int(not self.early_stopped),
                    "early_stopped": int(self.early_stopped),
                    "total_generations": self.generation_count,
                    "final_best_fitness": self.best_fitness,
                    "total_injections": self.total_injections,
                    "final_consecutive_injections": self.consecutive_injections_without_improvement,
                }
                self.metrics_logger(final_metrics, step=self.generation_count)
            except Exception as e:
                logger.warning(f"Failed to log final metrics: {e}")

        return self.best_sequence, self.best_fitness

    def shutdown(self):
        """
        Shutdown the generator.

        Sets ``self.stop_run`` to True and then invokes ``self.on_generation``
        once with this instance as its argument. The callback is skipped when
        it is missing or not callable, so shutdown never fails just because no
        generation callback is configured.
        NOTE(review): the extra callback call presumably lets the hook observe
        the stop flag immediately; confirm against the on_generation
        implementation.
        """
        self.stop_run = True
        callback = getattr(self, "on_generation", None)

        if callable(callback):
            callback(self)
