class NSGA2:
    """
    NSGA-II for Multi-Objective Combinatorial Optimization.
    Supports: TSP, Knapsack, and CVRP problems.
    """
    
    def __init__(self, problem, population_size=100, n_generations=200,
                 crossover_prob=0.9, mutation_prob=0.1, tournament_size=2,
                 verbose=True, reference_point=None):
        """Configure the optimiser for ``problem``.

        Args:
            problem: Problem instance. Its kind (TSP / knapsack / CVRP) is
                inferred from its attributes by ``_detect_problem_type``; the
                other methods call ``problem.evaluate(solution)``.
            population_size: Number of individuals kept per generation.
            n_generations: Number of generations ``run`` iterates.
            crossover_prob: Probability that a parent pair is recombined.
            mutation_prob: Mutation probability used by the mutation operators.
            tournament_size: Participants per selection tournament.
            verbose: Stored on the instance.
            reference_point: Stored on the instance (presumably intended for
                hypervolume computation — TODO confirm).

        Raises:
            ValueError: if the problem type cannot be inferred.
        """
        # Run configuration.
        self.problem = problem
        self.population_size = population_size
        self.n_generations = n_generations
        self.crossover_prob = crossover_prob
        self.mutation_prob = mutation_prob
        self.verbose = verbose
        self.tournament_size = tournament_size
        self.reference_point = reference_point

        self.problem_type = self._detect_problem_type()
        logger.info(f"Detected problem type: {self.problem_type}")

        # Objective count: ``n_objectives`` takes precedence over
        # ``num_objectives``; default to a bi-objective problem.
        for attr_name in ('n_objectives', 'num_objectives'):
            if hasattr(problem, attr_name):
                self.n_objectives = getattr(problem, attr_name)
                break
        else:
            self.n_objectives = 2

        # TSP and CVRP objectives are minimised; knapsack objectives are maximised.
        self.is_minimization = self.problem_type in ('tsp', 'cvrp')

        # Current population and its objective vectors (parallel lists).
        self.population = []
        self.objective_values = []
        # Non-dominated subset of the current population.
        self.pareto_front = []
        self.pareto_objectives = []

        # Per-generation statistics appended by ``run``.
        self.history = {
            'hypervolume': [],
            'runtime': [],
            'num_solutions': [],
            'best_solution_per_objective': [[] for _ in range(self.n_objectives)],
        }

        self.start_time = None
    
    def _detect_problem_type(self) -> str:
        if hasattr(self.problem, 'n_customers') and hasattr(self.problem, 'n_vehicles'):
            return 'cvrp'
        elif hasattr(self.problem, 'n_cities'):
            return 'tsp'
        elif hasattr(self.problem, 'n_items'):
            return 'knapsack'
        elif hasattr(self.problem, 'n_customers'):
            return 'cvrp'
        else:
            raise ValueError("Unknown problem type")
    
    def initialize_population(self):
        """Generate and evaluate the initial population.

        Candidates are built per problem type: a shuffled city permutation for
        TSP, ``_generate_knapsack_solution()`` for knapsack, and for CVRP the
        problem's own ``random_solution()`` when it provides a callable one,
        otherwise ``_generate_cvrp_solution()``. A candidate is kept only if
        no objective returned by ``problem.evaluate`` is NaN or inf.
        Generation/evaluation errors are logged and retried, up to
        ``20 * population_size`` attempts in total. If fewer than
        ``population_size`` valid solutions result, randomly chosen valid ones
        are deep-copied to fill the population.

        Raises:
            ValueError: if not a single valid solution could be generated.
        """
        logger.info(f"Generating initial population of {self.population_size} solutions...")

        self.population = []
        self.objective_values = []

        # Not every CVRP problem object exposes random_solution(); fall back to
        # this class's greedy route-packing generator instead of failing every
        # attempt.
        problem_generator = getattr(self.problem, 'random_solution', None)
        generate_cvrp = (problem_generator if callable(problem_generator)
                         else self._generate_cvrp_solution)

        max_attempts = self.population_size * 20
        attempts = 0

        while len(self.population) < self.population_size and attempts < max_attempts:
            attempts += 1
            try:
                if self.problem_type == 'tsp':
                    sol = list(range(self.problem.n_cities))
                    random.shuffle(sol)
                elif self.problem_type == 'knapsack':
                    sol = self._generate_knapsack_solution()
                else:
                    sol = generate_cvrp()

                obj_values = self.problem.evaluate(sol)

                # Discard candidates with invalid (NaN / infinite) objectives.
                if any(np.isnan(val) or np.isinf(val) for val in obj_values):
                    continue

                self.population.append(sol)
                self.objective_values.append(obj_values)
            except Exception as e:
                # Best-effort generation: log and keep trying.
                logger.warning(f"Error generating solution: {e}")

        if len(self.population) < self.population_size:
            if not self.population:
                raise ValueError("Failed to generate any valid initial solutions")
            logger.warning(f"Only generated {len(self.population)} valid solutions. Duplicating...")
            while len(self.population) < self.population_size:
                idx = random.randint(0, len(self.population) - 1)
                self.population.append(deepcopy(self.population[idx]))
                self.objective_values.append(list(self.objective_values[idx]))

        logger.info(f"Successfully generated {len(self.population)} solutions in {attempts} attempts")
    
    def _generate_cvrp_solution(self) -> List[List[int]]:
        unassigned = list(range(1, self.problem.n_customers + 1))
        random.shuffle(unassigned)
        
        solution = []
        
        while unassigned:
            route = []
            route_demand = 0
            
            for customer in unassigned[:]:
                customer_demand = self.problem.customers[customer].demand
                if route_demand + customer_demand <= self.problem.vehicle_capacity:
                    route.append(customer)
                    route_demand += customer_demand
                    unassigned.remove(customer)
            
            if not route and unassigned:
                customer = min(unassigned, key=lambda c: self.problem.customers[c].demand)
                route.append(customer)
                unassigned.remove(customer)
            
            if route:
                solution.append(route)
        
        return solution
    
    def _generate_knapsack_solution(self) -> List[int]:
        sol = [0] * self.problem.n_items
        
        if hasattr(self.problem, 'capacity') and hasattr(self.problem, 'weights'):
            selection_prob = min(0.3, self.problem.capacity / (sum(self.problem.weights) / 2))
            for i in range(self.problem.n_items):
                sol[i] = 1 if random.random() < selection_prob else 0
            
            current_weight = sum(self.problem.weights[i] for i in range(len(sol)) if sol[i] == 1)
            if current_weight > self.problem.capacity:
                sol = self._repair_knapsack_solution(sol)
        else:
            sol = [random.randint(0, 1) for _ in range(self.problem.n_items)]
        
        return sol
    
    # CROSSOVER OPERATORS
    
    def crossover_cvrp(self, parent1: List[List[int]], parent2: List[List[int]]) -> Tuple[List[List[int]], List[List[int]]]:
        offspring1 = self._cvrp_crossover_single(parent1, parent2)
        offspring2 = self._cvrp_crossover_single(parent2, parent1)
        return offspring1, offspring2
    
    def _cvrp_crossover_single(self, parent1: List[List[int]], parent2: List[List[int]]) -> List[List[int]]:
        """Generic crossover - NO distance knowledge"""
        offspring = [route[:] for route in parent2]
        
        if not parent1:
            return offspring
        
        num_routes_to_select = min(random.randint(1, 2), len(parent1))
        selected_routes = random.sample(parent1, num_routes_to_select)
        
        customers_to_insert = set()
        for route in selected_routes:
            customers_to_insert.update(route)
        
        for route in offspring:
            route[:] = [c for c in route if c not in customers_to_insert]
        
        offspring = [route for route in offspring if route]
        
        # CHANGED: Random insertion instead of best insertion
        for customer in customers_to_insert:
            offspring = self._random_insertion_cvrp(offspring, customer)
        
        return offspring

    def _random_insertion_cvrp(self, solution: List[List[int]], customer: int) -> List[List[int]]:
        """Random insertion - NO distance knowledge (like TSP operators)"""
        if solution:
            route_idx = random.randint(0, len(solution) - 1)
            pos = random.randint(0, len(solution[route_idx]))
            solution[route_idx].insert(pos, customer)
        else:
            solution.append([customer])
        return solution

    def crossover_tsp(self, parent1, parent2):
        n_cities = len(parent1)
        
        cx_points = sorted(random.sample(range(n_cities), 2))
        cx_point1, cx_point2 = cx_points
        
        offspring1 = [-1] * n_cities
        offspring2 = [-1] * n_cities
        
        offspring1[cx_point1:cx_point2] = parent1[cx_point1:cx_point2]
        offspring2[cx_point1:cx_point2] = parent2[cx_point1:cx_point2]
        
        pos1 = cx_point2
        pos2 = cx_point2
        
        while -1 in offspring1:
            city = parent2[pos2 % n_cities]
            if city not in offspring1:
                offspring1[pos1 % n_cities] = city
                pos1 = (pos1 + 1) % n_cities
            pos2 = (pos2 + 1) % n_cities
        
        pos1 = cx_point2
        pos2 = cx_point2
        
        while -1 in offspring2:
            city = parent1[pos2 % n_cities]
            if city not in offspring2:
                offspring2[pos1 % n_cities] = city
                pos1 = (pos1 + 1) % n_cities
            pos2 = (pos2 + 1) % n_cities
        
        return offspring1, offspring2
    
    def crossover_knapsack(self, parent1, parent2):
        n_items = len(parent1)
        
        offspring1 = parent1.copy()
        offspring2 = parent2.copy()
        
        for i in range(n_items):
            if random.random() < 0.5:
                offspring1[i], offspring2[i] = offspring2[i], offspring1[i]
        
        if not self.is_solution_feasible(offspring1):
            offspring1 = self._repair_knapsack_solution(offspring1)
        if not self.is_solution_feasible(offspring2):
            offspring2 = self._repair_knapsack_solution(offspring2)
        
        return offspring1, offspring2
    
    # MUTATION OPERATORS
    
    def mutation_cvrp(self, solution: List[List[int]]) -> List[List[int]]:
        if random.random() > self.mutation_prob:
            return solution
        
        mutated = [route[:] for route in solution]
        
        operator = random.choice(['swap_within', 'swap_between', 'relocate', '2opt'])
        
        if operator == 'swap_within':
            mutated = self._mutation_swap_within_route(mutated)
        elif operator == 'swap_between':
            mutated = self._mutation_swap_between_routes(mutated)
        elif operator == 'relocate':
            mutated = self._mutation_relocate(mutated)
        elif operator == '2opt':
            mutated = self._mutation_2opt_within_route(mutated)
        
        mutated = self._repair_cvrp_solution(mutated)
        
        return mutated
    
    def _mutation_swap_within_route(self, solution: List[List[int]]) -> List[List[int]]:
        valid_routes = [i for i, route in enumerate(solution) if len(route) >= 2]
        
        if not valid_routes:
            return solution
        
        route_idx = random.choice(valid_routes)
        route = solution[route_idx]
        
        i, j = random.sample(range(len(route)), 2)
        route[i], route[j] = route[j], route[i]
        
        return solution
    
    def _mutation_swap_between_routes(self, solution: List[List[int]]) -> List[List[int]]:
        """Swap between routes - NO capacity check (let evaluate handle it)"""
        valid_routes = [i for i, route in enumerate(solution) if len(route) >= 1]
        
        if len(valid_routes) < 2:
            return solution
        
        route_idx1, route_idx2 = random.sample(valid_routes, 2)
        route1, route2 = solution[route_idx1], solution[route_idx2]
        
        pos1 = random.randint(0, len(route1) - 1)
        pos2 = random.randint(0, len(route2) - 1)
        
        # CHANGED: Just swap, NO capacity check
        route1[pos1], route2[pos2] = route2[pos2], route1[pos1]
        
        return solution
    
    def _mutation_relocate(self, solution: List[List[int]]) -> List[List[int]]:
        """Relocate customer - NO capacity check"""
        non_empty_routes = [i for i, route in enumerate(solution) if len(route) >= 1]
        
        if not non_empty_routes:
            return solution
        
        source_idx = random.choice(non_empty_routes)
        source_route = solution[source_idx]
        
        if not source_route:
            return solution
        
        customer_pos = random.randint(0, len(source_route) - 1)
        customer = source_route.pop(customer_pos)
        
        # CHANGED: Random target, NO capacity check
        if solution and random.random() < 0.8:
            target_idx = random.randint(0, len(solution) - 1)
            insert_pos = random.randint(0, len(solution[target_idx]))
            solution[target_idx].insert(insert_pos, customer)
        else:
            solution.append([customer])
        
        solution = [route for route in solution if route]
        return solution

    def _mutation_2opt_within_route(self, solution: List[List[int]]) -> List[List[int]]:
        valid_routes = [i for i, route in enumerate(solution) if len(route) >= 3]
        
        if not valid_routes:
            return solution
        
        route_idx = random.choice(valid_routes)
        route = solution[route_idx]
        
        i = random.randint(0, len(route) - 2)
        j = random.randint(i + 1, len(route) - 1)
        
        route[i:j+1] = reversed(route[i:j+1])
        
        return solution
    
    # def _repair_cvrp_solution(self, solution: List[List[int]]) -> List[List[int]]:
    #     """Minimal repair - only ensure all customers present, NO capacity awareness"""
    #     all_customers = set(range(1, self.problem.n_customers + 1))
    #     present_customers = set()
        
    #     # Remove duplicates
    #     for route in solution:
    #         present_customers.update(route)
        
    #     seen = set()
    #     for route in solution:
    #         to_remove = []
    #         for i, customer in enumerate(route):
    #             if customer in seen:
    #                 to_remove.append(i)
    #             else:
    #                 seen.add(customer)
    #         for i in reversed(to_remove):
    #             route.pop(i)
        
    #     # Add missing customers randomly
    #     missing = all_customers - seen
    #     for customer in missing:
    #         solution = self._random_insertion_cvrp(solution, customer)
        
    #     solution = [route for route in solution if route]
    #     return solution

    def _repair_cvrp_solution(self, solution: List[List[int]]) -> List[List[int]]:
        """CAPACITY-AWARE repair - handles capacity violations"""
        # Step 1: Remove duplicates
        seen = set()
        for route in solution:
            to_remove = []
            for i, customer in enumerate(route):
                if customer in seen:
                    to_remove.append(i)
                else:
                    seen.add(customer)
            for i in reversed(to_remove):
                route.pop(i)
        
        solution = [route for route in solution if route]
        
        # Step 2: Add missing customers
        all_customers = set(range(1, self.problem.n_customers + 1))
        missing = all_customers - seen
        
        for customer in missing:
            customer_demand = self.problem.customers[customer].demand
            inserted = False
            for route in solution:
                route_demand = sum(self.problem.customers[c].demand for c in route)
                if route_demand + customer_demand <= self.problem.vehicle_capacity:
                    route.append(customer)
                    inserted = True
                    break
            if not inserted:
                solution.append([customer])
        
        # Step 3: FIX CAPACITY VIOLATIONS (THE KEY ADDITION!)
        repaired = []
        overflow = []
        
        for route in solution:
            if not route:
                continue
            route_demand = sum(self.problem.customers[c].demand for c in route)
            
            if route_demand <= self.problem.vehicle_capacity:
                repaired.append(route)
            else:
                # Split overloaded route
                new_route = []
                current_demand = 0
                for customer in route:
                    d = self.problem.customers[customer].demand
                    if current_demand + d <= self.problem.vehicle_capacity:
                        new_route.append(customer)
                        current_demand += d
                    else:
                        overflow.append(customer)
                if new_route:
                    repaired.append(new_route)
        
        # Step 4: Place overflow into new routes
        while overflow:
            new_route = []
            current_demand = 0
            remaining = []
            for customer in overflow:
                d = self.problem.customers[customer].demand
                if current_demand + d <= self.problem.vehicle_capacity:
                    new_route.append(customer)
                    current_demand += d
                else:
                    remaining.append(customer)
            if new_route:
                repaired.append(new_route)
            if len(remaining) == len(overflow):  # No progress
                for c in remaining:
                    repaired.append([c])
                break
            overflow = remaining
        
        return repaired

    def mutation_tsp(self, solution):
        mutated = solution.copy()
        
        if random.random() < self.mutation_prob:
            i, j = random.sample(range(len(solution)), 2)
            mutated[i], mutated[j] = mutated[j], mutated[i]
        
        return mutated
    
    def mutation_knapsack(self, solution):
        mutated = solution.copy()
        
        for i in range(len(solution)):
            if random.random() < self.mutation_prob:
                mutated[i] = 1 - mutated[i]
        
        if not self.is_solution_feasible(mutated):
            mutated = self._repair_knapsack_solution(mutated)
        
        return mutated
    
    def _repair_knapsack_solution(self, solution):
        if not hasattr(self.problem, 'weights') or not hasattr(self.problem, 'capacity'):
            return solution
        
        repaired = solution.copy()
        current_weight = sum(self.problem.weights[i] for i in range(len(repaired)) if repaired[i] == 1)
        
        if current_weight <= self.problem.capacity:
            return repaired
        
        selected_items = [i for i in range(len(repaired)) if repaired[i] == 1]
        
        if hasattr(self.problem, 'values'):
            item_ratios = []
            for i in selected_items:
                if self.problem.weights[i] > 0:
                    total_value = sum(obj_values[i] for obj_values in self.problem.values)
                    ratio = total_value / self.problem.weights[i]
                else:
                    ratio = 0
                item_ratios.append((i, ratio))
            
            item_ratios.sort(key=lambda x: x[1])
            
            for item_idx, _ in item_ratios:
                if current_weight <= self.problem.capacity:
                    break
                repaired[item_idx] = 0
                current_weight -= self.problem.weights[item_idx]
        else:
            random.shuffle(selected_items)
            for item_idx in selected_items:
                if current_weight <= self.problem.capacity:
                    break
                repaired[item_idx] = 0
                current_weight -= self.problem.weights[item_idx]
        
        return repaired
    
    def is_solution_feasible(self, solution) -> bool:
        try:
            obj_values = self.problem.evaluate(solution)
            return all(val != float('inf') and val != -float('inf') for val in obj_values)
        except:
            return False
    
    # NSGA-II CORE
    
    def non_dominated_sort(self):
        n_solutions = len(self.population)
        domination_count = [0] * n_solutions
        dominated_solutions = [[] for _ in range(n_solutions)]
        
        for i in range(n_solutions):
            for j in range(n_solutions):
                if i != j:
                    if self.is_minimization:
                        if (all(self.objective_values[i][k] <= self.objective_values[j][k] 
                               for k in range(self.n_objectives)) and
                            any(self.objective_values[i][k] < self.objective_values[j][k] 
                               for k in range(self.n_objectives))):
                            dominated_solutions[i].append(j)
                        elif (all(self.objective_values[j][k] <= self.objective_values[i][k] 
                                 for k in range(self.n_objectives)) and
                              any(self.objective_values[j][k] < self.objective_values[i][k] 
                                 for k in range(self.n_objectives))):
                            domination_count[i] += 1
                    else:
                        if (all(self.objective_values[i][k] >= self.objective_values[j][k] 
                               for k in range(self.n_objectives)) and
                            any(self.objective_values[i][k] > self.objective_values[j][k] 
                               for k in range(self.n_objectives))):
                            dominated_solutions[i].append(j)
                        elif (all(self.objective_values[j][k] >= self.objective_values[i][k] 
                                 for k in range(self.n_objectives)) and
                              any(self.objective_values[j][k] > self.objective_values[i][k] 
                                 for k in range(self.n_objectives))):
                            domination_count[i] += 1
        
        fronts = [[]]
        for i in range(n_solutions):
            if domination_count[i] == 0:
                fronts[0].append(i)
        
        front_index = 0
        while front_index < len(fronts):
            if not fronts[front_index]:
                front_index += 1
                continue
                
            next_front = []
            for solution_idx in fronts[front_index]:
                for dominated_idx in dominated_solutions[solution_idx]:
                    domination_count[dominated_idx] -= 1
                    if domination_count[dominated_idx] == 0:
                        next_front.append(dominated_idx)
            
            front_index += 1
            if next_front:
                fronts.append(next_front)
        
        fronts = [front for front in fronts if front]
        return fronts
    
    def calculate_crowding_distance(self, front):
        if len(front) <= 2:
            return [float('inf')] * len(front)
        
        crowding_distances = [0.0] * len(front)
        
        for obj_idx in range(self.n_objectives):
            sorted_indices = sorted(range(len(front)), 
                                   key=lambda i: self.objective_values[front[i]][obj_idx])
            
            crowding_distances[sorted_indices[0]] = float('inf')
            crowding_distances[sorted_indices[-1]] = float('inf')
            
            objective_range = (
                self.objective_values[front[sorted_indices[-1]]][obj_idx] - 
                self.objective_values[front[sorted_indices[0]]][obj_idx]
            )
            
            if objective_range == 0:
                continue
                
            for i in range(1, len(front) - 1):
                crowding_distances[sorted_indices[i]] += (
                    self.objective_values[front[sorted_indices[i+1]]][obj_idx] - 
                    self.objective_values[front[sorted_indices[i-1]]][obj_idx]
                ) / objective_range
        
        return crowding_distances
    
    def tournament_selection(self):
        fronts = self.non_dominated_sort()
        
        crowding_distances = []
        for front in fronts:
            crowding_distances.append(self.calculate_crowding_distance(front))
        
        solution_rank = {}
        solution_crowding = {}
        
        for i, front in enumerate(fronts):
            for j, solution_idx in enumerate(front):
                solution_rank[solution_idx] = i
                solution_crowding[solution_idx] = crowding_distances[i][j]
        
        participants = random.sample(range(len(self.population)), 
                                    min(self.tournament_size, len(self.population)))
        
        winner = participants[0]
        for participant in participants[1:]:
            if solution_rank[participant] < solution_rank[winner]:
                winner = participant
            elif solution_rank[participant] == solution_rank[winner]:
                if solution_crowding[participant] > solution_crowding[winner]:
                    winner = participant
        
        return winner
    
    def create_offspring(self):
        offspring = []
        offspring_objectives = []
        
        if self.problem_type == 'cvrp':
            crossover = self.crossover_cvrp
            mutation = self.mutation_cvrp
        elif self.problem_type == 'tsp':
            crossover = self.crossover_tsp
            mutation = self.mutation_tsp
        elif self.problem_type == 'knapsack':
            crossover = self.crossover_knapsack
            mutation = self.mutation_knapsack
        else:
            raise ValueError(f"Unsupported problem type: {self.problem_type}")
        
        while len(offspring) < self.population_size:
            parent1_idx = self.tournament_selection()
            parent2_idx = self.tournament_selection()
            
            while parent2_idx == parent1_idx and len(self.population) > 1:
                parent2_idx = self.tournament_selection()
            
            parent1 = self.population[parent1_idx]
            parent2 = self.population[parent2_idx]
            
            if random.random() < self.crossover_prob:
                child1, child2 = crossover(parent1, parent2)
            else:
                child1 = deepcopy(parent1)
                child2 = deepcopy(parent2)
            
            child1 = mutation(child1)
            child2 = mutation(child2)
            
            try:
                obj1 = self.problem.evaluate(child1)
                if not any(np.isnan(val) or np.isinf(val) for val in obj1):
                    offspring.append(child1)
                    offspring_objectives.append(obj1)
                
                if len(offspring) < self.population_size:
                    obj2 = self.problem.evaluate(child2)
                    if not any(np.isnan(val) or np.isinf(val) for val in obj2):
                        offspring.append(child2)
                        offspring_objectives.append(obj2)
            except Exception as e:
                logger.warning(f"Error evaluating offspring: {e}")
        
        return offspring, offspring_objectives
    
    def update_pareto_front(self):
        valid_solutions = []
        valid_objectives = []
        
        for i, obj_values in enumerate(self.objective_values):
            if not any(np.isnan(val) or np.isinf(val) for val in obj_values):
                valid_solutions.append(self.population[i])
                valid_objectives.append(obj_values)
        
        if not valid_solutions:
            logger.warning("No valid solutions found for Pareto front update")
            self.pareto_front = []
            self.pareto_objectives = []
            return
        
        is_efficient = [True] * len(valid_solutions)
        
        for i in range(len(valid_solutions)):
            if is_efficient[i]:
                for j in range(len(valid_solutions)):
                    if i != j and is_efficient[j]:
                        if self.is_minimization:
                            if (all(valid_objectives[j][k] <= valid_objectives[i][k] 
                                   for k in range(self.n_objectives)) and
                                any(valid_objectives[j][k] < valid_objectives[i][k] 
                                   for k in range(self.n_objectives))):
                                is_efficient[i] = False
                                break
                        else:
                            if (all(valid_objectives[j][k] >= valid_objectives[i][k] 
                                   for k in range(self.n_objectives)) and
                                any(valid_objectives[j][k] > valid_objectives[i][k] 
                                   for k in range(self.n_objectives))):
                                is_efficient[i] = False
                                break
        
        self.pareto_front = [valid_solutions[i] for i in range(len(valid_solutions)) if is_efficient[i]]
        self.pareto_objectives = [valid_objectives[i] for i in range(len(valid_solutions)) if is_efficient[i]]
    
    def run(self):
        self.start_time = time.time()
        logger.info(f"Starting NSGA-II optimization for {self.problem.__class__.__name__}")
        logger.info(f"Problem type: {self.problem_type}")
        
        if self.problem_type == 'cvrp':
            logger.info(f"CVRP with {self.problem.n_customers} customers, "
                       f"{self.problem.n_vehicles} vehicles, "
                       f"capacity={self.problem.vehicle_capacity}")
        elif self.problem_type == 'tsp':
            logger.info(f"TSP with {self.problem.n_cities} cities")
        elif self.problem_type == 'knapsack':
            logger.info(f"Knapsack with {self.problem.n_items} items")
        
        try:
            self.initialize_population()
            
            if self.objective_values:
                objectives = np.array(self.objective_values)
                min_vals = np.min(objectives, axis=0)
                max_vals = np.max(objectives, axis=0)
                
                logger.info(f"Initial objective value ranges:")
                for i in range(self.n_objectives):
                    logger.info(f"  Objective {i+1}: min={min_vals[i]:.4f}, max={max_vals[i]:.4f}")
            
            self.update_pareto_front()
            
            self.history['hypervolume'].append(0)
            self.history['runtime'].append(0)
            self.history['num_solutions'].append(len(self.pareto_front))
            
            if self.objective_values:
                objectives = np.array(self.objective_values)
                for obj_idx in range(self.n_objectives):
                    if self.is_minimization:
                        best_val = float(np.min(objectives[:, obj_idx]))
                    else:
                        best_val = float(np.max(objectives[:, obj_idx]))
                    self.history['best_solution_per_objective'][obj_idx].append(best_val)
            
            logger.info(f"Initial Pareto front size: {len(self.pareto_front)}")
            
            for generation in range(self.n_generations):
                generation_start = time.time()
                
                try:
                    offspring, offspring_objectives = self.create_offspring()
                    
                    combined_population = self.population + offspring
                    combined_objectives = self.objective_values + offspring_objectives
                    
                    temp_population = self.population
                    temp_objectives = self.objective_values
                    
                    self.population = combined_population
                    self.objective_values = combined_objectives
                    
                    fronts = self.non_dominated_sort()
                    
                    self.population = temp_population
                    self.objective_values = temp_objectives
                    
                    new_population = []
                    new_objectives = []
                    
                    for front in fronts:
                        if len(new_population) + len(front) <= self.population_size:
                            for idx in front:
                                new_population.append(combined_population[idx])
                                new_objectives.append(combined_objectives[idx])
                        else:
                            temp_population = self.population
                            temp_objectives = self.objective_values
                            
                            self.population = combined_population
                            self.objective_values = combined_objectives
                            
                            crowding_distances = self.calculate_crowding_distance(front)
                            
                            self.population = temp_population
                            self.objective_values = temp_objectives
                            
                            sorted_front = [(idx, dist) for idx, dist in zip(front, crowding_distances)]
                            sorted_front.sort(key=lambda x: -x[1])
                            
                            for idx, _ in sorted_front:
                                if len(new_population) < self.population_size:
                                    new_population.append(combined_population[idx])
                                    new_objectives.append(combined_objectives[idx])
                                else:
                                    break
                        
                        if len(new_population) >= self.population_size:
                            break
                    
                    self.population = new_population
                    self.objective_values = new_objectives
                    
                    self.update_pareto_front()
                    
                    elapsed = time.time() - self.start_time
                    
                    self.history['hypervolume'].append(0)
                    self.history['runtime'].append(elapsed)
                    self.history['num_solutions'].append(len(self.pareto_front))
                    
                    if self.objective_values:
                        objectives = np.array(self.objective_values)
                        for obj_idx in range(self.n_objectives):
                            if self.is_minimization:
                                best_val = float(np.min(objectives[:, obj_idx]))
                            else:
                                best_val = float(np.max(objectives[:, obj_idx]))
                            self.history['best_solution_per_objective'][obj_idx].append(best_val)
                    
                    if self.verbose and (generation + 1) % 20 == 0:
                        gen_time = time.time() - generation_start
                        logger.info(f"Generation {generation + 1}/{self.n_generations}, "
                                f"Pareto size: {len(self.pareto_front)}, "
                                f"Time: {gen_time:.2f}s, "
                                f"Total: {elapsed:.2f}s")
                
                except Exception as e:
                    logger.error(f"Error in generation {generation + 1}: {str(e)}")
                    import traceback
                    logger.error(traceback.format_exc())
            
            total_time = time.time() - self.start_time
            logger.info(f"Optimization completed in {total_time:.2f} seconds")
            logger.info(f"Final Pareto front size: {len(self.pareto_front)}")
            
            if self.objective_values:
                objectives = np.array(self.objective_values)
                min_vals = np.min(objectives, axis=0)
                max_vals = np.max(objectives, axis=0)
                
                logger.info(f"Final objective value ranges:")
                for i in range(self.n_objectives):
                    logger.info(f"  Objective {i+1}: min={min_vals[i]:.4f}, max={max_vals[i]:.4f}")
            
            return list(zip(self.pareto_front, self.pareto_objectives))
            
        except Exception as e:
            logger.error(f"Error in optimization run: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            return list(zip(self.pareto_front, self.pareto_objectives)) if self.pareto_front else []
    
    def plot_convergence(self, filename='convergence.png'):
        """Plot the run's convergence history and save it to an image file.

        Left panel: ``self.history['num_solutions']`` (the Pareto-front size
        recorded after each generation). Right panel: one line per objective
        from ``self.history['best_solution_per_objective']``. Each value is the
        population minimum for minimization problems and the maximum
        otherwise, as recorded by ``run``.

        Args:
            filename: Path of the saved image (default ``'convergence.png'``).
        """
        fig = plt.figure(figsize=(12, 5))
        # Close this specific figure even if a plotting call raises, so
        # repeated calls (e.g. across evaluation runs) do not leak figures.
        try:
            plt.subplot(1, 2, 1)
            plt.plot(self.history['num_solutions'])
            plt.title('Pareto Front Size')
            plt.xlabel('Generation')
            plt.ylabel('Number of Solutions')
            plt.grid(True)

            plt.subplot(1, 2, 2)
            for obj_idx in range(self.n_objectives):
                plt.plot(self.history['best_solution_per_objective'][obj_idx],
                         label=f'Objective {obj_idx + 1}')
            plt.title('Best Solution per Objective')
            plt.xlabel('Generation')
            plt.ylabel('Objective Value')
            plt.legend()
            plt.grid(True)

            plt.tight_layout()
            plt.savefig(filename, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)
    
    def plot_pareto_front(self, filename='pareto_front.png'):
        """Scatter the final population and Pareto front and save the figure.

        Only two-objective problems are supported; otherwise a warning is
        logged and nothing is drawn. Population points are light gray. Pareto
        points are blue and joined by a dashed line in order of objective 1.
        Axis labels depend on ``self.problem_type``.

        Args:
            filename: Path of the saved image (default ``'pareto_front.png'``).
        """
        if self.n_objectives != 2:
            logger.warning("Pareto front visualization only supported for 2 objectives")
            return

        pop_objectives = np.array(self.objective_values)
        pareto_objectives = np.array(self.pareto_objectives)

        # np.array([]) is 1-D, so the column indexing below would raise
        # IndexError (e.g. when run() failed and returned an empty front).
        if pareto_objectives.ndim != 2 or len(pareto_objectives) == 0:
            logger.warning("No Pareto front to plot")
            return

        fig = plt.figure(figsize=(10, 8))
        # Close this specific figure even if plotting/saving raises.
        try:
            if pop_objectives.ndim == 2 and len(pop_objectives) > 0:
                plt.scatter(pop_objectives[:, 0], pop_objectives[:, 1],
                            c='lightgray', s=30, alpha=0.5, label='Population')

            plt.scatter(pareto_objectives[:, 0], pareto_objectives[:, 1],
                        c='blue', s=50, label='Pareto Front')

            if len(pareto_objectives) > 1:
                # Sort by objective 1 so the dashed line traces the front.
                sorted_indices = np.argsort(pareto_objectives[:, 0])
                sorted_objectives = pareto_objectives[sorted_indices]
                plt.plot(sorted_objectives[:, 0], sorted_objectives[:, 1], 'b--', alpha=0.7)

            if self.problem_type == 'cvrp':
                plt.title('CVRP Pareto Front (Minimization)')
                plt.xlabel('Total Distance')
                plt.ylabel('Makespan (Longest Route)')
            elif self.problem_type == 'tsp':
                plt.title('TSP Pareto Front (Minimization)')
                plt.xlabel('Objective 1 (Distance)')
                plt.ylabel('Objective 2 (Distance)')
            elif self.problem_type == 'knapsack':
                plt.title('Knapsack Pareto Front (Maximization)')
                plt.xlabel('Objective 1 (Value)')
                plt.ylabel('Objective 2 (Value)')
            else:
                # Unknown problem type: don't mislabel it as knapsack.
                sense = 'Minimization' if self.is_minimization else 'Maximization'
                plt.title(f'Pareto Front ({sense})')
                plt.xlabel('Objective 1')
                plt.ylabel('Objective 2')

            plt.grid(True)
            plt.legend()
            plt.savefig(filename, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)


def evaluate_nsga2_with_evaluator(num_runs=5, *, run_tsp=False, run_knapsack=True,
                                  tsp_reference_point=(20, 20),
                                  knapsack_reference_point=(5, 5)):
    """Evaluate NSGA-II using MOCOEvaluator.

    For each enabled benchmark this runs ``num_runs`` NSGA-II runs through
    ``MOCOEvaluator.evaluate_algorithm``. It then prints each evaluator's
    report and a per-result summary, and generates comparison / Pareto-front
    plots on a best-effort basis.

    Args:
        num_runs: Runs per problem, forwarded to ``evaluate_algorithm``.
        run_tsp: Evaluate on ``BiObjectiveTSP`` (20 cities). Off by default.
        run_knapsack: Evaluate on ``MultiObjectiveKnapsack`` (200 items,
            2 objectives, capacity 25.0). On by default.
        tsp_reference_point: Hypervolume reference point for TSP. TSP is
            minimized, so the point should lie above (be worse than) every
            Pareto-front point.
        knapsack_reference_point: Hypervolume reference point for knapsack.
            Knapsack is maximized, so the point should lie below (be worse
            than) every Pareto-front point.

    Returns:
        ``(tsp_evaluator, knapsack_evaluator)``. Both evaluators are always
        created; an evaluator whose problem was disabled is returned without
        ``evaluate_algorithm`` having been called on it.
    """
    # Common NSGA-II parameters for every benchmark.
    nsga2_params = {
        'population_size': 100,
        'n_generations': 100,
        'crossover_prob': 0.9,
        'mutation_prob': 0.1,
        'tournament_size': 3,
        'verbose': True,
    }

    tsp_evaluator = MOCOEvaluator(reference_point=tsp_reference_point,
                                  confidence_level=0.95)
    # NOTE(review): an earlier comment suggested (0, 0) for knapsack; the
    # default (5, 5) is kept to preserve existing results — confirm it lies
    # below every point of the knapsack Pareto front.
    knapsack_evaluator = MOCOEvaluator(reference_point=knapsack_reference_point,
                                       confidence_level=0.95)

    # Each entry: (label used in messages, problem name for the
    # "Evaluating on ..." line, evaluator, problem class, problem params).
    # A disabled problem's class is never referenced.
    benchmarks = []
    if run_tsp:
        benchmarks.append(("TSP", "BiObjectiveTSP", tsp_evaluator,
                           BiObjectiveTSP, {'n_cities': 20}))
    if run_knapsack:
        benchmarks.append(("Knapsack", "MultiObjectiveKnapsack", knapsack_evaluator,
                           MultiObjectiveKnapsack,
                           {'n_items': 200, 'n_objectives': 2, 'capacity': 25.0}))

    print("\n" + "=" * 50)
    print(f"Evaluating NSGA-II with {num_runs} runs per problem:")
    print("=" * 50)

    for _, problem_name, evaluator, problem_class, problem_params in benchmarks:
        print(f"\nEvaluating on {problem_name}:")
        evaluator.evaluate_algorithm(
            algorithm_class=NSGA2,
            problem_class=problem_class,
            algorithm_name="NSGA-II",
            parameters=nsga2_params,
            problem_params=problem_params,
            num_runs=num_runs,
        )

    for label, _, evaluator, _, _ in benchmarks:
        print(f"\nGenerating {label} report:")
        evaluator.generate_report()

    print("\nAggregated Results:")
    for label, _, evaluator, _, _ in benchmarks:
        print(f"\n{label} Results:")
        for result in evaluator.results:
            print(f"\n{result.algorithm_name} on {result.problem_name}:")
            print(f"Average Runtime: {result.runtime:.2f} seconds")
            print(f"Final Hypervolume: {result.hypervolume:.4f}")
            print(f"Final Non-dominated solutions: {result.num_nondominated}")

    # Plotting is best-effort: a failure here must not lose the results above.
    try:
        for label, _, evaluator, _, _ in benchmarks:
            print(f"\nGenerating {label} visualizations...")
            evaluator.plot_comparison()
            evaluator.plot_pareto_front(show_all=True)
    except Exception as e:
        print(f"Could not generate plots: {e}")

    return tsp_evaluator, knapsack_evaluator


def test_nsga2_on_tsp(n_cities=20, population_size=100, n_generations=200, *,
                      crossover_prob=0.9, mutation_prob=0.1, tournament_size=2):
    """Run NSGA-II once on a BiObjectiveTSP instance and save its plots.

    Args:
        n_cities: Number of cities in the generated TSP instance.
        population_size: NSGA-II population size.
        n_generations: Number of generations to evolve.
        crossover_prob: Crossover probability passed to NSGA2.
        mutation_prob: Mutation probability passed to NSGA2.
        tournament_size: Tournament size passed to NSGA2 (2 is NSGA2's default).

    Returns:
        The NSGA2 optimizer after ``run()``. Its ``pareto_front`` and
        ``pareto_objectives`` hold the result.
    """
    logger.info(f"Testing NSGA-II on BiObjectiveTSP with {n_cities} cities")

    problem = BiObjectiveTSP(n_cities=n_cities)

    optimizer = NSGA2(
        problem=problem,
        population_size=population_size,
        n_generations=n_generations,
        crossover_prob=crossover_prob,
        mutation_prob=mutation_prob,
        tournament_size=tournament_size,
        verbose=True,
    )

    # The result is kept on the optimizer; the return value is not needed here.
    optimizer.run()

    optimizer.plot_convergence()
    optimizer.plot_pareto_front()

    return optimizer


def test_nsga2_on_knapsack(n_items=50, n_objectives=2, population_size=100, n_generations=200, *,
                           capacity=10.0, crossover_prob=0.9, mutation_prob=0.1,
                           tournament_size=2):
    """Run NSGA-II once on a MultiObjectiveKnapsack instance and save its plots.

    Args:
        n_items: Number of items in the generated knapsack instance.
        n_objectives: Number of objectives of the instance.
        population_size: NSGA-II population size.
        n_generations: Number of generations to evolve.
        capacity: Knapsack capacity passed to MultiObjectiveKnapsack.
        crossover_prob: Crossover probability passed to NSGA2.
        mutation_prob: Mutation probability passed to NSGA2.
        tournament_size: Tournament size passed to NSGA2 (2 is NSGA2's default).

    Returns:
        The NSGA2 optimizer after ``run()``. Its ``pareto_front`` and
        ``pareto_objectives`` hold the result.
    """
    logger.info(f"Testing NSGA-II on MultiObjectiveKnapsack with {n_items} items")

    problem = MultiObjectiveKnapsack(n_items=n_items, n_objectives=n_objectives,
                                     capacity=capacity)

    optimizer = NSGA2(
        problem=problem,
        population_size=population_size,
        n_generations=n_generations,
        crossover_prob=crossover_prob,
        mutation_prob=mutation_prob,
        tournament_size=tournament_size,
        verbose=True,
    )

    # The result is kept on the optimizer; the return value is not needed here.
    optimizer.run()

    optimizer.plot_convergence()
    # The Pareto scatter plot is only meaningful for two objectives.
    if n_objectives == 2:
        optimizer.plot_pareto_front()

    return optimizer



if __name__ == "__main__":
    # Single-instance runs that save convergence / Pareto-front plots are
    # available via test_nsga2_on_tsp(...) and test_nsga2_on_knapsack(...).
    # This entry point runs the multi-run MOCOEvaluator benchmark instead.
    print("\nEvaluating NSGA-II with MOCOEvaluator:")
    tsp_eval, knapsack_eval = evaluate_nsga2_with_evaluator(num_runs=2)