import json
import logging
import os
from typing import List

from ...utils.multi_objective.evaluator import MultiObjectiveEvaluator
from ...utils.multi_objective.storage import MultiObjectiveGenerationStorage
from .individual import MultiObjectiveIndividual
from ..llm_integration import LLMClient
from ...config.settings import DEFAULT_EVOLUTION_PARAMS
from ..evolution_algorithms.multi_objective import MultiObjectiveEvolution
from ..operators.initialize_operator import InitializeOperator

class MultiObjectiveEvolutionEngine:
    """Run an LLM-assisted multi-objective evolutionary search for one problem.

    The engine builds its collaborators (generation storage, evaluator, LLM
    client, initialization operator and evolution algorithm) from the
    directory given as ``problem_path`` and exposes :meth:`run_evolution` as
    the main entry point.
    """

    # A breeding round (select + crossover + mutate) can legitimately produce
    # no usable child. Each generation is therefore limited to this many
    # rounds per requested offspring slot so the loop always terminates.
    _BREEDING_ATTEMPTS_PER_SLOT = 10

    def __init__(self, problem_path: str):
        """Create the engine and its collaborators for ``problem_path``.

        Args:
            problem_path: Directory containing ``problem_config.json``. The
                same path is handed to the storage, the evaluator and the
                LLM client factory.

        Raises:
            OSError: If ``problem_config.json`` cannot be opened.
            json.JSONDecodeError: If the config file is not valid JSON.
        """
        self.problem_path = problem_path
        self.storage = MultiObjectiveGenerationStorage(problem_path)
        self.evaluator = MultiObjectiveEvaluator(problem_path)
        self.llm_client = LLMClient.from_config(problem_path)
        self.initialize_operator = InitializeOperator(self.llm_client)

        config = self._load_problem_config()
        # Problem-specific values override the project-wide defaults.
        self.evolution_params = {
            **DEFAULT_EVOLUTION_PARAMS,
            **config.get("evolution_params", {}),
        }

        self.evolution_algorithm = MultiObjectiveEvolution(
            self.evolution_params,
            self.llm_client,
        )

    def initialize_population(self, size: int) -> List["MultiObjectiveIndividual"]:
        """Build the generation-0 population from LLM-generated ideas.

        The ideas generator is asked for ``size`` ideas, and initial code is
        generated for each idea. Ideas for which code generation returns a
        falsy result are dropped, so the returned list may be shorter than
        ``size``.

        Args:
            size: Number of ideas requested from the ideas generator.

        Returns:
            Individuals with ``generation=0`` whose metadata records the idea
            they were generated from. Fitnesses are not evaluated here.

        Raises:
            KeyError: If the problem config lacks ``description``,
                ``function_name``, ``input_format`` or ``output_format``.
        """
        problem_config = self._load_problem_config()
        ideas = self.initialize_operator.ideas_generator.generate_ideas(
            problem_config["description"],
            size,
        )
        population = []
        for idea in ideas:
            code = self.initialize_operator.generate_initial_code(
                problem_config["description"],
                problem_config["function_name"],
                problem_config["input_format"],
                problem_config["output_format"],
                idea,
            )
            if code:
                population.append(MultiObjectiveIndividual(
                    code,
                    generation=0,
                    metadata={"idea": idea},
                ))
        return population

    def run_evolution(self, generations: int = None, population_size: int = None):
        """Evolve a population and return the final generation.

        Args:
            generations: Number of generations to run. ``None`` uses
                ``evolution_params["generations"]``; an explicit ``0`` is
                honoured and only initializes and evaluates the population.
            population_size: Target population size. ``None`` uses
                ``evolution_params["population_size"]``.

        Returns:
            The population left after the last survivor selection, or the
            evaluated initial population when ``generations`` is 0.
        """
        # Compare against None rather than relying on truthiness so that an
        # explicit 0 is not silently replaced by the configured default.
        if generations is None:
            generations = self.evolution_params["generations"]
        if population_size is None:
            population_size = self.evolution_params["population_size"]

        population = self.initialize_population(population_size)
        for ind in population:
            ind.fitnesses = self.evaluator.evaluate(ind.code)

        for gen in range(generations):
            offspring = self._breed_offspring(population, population_size)
            population = self.evolution_algorithm.survive(
                population, offspring, population_size
            )
            self.storage.save_generation(gen, population)

        return population

    def _breed_offspring(
        self,
        population: List["MultiObjectiveIndividual"],
        population_size: int,
    ) -> List["MultiObjectiveIndividual"]:
        """Produce evaluated offspring for one generation.

        Repeats select -> crossover -> mutate until at least
        ``population_size`` children with non-empty code were collected, or
        until the attempt budget is exhausted. Because a crossover may return
        several children, the result can slightly exceed ``population_size``.
        """
        offspring = []
        max_attempts = population_size * self._BREEDING_ATTEMPTS_PER_SLOT
        attempts = 0
        while len(offspring) < population_size and attempts < max_attempts:
            attempts += 1
            parents = self.evolution_algorithm.select(population, 2)
            for child in self.evolution_algorithm.crossover(parents):
                mutated_child = self.evolution_algorithm.mutate(child)
                # Children whose mutation produced no code are discarded.
                if mutated_child.code:
                    mutated_child.fitnesses = self.evaluator.evaluate(
                        mutated_child.code
                    )
                    offspring.append(mutated_child)

        if len(offspring) < population_size:
            logging.getLogger(__name__).warning(
                "Only %d of %d offspring produced after %d breeding attempts",
                len(offspring),
                population_size,
                attempts,
            )
        return offspring

    def _calculate_pareto_front(
        self, population: List["MultiObjectiveIndividual"]
    ) -> List["MultiObjectiveIndividual"]:
        """Return the individuals of ``population`` that no other one dominates.

        Dominance is delegated to each individual's ``dominates`` method. An
        individual is never compared against itself, so the result does not
        depend on whether ``dominates`` is strict. Input order is preserved.
        """
        return [
            ind
            for ind in population
            if not any(
                other is not ind and other.dominates(ind)
                for other in population
            )
        ]

    def _load_problem_config(self) -> dict:
        """Read and parse ``problem_config.json`` from ``self.problem_path``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        config_path = os.path.join(self.problem_path, "problem_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)