from typing import List, Tuple
import random
import numpy as np
from .base import BaseEvolutionAlgorithm
from ..individual import AlgorithmIndividual
from ..llm_integration import LLMClient
from ..operators import VerifyOperator
from ..operators import DECrossoverOperator, DEMutationOperator

class DifferentialEvolution(BaseEvolutionAlgorithm):
    """LLM-driven variant of differential evolution over algorithm code.

    Crossover and mutation are delegated to LLM-backed operators that rewrite
    an individual's ``code``. Survivor selection is one-to-one: each offspring
    competes only against the population member at the same index, and the
    lower fitness wins (``None`` fitness counts as worst).
    """

    def __init__(self, config: dict, llm_client: LLMClient):
        """Create the optimizer.

        Args:
            config: Algorithm configuration; also forwarded to the base class.
                Reads ``"F"`` (default 0.8) and ``"CR"`` (default 0.7).
            llm_client: Client shared by the crossover, mutation and verify
                operators.
        """
        super().__init__(config)
        # NOTE: unlike textbook DE, both values are used in this class as
        # probabilities compared against random.random(): CR gates crossover,
        # F gates mutation.
        self.F = config.get("F", 0.8)
        self.CR = config.get("CR", 0.7)
        self.llm_client = llm_client
        self.crossover_operator = DECrossoverOperator(llm_client)
        self.mutation_operator = DEMutationOperator(llm_client)
        # Not used by the methods in this class; presumably consumed by the
        # base class or external callers — TODO confirm.
        self.verify_operator = VerifyOperator(llm_client)

    def select(self, population: List[AlgorithmIndividual], num_parents: int) -> List[AlgorithmIndividual]:
        """Pick ``num_parents`` individuals uniformly at random.

        Each pick is independent, so the same individual may be returned more
        than once. Each pick draws a group of up to three distinct candidates
        (mirroring DE's target/donor triple) but only the first is returned;
        the other two are currently unused.

        Args:
            population: Individuals to draw from.
            num_parents: Number of picks to make.

        Returns:
            A list of ``num_parents`` individuals, or an empty list when
            ``population`` is empty.
        """
        if not population:
            return []
        # Clamp the group size so populations with fewer than three members
        # don't raise ValueError; for three or more members the random draws
        # are identical to sampling a fixed triple.
        group_size = min(3, len(population))
        return [random.sample(population, group_size)[0] for _ in range(num_parents)]

    def crossover(self, parents: List[AlgorithmIndividual]) -> List[AlgorithmIndividual]:
        """Recombine the first parent via the LLM crossover operator.

        With probability ``self.CR`` the operator is asked for new code based
        on ``parents[0].code``. A truthy result becomes a new individual whose
        generation is one past the target's. Otherwise (operator skipped or it
        returned a falsy value) the target itself is returned.

        Args:
            parents: Candidate parents; only the first one is used.

        Returns:
            A one-element list, or ``parents`` unchanged when it is empty.
        """
        if not parents:
            return parents

        target = parents[0]
        if random.random() < self.CR:
            child_code = self.crossover_operator.crossover(target.code)
            if child_code:
                return [AlgorithmIndividual(child_code, target.generation + 1)]

        return [target]

    def mutate(self, individual: AlgorithmIndividual) -> AlgorithmIndividual:
        """Mutate ``individual`` via the LLM operator with probability ``self.F``.

        The mutated individual keeps ``individual.generation`` (only
        ``crossover`` in this class increments the generation). If the
        operator is skipped or returns a falsy value, ``individual`` itself is
        returned unchanged.
        """
        if random.random() < self.F:
            mutated_code = self.mutation_operator.mutate(individual.code)
            if mutated_code:
                return AlgorithmIndividual(mutated_code, individual.generation)

        return individual

    def survive(self, population: List[AlgorithmIndividual], offspring: List[AlgorithmIndividual],
               pop_size: int) -> List[AlgorithmIndividual]:
        """One-to-one survivor selection, minimizing fitness.

        ``offspring[i]`` replaces ``population[i]`` only if its fitness is
        strictly lower; ties keep the parent. Missing (``None``) fitness is
        treated as +inf. Population members without a paired offspring are
        carried over in order, and the result is truncated to ``pop_size``.

        Args:
            population: Current population, index-aligned with ``offspring``.
            offspring: Candidate replacements.
            pop_size: Maximum size of the returned population.

        Returns:
            The next population, at most ``pop_size`` long.
        """
        def fitness_key(ind: AlgorithmIndividual) -> float:
            # Compare against None explicitly: ``fitness or inf`` would turn a
            # perfect fitness of 0 into +inf and make it lose every comparison.
            return float("inf") if ind.fitness is None else ind.fitness

        survivors = [
            child if fitness_key(child) < fitness_key(parent) else parent
            for parent, child in zip(population[:len(offspring)], offspring)
        ]

        if len(survivors) < pop_size:
            survivors.extend(population[len(survivors):pop_size])

        return survivors[:pop_size]