from typing import List
import random
import numpy as np
from .base import BaseEvolutionAlgorithm
from ..multi_objective.individual import MultiObjectiveIndividual
from ..llm_integration import LLMClient
from ..operators.multi_objective_crossover import MultiObjectiveCrossoverOperator
from ..operators.multi_objective_mutation import MultiObjectiveMutationOperator
from ..operators.verify_operator import VerifyOperator

class MultiObjectiveEvolution(BaseEvolutionAlgorithm):
    """Multi-objective evolution using Pareto fronts and crowding distance.

    Ranking, parent selection and survival follow the NSGA-II scheme:
    individuals are grouped into non-dominated fronts (using
    ``MultiObjectiveIndividual.dominates``), and individuals in the same
    front are ordered by crowding distance so that less crowded solutions
    are preferred. Crossover and mutation are delegated to LLM-backed
    operators that share one ``VerifyOperator``.
    """

    def __init__(self, config: dict, llm_client: LLMClient):
        """Create the LLM-backed operators.

        Args:
            config: Algorithm settings, also forwarded to
                ``BaseEvolutionAlgorithm``. This class reads
                ``tournament_size`` (default 3) here, and ``crossover_rate``
                (default 0.9) / ``mutation_rate`` (default 0.1) via
                ``self.config``.
            llm_client: Client shared by the verify, crossover and mutation
                operators.
        """
        super().__init__(config)
        self.llm_client = llm_client
        self.verify_operator = VerifyOperator(llm_client)
        self.crossover_operator = MultiObjectiveCrossoverOperator(
            llm_client, self.verify_operator)
        self.mutation_operator = MultiObjectiveMutationOperator(
            llm_client, self.verify_operator)
        self.tournament_size = config.get("tournament_size", 3)

    def select(self, population: List[MultiObjectiveIndividual],
               num_parents: int) -> List[MultiObjectiveIndividual]:
        """Choose ``num_parents`` parents by crowded tournament selection.

        Each tournament samples ``tournament_size`` distinct individuals
        (capped at the population size). The winner is the one with the
        lowest Pareto rank; ties go to the larger crowding distance. The
        same individual may win several tournaments.

        Args:
            population: Candidates to select from.
            num_parents: Number of tournaments (and parents) to run.

        Returns:
            The selected parents. The list is empty when ``population`` is
            empty.
        """
        if not population:
            return []

        # Sorting (re)assigns ``rank`` on every individual as a side effect.
        fronts = self._non_dominated_sort(population)
        crowding = {}
        for front in fronts:
            crowding.update(self._crowding_distances(front))

        def crowded_key(ind):
            # Lower rank first; within a rank, larger crowding distance first.
            return (ind.rank, -crowding.get(id(ind), 0.0))

        sample_size = min(self.tournament_size, len(population))
        return [
            min(random.sample(population, sample_size), key=crowded_key)
            for _ in range(num_parents)
        ]

    def crossover(self, parents: List[MultiObjectiveIndividual]) -> List[MultiObjectiveIndividual]:
        """Recombine the first two parents with probability ``crossover_rate``.

        Args:
            parents: Selected parents; only the first two are used.

        Returns:
            ``parents`` itself when fewer than two are given. Otherwise, a
            one-element list with the new child (generation = first parent's
            generation + 1). Falls back to ``[parents[0]]`` when crossover is
            skipped or the operator returns no code.
        """
        if len(parents) < 2:
            return parents

        if random.random() < self.config.get("crossover_rate", 0.9):
            objective_names = list(parents[0].fitnesses.keys())

            child_code, _ = self.crossover_operator.crossover(
                parents[0], parents[1], objective_names)

            if child_code:
                return [MultiObjectiveIndividual(
                    child_code,
                    generation=parents[0].generation + 1
                )]
        return [parents[0]]

    def mutate(self, individual: MultiObjectiveIndividual) -> MultiObjectiveIndividual:
        """Mutate ``individual`` with probability ``mutation_rate``.

        Returns:
            A new individual carrying the mutated code and the same
            generation. The original object is returned unchanged when
            mutation is skipped or the operator returns no code.
        """
        if random.random() < self.config.get("mutation_rate", 0.1):
            objective_names = list(individual.fitnesses.keys())
            mutated_code, _ = self.mutation_operator.mutate(
                individual, objective_names)

            if mutated_code:
                return MultiObjectiveIndividual(
                    mutated_code,
                    generation=individual.generation
                )
        return individual

    def survive(self, population: List[MultiObjectiveIndividual],
                offspring: List[MultiObjectiveIndividual],
                pop_size: int) -> List[MultiObjectiveIndividual]:
        """Select the next generation from parents plus offspring (elitist).

        Whole fronts are taken in rank order while they fit. The first front
        that does not fit is truncated by keeping the individuals with the
        largest crowding distance, which preserves spread along the Pareto
        front.

        Returns:
            At most ``pop_size`` distinct individuals.
        """
        fronts = self._non_dominated_sort(population + offspring)
        survivors = []
        for front in fronts:
            remaining = pop_size - len(survivors)
            if remaining <= 0:
                break
            if len(front) <= remaining:
                survivors.extend(front)
                continue
            distances = self._crowding_distances(front)
            # sorted() is stable under reverse=True, so equal distances keep
            # their front order.
            least_crowded = sorted(
                front, key=lambda ind: distances[id(ind)], reverse=True)
            survivors.extend(least_crowded[:remaining])
            break
        return survivors

    @staticmethod
    def _crowding_distances(front: List[MultiObjectiveIndividual]) -> dict:
        """Compute the NSGA-II crowding distance of every member of ``front``.

        For each objective, members are sorted by value. The two extremes
        get an infinite distance; each interior member adds the normalized
        gap between its neighbours. Objectives that are missing on some
        member, or whose values cannot be converted with ``float()``, are
        skipped.

        Returns:
            A dict mapping ``id(individual)`` to its distance. It is keyed by
            identity so individuals need not be hashable.
        """
        distances = {id(ind): 0.0 for ind in front}
        if not front:
            return distances
        infinity = float("inf")
        for name in list(front[0].fitnesses.keys()):
            try:
                values = {id(ind): float(ind.fitnesses[name]) for ind in front}
            except (KeyError, TypeError, ValueError):
                continue
            ordered = sorted(front, key=lambda ind: values[id(ind)])
            distances[id(ordered[0])] = infinity
            distances[id(ordered[-1])] = infinity
            span = values[id(ordered[-1])] - values[id(ordered[0])]
            # ``not span > 0`` also rejects a NaN span; a flat objective adds
            # nothing.
            if not span > 0:
                continue
            for prev, cur, nxt in zip(ordered, ordered[1:], ordered[2:]):
                distances[id(cur)] += (values[id(nxt)] - values[id(prev)]) / span
        return distances

    def _non_dominated_sort(self, population: List[MultiObjectiveIndividual]) -> List[List[MultiObjectiveIndividual]]:
        """Partition ``population`` into Pareto fronts (fast non-dominated sort).

        Repeated references to the same object are collapsed first, so every
        individual lands in exactly one front. As side effects, each
        individual gets ``rank`` (the index of its front),
        ``domination_count`` and ``dominated_set``.

        Returns:
            Fronts in order. Front 0 holds the individuals that no other
            individual dominates. An empty input yields an empty list.
        """
        # Offspring built by ``crossover``/``mutate`` can be the very parent
        # objects (both fall back to returning a parent unchanged). Counting
        # such an object twice would put it twice into a front and can drive
        # domination counters negative.
        unique = list({id(ind): ind for ind in population}.values())

        first_front = []
        for ind in unique:
            ind.domination_count = 0
            ind.dominated_set = []
            for other in unique:
                if ind.dominates(other):
                    ind.dominated_set.append(other)
                elif other.dominates(ind):
                    ind.domination_count += 1
            if ind.domination_count == 0:
                ind.rank = 0
                first_front.append(ind)

        fronts = []
        current = first_front
        while current:
            fronts.append(current)
            next_front = []
            for ind in current:
                for dominated in ind.dominated_set:
                    dominated.domination_count -= 1
                    if dominated.domination_count == 0:
                        # len(fronts) is the index the next front will get.
                        dominated.rank = len(fronts)
                        next_front.append(dominated)
            current = next_front
        return fronts